I'm currently training a neural network to classify food images into food groups, with 5 output classes. However, whenever I begin training the network, I get this error:
ValueError: Expected input batch_size (64) to match target batch_size (30).
Here's my neural network definition and training code. I'd really appreciate help. I'm relatively new to PyTorch and can't figure out exactly what the problem is in my code. Thanks!
from torch import nn, optim

#Define the Network Architecture
model = nn.Sequential(nn.Linear(7500, 4950),
                      nn.ReLU(),
                      nn.Linear(4950, 1000),
                      nn.ReLU(),
                      nn.Linear(1000, 250),
                      nn.ReLU(),
                      nn.Linear(250, 5),
                      nn.LogSoftmax(dim = 1))
#Define loss
criterion = nn.NLLLoss()
#Initial forward pass
images, labels = next(iter(trainloader))
images = images.view(images.shape[0], -1)
print(images.shape)
logits = model(images)
print(logits.shape)
loss = criterion(logits, labels)
print(loss)
#Define Optimizer
optimizer = optim.SGD(model.parameters(), lr = 0.01)
Training the Network:
epochs = 10
for e in range(epochs):
    running_loss = 0
    for image, labels in trainloader:
        #Flatten Images
        images = images.view(images.shape[0], -1)
        #Set gradients to 0
        optimizer.zero_grad()
        #Output
        output = model(images)
        loss = criterion(output, labels) #Where the error occurs
        loss.backward()
        #Gradient Descent Step
        optimizer.step()
        running_loss += loss.item()
    else:
        print(f"Training loss: {running_loss/len(trainloader)}")
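Not 100% sure, but the first thing I'd check is the first layer, nn.Linear(7500, 4950).
Keep 7500 only if you're absolutely sure that each flattened image has 7500 values. Remember that the first value is always your input size per sample, so it has to equal channels * height * width of your images. Putting 1 there would not make the model work with any size of images; the layer would then only accept inputs with a single value each.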
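To check what the first value should be, you can multiply out the per-image shape. This is a minimal sketch that assumes 3-channel 50x50 images; that shape and the helper name flattened_size are made up for illustration, so substitute whatever images.shape[1:] reports for a batch from your own loader.

```python
from math import prod


def flattened_size(image_shape):
    """Number of values per image once all non-batch dimensions are flattened."""
    return prod(image_shape)


# Hypothetical per-image shape (channels, height, width); not taken from the question.
in_features = flattened_size((3, 50, 50))
print(in_features)
```

If the number printed for your real data is not 7500, the first nn.Linear layer needs that number instead.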
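By the way, PyTorch has a flatten module. Use nn.Flatten instead of images.view() so you don't make shape errors and waste more time unnecessarily.
The bigger problem is that you mix images and image as variable names. The for loop unpacks each batch into image, but the loop body flattens and feeds images, which is still the 64-image batch from your initial forward pass. The labels do change on every iteration, so when the last batch has only 30 samples, the loss gets 64 outputs and 30 targets, which is exactly the ValueError you are seeing. Use one name consistently (for images, labels in trainloader) and don't reuse a variable from earlier code inside the loop; mixing them is also confusing for anyone who reads your code.
Also, could you give more info on your data? For example, is it greyscale, and what is the image size?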
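Here is a minimal sketch of the two suggestions above: nn.Flatten inside the model, and one consistent name for the batch in the loop. The data is a synthetic stand-in (94 random 3x50x50 images, an assumption chosen so that a batch size of 64 leaves a final batch of 30), and the hidden layers are smaller than yours just to keep it quick to run.

```python
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for the real dataset (assumption): 94 RGB images of 50x50,
# so batch_size=64 yields one full batch and a final batch of 30.
fake_images = torch.randn(94, 3, 50, 50)
fake_labels = torch.randint(0, 5, (94,))
trainloader = DataLoader(TensorDataset(fake_images, fake_labels), batch_size=64)

# nn.Flatten keeps the batch dimension and flattens the rest (3 * 50 * 50 = 7500).
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(7500, 256),
    nn.ReLU(),
    nn.Linear(256, 5),
    nn.LogSoftmax(dim=1),
)
criterion = nn.NLLLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

batch_sizes = []  # (output batch size, label batch size) for every step
for epoch in range(2):
    running_loss = 0.0
    # The same name is used here and in the body, so each step sees the current batch.
    for images, labels in trainloader:
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batch_sizes.append((output.shape[0], labels.shape[0]))
    print(f"Training loss: {running_loss / len(trainloader)}")
```

With the names matched, the output and label batch sizes agree on every step, including the smaller final batch.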