Stable-Baselines3 and PettingZoo

372 Views Asked by At

I am trying to understand how to train agents in a PettingZoo environment using PPO, the single-agent algorithm implemented in Stable-Baselines3.

I'm following this tutorial, in which the agents act in a cooperative environment and are all trained with PPO using parameter sharing (a single policy shared by all agents). However, when I pass the PettingZoo environment into the PPO constructor of Stable-Baselines3, I get the following error message:

The algorithm only supports (<class 'gym.spaces.box.Box'>, <class 'gym.spaces.discrete.Discrete'>, <class 'gym.spaces.multi_discrete.MultiDiscrete'>, <class 'gym.spaces.multi_binary.MultiBinary'>) as action spaces but Box(-1.0, 1.0, (1,), float32) was provided

Here is my full code:

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.utils.conversions import aec_to_parallel
import supersuit as ss

from stable_baselines3.ppo import CnnPolicy
from stable_baselines3 import PPO




def main():
    """Train a parameter-shared PPO policy on a continuous Pistonball environment.

    Builds a PettingZoo ``pistonball_v6`` parallel environment and preprocesses
    its image observations with SuperSuit. It then exposes every agent as one
    sub-environment of a Stable-Baselines3 ``VecEnv``, so a single policy is
    trained on and shared by all agents. Finally it trains PPO with a CNN policy.

    Version note (likely cause of the "The algorithm only supports
    (<class 'gym.spaces.box.Box'>, ...) as action spaces but Box(...) was
    provided" error):
    - Stable-Baselines3 1.x checks spaces against the legacy ``gym`` package.
      Recent PettingZoo/SuperSuit releases produce ``gymnasium.spaces``
      objects, so a perfectly valid Box is rejected because it is a
      *different* Box class.
    - Use Stable-Baselines3 >= 2.0 (gymnasium-based) with gymnasium-based
      PettingZoo/SuperSuit. Otherwise, pin all three libraries to mutually
      compatible gym-based versions.
    """
    # Build the environment directly in its parallel form. The SuperSuit
    # wrappers below accept parallel envs, so no AEC -> parallel conversion
    # is needed before pettingzoo_env_to_vec_env_v1.
    env = pistonball_v6.parallel_env(
        n_pistons=20,
        time_penalty=-0.1,
        continuous=True,  # continuous (Box) actions instead of discrete ones
        random_drop=True,
        random_rotate=True,
        ball_mass=0.75,
        ball_friction=0.3,
        ball_elasticity=1.5,
        max_cycles=125,
    )
    # Reduce the complexity of the observation by keeping only the blue channel.
    env = ss.color_reduction_v0(env, mode="B")
    # Resize the observation to 84x84 to reduce its dimension.
    env = ss.resize_v1(env, x_size=84, y_size=84)
    # Stack the 3 most recent consecutive frames so the policy can infer the
    # ball's velocity and acceleration from a single observation.
    env = ss.frame_stack_v1(env, 3)
    # Treat each agent as one sub-environment of a vector env: this is what
    # lets a single PPO policy be shared by all agents (parameter sharing).
    env = ss.pettingzoo_env_to_vec_env_v1(env)
    # Run 2 copies of that vector env on 1 CPU, wrapped in a VecEnv that
    # Stable-Baselines3 accepts.
    env = ss.concat_vec_envs_v1(env, 2, num_cpus=1, base_class="stable_baselines3")

    # PPO with a CNN policy, since observations are stacked image frames.
    model = PPO(
        CnnPolicy,
        env,
        verbose=3,
        gamma=0.95,
        n_steps=256,
        ent_coef=0.0905168,
        learning_rate=0.00062211,
        vf_coef=0.042202,
        max_grad_norm=0.9,
        gae_lambda=0.99,
        n_epochs=5,
        clip_range=0.3,
        batch_size=256,
    )

    model.learn(total_timesteps=100000)


if __name__ == "__main__":
    main()
0

There are 0 best solutions below