I was trying to train a convolutional network, but it is not improving: the loss does not decrease, and the train function also terminates much more quickly than usual. Below is minimal code reproducing the problem.
using Flux
data=rand(200, 100, 1, 50)
label=rand([0.0,1.0], 1, 50)
model=Chain(
    Conv((3,3), 1=>5, pad=(1,1)),
    MaxPool((2,2)),
    Conv((3,3), 5=>5, pad=(1,1)),
    MaxPool((2,2)),
    Conv((3,3), 5=>5, pad=(1,1)),
    MaxPool((2,2)),
    x->reshape(x, :, size(x, 4)),
    x->σ.(x),
    Dense(1500,100),
    Dense(100,1)
)
model(data)
loss=Flux.mse
opt=Descent(0.1)
param=params(model)
loss(model(data), label) #=>0.3492440767136241
Flux.train!(loss, param, zip(data, label), opt)
loss(model(data), label) #=>0.3492440767136241
The first argument to Flux.train! needs to be a function which accepts the data, runs the model, and returns the loss. Its loop looks something like this:
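(This is a rough sketch of what train! does internally, not its exact source; ps is the Params collection you passed, your param.)

for d in data                               # with your call, data is zip(data, label), so each d is a pair of two numbers
    gs = gradient(() -> loss(d...), ps)     # this evaluates loss(d...), which never touches the model's parameters
    Flux.Optimise.update!(opt, ps, gs)      # the gradient holds nothing for the weights, so nothing updates
end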
But the loss function you provide doesn't call the model at all; it just compares the data point to the label directly.
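A loss that actually runs the model might look like this (a sketch; loss_fn is just a name introduced here, keeping your existing implicit-parameters style):

loss_fn(x, y) = Flux.mse(model(x), y)                # run the model on the input, then compare to the label
Flux.train!(loss_fn, param, [(data, label)], opt)    # one batch containing all 50 images at once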
There is more to fix here, though. What's being iterated over is tuples of single numbers, starting with zip(data, label) |> first, which I don't think is what you want. Maybe you wanted Flux.DataLoader to iterate over batches of images?
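For instance, something like this (a sketch, assuming you want mini-batches of 10 observations taken along the last dimension, and reusing loss_fn from above):

loader = Flux.DataLoader((data, label), batchsize=10, shuffle=true)   # yields (x, y) pairs of size 200×100×1×10 and 1×10
Flux.train!(loss_fn, param, loader, opt)                              # now each step sees a real batch of images

Call that once per epoch, and loss_fn(data, label) should actually decrease between epochs.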