I'm trying to train a simple CNN with Flux and running into a weird issue. During training the loss appears to go down (suggesting that training is working), but the "trained" model's output is very bad, and when I calculated the loss by hand it differed from what the training loop reported; the model behaves as if it hadn't been trained at all.
I then started comparing the loss returned inside the gradient call vs. outside it, and after a lot of digging I think the problem is related to the BatchNorm layer. Consider the following minimal example:
using Flux
x = rand(100,100,1,1) #say a 100x100 greyscale image: 1 channel, batch size 1
y = @. 5*x + 3 #output image, some relationship to the input values (doesn't matter for this)
m = Chain(BatchNorm(1),Conv((1,1),1=>1)) #very simple model (doesn't really do anything but illustrates the problem)
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m) #loss calculated by gradient
l_final = Flux.mse(m(x),y) #loss calculated again using the model (no parameters have been updated)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")
All three losses are different, sometimes drastically so (on a run just now I got 22.6, 30.7, and 23.0), but I would expect them all to be the same.
Interestingly, if I remove the BatchNorm layer the losses are all identical, i.e. running:
using Flux
x = rand(100,100,1,1) #say a 100x100 greyscale image: 1 channel, batch size 1
y = @. 5*x + 3 #output image
m = Chain(Conv((1,1),1=>1))
l_init = Flux.mse(m(x),y) #initial loss after model creation
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m)
l_final = Flux.mse(m(x),y)
println("initial loss: $l_init")
println("loss calculated in withgradient: $l_grad")
println("final loss: $l_final")
Produces the same number for each loss calculation.
Why does including the BatchNorm layer change the value of the loss like this?
My (limited) understanding was that BatchNorm just normalizes the input values, which I can see would change the loss relative to the unnormalized case, but I don't understand why it produces different loss values for the same input on the same model when none of the model's parameters have been updated.
Look at the documentation of BatchNorm. The key bit here is that by default track_stats=true. That is what leads to the changing outputs. If you don't want this behaviour, initialise your model with
m = Chain(BatchNorm(1, track_stats=false),Conv((1,1),1=>1))
and you'll get identical outputs, as in your second example.
The BatchNorm is initialised with zero mean and unit std, and your input data isn't. During the forward pass inside withgradient the layer runs in training mode and moves its running statistics towards the statistics of your batch, so the next call outside the gradient normalizes with different values. That's why you get changing output even with repeated identical input when track_stats=true, as far as I can see it (quickly).
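For what it's worth, here is a small sketch (not from the original post) that makes the mechanism visible. It assumes a current Flux version, where BatchNorm keeps its running statistics in the fields μ and σ², supports the track_stats keyword, and can be frozen with Flux.testmode!:
using Flux
x = rand(Float32,100,100,1,1)
y = @. 5f0*x + 3f0
m = Chain(BatchNorm(1),Conv((1,1),1=>1))
println("running stats before: ", m[1].μ, " ", m[1].σ²) #initialised to zero mean, unit variance
#inside withgradient the layer is in training mode: it normalizes with the batch
#statistics of x AND updates the running statistics
l_grad, grad = Flux.withgradient(m -> Flux.mse(m(x),y), m)
println("running stats after:  ", m[1].μ, " ", m[1].σ²) #now moved towards the batch statistics
println("loss outside gradient: ", Flux.mse(m(x),y)) #uses the updated running stats, hence a different value
#two ways to make repeated evaluations agree:
m2 = Chain(BatchNorm(1, track_stats=false),Conv((1,1),1=>1)) #always normalize with batch statistics
Flux.testmode!(m) #or freeze m so it always uses the stored statistics and never updates them
With track_stats=false the layer normalizes with the per-batch statistics every time, and with testmode! it always uses the stored ones, so either way the same input gives the same output.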