I have the following custom feature extractor:
class NatureCNN(BaseFeaturesExtractor):
    """CNN feature extractor in the style of the Nature DQN, with batch norm.

    Three Conv2d -> BatchNorm2d -> ReLU stages are flattened and projected to
    ``features_dim`` by a Linear -> BatchNorm1d -> ReLU head.

    :param observation_space: Image observation space. ``shape[0]`` is used as
        the number of input channels (assumes channel-first images -- TODO
        confirm against the environment / wrappers).
    :param features_dim: Number of features produced by ``forward``.
    """

    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 512):
        super().__init__(observation_space, features_dim)
        # Neither attribute is read inside this class; presumably they are used
        # by code outside this snippet. create_dump_dir is not defined here
        # either -- TODO confirm where it lives.
        self.frame_cnt = 0
        self.dump_dir_path = self.create_dump_dir()

        n_input_channels = observation_space.shape[0]
        self.cnn = nn.Sequential(
            nn.Conv2d(n_input_channels, 32, kernel_size=8, stride=2, padding=0),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(),
        )

        # Compute the flattened CNN output size with one dummy forward pass.
        # The pass runs in eval mode so that the random sample does not update
        # the BatchNorm running statistics. The previous mode is restored after.
        was_training = self.cnn.training
        self.cnn.eval()
        with th.no_grad():
            # [None] prepends a batch dimension of size 1.
            sample = th.as_tensor(observation_space.sample()[None]).float()
            n_flatten = self.cnn(sample).shape[1]
        self.cnn.train(was_training)

        self.linear = nn.Sequential(
            nn.Linear(n_flatten, features_dim),
            nn.BatchNorm1d(features_dim),
            nn.ReLU(),
        )

    def forward(self, observations: th.Tensor) -> th.Tensor:
        """Return a ``(batch, features_dim)`` feature tensor for ``observations``."""
        return self.linear(self.cnn(observations))
I pass the custom feature extractor (shown above) to PPO like this:
from functools import partial

# Value passed to CustomActivation's constructor.
# NOTE(review): stable-baselines3's MlpExtractor appears to create each
# activation layer by calling activation_fn() with no arguments, so a required
# constructor argument has to be bound up front. TODO: set the real value here.
CUSTOM_ACTIVATION_PARAM = None

policy_kwargs = dict(
    features_extractor_class=NatureCNN,
    # Intended to replace the default activation (the Tanh layers between the
    # Linear layers of the MlpExtractor's policy_net and value_net). net_arch
    # is not passed, so the Linear layer sizes are expected to stay at their
    # defaults.
    activation_fn=partial(CustomActivation, param=CUSTOM_ACTIVATION_PARAM),
)

model = PPO(
    "CnnPolicy",
    env,
    # learning_rate=0.0001,
    batch_size=128,
    clip_range=0.10,
    max_grad_norm=0.5,
    verbose=1,
    seed=1,
    device="cuda",
    tensorboard_log="./tb_logs/",
    policy_kwargs=policy_kwargs,
)
I just want to replace the activation function of the `MlpExtractor` class with my custom activation, shown below:
class CustomActivation(nn.Module):
    """Custom activation module intended to replace Tanh in the MlpExtractor.

    NOTE(review): stable-baselines3 appears to instantiate ``activation_fn``
    with no arguments, so bind ``param`` (e.g. with ``functools.partial``)
    before passing this class via ``policy_kwargs``.

    :param param: Parameter of the custom activation. It is stored on the
        module so that ``forward`` can use it.
    """

    def __init__(self, param):
        super().__init__()
        # Keep the parameter; the original discarded it, leaving forward()
        # nothing to work with.
        self.param = param

    def forward(self, x):
        # Placeholder: the original returned an undefined name (`result`),
        # which fails with a NameError. Fail explicitly until the activation
        # logic is written.
        raise NotImplementedError(
            "CustomActivation.forward: implement the custom activation logic"
        )
and I want the rest of the `MlpExtractor` structure to stay exactly the same, as shown below:
(mlp_extractor): MlpExtractor(
(shared_net): Sequential()
(policy_net): Sequential(
(0): Linear(in_features=512, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
)
(value_net): Sequential(
(0): Linear(in_features=512, out_features=64, bias=True)
(1): Tanh()
(2): Linear(in_features=64, out_features=64, bias=True)
(3): Tanh()
)
)
(action_net): Linear(in_features=64, out_features=2, bias=True)
(value_net): Linear(in_features=64, out_features=1, bias=True)
)
I am new to stable-baselines3, so can anyone help me with this? I tried the method mentioned here, but it gives me errors. Any help or suggestions would be much appreciated.