I am using the DQN algorithm, and I need to obtain the (optimal) action for a specific state inside the environment's step function. Is there any way to do this?
# %% Env class
class SimpleEnv(gymnasium.Env):
    """Toy environment with two discrete actions and a fixed observation.

    The observation never changes during an episode. An episode terminates
    on the first step at which more than three steps have been taken; that
    terminating step yields reward 1 and every earlier step yields reward 0.
    """

    def __init__(self, env_config=None):
        """Create the action and observation spaces.

        Args:
            env_config: Mapping that provides an ``'env_name'`` entry. When
                omitted, ``{'env_name': 'simple_cnn_env'}`` is used.
        """
        # Build the default per call; a dict literal in the signature would be
        # one shared mutable object across all instances.
        if env_config is None:
            env_config = {'env_name': 'simple_cnn_env'}
        self.env_name = env_config['env_name']
        self.action_space = gymnasium.spaces.Discrete(2)
        # Stored on the instance: as a bare local name the space was discarded
        # as soon as __init__ returned.
        self.observation_space = gymnasium.spaces.Box(
            low=0.0 * np.ones(2, dtype=float),
            high=1.0 * np.ones(2, dtype=float),
            shape=(2,),
            dtype=float,
        )
        # Both are (re)assigned by reset(); declared here so the attributes
        # always exist.
        self.initial_state = None
        self.timestep = 0

    def reset(self, *, seed=None, options=None):
        """Start a new episode.

        Args:
            seed: Optional seed forwarded to ``gymnasium.Env.reset``.
            options: Accepted for API compatibility; not used.

        Returns:
            Tuple ``(state, info)`` where ``info`` is an empty dict.
        """
        # Lets the base class seed its random generator when a seed is given.
        super().reset(seed=seed)
        # NOTE(review): this (vector, image) tuple does not look like a sample
        # of the 2-element Box declared in __init__ -- confirm which layout is
        # intended and align the observation space with it.
        self.initial_state = (
            np.array([0, 1]),
            np.zeros((21, 21, 3), dtype=np.uint8),
        )
        self.timestep = 0
        return self.initial_state, {}

    def _take_action(self, action):
        """Compute the outcome of one step.

        Args:
            action: The chosen action; currently not used by the dynamics.

        Returns:
            Tuple ``(next_state, reward, done)``.
        """
        # The state is static: every step returns the observation from reset().
        next_state = self.initial_state
        done = self.timestep > 3
        reward = 1 if done else 0
        return next_state, reward, done

    def step(self, action):
        """Advance the episode by one step.

        Args:
            action: The chosen action, forwarded to ``_take_action``.

        Returns:
            Tuple ``(state, reward, terminated, truncated, info)``;
            ``truncated`` is always ``False`` and ``info`` is an empty dict.
        """
        self.timestep += 1
        state, reward, done = self._take_action(action)
        ## Here I need to obtain the action based on DQN (or every policy in general) for the current state ##
        # ....
        return state, reward, done, False, {}
# %% Network
class SimpleConv(TorchModelV2, nn.Module, ABC):
### Here I define my network ######
.....