What's the gradients dtype during mixed precision training?

I want to figure out how torch.cuda.amp.autocast works, so I conducted an experiment. The code is as follows:

import torch
import torch.nn as nn
import torch.optim as optim

class CustomModel(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(CustomModel, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc3 = nn.Linear(hidden_size, num_classes)
        
    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        out = self.fc3(out)
        return out

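# X_train_tensor and y_train_tensor are assumed to exist already as CUDA tensors
# (float32 features and integer class labels for CrossEntropyLoss)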
input_size = X_train_tensor.shape[1]
hidden_size = 16000
num_classes = 2000
model = CustomModel(input_size, hidden_size, num_classes).to('cuda')
scaler = torch.cuda.amp.GradScaler()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 1
for epoch in range(num_epochs):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast(dtype=torch.float16, enabled=True):
        outputs = model(X_train_tensor)
        loss = criterion(outputs, y_train_tensor)
        
    # Scale the loss before backward so small float16 gradients don't underflow to zero
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    
    print(outputs.dtype)
    print(model.fc1.weight.grad.dtype)
    print(model.fc2.weight.grad.dtype)

print("Done!")

and I got the following outputs:

torch.float16
torch.float32
torch.float32
Done!

I am confused about which dtype the gradients should have. And if all the gradients are float32, is it still necessary to use a grad scaler? Can anyone help me with this?

I tried using torch.cuda.amp.autocast in a demo, and the results show that all gradients are float32.

There are 2 best solutions below

Answer from ZainSharief:

It is typically recommended to keep your gradients in float32, and it isn't strictly necessary to use a grad scaler; you could instead convert the gradients to float32 yourself, like so:

# Cast any existing gradients back to float32 in place
for param in model.parameters():
    if param.grad is not None:
        param.grad = param.grad.to(torch.float32)

However, it's typically recommended to use a grad scaler for numerical stability reasons, and it's just easier.
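To see why the scaler still helps even though the stored .grad tensors end up as float32, here is a minimal sketch (just an illustration, not part of the original answer) of how small float16 values underflow to zero. The same thing can happen to intermediate gradients computed in half precision during the backward pass, and loss scaling is what prevents it:

import torch

# A gradient-sized value that float32 can represent but float16 flushes to zero
small = torch.tensor(1e-8)                 # float32 by default

print(small.item())                        # 1e-08
print(small.half().item())                 # 0.0 (underflows in float16)

# Multiplying first (what GradScaler does to the loss; 65536.0 is its default
# initial scale) keeps the value representable in float16, and it can be
# divided back down once the gradients are held in float32
scaled = (small * 65536.0).half()
print((scaled.float() / 65536.0).item())   # roughly 1e-08 again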

Answer from Saul Aryeh Kohn:

When you pass your tensor to CUDA, it will promote the operations to match the input dtype. Are your inputs float32? You are explicitly casting your output to float16... but I expect that in the computation of the loss function it is cast back up to float32.

This is also probably what you want, since your gradients are less likely to vanish if you scale a float32 loss up rather than a float16 one.
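One way to check this yourself is a small sketch like the one below (separate from the question's code, and assuming a CUDA device is available): under autocast the linear layer's output comes out as float16, while the loss and the parameter gradients stay float32.

import torch
import torch.nn as nn

model = nn.Linear(8, 4).cuda()            # parameters are float32
x = torch.randn(2, 8, device='cuda')      # float32 input
target = torch.randint(0, 4, (2,), device='cuda')

with torch.cuda.amp.autocast(dtype=torch.float16):
    out = model(x)                                   # matmul runs in float16
    loss = nn.functional.cross_entropy(out, target)  # loss op autocasts to float32

loss.backward()

print(out.dtype)                # torch.float16
print(loss.dtype)               # torch.float32
print(model.weight.grad.dtype)  # torch.float32, matching the float32 parameters

The gradients end up stored in float32 because they match the float32 parameters; the grad scaler is there to protect the float16 arithmetic inside the backward pass rather than the final storage dtype.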