How do I get value function/critic values from RLlib's PPO algorithm for a range of observations?

Goal: I want to train a PPO agent on a problem and determine its optimal value function for a range of observations. Later, I plan to work with this value function (economic inequality research). The problem is complex enough that dynamic programming techniques no longer work.

Approach: To check whether I get correct outputs for the value function, I trained PPO on a simple problem whose analytical solution is known. However, the value function results are rubbish, which is why I suspect that I have done something wrong.
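To make "rubbish" measurable, the learned and analytical value functions can be compared on the same grid of states. A minimal sketch, assuming both have already been evaluated on identical states and are available as sequences of numbers (`compare_value_functions` is a hypothetical helper, not part of RLlib):

```python
import numpy as np


def compare_value_functions(learned, analytical):
    """Return simple error metrics between a learned and an analytical VF.

    Assumes both inputs are 1-D sequences evaluated on the same states.
    """
    learned = np.asarray(learned, dtype=np.float64)
    analytical = np.asarray(analytical, dtype=np.float64)
    if learned.shape != analytical.shape:
        raise ValueError("value functions must be evaluated on the same states")
    err = learned - analytical
    return {
        # Worst-case deviation over the grid
        "max_abs_error": float(np.max(np.abs(err))),
        # Root-mean-square deviation over the grid
        "rmse": float(np.sqrt(np.mean(err ** 2))),
        # Mean signed error: a large value hints at a systematic offset
        "bias": float(np.mean(err)),
    }
```

Separating `bias` from `rmse` helps distinguish a curve that has the right shape but is shifted from one that has the wrong shape altogether.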

The code:

from keras import backend as k_util
...

parser = argparse.ArgumentParser()

# Define framework to use
parser.add_argument(
    "--framework",
    choices=["tf", "tf2", "tfe", "torch"],
    default="tf",
    help="The DL framework specifier.",
)
...

def get_rllib_config(seeds, debug=False, framework="tf") -> Dict:
...

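def get_value_function(agent, min_state, max_state):
    policy = agent.get_policy()
    value_function = []
    # One forward pass per integer state in [min_state, max_state)
    for i in np.arange(min_state, max_state, 1):
        # Observation batch of shape (1, 1)
        model_out, _ = policy.model({"obs": np.array([[i]], dtype=np.float32)})
        # value_function() returns the critic output of the most recent forward pass
        value = k_util.eval(policy.model.value_function())[0]
        value_function.append(value)
        print(i, value)
    return value_function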


def train_schedule(config, reporter):
    rllib_config = config["config"]
    iterations = rllib_config.pop("training_iteration", 10)

    agent = PPOTrainer(env=rllib_config["env"], config=rllib_config)
    for _ in range(iterations):
        result = agent.train()
        reporter(**result)
    values = get_value_function(agent, 0, 100)
    print(values)
    agent.stop()

...

resources = PPO.default_resource_request(exp_config)
tune_analysis = tune.Tuner(tune.with_resources(train_schedule, resources=resources), param_space=exp_config).fit()
ray.shutdown()

First, I get the policy (policy = agent.get_policy()) and run a forward pass for each of the 100 observations (model_out, _ = policy.model({"obs": np.array([[i]], dtype=np.float32)})). After each forward pass, I call the value_function() method to get the output of the critic network and evaluate the resulting tensor with the Keras backend.
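An alternative way to obtain the same critic values is to send all observations in one batch through the policy's own action computation. This is a minimal sketch, assuming a Ray 2.x version in which `Policy.compute_actions` returns `(actions, state_outs, extra_fetches)` and PPO stores the critic output in `extra_fetches["vf_preds"]`; both are worth verifying for the installed version, and `get_value_function_batched` is a hypothetical name. Because the policy runs its own forward pass, in `tf` graph mode this likely also avoids evaluating tensors through a separate Keras session. As far as I know, it does not apply an observation filter configured via `observation_filter`, so filtered observations would have to be passed in if one is used.

```python
import numpy as np


def get_value_function_batched(policy, min_state, max_state):
    """Query critic values for all integer states in [min_state, max_state).

    Assumption: policy.compute_actions(obs_batch, explore=False) returns
    (actions, state_outs, extra_fetches), with PPO's critic output stored
    under extra_fetches["vf_preds"].
    """
    # One batch of shape (N, 1) instead of N single-observation forward passes
    obs_batch = np.arange(min_state, max_state, 1, dtype=np.float32).reshape(-1, 1)
    # explore=False: deterministic action selection; the critic output
    # does not depend on exploration
    _, _, extra_fetches = policy.compute_actions(obs_batch, explore=False)
    return np.asarray(extra_fetches["vf_preds"]).reshape(-1)
```

For example, `values = get_value_function_batched(agent.get_policy(), 0, 100)` could replace the call to `get_value_function` in `train_schedule`.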

The results:

[Figure: True VF (analytical solution)]
[Figure: VF output of RLlib]

Unfortunately, as you can see, the results are not promising. Maybe I have missed a pre- or postprocessing step? Does the value_function() method actually return the output of the last layer of the critic network?

I am very grateful for any help!

Answer:

It's not part of your script, but I assume that you have trained the policy before you attempt to get useful values out of it.

You are correct in assuming that value_function() returns the output of the last layer of the critic network in RLlib's implementations. Have a look at the value function metrics to see whether it is actually learning anything (RLlib logs .../learner_stats/vf_loss and .../learner_stats/vf_explained_var)! After training the model, I'd also try to query the model directly. If that looks better, something is likely off with the code you posted here.
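`vf_explained_var` is, roughly, the explained variance of the critic's predictions with respect to the value targets of the training batch. The standard formula is sketched below; RLlib's own implementation may differ in details such as clipping, so treat this as an illustration rather than RLlib's code. Values near 1 mean the critic tracks its targets well; values near 0 or below mean it explains little or nothing.

```python
import numpy as np


def explained_variance(y_true, y_pred):
    """Standard explained variance: 1 - Var(y_true - y_pred) / Var(y_true).

    Returns NaN when the targets have zero variance (ratio undefined).
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    var_y = np.var(y_true)
    if var_y == 0.0:
        return float("nan")
    return float(1.0 - np.var(y_true - y_pred) / var_y)
```

Note that a constant offset between predictions and targets leaves the explained variance at 1 (e.g. `explained_variance([1, 2, 3], [2, 3, 4])`), so a good `vf_explained_var` alone does not rule out a systematically shifted value function.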