How to write a proper dataset_fn in tff.simulation.FilePerUserClientData?


I'm currently implementing federated learning using tff.

Because the dataset is very large, we split it into many npy files, and I'm currently putting the dataset together using tff.simulation.FilePerUserClientData.

This is what I'm trying to do:

client_ids_to_files = dict()
for i in range(len(train_filepaths)):
  client_ids_to_files[str(i)] = train_filepaths[i]

def dataset_fn(filepath):
  print(filepath)
  dataSample = np.load(filepath)
  label = filepath[:-4].strip().split('_')[-1]
  return tf.data.Dataset.from_tensor_slices((dataSample, label))

train_filePerClient = tff.simulation.FilePerUserClientData(client_ids_to_files, dataset_fn)

However, it doesn't work: the filepath argument passed to the callback is a tensor with dtype string, not a Python string. The value of filepath is: Tensor("hash_table_Lookup/LookupTableFindV2:0", shape=(), dtype=string)

Instead of one of the paths from client_ids_to_files, the tensor seems to contain something else entirely. Am I doing something wrong? How can I write a proper dataset_fn for tff.simulation.FilePerUserClientData using npy files?

EDIT: Here is the error log. The error itself is not directly related to my question, but it shows which functions are called:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-46-e61ddbe06cdb> in <module>
     22     return tf.data.Dataset.from_tensor_slices(filepath)
     23 
---> 24 train_filePerClient = tff.simulation.FilePerUserClientData(client_ids_to_files,dataset_fn)
     25 

~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/simulation/file_per_user_client_data.py in __init__(self, client_ids_to_files, dataset_fn)
     52       return dataset_fn(client_ids_to_files[client_id])
     53 
---> 54     @computations.tf_computation(tf.string)
     55     def dataset_computation(client_id):
     56       client_ids_to_path = tf.lookup.StaticHashTable(

~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py in __call__(self, tff_internal_types, *args)
    405                                             parameter_type)
    406       args, kwargs = unpack_arguments_fn(next(wrapped_fn_generator))
--> 407       result = fn_to_wrap(*args, **kwargs)
    408       if result is None:
    409         raise ComputationReturnedNoneError(fn_to_wrap)

~/fasttext-venv/lib/python3.6/site-packages/tensorflow_federated/python/simulation/file_per_user_client_data.py in dataset_computation(client_id)
     59               list(client_ids_to_files.values())), '')
     60       client_path = client_ids_to_path.lookup(client_id)
---> 61       return dataset_fn(client_path)
     62 
     63     self._create_tf_dataset_fn = create_dataset_for_filename_fn

<ipython-input-46-e61ddbe06cdb> in dataset_fn(filepath)
     17         filepath = tf.print(filepath)
     18     print(filepath)
---> 19     dataSample = np.load(filepath)
     20     print(dataSample)
     21     label = filepath[:-4].strip().split('_')[-1]

~/fasttext-venv/lib/python3.6/site-packages/numpy/lib/npyio.py in load(file, mmap_mode, allow_pickle, fix_imports, encoding)
    426         own_fid = False
    427     else:
--> 428         fid = open(os_fspath(file), "rb")
    429         own_fid = True
    430 

TypeError: expected str, bytes or os.PathLike object, not Operation
1 Answer

The problem is that dataset_fn must be serializable as a tf.Graph. This is required because TFF uses TensorFlow graphs to execute logic on remote machines.
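You can reproduce this constraint outside of TFF: tracing dataset_fn with tf.function and a symbolic string spec puts it in the same situation, since filepath is then a placeholder tensor with no concrete value. This is a minimal sketch of the tracing behavior, not TFF's actual wrapper:

```python
import tensorflow as tf

# Minimal sketch of what TFF's tf_computation wrapper does: the function is
# traced into a graph, so `filepath` is a symbolic tensor, not a Python str.
def dataset_fn(filepath):
    print(filepath)  # during tracing, prints a symbolic Tensor(..., dtype=string)
    # np.load(filepath) would fail here: numpy needs a concrete path string.
    return tf.data.TextLineDataset(filepath)  # graph ops accept string tensors

# Tracing with a symbolic scalar-string spec mirrors TFF's @tf_computation(tf.string).
traced = tf.function(dataset_fn).get_concrete_function(
    tf.TensorSpec(shape=(), dtype=tf.string))
```

Any op used inside dataset_fn must therefore accept tensors, which rules out numpy calls on filepath.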

In this case, np.load is not serializable as a graph operation. It looks like numpy is used to load the data from disk into memory, and then tf.data.Dataset.from_tensor_slices creates a dataset from that in-memory object. It may be possible to save the files in a different format and use a native tf.data.Dataset operation to load them from disk, rather than going through Python. Some options are tf.data.TFRecordDataset, tf.data.TextLineDataset, or tf.data.experimental.SqlDataset.
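For example, one way to keep everything inside the graph is to convert each client's npy file to a TFRecord of serialized tensors up front (a one-time, eager Python step), and then write dataset_fn using TensorFlow ops only. The conversion helper and the filename-based label handling below are assumptions mirroring the question's naming scheme, not part of any TFF API:

```python
import numpy as np
import tensorflow as tf

def npy_to_tfrecord(npy_path, tfrecord_path):
    # One-time eager conversion: each row of the array becomes one record.
    data = np.load(npy_path)
    with tf.io.TFRecordWriter(tfrecord_path) as writer:
        for row in data:
            writer.write(tf.io.serialize_tensor(tf.constant(row)).numpy())

def dataset_fn(filepath):
    # Graph-serializable: only TensorFlow ops, no np.load.
    # The label is recovered from the filename with string ops, mirroring the
    # question's filepath[:-4].split('_')[-1] (assumed naming scheme).
    stem = tf.strings.regex_replace(filepath, r"\.tfrecord$", "")
    label = tf.strings.split(stem, "_")[-1]
    records = tf.data.TFRecordDataset(filepath)
    return records.map(
        lambda x: (tf.io.parse_tensor(x, out_type=tf.float64), label))
```

Each npy file would be converted once (for example train_0.npy to train_0.tfrecord), and client_ids_to_files would then map client ids to the .tfrecord paths instead.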