In Python, I am using stablebaselines3 and gymnasium to implement a custom DQN. Using atari games I tested the agent and works, now I need to test it also on environments like CartPole
The problem is that this kind of environment does not return frames as observation but instead returns just a vector.
So I need a way to make return CartPole frames as observation and apply the same preprocessing stuff that I do on Atari games (like stack 4 frames of the game together)
I searched on the internet how to do it and I came up with this code after some tries, but I have some problems.
This is the code:
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
import numpy as np
import cv2
class CartPoleImageWrapper(gym.Wrapper):
metadata = {'render.modes': ['rgb_array']}
def __init__(self, env):
super(CartPoleImageWrapper, self).__init__(env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
def _get_image_observation(self):
# Render the CartPole environment
cartpole_image = self.render()
# Resize the image to 84x84 pixels
resized_image = cv2.resize(cartpole_image, (84, 84))
# make it grayscale
resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
resized_image = np.expand_dims(resized_image, axis=-1)
return resized_image
def reset(self):
self.env.reset()
return self._get_image_observation()
def step(self, action):
observation, reward, terminated, info = self.env.step(action)
return self._get_image_observation(), reward, terminated, info
env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
#exit()
vec_env.close()
And the error is this when I call env.reset():
Traceback (most recent call last):
File "/data/g.carfi/rl/tmp.py", line 41, in <module>
obs = vec_env.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 41, in reset
observation = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 113, in reset
observations = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/monitor.py", line 83, in reset
return self.env.reset(**kwargs)
TypeError: reset() got an unexpected keyword argument 'seed'
how can I solve the problem?
The issue you're encountering is due to the fact that the reset() method of the CartPoleEnv class does not accept the seed argument, but it seems like it's being passed internally by VecEnv.
To solve this problem, you can modify the reset() method in your CartPoleImageWrapper class to handle this discrepancy. You can simply ignore the seed argument when calling the reset() method of the wrapped environment. Here's how you can do it:
With this modification, you should be able to use your CartPoleImageWrapper with VecFrameStack without encountering the TypeError related to the unexpected seed argument.