PyTorch Linear layer outputting a NaN tensor

I am trying to implement a SAC agent to train in a custom Gym environment, but I get NaN values after the state passes through the first linear layer: the assertion I added to debug the forward pass is raised, and at that point the layer's parameters are all NaN.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Normal


class ActorNet(nn.Module):
    LOG_STD_MAX = 2
    LOG_STD_MIN = -20

    def __init__(self, state_dims, action_dims, checkpoint_dir="checkpoint", name="ActorNet"):
        super().__init__()
        # self.net = nn.Sequential(
        #         nn.Linear(state_dims, 512),
        #         nn.Sigmoid(),
        #         nn.Linear(512, 256),
        #         nn.Sigmoid()
        # )
        self.net1 = nn.Linear(state_dims, 512)
        self.net2 = nn.Linear(512, 256)
        self.mu_layer = nn.Linear(256, action_dims)
        self.log_sigma_layer = nn.Linear(256, action_dims)

        self.name = name
        self.checkpoint_dir = checkpoint_dir

    def forward(self, state, explore=True, with_logprob=True):
        # raw_output = self.net(state)
        raw_output = self.net1(state)
        assert not torch.isnan(raw_output).any(), f"tensor has NaNs after net1, {list(self.net1.parameters())}"
        raw_output = self.net2(raw_output)
        assert not torch.isnan(raw_output).any(), "tensor has NaNs after net2"
        mu = self.mu_layer(raw_output)
        log_std = self.log_sigma_layer(raw_output)
        # clamp log_std to the bounds used in the paper
        log_std = torch.clamp(log_std, min=ActorNet.LOG_STD_MIN, max=ActorNet.LOG_STD_MAX)
        std = torch.exp(log_std)
        pi_distribution = Normal(loc=mu, scale=std)
        pi_action = pi_distribution.rsample() if explore else mu
        # tanh-squashing correction of the log-probability, copied from the
        # Spinning Up implementation; need to revisit
        if with_logprob:
            logp_pi = pi_distribution.log_prob(pi_action).sum(axis=-1)
            logp_pi -= (2 * (np.log(2) - pi_action - F.softplus(-2 * pi_action))).sum(axis=-1)
        else:
            logp_pi = None

        # I need action values in (0, 1), so rescale tanh's (-1, 1) output
        pi_action = (torch.tanh(pi_action) + 1) / 2

        return pi_action, logp_pi

The state does not contain any NaN or inf values (the max ≈ 300 and the min = 0). This is how the actor loss is computed (taken from the Spinning Up implementation of the paper):

def compute_actor_loss(self, batch: Experience):
    states = torch.as_tensor(batch.state)
    pi, logp_pi = self.pi(states)
    q1_pi = self.q1(states, pi)
    q2_pi = self.q2(states, pi)
    # take the min of the two Q estimates to mitigate overestimation
    q_pi = torch.min(q1_pi, q2_pi)

    # entropy-regularized policy loss: alpha * log-prob minus min-Q
    loss_pi = (self.alpha * logp_pi - q_pi).mean()
    return loss_pi
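
For reference, these are the kinds of finiteness checks I can drop into compute_actor_loss just before the loss is returned, to see which quantity goes bad first (my own assertions in the same style as the ones in forward, not part of the Spinning Up code):

# hypothetical checks, placed just before loss_pi is returned
assert torch.isfinite(states).all(), "states contain NaN/inf"
assert torch.isfinite(logp_pi).all(), "logp_pi contains NaN/inf"
assert torch.isfinite(q_pi).all(), "q_pi contains NaN/inf"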

The learning rate is 3e-4 and the optimizer is Adam. Normalizing the inputs and Xavier-initializing the weights didn't help, and I'm out of ideas on how to debug this. Is it possible that any of the log or exp expressions are causing NaNs, or leading to exploding/vanishing gradients? One thing I'm considering is a NaN-localizing training step like the sketch below.
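
A rough version of that sketch (`model`, `optimizer`, and `loss` are placeholders for my actual agent objects; torch.autograd.set_detect_anomaly makes the backward pass raise at the op that first produces a NaN):

import torch

# make autograd raise an error at the backward op that first produces a NaN
torch.autograd.set_detect_anomaly(True)

def debug_step(model, optimizer, loss):
    optimizer.zero_grad()
    loss.backward()
    # inspect gradients before the optimizer step
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            raise RuntimeError(f"non-finite gradient in {name}")
    optimizer.step()
    # inspect parameters after the step (this is where an all-NaN layer would show up)
    for name, p in model.named_parameters():
        if not torch.isfinite(p).all():
            raise RuntimeError(f"non-finite parameter in {name}")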
