TFF: Remote Executor


We are setting up a federated scenario with Server and Client on different physical machines.

[Setup diagram omitted.]

On the server, we have used a Docker container to kickstart the executor service:

[Screenshot of the Docker command omitted.]

The above has been borrowed from the TFF Kubernetes tutorial. We believe this creates a 'local executor' [Ref 1], which is then exposed through a gRPC server [Ref 2].

Ref 1: [screenshot omitted.]

Ref 2: [screenshot omitted.]

Next, on client 1, we call tff.framework.RemoteExecutor, which connects to the gRPC server:

[Screenshot of the client-side code omitted.]

Our understanding, based on the above, is that the RemoteExecutor runs on the client and connects to the gRPC server on the server machine.
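For reference, here is a minimal sketch of what that client-side setup might look like; the server address is hypothetical and the exact constructor arguments depend on the TFF version:

import grpc
import tensorflow_federated as tff

# Hypothetical address and port of the gRPC executor service on the server.
channel = grpc.insecure_channel('10.0.0.1:8000')

# The RemoteExecutor lives on the client and forwards computations
# over the channel to the server-side executor service.
remote_executor = tff.framework.RemoteExecutor(channel)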

Assuming the above is correct, how can we send a tff.tf_computation from the server to the client and print the output on the client side, to ensure the whole setup works well?

1 Answer
Your understanding is definitely correct.

If you construct an ExecutorFactory directly, as seems to be the case in the code above, passing it to tff.framework.set_default_context will install your remote stack as the default mechanism for executing computations in the TFF runtime. You should additionally be able to pass the appropriate channels to tff.backends.native.set_remote_execution_context to handle the remote executor construction and context installation if desired, but the way you are doing it certainly works, and allows for greater customization.
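As a rough sketch of that second option (the worker address below is hypothetical, and the exact signatures may differ across TFF versions):

import grpc
import tensorflow_federated as tff

# One gRPC channel per remote worker; the address is a placeholder.
channels = [grpc.insecure_channel('10.0.0.1:8000')]

# Installs a remote execution context as the default TFF runtime, so that
# computations invoked in this process are delegated to the remote worker.
tff.backends.native.set_remote_execution_context(channels)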

Once you have set this up, running an example end-to-end should be fairly simple. We will set up a computation which takes a set of federated integers, prints on the clients, and sums the integers up. Let:

import tensorflow as tf
import tensorflow_federated as tff

@tff.tf_computation(tf.int32)
def print_and_return(x):
  # We must use tf.print here, as this logic will be
  # serialized and run on the clients as TensorFlow.
  tf.print('hello world')
  return x

@tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
def print_and_sum(federated_arg):
  # Map the printing computation over the clients, then aggregate the results.
  same_ints = tff.federated_map(print_and_return, federated_arg)
  return tff.federated_sum(same_ints)

Suppose we have N clients; we simply instantiate the set of federated integers, and invoke our computation.

# N is the number of clients participating in the computation.
federated_ints = [1] * N
total = print_and_sum(federated_ints)
assert total == N

This should cause the tf.print calls defined above to run on the remote machine; as long as tf.print is directed to an output stream you can monitor (for example, the logs of the Docker container running the executor service), you should be able to see the output.

PS: you may note that the federated sum above is unnecessary; it certainly is. The same effect can be had by simply mapping the identity function with the serialized print.
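For instance, a stripped-down sketch of the same idea, reusing the print_and_return defined above, could look like:

@tff.federated_computation(tff.FederatedType(tf.int32, tff.CLIENTS))
def just_print(federated_arg):
  # Mapping alone is enough to trigger the serialized tf.print on each client;
  # no aggregation is needed.
  return tff.federated_map(print_and_return, federated_arg)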