How can I use transfer learning in federated learning?

255 Views Asked by At

I tried to implement federated learning using the TensorFlow Federated (TFF) Federated Core API:

def create_keras_model():
    """Build the CNN classifier.

    Three Conv2D + MaxPooling2D stages (16, 64 and 128 filters) are followed
    by a flattened dense head of 128 and 64 ReLU units and a 10-unit softmax
    output. The first layer declares an input shape of (226, 232, 1).

    Returns:
        The assembled ``Sequential`` model (not compiled here).
    """
    return Sequential([
        # Convolutional feature extractor.
        Conv2D(16, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu', input_shape=(226,232,1)),
        MaxPooling2D((2,2), strides=(2,2), padding='same'),
        Conv2D(64, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'),
        MaxPooling2D((2,2), strides=(2,2), padding='same'),
        Conv2D(128, kernel_size=(3,3), strides=(1,1), padding='same', activation='relu'),
        MaxPooling2D((2,2), strides=(2,2), padding='same'),
        Flatten(),
        # Dense classification head.
        Dense(128, activation='relu'),
        Dense(64, activation='relu'),
        Dense(10, activation='softmax'),
    ])



def model_fn():
    """Wrap a freshly built Keras model with ``tff.learning.from_keras_model``.

    A new Keras model is created on every call. The input spec is the
    ``element_spec`` of the first element of ``federated_train_data``; the
    loss is sparse categorical cross-entropy and the only metric is sparse
    categorical accuracy.
    """
    return tff.learning.from_keras_model(
        create_keras_model(),
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

def initialize_fn():
  """Return the trainable variables of a newly constructed model."""
  return model_fn().trainable_variables

def next_fn(server_weights, federated_dataset):
  """Outline one round: broadcast, client training, averaging, server update.

  NOTE(review): `broadcast` and `mean` are not defined in the visible code,
  and `client_update` / `server_update` are called here with fewer arguments
  than their definitions take -- presumably this is a conceptual sketch
  rather than runnable code; confirm.
  """
  # Server -> clients.
  weights_on_clients = broadcast(server_weights)

  # Local training on every client.
  per_client_weights = client_update(federated_dataset, weights_on_clients)

  # Clients -> server: average, then install the result as the new server weights.
  averaged_weights = mean(per_client_weights)
  return server_update(averaged_weights)


@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Train `model` on one client's dataset, starting from the server weights.

  Args:
    model: Object exposing `trainable_variables` and `forward_pass(batch)`.
    dataset: Iterable of batches to train on.
    server_weights: Values assigned to the model's trainable variables
      before training; must match their structure.
    client_optimizer: Optimizer whose `apply_gradients` performs each step.

  Returns:
    The model's trainable variables after one pass over `dataset`.
  """
  local_vars = model.trainable_variables

  # Overwrite the local variables with the values received from the server.
  tf.nest.map_structure(
      lambda var, value: var.assign(value), local_vars, server_weights)

  for examples in dataset:
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as recorder:
      result = model.forward_pass(examples)

    gradients = recorder.gradient(result.loss, local_vars)

    # One optimizer step per batch.
    client_optimizer.apply_gradients(zip(gradients, local_vars))

  return local_vars

@tf.function
def server_update(model, mean_client_weights):
  """Overwrite the model's trainable variables with the averaged client weights.

  Args:
    model: Object exposing `trainable_variables`.
    mean_client_weights: Values to assign; must match the structure of the
      model's trainable variables.

  Returns:
    The model's trainable variables after the assignment.
  """
  server_vars = model.trainable_variables
  tf.nest.map_structure(
      lambda var, value: var.assign(value), server_vars, mean_client_weights)
  return server_vars

@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def get_average_temperature(client_temperatures):
  """Return the federated mean of the float32 values placed at the clients."""
  average = tff.federated_mean(client_temperatures)
  return average

@tff.tf_computation(tf.float32)
def add_half(x):
  """Return `x` increased by 0.5."""
  incremented = tf.add(x, 0.5)
  return incremented


@tff.federated_computation(tff.FederatedType(tf.float32, tff.CLIENTS))
def add_half_on_clients(x):
  """Apply `add_half` to the client-placed float32 values via `tff.federated_map`."""
  mapped = tff.federated_map(add_half, x)
  return mapped


@tff.tf_computation
def server_init():
  """Return the trainable variables of a newly constructed model."""
  return model_fn().trainable_variables


@tff.federated_computation
def initialize_fn():
  """Place the result of `server_init` at the server."""
  initial_weights = server_init()
  return tff.federated_value(initial_weights, tff.SERVER)


# Throwaway model instance, built only to read its input spec.
whimsy_model = model_fn()
# TFF type of a client dataset: a sequence of elements matching the model's input spec.
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)


# TFF type of the model weights, taken from the result type of `server_init`.
model_weights_type = server_init.type_signature.result


@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  """Run `client_update` with a new model and an SGD optimizer (learning rate 0.01).

  Args:
    tf_dataset: The client's dataset, typed as `tf_dataset_type`.
    server_weights: The weights to start from, typed as `model_weights_type`.

  Returns:
    Whatever `client_update` returns for this dataset.
  """
  return client_update(
      model_fn(),
      tf_dataset,
      server_weights,
      tf.keras.optimizers.SGD(learning_rate=0.01),
  )


@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  """Run `server_update` on a new model with the given averaged weights."""
  return server_update(model_fn(), mean_client_weights)


# Federated types used in the signature of the round computation:
# the model weights placed at the server, and the dataset type placed at the clients.
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  """Run one round: broadcast, client training, averaging, server update.

  Args:
    server_weights: Current weights, typed as `federated_server_type`.
    federated_dataset: Client datasets, typed as `federated_dataset_type`.

  Returns:
    A pair of the updated server weights and the per-client weights produced
    by `client_update_fn`.
  """
  # Server -> clients.
  broadcast_weights = tff.federated_broadcast(server_weights)

  # Local training on every client.
  per_client_weights = tff.federated_map(
      client_update_fn, (federated_dataset, broadcast_weights))

  # Clients -> server: average, then install the result on the server.
  averaged_weights = tff.federated_mean(per_client_weights)
  updated_server_weights = tff.federated_map(server_update_fn, averaged_weights)

  return updated_server_weights, per_client_weights


# Bundle the initialization and round computations into one iterative process.
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)


# Starting server state, obtained by running the initialization computation.
server_state = federated_algorithm.initialize()

Then I save `server_state` (the server weights) after each round:

# Build the checkpoint manager once: its configuration is the same for every
# round, so there is no need to construct a new one per iteration.
checkpoint_manager = FileCheckpointManager(
    root_dir='/content/drive/MyDrive',
    prefix='fed_per_',
    step=1,
    keep_total=1,
    keep_first=True,
)

# Rounds are numbered 3 through 14. The loop variable is `round_num` rather
# than `round` so the builtin `round()` is not shadowed at module level.
for round_num in range(3, 15):
  server_state, client_weights = federated_algorithm.next(server_state, federated_train_data)
  # Persist the server state produced by this round.
  checkpoint_manager.save_checkpoint(state=server_state, round_num=round_num)

Now I want to use this pre-trained model in a new federated learning task, where the weights of the convolutional (CNN) layers stay fixed and only the weights of the last three layers are updated.

Could someone help me with how I can do this?

1

There is 1 best solution below

2
On

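You can freeze layers through the Keras layers API: loop over the model's layers with a for loop and, for each layer whose weights should stay fixed, set its `trainable` attribute to `False`: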

layer.trainable = False