I'm posting a question that was posted by another user and then deleted. I had the same question, and I found an answer. The original question:
I am currently trying to implement a categorical DQN following this tutorial: https://www.tensorflow.org/agents/tutorials/9_c51_tutorial
The following part is giving me a bit of a headache though:
random_policy = random_tf_policy.RandomTFPolicy(env.time_step_spec(),
                                                env.action_spec())

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=agent.collect_data_spec,
    batch_size=1,
    max_length=replay_buffer_capacity)  # this is 100
# ...
def collect_step(environment, policy):
    time_step = environment.current_time_step()
    action_step = policy.action(time_step)
    next_time_step = environment.step(action_step.action)
    traj = trajectory.from_transition(time_step, action_step, next_time_step)
    print(traj)
    # Add trajectory to the replay buffer
    replay_buffer.add_batch(traj)

for _ in range(initial_collect_steps):
    collect_step(env, random_policy)
For context: agent.collect_data_spec is of the following shape:
Trajectory(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), observation=BoundedTensorSpec(shape=(4, 84, 84), dtype=tf.float32, name='screen', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), action=BoundedTensorSpec(shape=(), dtype=tf.int32, name='play', minimum=array(0), maximum=array(6)), policy_info=(), next_step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
And here is what a sample traj looks like:
Trajectory(step_type=<tf.Tensor: shape=(), dtype=int32, numpy=0>, observation=<tf.Tensor: shape=(4, 84, 84), dtype=float32, numpy=array([tensor contents omitted], dtype=float32)>, action=<tf.Tensor: shape=(), dtype=int32, numpy=1>, policy_info=(), next_step_type=<tf.Tensor: shape=(), dtype=int32, numpy=1>, reward=<tf.Tensor: shape=(), dtype=float32, numpy=0.0>, discount=<tf.Tensor: shape=(), dtype=float32, numpy=1.0>)
So, everything should check out, right? The environment outputs a tensor of shape [4, 84, 84], same as the replay buffer expects. Except I'm getting the following error:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [4,84,84], indices.shape [1], params.shape [100,4,84,84] [Op:ResourceScatterUpdate]
This suggests that it is actually expecting a tensor of shape [1, 4, 84, 84]. The thing is, though, if I have my environment output a tensor of that shape, I get another error message telling me that the output shape doesn't match the spec shape (duh). And if I then adjust the spec shape to [1, 4, 84, 84], suddenly the replay buffer expects a shape of [1, 1, 4, 84, 84], and so on...
Finally, for completeness, here are the time_step_spec and action_spec of my environment, respectively:
TimeStep(step_type=TensorSpec(shape=(), dtype=tf.int32, name='step_type'), reward=TensorSpec(shape=(), dtype=tf.float32, name='reward'), discount=BoundedTensorSpec(shape=(), dtype=tf.float32, name='discount', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)), observation=BoundedTensorSpec(shape=(4, 84, 84), dtype=tf.float32, name='screen', minimum=array(0., dtype=float32), maximum=array(1., dtype=float32)))
---
BoundedTensorSpec(shape=(), dtype=tf.int32, name='play', minimum=array(0), maximum=array(6))
I've spent the better half of today trying to get the tensor to fit properly, but you can't just reshape it since it's an attribute of the trajectory, so in a last-ditch effort I'm hoping some kind stranger out there can tell me what the heck is going on here.
Thank you in advance!
It seems that in the collect_step function, traj is a single trajectory, not a batch. You therefore need to expand its dimensions into a batch of size one before adding it to the buffer. Note that you can't just do tf.expand_dims(traj, 0), since traj is a nested structure (a Trajectory namedtuple), not a tensor. There is a helper for applying a function to every tensor in a nested structure: tf.nest.map_structure.
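A minimal sketch of the idea, using plain TensorFlow (the dict below is a stand-in mirroring a few fields of the Trajectory spec from the question, not the real namedtuple):

```python
import tensorflow as tf

# Stand-in for the unbatched trajectory produced in collect_step: each leaf
# tensor lacks the leading batch dimension the replay buffer expects.
unbatched_traj = {
    "step_type": tf.constant(0),
    "observation": tf.zeros((4, 84, 84), dtype=tf.float32),
    "action": tf.constant(1),
    "reward": tf.constant(0.0),
    "discount": tf.constant(1.0),
}

# tf.nest.map_structure applies the function to every tensor in the nest,
# turning shape (4, 84, 84) into (1, 4, 84, 84), () into (1,), and so on.
batched_traj = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0),
                                     unbatched_traj)

print(batched_traj["observation"].shape)  # (1, 4, 84, 84)
print(batched_traj["reward"].shape)       # (1,)
```

In collect_step, the same mapping works directly on the Trajectory namedtuple, e.g. `batched = tf.nest.map_structure(lambda t: tf.expand_dims(t, 0), traj)` followed by `replay_buffer.add_batch(batched)`. TF-Agents also ships `tf_agents.utils.nest_utils.batch_nested_tensors`, which batches a nest against a spec; that may be worth checking as an alternative in your version.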