How to initialize the model with pretrained weights?

I am using the "stateful_clients" example from the tensorflow-federated examples. I want to initialize the model with my pretrained weights, so I call model.load_weights(init_weight) in the model function, but it does not seem to work: the validation accuracy in the first round is still low. How can I solve this?

def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = get_five_layers_cnn([28, 28, 1])
    keras_model.load_weights(init_weight)
    loss = tf.keras.losses.SparseCategoricalCrossentropy()
    return stateful_fedavg_tf.KerasModelWrapper(keras_model,
                                                test_data.element_spec, loss)

A quick primer on state and model weights in TFF

TFF takes a distinct perspective on state in machine learning, largely as a consequence of its desire to be purely functional.

Usually in machine learning, a model is conceptually a function which takes data and produces a prediction. However, this notion is a little overloaded at times: does 'model' refer to a trained model (fitting the specification above), or to an architecture which is parameterized by its weights, and therefore needs to accept those weights as an argument to be considered truly a 'function'? A conception somewhat in the middle is that of a 'stateful function', which I think tends to be what people mean when they use the term 'model'.

TFF standardizes on the parameterized-function understanding: for TFF, a 'model' is a function which accepts parameters along with data as arguments, producing a prediction. This is generally to avoid the notion of a stateful function, which is disallowed by a purely functional perspective (f(x) == f(x) should always be true, so f cannot have any state which affects its output).
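
To make this concrete, here is a minimal sketch contrasting the two views (predict, weights, and x are hypothetical names for illustration, not TFF API):

import tensorflow as tf

x = tf.random.normal([4, 8])  # a batch of dummy inputs

# Stateful view: the parameters live inside the object, so the output of
# keras_model(x) depends on hidden state (the layer's variables).
keras_model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
stateful_pred = keras_model(x)

# Functional view (TFF's): the parameters are an explicit argument, so the
# function is pure: the same (weights, x) always yields the same output.
def predict(weights, x):
    w, b = weights
    return tf.matmul(x, w) + b

functional_pred = predict((tf.random.normal([8, 10]), tf.zeros([10])), x)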

On the code in question

I'm not super familiar with this portion of the TFF codebase; in particular, I'm a little surprised at the behavior of the Keras model wrapper, as TFF usually wants to serialize all logic into TFF-defined data structures as soon as possible (at least, this is how I think about it). Glancing at the code, it looks to me like it could work, but there have been surprising interactions between TFF and Keras in the past.

Briefly, here is how this path should be working:

  1. The model function you define above is invoked while building the initialize computation, in a graph context; the logic to load the weights (or the assignment of the weights themselves, baked into the graph as constants) would hopefully be serialized into the graph that TFF generates to represent initialize.
  2. Upon calling iterative_process.initialize, you would find your desired weights populated in the appropriate attributes of the returned data structure. This would serve as the initial starting point for your iterative process, and you would be off to the races (a quick way to verify this is sketched after the list).
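
If step 1 serialized correctly, the weights in the initial state should match your pretrained Keras weights. A minimal sketch of that check, assuming a simple_fedavg-style ServerState whose weights live under state.model.trainable (adjust the attribute path to your example's structure):

import numpy as np

state = iterative_process.initialize()

# Rebuild the pretrained Keras model to compare against.
keras_model = get_five_layers_cnn([28, 28, 1])
keras_model.load_weights(init_weight)

# If initialization worked, every pair should be close.
for initial, pretrained in zip(state.model.trainable, keras_model.trainable_weights):
    print(np.allclose(initial, pretrained.numpy()))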

What I am suspicious of in the above is step 1. TFF will silently invoke your model_fn in a TensorFlow graph context, resulting in non-program-order semantics; if there is no control dependency between the weight assignment and the return value of your function (there isn't in the code above, and in fact it is not obvious how to force one), the assignment may be skipped at initialize time. Therefore the state returned from initialize won't have your specified weights.
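
To illustrate the hazard, here is a small TF1-style sketch (not TFF's exact tracing machinery, just the underlying graph-execution behavior): in a graph, an op only runs if a fetched tensor depends on it.

import tensorflow as tf

g = tf.Graph()
with g.as_default():
    v = tf.compat.v1.get_variable('v', initializer=0.0)
    assign_op = v.assign(42.0)  # side effect with no control dependency on `out`
    out = tf.identity(v)

with tf.compat.v1.Session(graph=g) as sess:
    sess.run(v.initializer)
    print(sess.run(out))  # prints 0.0 because the assignment op was never executed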

If this suspicion is correct, the appropriate solution is to run the weight-loading logic directly in Python, outside of any TFF-traced context. TFF provides some utilities to help with this kind of thing, like tff.learning.state_with_new_model_weights, which would be used roughly like:

state = iterative_process.initialize()
# Load the pretrained weights into a fresh Keras model.
keras_model = get_five_layers_cnn([28, 28, 1])
keras_model.load_weights(init_weight)
state_with_loaded_weights = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in keras_model.trainable_weights],
    non_trainable_weights=[v.numpy() for v in keras_model.non_trainable_weights])
...
# continue on using state_with_loaded_weights in the iterative process
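
One caveat: tff.learning.state_with_new_model_weights expects a state structured like the ones produced by the tff.learning iterative processes. Since the stateful_clients example builds its own ServerState, you may instead need to construct a new instance of that state type with the loaded weights substituted in; the same run-it-in-Python approach applies either way.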