I'm trying to get the hang of reinforcement learning, so I'm following the guide at: pytorch.org/tutorials/
They've implemented a DQN that solves CartPole using computer vision. I've basically copied their code and modified it to solve the LunarLander environment from its state vector instead (no computer vision). But I'm getting weird results: the model seems to be learning, since its score improves (with a lot of hiccups), until it fails spectacularly, gets stuck, and just makes strange movements without learning any further.
[Figure: learning-progress graph of a second model]
You can see both models failing in the same way towards the end of training.
I cannot figure out why this solution is not working. Could you have a look at my code and perhaps find and point out errors?
Global variables:
BATCH_SIZE = 1000
GAMMA = 0.999
EPS_START = 0.9
EPS_END = 0.05
EPS_DECAY = 1000
TARGET_UPDATE = 10
LEARNING_RATE = 0.01
MOMENTUM = 0.9
MEMORY_SIZE = 10000
env = gym.make('LunarLander-v2')
n_actions = env.action_space.n
n_observation_space = env.observation_space.shape[0]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = DQN(n_observation_space, n_actions).to(device)
target_net = DQN(n_observation_space, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
memory = ReplayMemory(MEMORY_SIZE)
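The parts I haven't pasted (the imports and the Transition / ReplayMemory / select_action helpers) are essentially unchanged from the tutorial; roughly, they look like this:
import math
import random
from collections import deque, namedtuple
from itertools import count

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    # Fixed-size buffer of transitions, sampled uniformly at random.
    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)

steps_done = 0

def select_action(state):
    # Epsilon-greedy selection; epsilon is decayed once per environment step.
    global steps_done
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if random.random() > eps_threshold:
        with torch.no_grad():
            # Greedy action: index of the largest predicted Q-value.
            return policy_net(state).max(1)[1].view(1, 1)
    return torch.tensor([[random.randrange(n_actions)]], device=device, dtype=torch.long)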
Learning loop:
def learn(num_episodes=50, render=False):
    for i_episode in range(num_episodes):
        # Initialize the environment and state
        state = torch.tensor([env.reset()], device=device, dtype=torch.float32)
        episode_reward = 0
        for t in count():
            # Select and perform an action
            action = select_action(state)
            next_state, reward, done, _ = env.step(action.item())
            episode_reward += reward
            reward = torch.tensor([reward], device=device, dtype=torch.float32)
            next_state = torch.tensor([next_state], device=device, dtype=torch.float32)

            # Store the transition in memory
            memory.push(state, action, next_state, reward)

            # Move to the next state
            state = next_state

            # Perform one step of the optimization (on the policy network)
            optimize_model()

            if render:
                env.render()

            if done:
                break

        all_rewards.append(episode_reward)

        # Update the target network, copying all weights and biases in DQN
        if i_episode % TARGET_UPDATE == 0:
            target_net.load_state_dict(policy_net.state_dict())
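The driver code is just this kind of thing (all_rewards is a plain module-level list that learn() appends to):

all_rewards = []                         # per-episode returns, used for the graphs above
learn(num_episodes=500, render=False)    # the episode count here is only an example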
Optimization methods:
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    batch = Transition(*zip(*transitions))

    # Mask of transitions whose next_state is not None
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None,
                                            batch.next_state)), device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state
                                       if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)

    # Q(s_t, a) for the actions that were actually taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)

    # max_a Q(s_{t+1}, a) from the target network for non-final next states
    next_state_values = torch.zeros(BATCH_SIZE, device=device)
    next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0].detach()
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch

    loss = F.mse_loss(state_action_values, expected_state_action_values.unsqueeze(1))

    # Optimize the model
    optimizer.zero_grad()
    loss.backward()
    for param in policy_net.parameters():
        param.grad.data.clamp_(-1, 1)
    optimizer.step()
Model:
class DQN(nn.Module):
    def __init__(self, input_size, output_size):
        super(DQN, self).__init__()
        self.l1 = nn.Linear(input_size, 512)
        self.l2 = nn.Linear(512, 512)
        self.l3 = nn.Linear(512, 256)
        self.l4 = nn.Linear(256, output_size)

    def forward(self, x):
        x = F.leaky_relu(self.l1(x))
        x = F.leaky_relu(self.l2(x))
        x = F.leaky_relu(self.l3(x))
        return self.l4(x)
If anyone's willing to run my code locally, please let me know; I'll clean it up and share it on GitHub.
Looking through your code, I can't find any bugs that stand out (though you didn't post everything). There are a few odd things, though:
- Your BATCH_SIZE of 1000 is quite massive. Of course you should go with whatever works best for you, but next time try something in the range of 32/64/128.
- EPS decay: I assume you're decaying your EPS at every time step with a 1/1000 decay rate. Given that you're using a very big network, try making your epsilon decay slower; see the sketch below.
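For example, keeping the tutorial-style exponential schedule and just stretching it out (the exact value below is illustrative, not tuned):

import math

EPS_DECAY = 20000   # illustrative; you currently have 1000

def eps_threshold(steps_done):
    # Same schedule as before: epsilon falls from EPS_START towards EPS_END,
    # but now over roughly 20x more environment steps.
    return EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)

You could also decay epsilon per episode instead of per step; either way, the point is to keep exploring for longer.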