In Python, I am using stablebaselines3
and gymnasium
to implement a custom DQN. Using atari games I tested the agent and works, now I need to test it also on environments like CartPole
The problem is that this kind of environment does not return frames as observation but instead returns just a vector.
So I need a way to make return CartPole frames as observation and apply the same preprocessing stuff that I do on Atari games (like stack 4 frames of the game together)
I searched on the internet how to do it and I came up with this code after some tries, but I have some problems.
This is the code:
from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
import numpy as np
import cv2
class CartPoleImageWrapper(gym.Wrapper):
metadata = {'render.modes': ['rgb_array']}
def __init__(self, env):
super(CartPoleImageWrapper, self).__init__(env)
self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)
def _get_image_observation(self):
# Render the CartPole environment
cartpole_image = self.render()
# Resize the image to 84x84 pixels
resized_image = cv2.resize(cartpole_image, (84, 84))
# make it grayscale
resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
resized_image = np.expand_dims(resized_image, axis=-1)
return resized_image
def reset(self):
self.env.reset()
return self._get_image_observation()
def step(self, action):
observation, reward, terminated, info = self.env.step(action)
return self._get_image_observation(), reward, terminated, info
env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
#exit()
vec_env.close()
And the error is this when I call env.reset()
:
Traceback (most recent call last):
File "/data/g.carfi/rl/tmp.py", line 41, in <module>
obs = vec_env.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 41, in reset
observation = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 113, in reset
observations = self.venv.reset()
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/monitor.py", line 83, in reset
return self.env.reset(**kwargs)
TypeError: reset() got an unexpected keyword argument 'seed'
how can I solve the problem?
The error is because the reset() method of the CartPoleEnv class does not accept the seed argument but is passed by VecEnv. To fix this issue you will need to modify your CartPoleImageWrapper to handle this correctly.
These are the changes I made:
Changed _get_image_observation() to use render() method with mode='rgb_array'.
The reset() method was changed to accept **kwargs and pass them to the internal environment's reset() method (CartPoleEnv).
Unnecessary gym import removed.
With these changes, your container should work properly with the vectorized environment.