PettingZoo Agent Training with Supersuit and Stable Baselines 3


I am trying to train agents in a PettingZoo environment using Stable Baselines 3 and SuperSuit.

The new beta version of Stable Baselines 3 now supports Gymnasium, not just Gym, so it should be compatible. Still, it appears to be incompatible.

Can anyone shed some light on whether it is possible to train agents in a PettingZoo environment using Stable Baselines 3? And if the most recent beta is incompatible, does anyone know which versions of the packages below are compatible and work well together?

Python version - 3.10.10
stable baselines3 - 2.0.0a13
SuperSuit - 3.8.0
pettingzoo - 1.23.1
gymnasium - 0.28.1
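To confirm which versions are actually installed in the interpreter that runs the training script (they can differ from what pip reports for another environment), a small helper like the one below can be used. It is a sketch that relies only on the standard-library importlib.metadata; the distribution names in PACKAGES are assumed to be the PyPI spellings of the packages listed above.

```python
from importlib.metadata import PackageNotFoundError, version
import platform

# Assumed PyPI distribution names for the packages listed above.
PACKAGES = ["stable-baselines3", "SuperSuit", "pettingzoo", "gymnasium"]


def installed_versions(names):
    """Map each distribution name to its installed version string, or None if absent."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            # Not installed in this interpreter's environment.
            result[name] = None
    return result


if __name__ == "__main__":
    print("Python", platform.python_version())
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```

Pasting this output into a question makes it clear which combination was actually in use.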

from stable_baselines3.ppo import CnnPolicy
from stable_baselines3 import PPO
from pettingzoo.butterfly import pistonball_v6
import supersuit as ss
from pettingzoo.utils.conversions import aec_to_parallel


env = pistonball_v6.env(n_pistons=20, 
                        time_penalty=-0.1, 
                        continuous=True, 
                        random_drop=True, 
                        random_rotate=True, 
                        ball_mass=0.75, 
                        ball_friction=0.3, 
                        ball_elasticity=1.5,
                        max_cycles=125)
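env = ss.color_reduction_v0(env, mode="B")  # keep only the blue channel
env = ss.resize_v1(env, x_size=84, y_size=84)  # downscale observations to 84x84
env = ss.frame_stack_v1(env, 3)  # stack the last 3 frames per agent
env = aec_to_parallel(env)  # convert the AEC env to the Parallel API
env = ss.pettingzoo_env_to_vec_env_v1(env)  # one vectorized sub-env per agent


env = ss.concat_vec_envs_v1(env, 1, base_class='stable_baselines3')  # 1 copy, exposed as an SB3 VecEnv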

model = PPO(CnnPolicy, 
            env, 
            verbose=3, 
            gamma=0.95, 
            n_steps=256, 
            ent_coef=0.0905168, 
            learning_rate=0.00062211, 
            vf_coef=0.042202, 
            max_grad_norm=0.9, 
            gae_lambda=0.99, 
            n_epochs=5, 
            clip_range=0.3, 
            batch_size=256)


model.learn(total_timesteps=2000000)
model.save('policy')

It raises the following error at line 26 (model = PPO(CnnPolicy, ...)):

TypeError: VectorEnv.get_attr() takes 2 positional arguments but 3 were given
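The message points to a signature mismatch. Stable Baselines 3's VecEnv.get_attr is declared as get_attr(attr_name, indices=None), while the VectorEnv class named in the traceback evidently accepts only a name. The self-contained sketch below reproduces that kind of error with stand-in classes and shows a hypothetical adapter that selects the indices itself instead of forwarding them. VectorEnv, ForwardingWrapper, TolerantWrapper and reproduce_error are all hypothetical illustrations, not the real Gymnasium or SuperSuit classes, and which wrapper in the real stack forwards the extra argument is an assumption, not something the traceback confirms.

```python
from typing import Any, List, Optional, Sequence, Union

Indices = Optional[Union[int, Sequence[int]]]


class VectorEnv:
    """Hypothetical stand-in for a vector env whose get_attr takes only a name."""

    num_envs = 2
    render_mode = None

    def get_attr(self, name: str) -> List[Any]:
        return [getattr(self, name)] * self.num_envs


class ForwardingWrapper:
    """Hypothetical SB3-style wrapper that forwards `indices` unchanged."""

    def __init__(self, venv: VectorEnv) -> None:
        self.venv = venv

    def get_attr(self, attr_name: str, indices: Indices = None) -> List[Any]:
        # Passes two arguments to a method that accepts one: raises TypeError.
        return self.venv.get_attr(attr_name, indices)


class TolerantWrapper(ForwardingWrapper):
    """Hypothetical workaround: fetch all values, then select indices locally."""

    def get_attr(self, attr_name: str, indices: Indices = None) -> List[Any]:
        values = self.venv.get_attr(attr_name)
        if indices is None:
            return values
        if isinstance(indices, int):
            indices = [indices]
        return [values[i] for i in indices]


def reproduce_error() -> str:
    """Return the TypeError message raised by the forwarding wrapper."""
    try:
        ForwardingWrapper(VectorEnv()).get_attr("render_mode")
    except TypeError as exc:
        return str(exc)
    return ""


if __name__ == "__main__":
    print(reproduce_error())
    print(TolerantWrapper(VectorEnv()).get_attr("render_mode", [0]))  # [None]
```

In practice, aligning package versions so that both sides agree on the signature is preferable to patching wrappers like this.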

Answer:

I have been testing very similar code (a slightly modified pistonball_v5) and made two changes to yours.

First, I removed env = aec_to_parallel(env).

Second, I changed env = ss.concat_vec_envs_v1(env, 1, base_class='stable_baselines3') to env = ss.concat_vec_envs_v1(env, num_vec_envs=4, num_cpus=1, base_class='stable_baselines3').

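I set num_cpus=1 because I have one GPU. (Strictly speaking, num_cpus sets how many CPU worker processes SuperSuit uses to run the environment copies, so it does not depend on the number of GPUs.)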
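Because pettingzoo_env_to_vec_env_v1 exposes each agent as one sub-environment and concat_vec_envs_v1 stacks num_vec_envs copies, raising num_vec_envs from 1 to 4 also changes how much data PPO collects per update. The sketch below does that arithmetic, assuming the question's n_pistons=20, n_steps=256 and batch_size=256, and using the fact that an SB3 PPO rollout holds n_steps × n_envs transitions. The function name ppo_rollout_stats is hypothetical.

```python
def ppo_rollout_stats(n_agents: int, num_vec_envs: int, n_steps: int, batch_size: int) -> dict:
    """Rough PPO bookkeeping for a SuperSuit-vectorized PettingZoo env."""
    n_envs = n_agents * num_vec_envs  # one SB3 sub-env per agent per copy
    rollout_size = n_steps * n_envs   # transitions collected before each update
    return {
        "n_envs": n_envs,
        "rollout_size": rollout_size,
        "minibatches_per_epoch": rollout_size // batch_size,
    }


# Question's setup (1 copy) versus the answer's setup (4 copies), 20 pistons each.
question = ppo_rollout_stats(n_agents=20, num_vec_envs=1, n_steps=256, batch_size=256)
answer = ppo_rollout_stats(n_agents=20, num_vec_envs=4, n_steps=256, batch_size=256)
print(question)  # {'n_envs': 20, 'rollout_size': 5120, 'minibatches_per_epoch': 20}
print(answer)    # {'n_envs': 80, 'rollout_size': 20480, 'minibatches_per_epoch': 80}
```

With four copies, each update sees four times as many transitions, so hyperparameters tuned for one copy (for example learning_rate or n_epochs) may need revisiting.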