I am struggling with `batch_size` in a TorchRL environment. I created an environment that passes the `check_env_specs` test. Here is an example of the environment, with the `_step` method simplified but with the relevant shape operations kept.
```python
from typing import Optional

import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.data import BoundedTensorSpec, CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvBase

# nb_joints and time_step are module-level constants in my script
# (the values below are placeholders)
nb_joints = 2
time_step = 0.01


# helper copied from the TorchRL pendulum tutorial: builds an unbounded spec
# mirroring the structure of a tensordict
def make_composite_from_td(td):
    composite = CompositeSpec(
        {
            key: make_composite_from_td(tensor)
            if isinstance(tensor, TensorDictBase)
            else UnboundedContinuousTensorSpec(
                dtype=tensor.dtype, device=tensor.device, shape=tensor.shape
            )
            for key, tensor in td.items()
        },
        shape=td.shape,
    )
    return composite


class FlyEnv(EnvBase):
    batch_locked = False

    def __init__(self, nb_joints=nb_joints, td_params=None, seed=None, device="cpu"):
        self.nb_joints = nb_joints
        self.dt = time_step
        if td_params is None:
            td_params = self.gen_params()
        super().__init__(device=device, batch_size=[])
        self._make_spec(td_params)
        if seed is None:
            seed = torch.empty((), dtype=torch.int64).random_().item()
        self.set_seed(seed)

    def gen_params(self, batch_size=None) -> TensorDictBase:
        if batch_size is None:
            batch_size = []
        obs_min = torch.concat(
            (
                -torch.pi * torch.ones(self.nb_joints),
                -10 * torch.ones(self.nb_joints),
            ),
            dim=-1,
        )
        td = TensorDict(
            {
                "params": TensorDict(
                    {
                        "obs_min": obs_min,
                        "obs_max": -obs_min,
                        "act_min": obs_min,
                        "act_max": -obs_min,
                    },
                    [],
                )
            },
            [],
        )
        if batch_size:
            td = td.expand(batch_size).contiguous()
        return td

    def _make_spec(self, td_params):
        self.observation_spec = CompositeSpec(
            observation=BoundedTensorSpec(
                low=td_params["params", "obs_min"],
                high=td_params["params", "obs_max"],
                shape=(2 * self.nb_joints,),
                dtype=torch.float32,
            ),
            params=make_composite_from_td(td_params["params"]),
            shape=td_params.shape,
        )
        # since the environment is stateless, we expect the previous output as input.
        # For this, ``EnvBase`` expects some state_spec to be available
        self.state_spec = self.observation_spec.clone()
        # action-spec will be automatically wrapped in input_spec when
        # `self.action_spec = spec` is called
        self.action_spec = BoundedTensorSpec(
            low=td_params["params", "act_min"],
            high=td_params["params", "act_max"],
            shape=(*td_params.shape, 2 * self.nb_joints),
            dtype=torch.float32,
        )
        self.reward_spec = UnboundedContinuousTensorSpec(shape=(*td_params.shape, 1))

    def _step(self, tensordict):
        action = tensordict["action"]
        new_obs = action
        reward = torch.sum(action, dim=-1)
        reward = reward.view(*tensordict.shape, 1)  # this is the line that fails
        done = torch.zeros_like(reward, dtype=torch.bool)
        out = TensorDict(
            {
                "observation": new_obs,
                "params": tensordict["params"],
                "reward": reward,
                "done": done,
            },
            tensordict.shape,
        )
        return out

    def _reset(self, tensordict):
        if tensordict is None or tensordict.is_empty():
            tensordict = self.gen_params(batch_size=self.batch_size)
        obs_max = tensordict["params", "obs_max"]
        obs_min = tensordict["params", "obs_min"]
        size = (*tensordict.shape, 2 * self.nb_joints)
        # for non batch-locked environments, the input ``tensordict`` shape dictates
        # the number of simulators run simultaneously. In other contexts, the initial
        # random state's shape will depend upon the environment batch-size instead.
        obs = (
            torch.rand(size, generator=self.rng, device=self.device)
            * (obs_max - obs_min)
            + obs_min
        )
        out = TensorDict(
            {
                "observation": obs,
                "params": tensordict["params"],
            },
            batch_size=tensordict.shape,
            device="cpu",
        )
        return out

    def _set_seed(self, seed: Optional[int]):
        rng = torch.manual_seed(seed)
        self.rng = rng
```
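For completeness, this is how the check is run (a minimal sketch, assuming the check referred to is `check_env_specs` from `torchrl.envs.utils`):

```python
from torchrl.envs.utils import check_env_specs

env = FlyEnv()
check_env_specs(env)  # performs a short rollout and compares it to the declared specs
```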
Then, in the TorchRL pipeline, I use this environment with an actor that takes inputs of shape `[*B, F]` and outputs `[*B, F]`, where `B` is the batch shape.
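For reference, the actor is wired roughly like this (the network, sizes, and key names below are placeholders, not my exact model):

```python
import torch.nn as nn
from tensordict.nn import TensorDictModule

# placeholder MLP: nn.Linear broadcasts over any leading batch dims,
# so the wrapped module maps [*B, F] -> [*B, F]
actor_net = nn.Sequential(
    nn.Linear(2 * nb_joints, 64),
    nn.Tanh(),
    nn.Linear(64, 2 * nb_joints),
)
policy = TensorDictModule(actor_net, in_keys=["observation"], out_keys=["action"])
```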
However, with a `SyncDataCollector`, the line `for i, tensordict_data in enumerate(collector):` raises the following error:
File "c:\Users\samje\Documents\EPFL\Cours\Semester project 2\Code\RL_copy.py", line 160, in <module>
for i, tensordict_data in enumerate(collector):
File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\collectors\collectors.py", line 952, in iterator
tensordict_out = self.rollout()
File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\_utils.py", line 469, in unpack_rref_and_invoke_function
return func(self, *args, **kwargs)
File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\collectors\collectors.py", line 1069, in rollout
env_output, env_next_output = self.env.step_and_maybe_reset(env_input)
File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\envs\common.py", line 2576, in step_and_maybe_reset
tensordict = self.step(tensordict)
File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\envs\common.py", line 1409, in step
next_tensordict = self._step(tensordict)
File "C:\Users\samje\anaconda3\envs\semester_proj\lib\site-packages\torchrl\envs\transforms\transforms.py", line 738, in _step
next_tensordict = self.base_env._step(tensordict_in)
File "c:\Users\samje\Documents\EPFL\Cours\Semester project 2\Code\environment_copy_copy.py", line 168, in _step
reward = reward.view(*tensordict.shape, 1)
RuntimeError: shape '[1]' is invalid for input of size 4
This is due to the fact that the action (the output of the policy network) has shape `[B, F]`, so `reward` has shape `[B]`, whereas `tensordict.shape` is `torch.Size([])`...
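The failure can be reproduced in isolation:

```python
import torch

reward = torch.ones(4, 2).sum(dim=-1)  # shape [4], like a reward from a batch of 4
td_shape = torch.Size([])              # the environment tensordict's empty batch size
reward.view(*td_shape, 1)              # RuntimeError: shape '[1]' is invalid for input of size 4
```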
I tried manually setting the `batch_size` in the various tensordicts/variables of the environment, which solves this issue but fails later in GAE (from `torchrl.objectives.value`): GAE concatenates tensordicts into shape `[B, T, F]`, and the new environment does not handle the batch size `[B, T]`...
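For context, GAE is applied to the collector output roughly like this (the value network, hyperparameters, and dummy data below are placeholders for illustration):

```python
import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torchrl.objectives.value import GAE

value_net = nn.Linear(2 * nb_joints, 1)
value_module = TensorDictModule(value_net, in_keys=["observation"], out_keys=["state_value"])
advantage_module = GAE(gamma=0.99, lmbda=0.95, value_network=value_module)

# the collector yields tensordicts with batch_size [B, T]; GAE treats the
# trailing batch dim as time, so the env must also tolerate [B, T] batches
B, T, F = 4, 8, 2 * nb_joints
tensordict_data = TensorDict(
    {
        "observation": torch.randn(B, T, F),
        ("next", "observation"): torch.randn(B, T, F),
        ("next", "reward"): torch.randn(B, T, 1),
        ("next", "done"): torch.zeros(B, T, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(B, T, 1, dtype=torch.bool),
    },
    batch_size=[B, T],
)
advantage_module(tensordict_data)  # writes "advantage" and "value_target"
```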
My question: is there a simple way to handle these batch sizes in a TorchRL environment?
Please let me know if anything is missing; I tried to keep this minimal, as it is already long enough.
A couple of things while looking at the code:
If you want a dynamic batch size, it is supported (I see you set `batch_locked = False`, which is what is needed there).
For the collector, we need to be able to tell the environment what the batch size is going to be. That is not an easy task, and we should find a way to streamline it... To put it gently, the compatibility between non-batch-locked envs and collectors is limited. `env.rollout` with a `tensordict=...` argument should work, though! I made a PR to solve this issue: https://github.com/pytorch/rl/pull/2030
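A minimal sketch of that workaround, following the stateless-pendulum tutorial pattern (the batch size here is arbitrary, and `policy` is whatever actor you use):

```python
batch_size = 4
env = FlyEnv()
# reset from an explicitly batched params tensordict, then roll out from it
td = env.reset(env.gen_params(batch_size=[batch_size]))
rollout = env.rollout(100, policy, tensordict=td, auto_reset=False)
print(rollout.batch_size)  # torch.Size([4, 100])
```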