Using Captum with Stable Baselines


I've been testing Captum feature attribution with a Stable Baselines PPO agent, and I wanted to make sure I was taking the right approach. I couldn't find any tutorials for using Captum in an RL setting...

  1. During an evaluation episode of the agent, I'm recording observations from the environment.
  2. Afterwards, I'm constructing a Captum integrated gradients object like this:

         ig = IntegratedGradients(agent.policy.mlp_extractor.policy_net)

  3. Then I'm choosing a target action and, after converting the recorded observations to a tensor, passing those as inputs to the attribute method, with the initial (reset) observation as the baseline (full sketch after this list):

         attr = ig.attribute(recorded_observations_tensor, baselines=initial_observation_tensor, target=action)
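
Putting the steps together, here is a minimal end-to-end sketch of what I'm doing (CartPole-v1 with an SB3 PPO agent; the training call, the feature names, and the mean-absolute aggregation at the end are just simplifications for this post):

    import gymnasium as gym
    import numpy as np
    import torch
    from captum.attr import IntegratedGradients
    from stable_baselines3 import PPO

    env = gym.make("CartPole-v1")
    agent = PPO("MlpPolicy", env)
    agent.learn(total_timesteps=10_000)  # stand-in for however the agent was trained

    # Step 1: record observations during one evaluation episode,
    # keeping the initial (reset) observation as the IG baseline.
    obs, _ = env.reset()
    initial_observation_tensor = torch.as_tensor(obs, dtype=torch.float32).unsqueeze(0)
    recorded = []
    done = False
    while not done:
        recorded.append(obs)
        action, _ = agent.predict(obs, deterministic=True)
        obs, _, terminated, truncated, _ = env.step(int(action))
        done = terminated or truncated
    recorded_observations_tensor = torch.as_tensor(np.stack(recorded), dtype=torch.float32)

    # Step 2: integrated gradients with the policy MLP as the forward function.
    # With the default FlattenExtractor, policy_net accepts raw observations directly.
    ig = IntegratedGradients(agent.policy.mlp_extractor.policy_net)

    # Step 3: attribute every recorded observation for a target action
    # (0 = "Left" in CartPole). Note that policy_net outputs the pre-action
    # latent (64-dim by default), so `target` here indexes a latent unit rather
    # than an action logit; this is part of what I'm asking about.
    attr = ig.attribute(
        recorded_observations_tensor,
        baselines=initial_observation_tensor,
        target=0,
    )

    # Summarize per-feature attribution over the episode (my own aggregation choice).
    feature_names = ["Cart position", "Cart velocity", "Pole angle", "Pole angular velocity"]
    for name, value in zip(feature_names, attr.abs().mean(dim=0)):
        print(f"{name}: {value.item():.3e}")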

For the simple Gymnasium CartPole environment, this produces output like the following:

Importances of features for action: Left
Cart position: 2.047e-05
Cart velocity: 4.940e-04
Pole angle: 4.476e-06
Pole angular velocity: 2.138e-04
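
One variation I've been considering, in case policy_net alone isn't the right forward function: composing the full policy path (features extractor, actor MLP, and action head) so that target indexes the actual action logits. I'm reading extract_features, forward_actor, and action_net from the SB3 source, and this assumes the default shared features extractor, so treat it as a sketch rather than something I've verified:

    # Sketch of the alternative: attribute through the full policy path so the
    # `target` argument selects an action logit (0 = "Left") instead of a latent unit.
    def policy_logits(observations):
        features = agent.policy.extract_features(observations)          # flatten on CartPole
        latent_pi = agent.policy.mlp_extractor.forward_actor(features)  # actor MLP
        return agent.policy.action_net(latent_pi)                       # per-action logits

    ig_logits = IntegratedGradients(policy_logits)
    attr_logits = ig_logits.attribute(
        recorded_observations_tensor,
        baselines=initial_observation_tensor,
        target=0,
    )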

Am I on the right track in terms of recording observations and using the policy net as the forward function?
