Get frames as observation for CartPole environment

73 Views Asked by At

In Python, I am using stablebaselines3 and gymnasium to implement a custom DQN. Using atari games I tested the agent and works, now I need to test it also on environments like CartPole The problem is that this kind of environment does not return frames as observation but instead returns just a vector. So I need a way to make return CartPole frames as observation and apply the same preprocessing stuff that I do on Atari games (like stack 4 frames of the game together)

I searched on the internet how to do it and I came up with this code after some tries, but I have some problems.

This is the code:

from stable_baselines3.common.env_util import make_atari_env, make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage
import gymnasium as gym
from gymnasium import spaces
from gymnasium.envs.classic_control import CartPoleEnv
import numpy as np
import cv2


class CartPoleImageWrapper(gym.Wrapper):
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, env):
        super(CartPoleImageWrapper, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _get_image_observation(self):
        # Render the CartPole environment
        cartpole_image = self.render()

        # Resize the image to 84x84 pixels
        resized_image = cv2.resize(cartpole_image, (84, 84))
        # make it grayscale
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
        resized_image = np.expand_dims(resized_image, axis=-1)
        return resized_image

    def reset(self):
        self.env.reset()
        return self._get_image_observation()

    def step(self, action):
        observation, reward, terminated, info = self.env.step(action)
        return self._get_image_observation(), reward, terminated, info


env = CartPoleImageWrapper(CartPoleEnv(render_mode='rgb_array'))
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")
#exit()
    
vec_env.close()

And the error is this when I call env.reset():

Traceback (most recent call last):
    File "/data/g.carfi/rl/tmp.py", line 41, in <module>
        obs = vec_env.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_frame_stack.py", line 41, in reset
        observation = self.venv.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/vec_transpose.py", line 113, in reset
        observations = self.venv.reset()
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/vec_env/dummy_vec_env.py", line 77, in reset
        obs, self.reset_infos[env_idx] = self.envs[env_idx].reset(seed=self._seeds[env_idx], **maybe_options)
    File "/data/g/.virtualenvs/rl_new/lib/python3.8/site-packages/stable_baselines3/common/monitor.py", line 83, in reset
        return self.env.reset(**kwargs)
    TypeError: reset() got an unexpected keyword argument 'seed'

how can I solve the problem?

2

There are 2 best solutions below

0
On

The error is because the reset() method of the CartPoleEnv class does not accept the seed argument but is passed by VecEnv. To fix this issue you will need to modify your CartPoleImageWrapper to handle this correctly.

import cv2
import gym
import numpy as np
from gym import spaces
from gym.envs.classic_control import CartPoleEnv
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.vec_env import VecFrameStack, VecTransposeImage


class CartPoleImageWrapper(gym.Wrapper):
    metadata = {'render.modes': ['rgb_array']}

    def __init__(self, env):
        super(CartPoleImageWrapper, self).__init__(env)
        self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def _get_image_observation(self):
        # Render the CartPole environment
        cartpole_image = self.env.render(mode='rgb_array')

        # Resize the image to 84x84 pixels
        resized_image = cv2.resize(cartpole_image, (84, 84))
        # make it grayscale
        resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
        resized_image = np.expand_dims(resized_image, axis=-1)
        return resized_image

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        return self._get_image_observation()

    def step(self, action):
        observation, reward, terminated, info = self.env.step(action)
        return self._get_image_observation(), reward, terminated, info


env = CartPoleImageWrapper(CartPoleEnv())
vec_env = make_vec_env(lambda: env, n_envs=1)
vec_env = VecTransposeImage(vec_env)
vec_env = VecFrameStack(vec_env, n_stack=4)
obs = vec_env.reset()
print(f"Observation space: {obs.shape}")

vec_env.close()

These are the changes I made:

  1. Changed _get_image_observation() to use render() method with mode='rgb_array'.

  2. The reset() method was changed to accept **kwargs and pass them to the internal environment's reset() method (CartPoleEnv).

  3. Unnecessary gym import removed.

With these changes, your container should work properly with the vectorized environment.

0
On

The issue you're encountering is due to the fact that the reset() method of the CartPoleEnv class does not accept the seed argument, but it seems like it's being passed internally by VecEnv.

To solve this problem, you can modify the reset() method in your CartPoleImageWrapper class to handle this discrepancy. You can simply ignore the seed argument when calling the reset() method of the wrapped environment. Here's how you can do it:

class CartPoleImageWrapper(gym.Wrapper):
    metadata = {'render.modes': ['rgb_array']}

def __init__(self, env):
    super(CartPoleImageWrapper, self).__init__(env)
    self.observation_space = spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

def _get_image_observation(self):
    # Render the CartPole environment
    cartpole_image = self.render()

    # Resize the image to 84x84 pixels
    resized_image = cv2.resize(cartpole_image, (84, 84))
    # make it grayscale
    resized_image = cv2.cvtColor(resized_image, cv2.COLOR_RGB2GRAY)
    resized_image = np.expand_dims(resized_image, axis=-1)
    return resized_image

def reset(self, **kwargs):
    self.env.reset(**kwargs)  # Ignore the 'seed' argument
    return self._get_image_observation()

def step(self, action):
    observation, reward, terminated, info = self.env.step(action)
    return self._get_image_observation(), reward, terminated, info

With this modification, you should be able to use your CartPoleImageWrapper with VecFrameStack without encountering the TypeError related to the unexpected seed argument.