TFAGENTS: clarification on the usage of observation_and_action_constraint_splitter for DqnAgent agents

I'm trying to create a DqnAgent with a mask for valid/invalid actions. According to this post, I should specify a splitter_fn for the observation_and_action_constraint_splitter argument. According to the tf_agents docs, the splitter_fn would look like this:

def observation_and_action_constraint_splitter(observation):
  return observation['network_input'], observation['constraint'] 
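
To make the expected input concrete, here is a minimal sketch of the dict-structured observation that this splitter assumes. The key names come from the snippet above; the sizes (56 features, 14 actions) and the NumPy arrays standing in for real observations are assumptions for illustration only.

```python
import numpy as np


def observation_and_action_constraint_splitter(observation):
    # Only indexes into the dict, so it makes no assumption about the
    # type of the values stored under each key.
    return observation['network_input'], observation['constraint']


# Hypothetical observation: 56 features for the network and a 0/1 mask
# with one entry per action (14 actions assumed here).
example_observation = {
    'network_input': np.zeros(56, dtype=np.float32),
    'constraint': np.ones(14, dtype=np.int32),
}

network_input, mask = observation_and_action_constraint_splitter(example_observation)
print(network_input.shape, mask.shape)
```

Because the function does nothing but look up two keys, it behaves the same for any dict that carries those keys.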

I assumed the observation argument would be the array returned by env.step(action).observation, which in my case has shape (56,). It is a flattened array whose original shape is (14, 4): each row holds 4 feature values for one choice, there are 5-14 valid choices, and the features of an invalid choice are all 0. So I wrote my splitter_fn like this:

def observation_and_action_constrain_splitter(observation):
     print(observation)
     temp = observation.reshape(14,-1)
     action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
     return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)

agent = DqnAgent(
    tf_time_step_spec,
    tf_action_spec,
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=tf_common.element_wise_squared_loss,
    train_step_counter=train_step_counter,
    observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
)

However, it returned the following error when running the above code cell:

BoundedTensorSpec(shape=(56,), dtype=tf.float32, name='observation', minimum=array(-3.4028235e+38, dtype=float32), maximum=array(3.4028235e+38, dtype=float32))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-213-07450ea5ba21> in <module>()
     13     td_errors_loss_fn=tf_common.element_wise_squared_loss,
     14     train_step_counter=train_step_counter,
---> 15     observation_and_action_constraint_splitter=observation_and_action_constrain_splitter
     16     )
     17 

4 frames
<ipython-input-212-dbfee6076511> in observation_and_action_constrain_splitter(observation)
      1 def observation_and_action_constrain_splitter(observation):
      2      print(observation)
----> 3      temp = observation.reshape(14,-1)
      4      action_mask = (~(temp==0).all(axis=1)).astype(np.int32).ravel()
      5      return observation, tf.convert_to_tensor(action_mask, dtype=tf.int32)

AttributeError: 'BoundedTensorSpec' object has no attribute 'reshape'
  In call to configurable 'DqnAgent' (<class 'tf_agents.agents.dqn.dqn_agent.DqnAgent'>)

It turns out that print(observation) shows a BoundedTensorSpec object, not an array or a tf.Tensor. How can I create my action mask from a BoundedTensorSpec, which doesn't even contain the observation values?

Thanks in advance!

PS: tf_agents version is 0.12.0

There is 1 best solution below

I faced the same problem. I solved it by passing the function observation_and_action_constrain_splitter to the policy instead of to DqnAgent:

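# Imports needed by this snippet. train_env, categorical_q_net, optimizer and
# the hyperparameters used below are assumed to be defined earlier.
from tf_agents.agents.categorical_dqn import categorical_dqn_agent
from tf_agents.policies import random_tf_policy
from tf_agents.utils import common

# The agent is built without the splitter.
agent = categorical_dqn_agent.CategoricalDqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    categorical_q_network=categorical_q_net,
    optimizer=optimizer,
    min_q_value=min_q_value,
    max_q_value=max_q_value,
    n_step_update=n_step_update,
    td_errors_loss_fn=common.element_wise_squared_loss,
    gamma=gamma,
    train_step_counter=train_step_counter)
agent.initialize()

# The splitter is passed to the policy instead.
random_policy = random_tf_policy.RandomTFPolicy(
    train_env.time_step_spec(),
    train_env.action_spec(),
    observation_and_action_constraint_splitter=observation_and_action_constraint_splitter)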
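
Independently of where the splitter is passed, the traceback in the question shows it being called on the observation spec while the agent is constructed, so array methods such as reshape are not available at that point. One possible way around this, sketched below with plain NumPy, is to compute the mask from real observation values (for example inside the environment) and ship it alongside the features in a dict, so the splitter only has to index. The helper name build_masked_observation and the dict keys are hypothetical; the (14, 4) layout is taken from the question.

```python
import numpy as np

NUM_CHOICES = 14   # rows in the original (14, 4) layout
NUM_FEATURES = 4   # feature values per choice


def build_masked_observation(flat_observation):
    """Return a dict with the flat features and a 0/1 mask of valid choices."""
    flat_observation = np.asarray(flat_observation, dtype=np.float32)
    rows = flat_observation.reshape(NUM_CHOICES, NUM_FEATURES)
    # A choice counts as valid if at least one of its features is non-zero.
    mask = (rows != 0).any(axis=1).astype(np.int32)
    return {'network_input': flat_observation, 'constraint': mask}


def observation_and_action_constraint_splitter(observation):
    # Pure indexing: no array methods are called on the values.
    return observation['network_input'], observation['constraint']


# Example: 5 valid choices followed by 9 all-zero (invalid) rows.
example = np.zeros((NUM_CHOICES, NUM_FEATURES), dtype=np.float32)
example[:5] = 1.0
features, mask = observation_and_action_constraint_splitter(
    build_masked_observation(example.ravel()))
print(mask)
```

With this layout the environment's observation spec would also need to be a dict with matching keys; that part is not shown here.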

I hope this helped you.