I am using a custom PPO model with ray.tune(), and I want to add some self-supervised learning that depends on batch['obs'], batch['done'], batch['action'] and batch['next_obs'].
I have defined some layers in my model that are called only during training.
I have defined a loss function which I am passing to the trainer. Within the loss function, I pass various inputs through layers that are never called in the forward model. Specifically, the inputs are train_batch['actions'] and things derived from the observation, and the layers are ones I have stored as attributes of the model (e.g. model.loss_context).
The layers that are not in the forward model (i.e. the ones called only inside the loss function) do not seem to receive any gradient: I am recording the magnitude of their weights and it never changes, even in a deliberately simple test case that is just a huge weight decay on a layer called outside the forward model.
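For reference, I track the weight magnitude roughly like this through the policy's stats_fn (simplified sketch; test_dense stands in for whichever loss-only layer I am checking):

```python
import tensorflow as tf

def stats(policy, train_batch):
    # Log the global L2 norm of the loss-only layer's weights so I can see
    # whether they move at all between training iterations.
    return {
        "test_dense_weight_norm": tf.linalg.global_norm(
            policy.model.test_dense.variables
        ),
    }
```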
I have also tried adding these layers to an overridden custom_loss() function, as per the example https://github.com/ray-project/ray/blob/50e1fda022a81e5015978cf723f7b5fd9cc06b2c/rllib/examples/models/custom_loss_model.py, but in that case the weights for those layers do not even initialise.
Has anyone solved this? I see a number of Stack Overflow questions asking about this, but no answers!
See above: I was expecting the weights to change. Here is the relevant code (the policy definition and the loss function):
LoggedPPO = PPOTFPolicy.with_updates(
    name="SHPPOPolicy",
    loss_fn=ppo_surrogate_loss,
    grad_stats_fn=grad_stats,
    stats_fn=stats,
)
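# Inside the custom loss function (ppo_surrogate_loss above), the model's
# flat logits are split back into their segments and the extra layer applied: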
context, action_mask, net_mask = tf.split(
    logits,
    [
        model.context_dim * model.max_num_nets,
        model.max_num_nets * (9 + model.svg_feature_dict["max_layers"]),
        model.max_num_nets,
    ],
    axis=1,
)
x = model.test_dense(context)
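# Trivial test case: a pure L2 weight-decay term on test_dense, a layer
# that never appears in the forward pass.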
wd_loss = sum(
    [tf.reduce_sum(v ** 2) for v in model.test_dense.variables]
) + 1e-4
batch_loss = [ ..... wd_loss]
In this example, test_dense, which is not called during the forward pass, never gets updated, even though the case is trivial and the optimiser should simply shrink that layer's weights toward zero.
You'll want to make sure that you are actually doing a couple of things: that the variables of the extra layers are registered with the model (for a TF ModelV2 that means calling self.register_variables() on them, as the linked example does), so that they end up in the policy's trainable variables, and that your self-supervised term is actually part of the total loss that your loss function returns.
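For illustration, the pattern from the linked example boils down to roughly the following (a minimal sketch, not your model; it assumes the TF ModelV2 API, a flat Box observation space, and placeholder names such as SSLModel and ssl_dense):

```python
import tensorflow as tf
from ray.rllib.models.tf.tf_modelv2 import TFModelV2
from ray.rllib.models.tf.fcnet import FullyConnectedNetwork


class SSLModel(TFModelV2):
    def __init__(self, obs_space, action_space, num_outputs, model_config, name):
        super().__init__(obs_space, action_space, num_outputs, model_config, name)
        # Standard policy/value network used by forward().
        self.fcnet = FullyConnectedNetwork(
            obs_space, action_space, num_outputs, model_config, name + "_fcnet")
        self.register_variables(self.fcnet.variables())

        # Extra layer that is only ever used inside custom_loss().
        self.ssl_dense = tf.keras.layers.Dense(32, name="ssl_dense")
        self.ssl_dense.build((None, int(obs_space.shape[0])))
        # Without this registration, RLlib does not know about the layer's
        # weights and they never receive gradients.
        self.register_variables(self.ssl_dense.variables)

    def forward(self, input_dict, state, seq_lens):
        return self.fcnet(input_dict, state, seq_lens)

    def value_function(self):
        return self.fcnet.value_function()

    def custom_loss(self, policy_loss, loss_inputs):
        obs = tf.cast(loss_inputs["obs"], tf.float32)
        # Toy self-supervised term standing in for the real objective.
        self.ssl_loss = tf.reduce_mean(tf.square(self.ssl_dense(obs)))
        # Depending on the RLlib version, policy_loss is a scalar or a list.
        if isinstance(policy_loss, (list, tuple)):
            return [loss + self.ssl_loss for loss in policy_loss]
        return policy_loss + self.ssl_loss
```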
I'd recommend using zero workers, not using Tune for now, and setting breakpoints in the sections of code you modified; it's hard to tell from here which of the above steps is not being taken.
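Concretely, a Tune-free debug setup could look roughly like this (a sketch only; it assumes the same old-style trainer API that with_updates() belongs to, and SHPPOTrainer and the env name are placeholders):

```python
import ray
from ray.rllib.agents.ppo import PPOTrainer

# Attach the custom policy (LoggedPPO from your snippet) to a PPO trainer.
SHPPOTrainer = PPOTrainer.with_updates(default_policy=LoggedPPO)

ray.init(local_mode=True)  # run everything in a single process
trainer = SHPPOTrainer(
    env="CartPole-v0",  # placeholder; use your own env
    config={
        "num_workers": 0,   # no remote rollout workers -> breakpoints fire locally
        "framework": "tf",
    },
)
for _ in range(3):
    result = trainer.train()  # step through your loss_fn / model code here
    print(result["episode_reward_mean"])
```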
Since you are mentioning with_updates(): that API has been deprecated, and using it makes debugging an issue like this one a little harder. Consider upgrading! The current PPO policies can simply be subclassed. Posting a full reproduction script on GitHub would also make it much easier to see what your modifications look like.
Cheers