How to load ServerState.optimizer_state to continue training in TensorFlow Federated


Does TFF have a way to save and load optimizer state, similar to what it offers for model weights? For model weights there are the ModelWeights.assign_weights_to() and tff.learning.state_with_new_model_weights() functions. Is there an equivalent for optimizer state, especially when using a server-side optimizer other than SGD?

I could not find anything to save and load state of optimizer.


2 Answers

Accepted answer

This should be achievable with TFF's tff.simulation.FileCheckpointManager. In particular, the usage in Google Research's federated repository was originally written to support restarting from checkpoints when using learning-rate scheduling and adaptive optimization on the server, an application for which correctly restoring the optimizer state is crucial.

As long as your tff.templates.IterativeProcess returns the appropriate optimizer state, simply using the FileCheckpointManager out of the box should just work.


Piggybacking off of Keith's answer: the checkpoint manager is no longer available in Google's federated research repository. It has been upstreamed into the TFF GitHub repository itself.

You can access it either through the tensorflow-federated-nightly pip package, or else by cloning the repository.

The code essentially delegates to tf.saved_model.save, so you could alternatively use that API directly.
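To illustrate that route, here is a sketch that saves and reloads variables with tf.saved_model directly. The OptState module and its adam_m/adam_v variables are made up for the example; in practice you would wrap whatever tensors make up your server optimizer state:

```python
import tensorflow as tf

# A tf.Module standing in for server optimizer state (e.g. Adam's first and
# second moments). The names here are illustrative, not TFF's actual fields.
class OptState(tf.Module):
    def __init__(self):
        super().__init__()
        self.adam_m = tf.Variable(tf.zeros([3]), name='adam_m')
        self.adam_v = tf.Variable(tf.zeros([3]), name='adam_v')

state = OptState()
state.adam_m.assign([0.1, 0.2, 0.3])  # pretend some training happened

# Persist and restore the variables with the SavedModel machinery.
tf.saved_model.save(state, '/tmp/opt_state')
restored = tf.saved_model.load('/tmp/opt_state')
# restored.adam_m / restored.adam_v now hold the saved values.
```

Because tf.saved_model tracks any tf.Variable attached to a tf.Module, this works for arbitrary optimizer slot variables, not just the ones named here.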