How do I calculate the accuracy of my Vision Transformer?


I'm new to PyTorch and want to find the accuracy of each epoch. I know that accuracy is the number of correct predictions divided by the total number of samples, but I don't know how to integrate this into my code:

for epoch in range(1):
    epoch_losses = []
    model.train()
    for step, (inputs, labels) in enumerate(train_dataloader):
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
    if epoch % 1 == 0:  # For every epoch
        print(f">>> Epoch {epoch+1} train loss: ", np.mean(epoch_losses))
        model.eval()
        epoch_losses = []

        for step, (inputs, labels) in enumerate(test_dataloader):
            # Unpack the tuple
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            epoch_losses.append(loss.item())
        print(f">>> Epoch {epoch+1} test loss: ", np.mean(epoch_losses))

There are 2 answers below.

BEST ANSWER

To calculate the accuracy of your Vision Transformer model, you need to keep track of the number of correct predictions during both training and testing epochs. Then, you can divide the total number of correct predictions by the total number of samples to get the accuracy.

Here's how you can modify your code to calculate accuracy during each epoch:

import numpy as np
import torch

for epoch in range(num_epochs):
    train_losses = []
    correct_predictions = 0
    total_samples = 0
    
    # Training phase
    model.train()
    for step, (inputs, labels) in enumerate(train_dataloader):
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        
        # Calculate accuracy during training
        _, predicted = torch.max(outputs, 1)
        correct_predictions += torch.sum(predicted == labels).item()
        total_samples += labels.size(0)

    train_accuracy = correct_predictions / total_samples
    
    # Validation phase
    model.eval()
    test_losses = []
    correct_predictions = 0
    total_samples = 0
    
    with torch.no_grad():
        for step, (inputs, labels) in enumerate(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_losses.append(loss.item())
            
            # Calculate accuracy during validation
            _, predicted = torch.max(outputs, 1)
            correct_predictions += torch.sum(predicted == labels).item()
            total_samples += labels.size(0)

    test_accuracy = correct_predictions / total_samples
    
    print(f">>> Epoch {epoch+1} train loss: {np.mean(epoch_losses)} train accuracy: {train_accuracy}")
    print(f">>> Epoch {epoch+1} test loss: {np.mean(epoch_losses)} test accuracy: {test_accuracy}")

In this code:

  • correct_predictions keeps track of the number of correct predictions.
  • total_samples keeps track of the total number of samples.
  • torch.max(outputs, 1) returns both the maximum value and its index along dimension 1 (the class dimension); the index is the predicted class.
  • torch.sum(predicted == labels).item() calculates the number of correct predictions in a batch.

Finally, you compute accuracy by dividing the total number of correct predictions by the total number of samples.
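
If you prefer to keep this logic in one place, here is a minimal sketch of a helper that computes accuracy for a single batch. batch_accuracy is a hypothetical name, not a PyTorch function:

import torch

def batch_accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
    """Fraction of samples whose argmax prediction matches the label."""
    # outputs: [batch_size, num_classes] raw logits; labels: [batch_size] class indices
    predicted = outputs.argmax(dim=1)
    return (predicted == labels).float().mean().item()

# Quick check with dummy logits for a 3-sample, 5-class batch
logits = torch.tensor([[2.0, 0.1, 0.3, 0.0, 0.1],
                       [0.2, 3.0, 0.1, 0.0, 0.1],
                       [0.1, 0.2, 0.3, 0.4, 5.0]])
labels = torch.tensor([0, 1, 3])
print(batch_accuracy(logits, labels))  # 2 of 3 correct -> ~0.667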

ANSWER 2

Since you said it's a 5-class problem, your output tensor should have shape [1, 5] (for a batch of one sample). Do a torch.max on the output tensor.

#This is assuming that the num of classes is in dim 1
_, pred = torch.max(scores, dim=1) 

The first output is the value, and the second, which is what we need, is the index of the most probable class. Next, do a direct comparison with your ground-truth label.

# With a batch size of 1, a direct comparison works
if pred == labels:
    correct_predictions += 1

Then, average it out for your entire val/train set.

This is just a skeleton; you may have to change it to suit your particular setting.
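
For example, a minimal sketch of that averaging over a whole loader might look like this (model, device, and eval_loader are assumed from your own setup; the per-batch sum handles batch sizes larger than 1):

import torch

model.eval()
correct_predictions = 0
total_samples = 0

with torch.no_grad():
    for inputs, labels in eval_loader:      # assumed DataLoader yielding (inputs, labels)
        inputs, labels = inputs.to(device), labels.to(device)
        scores = model(inputs)              # shape [batch_size, 5]
        _, pred = torch.max(scores, dim=1)  # index of the most probable class per sample
        correct_predictions += (pred == labels).sum().item()
        total_samples += labels.size(0)

accuracy = correct_predictions / total_samples
print(f"Accuracy: {accuracy:.4f}")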