I am following the mlp_mnist.jl example from the Flux model zoo and I am a little confused by the accuracy computation and train functions.
The loss_and_accuracy function loops over the dataloader and computes the loss and the accuracy, which is fine:
function loss_and_accuracy(dataloader, model)
    loss::Float32 = 0.0f0
    acc::Int = 0
    num_samples::Int = 0
    for (X, y) in dataloader                    # 60000 samples in total for the train set
        ŷ = model(X)                            # (10, 256) per batch; the last batch is (10, 96)
        # agg = sum aggregates the loss over the batch; += accumulates it across batches
        loss += loss_func(ŷ, y, agg = sum)      # e.g. 132230.64
        # onecold(ŷ) .== onecold(y) -> (256,) .== (256,) -> BitVector of correct predictions
        acc += sum(onecold(ŷ) .== onecold(y))   # e.g. 25, counts the 1's (correct predictions)
        num_samples += size(X)[end]             # grows by 256 per batch up to 60000 (last batch is smaller)
    end
    # average loss and accuracy over all samples
    avg_loss = loss / num_samples               # num_samples = 60000 for the train set
    avg_acc = acc / num_samples
    return avg_loss, avg_acc
end
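To make sure I am reading the accuracy line correctly, here is a tiny made-up example of what I understand onehotbatch and onecold to do (the labels and predictions are invented purely for illustration):

using Flux: onehotbatch, onecold

y = onehotbatch([1, 2, 3, 1], 1:3)          # 3×4 one-hot matrix of true labels
ŷ = Float32[0.7 0.1 0.2 0.6;                # 3×4 "predictions", one column per sample
            0.2 0.8 0.3 0.3;
            0.1 0.1 0.5 0.1]
onecold(ŷ, 1:3)                             # [1, 2, 3, 1] : argmax of each column mapped to a label
onecold(y, 1:3)                             # [1, 2, 3, 1]
sum(onecold(ŷ, 1:3) .== onecold(y, 1:3))    # 4 correct predictions out of 4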
But the way it is called within the train function is a little confusing! By that I mean: in the inner for loop we compute gradients and update the parameters for the entire training set, and then we loop over the training set again (inside loss_and_accuracy) to compute the loss and accuracy for each batch. Shouldn't both be inside one loop, where for each batch the gradients are updated and the accuracy is computed? For example, in PyTorch this is generally how I would do it (see the second code snippet below; after it I have also sketched what I imagine the combined loop would look like in Flux).
function train(; kws...)
    args = Args(; kws...)
    model = build_model()
    optimizer = Flux.setup(Adam(args.η), model)
    # trainloader and testloader come from the example's data-loading helper (omitted here)
    for epoch in 1:args.epochs
        for (X, y) in trainloader
            grad = gradient(m -> loss_func(m(X), y), model)
            Flux.update!(optimizer, model, grad[1])
        end
        # Why are we looping over the trainloader again here, after already
        # updating the gradients in the for loop above?
        train_loss, train_acc = loss_and_accuracy(trainloader, model)
        test_loss, test_acc = loss_and_accuracy(testloader, model)
        println("Epoch : $epoch")
        println("training loss : $train_loss , training acc : $train_acc")
        println("testing loss : $test_loss, testing acc : $test_acc")
    end
end
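For context, trainloader and testloader are built by the example's data-loading helper, which as far as I remember looks roughly like this (the exact function name and the batchsize field of Args may differ slightly in the current zoo version):

using MLDatasets
using Flux
using Flux: onehotbatch, DataLoader

function getdata(args)
    xtrain, ytrain = MLDatasets.MNIST(:train)[:]   # 28×28×60000 images and 60000 labels
    xtest,  ytest  = MLDatasets.MNIST(:test)[:]
    xtrain, xtest  = Flux.flatten(xtrain), Flux.flatten(xtest)          # 784×N matrices
    ytrain, ytest  = onehotbatch(ytrain, 0:9), onehotbatch(ytest, 0:9)  # 10×N one-hot labels
    trainloader = DataLoader((xtrain, ytrain), batchsize = args.batchsize, shuffle = true)
    testloader  = DataLoader((xtest, ytest), batchsize = args.batchsize)
    return trainloader, testloader
end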
The PyTorch implementation would be something like this:
for e in range(epochs):
    running_loss, running_acc = 0.0, 0.0
    for X, y in trainloader:
        yhat = model(X)
        loss = loss_func(yhat, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        running_loss += loss.item()
        running_acc += compute_acc(yhat, y)  # compute_acc is some function that computes accuracy
    print(f"epoch : {e}, loss : {running_loss/len(trainloader)}, acc : {running_acc/len(trainloader)}")