What is the dtype of the gradients during mixed precision training?


I want to figure out how torch.cuda.amp.autocast works, so I ran an experiment. The code is as follows:

import torch
import torch.nn as nn
import torch.optim as optim

class CustomModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(CustomModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

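# X_train_tensor (float32 features) and y_train_tensor (class indices) are assumed to be CUDA tensors defined earlier
input_size = X_train_tensor.shape[1]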
hidden_size = 16000
num_classes = 2000
model = CustomModel(input_size, hidden_size, num_classes).to('cuda')
scaler = torch.cuda.amp.GradScaler()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 1
for epoch in range(num_epochs):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(dtype=torch.float16, enabled=True):
        outputs = model(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
    print(outputs.dtype)
    print(model.fc1.weight.grad.dtype)
    print(model.fc2.weight.grad.dtype)

print("Done!")

And I got the following output:

torch.float16
torch.float32
torch.float32
Done!

I'm confused and can't figure out what dtype the gradients should be. Also, if all the gradients are float32, is it still necessary to use a grad scaler? Can anyone help me with this?

I tried using torch.cuda.amp.autocast in a demo, and the results show that all the gradients are float32.
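A small sketch that reproduces the pattern in the question, under stated assumptions: it runs autocast on the CPU with bfloat16 (so it works without a GPU) as a stand-in for CUDA float16, and the tiny model and random data are hypothetical. The parameters themselves stay float32 and autocast casts them for each eligible op, so the parameter gradients (weight.grad) come back as float32. The gradient flowing into a low-precision activation, however, has the activation's low-precision dtype. With float16 on CUDA, those intermediate gradients are float16 and can underflow during backward, which is the problem loss scaling targets; bfloat16 shares float32's exponent range, which is why a grad scaler is mainly needed for float16.

```python
import torch
import torch.nn as nn

# Hypothetical tiny model and data; parameters are created in float32.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(8, 4), nn.ReLU(), nn.Linear(4, 3))
x = torch.randn(5, 8)
target = torch.randint(0, 3, (5,))

captured = {}

def save_grad_dtype(grad):
    # Record the dtype of the gradient flowing back into the activation.
    captured["activation_grad"] = grad.dtype

# CPU + bfloat16 is an assumed stand-in for the question's CUDA + float16 setup.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    hidden = model[0](x)  # Linear is autocast-eligible, so it runs in bfloat16
    hidden.register_hook(save_grad_dtype)
    logits = model[2](model[1](hidden))

# Compute the loss in float32 outside the autocast region.
loss = nn.functional.cross_entropy(logits.float(), target)
loss.backward()

print("activation dtype:", hidden.dtype)
print("activation grad dtype:", captured["activation_grad"])
print("weight dtype:", model[0].weight.dtype)
print("weight grad dtype:", model[0].weight.grad.dtype)
```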

Answer 1

It is typically recommended to keep your gradients in float32, and it isn't strictly necessary to use a grad scaler; you could instead convert the gradients to float32 like this:

for param in model.parameters():
    if param.grad is not None:
        param.grad = param.grad.to(torch.float32)

However, it's typically recommended to use a grad scaler for numerical stability reasons, and it's just easier.
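To make the numerical-stability point concrete, here is a sketch that emulates float16 with Python's struct module (format "e" is IEEE 754 half precision). The gradient value 1e-8 is hypothetical; the scale 65536 (2**16) matches the documented default init_scale of torch.cuda.amp.GradScaler. An unscaled tiny value is flushed to zero in float16, and casting to float32 after backward() cannot bring it back, whereas scaling before the half-precision step keeps it representable, and dividing by the same factor afterwards recovers it.

```python
import struct

def to_fp16(value: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision."""
    return struct.unpack("e", struct.pack("e", value))[0]

small_grad = 1e-8   # hypothetical tiny gradient value
scale = 65536.0     # 2**16, GradScaler's default initial scale

unscaled = to_fp16(small_grad)          # below the smallest fp16 subnormal, so it becomes 0.0
scaled = to_fp16(small_grad * scale)    # about 6.55e-4, comfortably representable in fp16
recovered = scaled / scale              # unscale in full precision

print(unscaled, scaled, recovered)
```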

Answer 2

When you move your tensors to CUDA, operations are promoted to match the input dtype. Are your inputs float32? You are explicitly casting your output to float16, but I expect the loss computation casts it back up to float32.

This is also probably what you want, since your gradients are less likely to vanish if you scale up a float32 loss rather than a float16 one.
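The scale factor used for the loss is not fixed in GradScaler; it is adjusted dynamically. The following is a simplified, hypothetical pure-Python sketch of that update rule (the class name DynamicLossScale is invented for illustration). Its defaults mirror GradScaler's documented defaults (init_scale=2**16, growth_factor=2.0, backoff_factor=0.5, growth_interval=2000). The real GradScaler works on tensors and also skips optimizer.step() when it finds infs or NaNs.

```python
import math

class DynamicLossScale:
    """Hypothetical sketch of dynamic loss scaling (not the PyTorch implementation)."""

    def __init__(self, init_scale=2.0**16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=2000):
        self.scale = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0  # consecutive steps without inf/NaN gradients

    def unscale(self, scaled_grads):
        """Divide scaled gradients by the current scale and report whether all are finite."""
        grads = [g / self.scale for g in scaled_grads]
        finite = all(math.isfinite(g) for g in grads)
        return grads, finite

    def update(self, found_inf):
        """Back off after an overflow; grow after growth_interval clean steps in a row."""
        if found_inf:
            self.scale *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale *= self.growth_factor
                self._good_steps = 0
```

For example, with DynamicLossScale(growth_interval=2), an overflowed step (unscale returning finite=False) halves the scale from 65536 to 32768 and the optimizer step would be skipped; two clean steps afterwards double it back to 65536. The design choice is to keep the scale as large as possible without overflowing, so small gradients stay representable in float16.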