I am trying to train a machine learning model where the loss function is binary cross entropy. Because of GPU limitations I can only use a batch size of 4, and I'm seeing a lot of spikes in the loss graph. So I'm thinking of back-propagating only after some predefined number of samples (>4). It's like I'll do 10 iterations with batch size 4, store the losses, and after the 10th iteration add the losses and back-propagate. Will that be similar to a batch size of 40?
TL;DR
Is f(a + b) = f(a) + f(b) true for binary cross entropy?
f(a + b) = f(a) + f(b) doesn't seem to be what you're after. This would imply that BCELoss is additive, which it clearly isn't. I think what you really care about is whether, for a batch (x, y) split at some index i,

`loss(x, y) == loss(x[:i], y[:i]) + loss(x[i:], y[i:])`

is true?
The short answer is no, because you're missing some scale factors. What you probably want is the following identity

`loss(x, y) == (i / b) * loss(x[:i], y[:i]) + ((b - i) / b) * loss(x[i:], y[i:])`

where `b` is the total batch size. This identity is used as motivation for the gradient accumulation method (see below). Also, this identity applies to any objective function which returns an average loss across the batch elements, not just BCE.
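As a quick numerical sanity check of this identity, here is a small sketch with arbitrary shapes; it relies on `F.binary_cross_entropy` using `reduction='mean'` by default, which is the averaging the identity assumes:

```python
import torch
import torch.nn.functional as F

pred = torch.rand(40, 5)            # "predictions" in (0, 1)
target = torch.rand(40, 5).round()  # binary targets in {0, 1}

# mean loss over the full batch
full = F.binary_cross_entropy(pred, target)

# weighted sum of mean losses over sub-batches of 4
weighted = sum(
    (4 / 40) * F.binary_cross_entropy(pred[i:i + 4], target[i:i + 4])
    for i in range(0, 40, 4)
)

print(torch.allclose(full, weighted))  # True (up to floating point error)
```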
Caveat/Pitfall: Keep in mind that batch norm will not behave exactly the same when using this approach, since it updates its internal running statistics from the statistics of each batch it sees during the forward pass, and with this approach those are computed on the small sub-batches rather than the full batch.
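To see the effect concretely, here is a small sketch comparing the running statistics of two `nn.BatchNorm1d` layers (the shapes are arbitrary), one fed the full batch and one fed sub-batches:

```python
import torch
import torch.nn as nn

x = torch.randn(40, 10)

bn_full = nn.BatchNorm1d(10)
bn_sub = nn.BatchNorm1d(10)

bn_full(x)                     # one forward pass over the full batch
for i in range(0, 40, 4):      # ten forward passes over sub-batches of 4
    bn_sub(x[i:i + 4])

# the running statistics accumulated by the two layers no longer match
print(torch.allclose(bn_full.running_mean, bn_sub.running_mean))  # almost certainly False
```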
We can actually do a little better memory-wise than computing the loss as a weighted sum followed by back-propagation. Instead, we can compute the gradient of each component of the equivalent sum individually and let the gradients accumulate. To explain, I'll give some examples of equivalent operations.
Consider the following model
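Here is a minimal sketch; the architecture, layer sizes, and the name `MyModel` are placeholders, and the only important part is that the outputs go through a sigmoid so they can be passed to binary cross entropy:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MyModel(nn.Module):
    def __init__(self, num_features=10, num_outputs=5):
        super().__init__()
        self.hidden = nn.Linear(num_features, 20)
        self.output = nn.Linear(20, num_outputs)

    def forward(self, x):
        x = F.relu(self.hidden(x))
        # sigmoid so the outputs are valid probabilities for binary cross entropy
        return torch.sigmoid(self.output(x))

model = MyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
```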
and let's say our data and targets for a single full batch are
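for example, random tensors with a full batch of 40 split into sub-batches of 4 to match the numbers in the question (the feature and output dimensions are arbitrary and simply match the sketch model above):

```python
batch_size = 40      # the "effective" batch size we want
sub_batch_size = 4   # what actually fits on the GPU
num_sub_batches = batch_size // sub_batch_size  # 10

x = torch.randn(batch_size, 10)        # inputs
y = torch.rand(batch_size, 5).round()  # binary targets in {0, 1}
```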
Full batch
The body of our training loop for an entire batch may look something like this
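As a sketch, reusing the `model`, `optimizer`, `x`, and `y` defined above:

```python
optimizer.zero_grad()

y_hat = model(x)
loss = F.binary_cross_entropy(y_hat, y)
loss.backward()

optimizer.step()
```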
Weighted sum of loss on sub-batches
We could have computed this as a weighted sum of multiple loss terms, one per sub-batch.
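Again a sketch, reusing the same placeholder names as above:

```python
optimizer.zero_grad()

loss = 0
for idx in range(num_sub_batches):
    start = idx * sub_batch_size
    end = start + sub_batch_size
    y_hat = model(x[start:end])
    # each sub-batch loss is weighted by its share of the full batch
    loss = loss + (sub_batch_size / batch_size) * F.binary_cross_entropy(y_hat, y[start:end])

loss.backward()
optimizer.step()
```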
Gradient accumulation
The problem with the previous approach is that, in order to apply back-propagation, PyTorch needs to keep the intermediate results of every sub-batch's forward pass in memory until `.backward()` is called. This ends up requiring a relatively large amount of memory, so you may still run into memory consumption issues.
To alleviate this problem, instead of computing a single loss and performing back-propagation once, we can perform gradient accumulation. This gives results equivalent to the previous version. The difference is that we perform a backward pass on each component of the loss, only stepping the optimizer once all of them have been back-propagated. This way the computation graph is cleared after each sub-batch, which helps with memory usage. Note that this works because `.backward()` actually accumulates (adds) the newly computed gradients into the existing `.grad` member of each model parameter. This is why `optimizer.zero_grad()` must be called only once, before the sub-batch loop, and not during or after it.
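A sketch of that loop, with the same placeholder names as above:

```python
optimizer.zero_grad()

for idx in range(num_sub_batches):
    start = idx * sub_batch_size
    end = start + sub_batch_size
    y_hat = model(x[start:end])
    loss = (sub_batch_size / batch_size) * F.binary_cross_entropy(y_hat, y[start:end])
    # backward() frees this sub-batch's computation graph and adds its
    # gradients to the .grad attribute of each parameter
    loss.backward()

optimizer.step()
```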