Accuracy and gradient update not within the same training loop


I am following the example mlp_mnist.jl from the Flux model zoo, and I am a little confused by the accuracy computation and train functions.

The accuracy-computation function loops over the dataloader and computes the accuracy, which is fine.

function loss_and_accuracy(dataloader, model)
    loss::Float32 = 0.0f0
    acc::Int = 0
    num_samples::Int = 0
    
    for (X, y) in dataloader # 60,000 samples in total for the training set
        ŷ = model(X) # (10, 256) per batch; the last batch is (10, 96)
        # Use agg = sum for batched inputs; note the += that accumulates across batches
        loss += loss_func(ŷ, y, agg = sum)  # e.g. 132230.64
        # onecold(ŷ) .== onecold(y) -> (256,) .== (256,)
        acc += sum(onecold(ŷ) .== onecold(y)) # e.g. 25 <- counts the matches from the line above
        num_samples += size(X)[end] # grows by 256 each batch, ends at 60000; the last batch has fewer than 256
    end
    # get average accuracy and average loss 
    avg_loss = loss / num_samples #num_samples = 60000 if train
    avg_acc = acc / num_samples
    return avg_loss, avg_acc
end
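
For reference, my understanding of the onecold part: it takes the argmax of each column and maps it back to a label, so onecold(ŷ) .== onecold(y) is a Bool vector with one entry per sample, and summing it counts the correct predictions in the batch. A toy example of my own (not from the model-zoo file):

using Flux: onehotbatch, onecold

y = onehotbatch([2, 7, 7], 0:9)             # 10×3 one-hot targets for labels 2, 7, 7
ŷ = rand(Float32, 10, 3)                    # stand-in for model(X) on a batch of 3
hits = onecold(ŷ, 0:9) .== onecold(y, 0:9)  # 3-element BitVector, one entry per sample
sum(hits)                                   # number of correct predictions in the batch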

But the way it's called within the train function is a little confusing! In the first for loop we update the gradients over the entire training set, and then we loop over the training set again (inside loss_and_accuracy) to compute the loss and accuracy for each batch. Shouldn't they be within one loop, where for each batch the gradients are updated and the accuracy is computed? For example, in PyTorch this is roughly how I would do it (see the PyTorch snippet further below, followed by a rough Flux sketch of what I have in mind).

function train(;kws...)
    args = Args(;kws...)
    model = build_model() 
    optimizer = Flux.setup(Adam(args.η), model)

    for epoch in 1:args.epochs
        for (X,y) in trainloader
            grad = gradient(m -> loss_func(m(X), y), model)
            Flux.update!(optimizer, model, grad[1])
        end

        # Why are we looping over the trainloader again after updating the gradients
        # in the for loop above?
        train_loss, train_acc = loss_and_accuracy(trainloader,model)
        test_loss, test_acc = loss_and_accuracy(testloader,model)
        println("Epoch : $epoch")
        println("training loss : $train_loss , training acc : $train_acc")
        println("testing loss : $test_loss, testing acc : $test_acc")
    end
end

A PyTorch implementation would be something like this:

for e in range(epochs):
    running_loss, running_acc = 0.0, 0.0
    for X, y in trainloader:
       yhat = model(X)
       loss = loss_func(yhat, y)
       optim.zero_grad()
       loss.backward()
       optim.step()

       running_loss += loss.item()
       running_acc += compute_acc(yhat, y) # compute_acc is some function that computes accuracy
    print(f"epoch : {e}, loss : {running_loss/len(trainloader)}, acc : {running_acc/len(trainloader)}")
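
In Flux, the combined version I have in mind would look roughly like this sketch (my own adaptation, not the model-zoo code; it assumes the same build_model, loss_func, trainloader, and args as above, and uses Flux.withgradient so the batch loss and the gradient come from a single call):

using Flux
using Flux: onecold

model = build_model()
optimizer = Flux.setup(Adam(args.η), model)

for epoch in 1:args.epochs
    running_loss = 0.0f0
    correct = 0
    num_samples = 0

    for (X, y) in trainloader
        # one gradient step per batch, as in the train function above
        loss, grads = Flux.withgradient(m -> loss_func(m(X), y), model)
        Flux.update!(optimizer, model, grads[1])

        # accumulate running metrics on the same batch (one extra forward pass for ŷ)
        ŷ = model(X)
        running_loss += loss * size(X)[end] # loss_func defaults to agg = mean, so scale by the batch size
        correct += sum(onecold(ŷ) .== onecold(y))
        num_samples += size(X)[end]
    end

    println("epoch $epoch: loss = $(running_loss / num_samples), acc = $(correct / num_samples)")
end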

