Looking at this PR, I see that one can define `on_start` and `on_gradient` callbacks for a `caffe.Solver` object:

```python
import caffe
solver = caffe.AdamSolver('solver.prototxt')
solver.add_callback(on_start, on_gradient)  # <- ??
```

What type of objects are `on_start` and `on_gradient`? What are these callbacks for? How can one use them (an example would be nice...)?
1. Where and how are the callbacks defined?
The callbacks are part of the `Solver`, and are thus defined in the `solver.hpp` file. To be exact, there is a `Callback` class, and a protected vector of such callbacks, which is a member of the `Solver` class. The `Solver` class also provides an `add_callback` function, which appends an object of type `Callback` to that vector. This ensures that each callback implements two methods: `on_start()` and `on_gradients_ready()`.
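The mechanism is small enough to model in a few lines. Below is a toy Python sketch of the pattern (my own code, not the actual Caffe implementation, which is C++ in `solver.hpp`/`solver.cpp`): the solver stores callbacks in a list and fires the two hooks around each training iteration.

```python
# Toy model of Caffe's Callback mechanism (illustration only, not the
# real Caffe code): a Callback exposes two hooks, and the Solver keeps
# a list of callbacks that it invokes around each iteration.

class Callback:
    def on_start(self):
        pass

    def on_gradients_ready(self):
        pass


class Solver:
    def __init__(self):
        self._callbacks = []  # corresponds to the protected vector in solver.hpp

    def add_callback(self, callback):
        self._callbacks.append(callback)

    def step(self, iters):
        for _ in range(iters):
            for cb in self._callbacks:
                cb.on_start()
            # ... forward pass, backward pass, gradient accumulation ...
            for cb in self._callbacks:
                cb.on_gradients_ready()
            # ... apply the weight update ...


class Tracer(Callback):
    """Records the order in which the hooks fire."""
    def __init__(self):
        self.events = []

    def on_start(self):
        self.events.append('start')

    def on_gradients_ready(self):
        self.events.append('gradients_ready')


tracer = Tracer()
solver = Solver()
solver.add_callback(tracer)
solver.step(2)
print(tracer.events)  # ['start', 'gradients_ready', 'start', 'gradients_ready']
```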
2. Where are the callbacks called?
This happens in the `solver.cpp` file, in the `Step()` function, which contains the main worker loop. In each iteration, every callback's `on_start()` is invoked first, then the forward and backward passes run and the gradients are accumulated, and finally every callback's `on_gradients_ready()` is invoked, just before the weights are updated.

3. Where is this used?
This callback feature was implemented when multi-GPU support was added. The only place (that I know of) where callbacks are used is to synchronize the solver between multiple GPUs: the `P2PSync` class in `parallel.hpp` inherits from the `Solver::Callback` class and implements an `on_start()` and an `on_gradients_ready()` method, which synchronize the GPUs and finally accumulate all the gradient updates.

4. How to use callbacks from Python?
As pull request #3020 explains, this callback mechanism is exposed to Python, so it should be straightforward to use. A full, runnable example is shown in this GitHub Gist I created.
5. How is this useful?
As the two callback functions do not take any arguments, you cannot simply use them to keep track of the loss or similar things. To do that, you have to create a wrapper class around the `Solver` and call `add_callback` with two of its methods as the callback functions. This allows you to access the net from within the callback, by using `self.solver.net`. In the following example, I use the `on_start` callback to load data into the net, and the `on_gradients_ready` callback to print the loss.
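A sketch of such a wrapper (the class name, the `batch_iterator` parameter, and the blob names `'data'`, `'label'`, and `'loss'` are my own assumptions; they depend on your net definition):

```python
class TrainingMonitor:
    """Hypothetical wrapper following the approach described above: it
    keeps a reference to the solver so that both callbacks can reach
    the net through self.solver.net."""

    def __init__(self, solver, batch_iterator):
        self.solver = solver
        self.batch_iterator = batch_iterator
        # pycaffe's add_callback takes the two hooks as plain callables:
        solver.add_callback(self.on_start, self.on_gradients_ready)

    def on_start(self):
        # Runs before each iteration's forward/backward pass: copy the
        # next batch into the input blobs. The blob names 'data' and
        # 'label' are assumptions and depend on the net definition.
        data, label = next(self.batch_iterator)
        self.solver.net.blobs['data'].data[...] = data
        self.solver.net.blobs['label'].data[...] = label

    def on_gradients_ready(self):
        # Runs after the gradients have been accumulated, just before
        # the weight update; the loss blob holds the current loss.
        print('loss =', float(self.solver.net.blobs['loss'].data))


# Usage (requires Caffe built with pycaffe and a real solver.prototxt):
#   import caffe
#   solver = caffe.AdamSolver('solver.prototxt')
#   monitor = TrainingMonitor(solver, my_batch_iterator)
#   solver.step(100)
```

The wrapper pattern exists only because the callbacks take no arguments: binding them as methods gives them access to the solver (and hence the net) through `self`.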