I'm trying to understand how to use the Actor class in tf_agents. I am using DDPG (an actor-critic method, although this doesn't really matter per se), and I am training on environments from the gym package, although again this isn't essential to the question.
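For context, here is roughly how I am wiring things up. This is a minimal, self-contained sketch rather than my real training script: a random policy stands in for my DDPG collect policy, Pendulum-v1 is just an example task, and the observers list is left empty.

import tensorflow as tf
from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy
from tf_agents.train import actor

# Example continuous-control task; in my real code this is my own gym env.
collect_env = suite_gym.load('Pendulum-v1')
train_step = tf.Variable(0, dtype=tf.int64)

# Stand-in for agent.collect_policy from my DdpgAgent.
collect_policy = random_py_policy.RandomPyPolicy(
    collect_env.time_step_spec(), collect_env.action_spec())

collect_actor = actor.Actor(
    collect_env,
    collect_policy,
    train_step,
    steps_per_run=1,
    observers=[])

collect_actor.run()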
I went into the class definition of train.Actor and, under the hood, its run method calls py_driver.PyDriver.run. My understanding is that once a gym environment reaches a terminal state, it needs to be reset. However, following the Actor and PyDriver classes, I don't see anywhere (outside of __init__) where env.reset() is called. Looking at the tutorial for sac_agent.SacAgent, I don't see env.reset() being called there either.
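My expectation comes from how a bare gym environment behaves outside of tf_agents (sketch below, using the classic 4-tuple step API; the environment name is again just an example):

import gym

env = gym.make('Pendulum-v1')
obs = env.reset()
for _ in range(1000):
    obs, reward, done, info = env.step(env.action_space.sample())
    if done:
        # With plain gym I have to reset explicitly once an episode ends,
        # and this is the call I can't find inside Actor/PyDriver.
        obs = env.reset()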
Can someone help me understand what I'm missing? Do I not need to call env.reset() at all, or is the reset happening in some code path I haven't found?
Here is PyDriver.run():
def run(
    self,
    time_step: ts.TimeStep,
    policy_state: types.NestedArray = ()
) -> Tuple[ts.TimeStep, types.NestedArray]:
  num_steps = 0
  num_episodes = 0
  while num_steps < self._max_steps and num_episodes < self._max_episodes:
    # For now we reset the policy_state for non batched envs.
    if not self.env.batched and time_step.is_first() and num_episodes > 0:
      policy_state = self._policy.get_initial_state(self.env.batch_size or 1)

    action_step = self.policy.action(time_step, policy_state)
    next_time_step = self.env.step(action_step.action)

    # When using observer (for the purpose of training), only the previous
    # policy_state is useful. Therefore substitube it in the PolicyStep and
    # consume it w/ the observer.
    action_step_with_previous_state = action_step._replace(state=policy_state)
    traj = trajectory.from_transition(
        time_step, action_step_with_previous_state, next_time_step)
    for observer in self._transition_observers:
      observer((time_step, action_step_with_previous_state, next_time_step))
    for observer in self.observers:
      observer(traj)
    for observer in self.info_observers:
      observer(self.env.get_info())

    if self._end_episode_on_boundary:
      num_episodes += np.sum(traj.is_boundary())
    else:
      num_episodes += np.sum(traj.is_last())

    num_steps += np.sum(~traj.is_boundary())

    time_step = next_time_step
    policy_state = action_step.state

  return time_step, policy_state
As you can see, it counts a step for every non-boundary transition (num_steps += np.sum(~traj.is_boundary())) and counts an episode when it hits a terminal or boundary transition, but nowhere in the loop is there a call to env.reset().
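For what it's worth, this is how I have been poking at the transitions around an episode end, with a throwaway observer attached to a PyDriver. Again just a sketch: random policy, and CartPole-v1 only because its episodes are short.

import numpy as np
from tf_agents.drivers import py_driver
from tf_agents.environments import suite_gym
from tf_agents.policies import random_py_policy

env = suite_gym.load('CartPole-v1')
policy = random_py_policy.RandomPyPolicy(env.time_step_spec(), env.action_spec())

def print_episode_edges(traj):
    # traj.is_last() flags the transition *into* the terminal time step;
    # traj.is_boundary() flags the following LAST -> FIRST transition
    # (the one excluded from num_steps in the run() loop above).
    if np.any(traj.is_last()) or np.any(traj.is_boundary()):
        print('step_type:', traj.step_type,
              'next_step_type:', traj.next_step_type,
              'is_last:', traj.is_last(),
              'is_boundary:', traj.is_boundary())

driver = py_driver.PyDriver(
    env, policy, observers=[print_episode_edges], max_steps=300)
driver.run(env.reset())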