Customized aggregation algorithm for gradient updates in TensorFlow Federated


I have been trying to implement this paper. Basically, I want to sum the per-client losses and compare the total with that of the previous epoch. Then, for each layer of the model, I want to compute the KL divergence between the weights of the server model and the client model to get layer-specific parameter updates, apply a softmax, and use the result to decide whether an adaptive update or the normal FedAvg approach is needed.

The algorithm is as follows: [image: FedMed algorithm]

I tried to use the code here to build a custom federated averaging process. My basic understanding is that some tf.computations and some tff.computations are involved. I gather that I need to change the orchestration logic in the run_one_round function and manipulate the client outputs to do adaptive averaging instead of vanilla federated averaging. The client_update tf.computation function returns all the values that I need, i.e. weights_delta (which can be used for the client model weights) and model_output (which can be used to calculate the loss).

But I am not sure where exactly I should make the changes.

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):

  server_message = tff.federated_map(server_message_fn, server_state)
  server_message_at_client = tff.federated_broadcast(server_message)

  client_outputs = tff.federated_map(
      client_update_fn, (federated_dataset, server_message_at_client))

  weight_denom = client_outputs.client_weight

  # TODO: instead of using tff.federated_mean, I wish to do an adaptive
  # aggregation based on client_outputs.weights_delta and the server_state model.
  round_model_delta = tff.federated_mean(
      client_outputs.weights_delta, weight=weight_denom)

  # client_outputs.weights_delta has all the client model weights.
  # client_outputs.client_weight has the number of examples per client.
  # client_outputs.model_output has the output of the model per client example.

I want to use the server model weights from the server_state object to calculate, per layer, the KL divergence between the weights of the server model and each client's model. I then want to use a relative weight to aggregate the client weights. In other words, instead of tff.federated_mean I wish to use an adaptive strategy based on the algorithm above.
I would appreciate suggestions on how to implement this. Basically, what I want to do is:
1) Sum all the client losses.
2) Calculate the KL divergence between each client and the server on a per-layer basis, and then determine whether to use adaptive optimization or FedAvg.

Also, is there a way to inspect this value as a Python value, which would be helpful for debugging? (I tried to use tf.print, but that was not helpful either.) Thanks!

Best answer

Simplest option: compute weights for mean on clients

If I read the algorithm above correctly, we need only compute some weights for a mean on-the-fly. tff.federated_mean accepts an optional CLIENTS-placed weight argument, so probably the simplest option here is to compute the desired weights on the clients and pass them in to the mean.

This would look something like (assuming the appropriate definitions of the variables used below, which we will comment on):


@tff.federated_computation(...)
def round_function(...):
  ...
  # We assume there is a tff.Computation training_fn that performs training,
  # and we're calling it here on the correct arguments
  trained_clients = tff.federated_map(training_fn, clients_placed_arguments)
  # Next we assume there is a variable in-scope server_model,
  # representing the 'current global model'.
  global_model_at_clients = tff.federated_broadcast(server_model)
  # Here we assume a function compute_kl_divergence, which takes
  # two structures of tensors and computes the KL divergence
  # (as a scalar) between them. The two arguments here are clients-placed,
  # so the result will be as well.
  kl_div_at_clients = tff.federated_map(compute_kl_divergence,
      (global_model_at_clients, trained_clients))
  # Perhaps we wish to not use raw KL divergence as the weight, but rather
  # some function thereof; if so, we map a postprocessing function to
  # the computed divergences. The result will still be clients-placed.
  mean_weight = tff.federated_map(postprocess_divergence, kl_div_at_clients)
  # Now we simply use the computed weights in the mean.
  return tff.federated_mean(trained_clients, weight=mean_weight)
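
To make the assumed helpers concrete, here is a minimal plain-Python sketch (no TFF or TensorFlow) of what compute_kl_divergence and postprocess_divergence might compute for a single layer. Everything in it is an assumption made for illustration: treating the softmax of a layer's weights as a probability distribution is only one possible reading of the algorithm, exp(-KL) is an arbitrary placeholder for the post-processing step, and softmax and weighted_mean are hypothetical helper names.

```python
import math

def softmax(values):
    """Turn a flat list of floats into a probability distribution."""
    peak = max(values)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]

def compute_kl_divergence(server_layer, client_layer):
    """KL(p || q), with p and q the softmax of the server and client weights."""
    p = softmax(server_layer)
    q = softmax(client_layer)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def postprocess_divergence(kl_div):
    """Placeholder choice (an assumption): clients closer to the server count more."""
    return math.exp(-kl_div)

def weighted_mean(client_layers, weights):
    """Element-wise weighted mean, the plain-Python analogue of a weighted federated mean."""
    total = sum(weights)
    return [
        sum(w * layer[i] for w, layer in zip(weights, client_layers)) / total
        for i in range(len(client_layers[0]))
    ]

# Toy data: one layer with three parameters, two clients.
server = [0.0, 1.0, 2.0]
clients = [[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]]
weights = [postprocess_divergence(compute_kl_divergence(server, c)) for c in clients]
aggregate = weighted_mean(clients, weights)
```

In this toy setup the first client matches the server exactly, so its divergence is zero and it receives the largest weight; the aggregate is therefore pulled towards that client's values.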

More flexible tool: tff.federated_reduce

TFF generally encourages algorithm developers to implement whatever they can 'in the aggregation', and as such exposes some highly customizable primitives like tff.federated_reduce, which allow you to run arbitrary TensorFlow "in the stream" between clients and server. If the above reading of the desired algorithm is incorrect and something more involved is needed, or you wish to flexibly experiment with totally different notions of aggregation (something TFF encourages and is designed to support), this may be the option for you.

In TFF's heuristic typing language, tff.federated_reduce has signature:

<{T}@CLIENTS, U, (<U, T> -> U)> -> U@SERVER

Meaning, federated_reduce takes a value of type T placed at the clients, a 'zero' in a reduction algebra of type U, and a function accepting a U and a T and producing a U, and applies this function 'in the stream' on the way between clients and server, producing a U placed at the server. The function (<U, T> -> U) will be applied to the partially accumulated value U and the 'next' element in the stream T (note, however, that TFF does not guarantee ordering of these values), returning another partially accumulated value U. The 'zero' should represent whatever 'partially accumulated' means over the empty set in your application; this will be the starting point of the reduction.
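
As a rough mental model only, the same shape of computation can be written with functools.reduce from the Python standard library. This sketch is not TFF code: the list client_values stands in for a clients-placed value, and the names accumulate and zero are chosen here for illustration.

```python
import functools
import random

def accumulate(partial, element):
    """(<U, T> -> U): fold one client value into the running (sum, count) pair."""
    running_sum, count = partial
    return (running_sum + element, count + 1)

zero = (0, 0)  # the accumulated value over an empty set of clients
client_values = [3, 1, 4, 1, 5]  # stand-in for a {T}@CLIENTS value

result = functools.reduce(accumulate, client_values, zero)

# Ordering is not guaranteed in TFF, so the function should give the same
# answer for any ordering of the client values.
shuffled = client_values[:]
random.shuffle(shuffled)
assert functools.reduce(accumulate, shuffled, zero) == result
```

Here U is a (sum, count) pair and T is a single integer. Reducing over an empty list simply returns the zero, which is why the zero has to be a sensible answer for 'no clients'.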

Application to this problem

The components

Your reduction function needs access to two pieces of data: the global model state and the result of training on a given client. This maps quite nicely to the type T. In this application, we will have something like:

T = <server_model=server_model_type, trained_model=trained_model_type>

These two types are likely to be the same, but may not necessarily be so.

Your reduction function will accept the partial aggregate, your server model and your client-trained model, returning a new partial aggregate. Here we will start assuming the same reading of the algorithm as above, that of a weighted mean with particular weights. Generally, the easiest way to compute a mean is to keep two accumulators, one for numerator and one for denominator. This will affect the choice of zero and reduction function below.

Your zero should contain a structure of tensors with value 0 mapping to the weights of your model--this will be the numerator. This would be generated for you if you had an aggregation like tff.federated_sum (as TFF knows what the zero should be), but for this case you'll have to get your hands on such a tensor yourself. This shouldn't be too hard with tf.nest.map_structure and tf.zeros_like.

For the denominator, we will assume we just need a scalar. TFF and TF are much more flexible than this--you could keep a per-layer or per-parameter denominator if desired--but for simplicity we will assume that we just want to divide by a single float in the end.

Therefore our type U will be something like:

U = <numerator=server_model_type, denominator=tf.float32>
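
A value of this shape can be sketched in plain Python as follows. In TensorFlow the numerator would more naturally be built with tf.nest.map_structure and tf.zeros_like, as mentioned above; the recursive helper zeros_like_structure below is a hypothetical stand-in so that the sketch runs without TensorFlow, and the layer names and shapes are made up.

```python
import collections

def zeros_like_structure(structure):
    """Recursively replace every number in a nested dict/list with 0.0."""
    if isinstance(structure, dict):
        return collections.OrderedDict(
            (name, zeros_like_structure(value))
            for name, value in structure.items())
    if isinstance(structure, (list, tuple)):
        return [zeros_like_structure(value) for value in structure]
    return 0.0

# Hypothetical two-layer model; the layer names and shapes are made up.
server_model = collections.OrderedDict(
    dense_kernel=[[0.1, -0.2], [0.3, 0.4]],
    dense_bias=[0.5, -0.6],
)

# The zero of type U: a model-shaped numerator and a scalar denominator.
zero = collections.OrderedDict(
    numerator=zeros_like_structure(server_model),
    denominator=0.0,
)
```

The numerator keeps the same layer names and shapes as the model, so a later element-wise accumulation lines up layer by layer.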

Finally we come to our reduction function. It will be more or less a different composition of the same pieces above; we will make slightly tighter assumptions about them here (in particular, that all the local functions are tff.tf_computations, a technical assumption that is arguably a bug in TFF). Our reduction function will be along these lines (assuming appropriate type aliases):

@tff.tf_computation(U, T)
def reduction(partial_accumulate, next_element):
  kl_div = compute_kl_divergence(
      next_element.server_model, next_element.trained_model)
  weight = postprocess_divergence(kl_div)
  # The model is a structure of tensors rather than a single tensor, so we
  # scale and accumulate leaf by leaf (assuming tf.nest can traverse it).
  new_numerator = tf.nest.map_structure(
      lambda acc, trained: acc + weight * trained,
      partial_accumulate.numerator, next_element.trained_model)
  new_denominator = partial_accumulate.denominator + weight
  return collections.OrderedDict(
      numerator=new_numerator, denominator=new_denominator)

Putting them together

The basic outline of a round will be similar to the above, but we have put more computation 'in the stream', and consequently there will be less on the clients. We assume here the same variable definitions.


@tff.federated_computation(...)
def round_function(...):
  ...
  trained_clients = tff.federated_map(training_fn, clients_placed_arguments)
  global_model_at_clients = tff.federated_broadcast(server_model)
  # This zip I believe is not necessary, but it helps my mental model.
  reduction_arg = tff.federated_zip(
      collections.OrderedDict(server_model=global_model_at_clients,
                              trained_model=trained_clients))
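  # We assume a zero as specified above. The result is a server-placed
  # <numerator, denominator> pair; dividing the numerator by the denominator
  # on the server then yields the weighted mean.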
  return tff.federated_reduce(reduction_arg,
                              zero,
                              reduction)
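
To check the arithmetic of this numerator/denominator scheme outside TFF, the whole round can be mimicked in plain Python. This is only a sketch under stated assumptions: client_weight is a hypothetical stand-in for postprocess_divergence(compute_kl_divergence(...)) and uses a simple inverse squared distance rather than a KL divergence, models are flat lists of floats, and functools.reduce plays the role of the in-stream reduction.

```python
import functools

def client_weight(server_model, trained_model):
    """Hypothetical stand-in for postprocess_divergence(compute_kl_divergence(...))."""
    distance = sum((s - t) ** 2 for s, t in zip(server_model, trained_model))
    return 1.0 / (1.0 + distance)

def reduction(partial, element):
    """Fold one (server_model, trained_model) pair into the accumulators."""
    weight = client_weight(element["server_model"], element["trained_model"])
    numerator = [
        acc + weight * value
        for acc, value in zip(partial["numerator"], element["trained_model"])
    ]
    return {"numerator": numerator,
            "denominator": partial["denominator"] + weight}

server_model = [1.0, 1.0]
trained_models = [[1.0, 1.0], [2.0, 1.0], [1.0, 3.0]]

# Plays the role of the zipped, clients-placed reduction argument.
reduction_arg = [
    {"server_model": server_model, "trained_model": m} for m in trained_models
]
zero = {"numerator": [0.0, 0.0], "denominator": 0.0}

accumulated = functools.reduce(reduction, reduction_arg, zero)

# Final server-side step: divide the numerator by the denominator.
new_model = [n / accumulated["denominator"] for n in accumulated["numerator"]]
```

With these toy values the three clients receive weights 1.0, 0.5 and 0.2, so the denominator is 1.7 and the new model is the correspondingly weighted mean of the three trained models.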