How to add a custom activation function in stable-baselines3 with a custom feature extractor?

48 Views Asked by At

I have the following custom feature extractor:

class NatureCNN(BaseFeaturesExtractor):
    """CNN features extractor with batch normalization for image observations.

    Three ``Conv2d -> BatchNorm2d -> ReLU`` stages are flattened and followed
    by a ``Linear -> BatchNorm1d -> ReLU`` projection to ``features_dim``.

    :param observation_space: Image observation space; ``shape[0]`` is used
        as the number of input channels (channel-first layout, as required
        by ``nn.Conv2d``).
    :param features_dim: Number of features produced by the extractor.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super().__init__(observation_space, features_dim)

        self.frame_cnt = 0
        # NOTE(review): create_dump_dir is not defined in this snippet;
        # presumably it is provided elsewhere in the user's code.
        self.dump_dir_path = self.create_dump_dir()

        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Infer the flattened CNN output size with one dummy forward pass.
        # BatchNorm updates its running statistics in training mode even
        # under no_grad, and it rejects a single-element batch when the final
        # feature map is 1x1. Run the probe in eval mode, then restore the mode.
        was_training = self.cnn.training
        self.cnn.eval()
        with th.no_grad():
            sample = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.cnn.train(was_training)

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, features_dim),
            nn.BatchNorm1d(features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Map a batch of observations to a ``(batch, features_dim)`` tensor."""
        cnn_outputs = self.cnn(observations)
        return self.linear(cnn_outputs)

I pass the custom feature extractor (shown above) to PPO as follows:

# Keep the custom CNN as the features extractor and replace the activation
# used in the policy/value MLPs (mlp_extractor). SB3 creates each activation
# layer by calling ``activation_fn()`` with no arguments, so CustomActivation
# must be constructible without arguments (or be wrapped, e.g. with
# functools.partial, to bind its parameters). net_arch is left at its default,
# so only the Tanh layers in policy_net/value_net are swapped.
policy_kwargs = dict(
    features_extractor_class=NatureCNN,
    activation_fn=CustomActivation,
)

model = PPO(
    'CnnPolicy',
    env,
    # learning_rate=0.0001,
    batch_size=128,
    clip_range=0.10,
    max_grad_norm=0.5,
    verbose=1,
    seed=1,
    device="cuda",
    tensorboard_log="./tb_logs/",
    policy_kwargs=policy_kwargs,
)

I want to change only the activation function of the MlpExtractor class to my custom activation, shown below:

class CustomActivation(nn.Module):
    """Template for a user-defined activation used inside SB3's MlpExtractor.

    Stable-Baselines3 builds each activation layer by calling
    ``activation_fn()`` with no arguments. ``param`` therefore has a default,
    so this class can be passed directly as ``policy_kwargs["activation_fn"]``.

    :param param: Parameter(s) of the custom activation (placeholder).
    """

    def __init__(self, param=None):
        super().__init__()
        # Parameters for the custom activation function.
        self.param = param

    def forward(self, x):
        # Implement the custom activation logic here. The result must keep
        # the last dimension of ``x``, because the next Linear layer in the
        # MLP expects that width.
        raise NotImplementedError("CustomActivation.forward is not implemented yet")

while keeping the rest of the MlpExtractor structure exactly the same, as shown below:

  (mlp_extractor): MlpExtractor(
    (shared_net): Sequential()
    (policy_net): Sequential(
      (0): Linear(in_features=512, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
    (value_net): Sequential(
      (0): Linear(in_features=512, out_features=64, bias=True)
      (1): Tanh()
      (2): Linear(in_features=64, out_features=64, bias=True)
      (3): Tanh()
    )
  )
  (action_net): Linear(in_features=64, out_features=2, bias=True)
  (value_net): Linear(in_features=64, out_features=1, bias=True)
)

I am new to stable-baselines3, so can anyone help me with this? I tried the method mentioned here, but it gives me errors. Any help or suggestions are much appreciated.

0

There are 0 best solutions below