I am new to tf.distribute and I don't know how to get a model's weights directly in memory. My sample code is below; it raises a RuntimeError.
import os
import json

# Dump the cluster information into `TF_CONFIG`.
tf_config = {
    'cluster': {
        'worker': ["localhost:12345"],
    },
    'task': {'type': 'worker', 'index': 0}
}
os.environ.pop('TF_CONFIG', None)
os.environ['TF_CONFIG'] = json.dumps(tf_config)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    # RuntimeError
    weights = strategy.run(model.get_weights)
And this is the error log:
Traceback (most recent call last):
File "C:\Users\USER\Documents\example.py", line 24, in <module>
weights = strategy.run(model.get_weights)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 1312, in run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 2888, in call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_strategy.py", line 676, in _call_for_each_replica
return mirrored_run.call_for_each_replica(
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_run.py", line 100, in call_for_each_replica
return _call_for_each_replica(strategy, fn, args, kwargs)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_run.py", line 242, in _call_for_each_replica
coord.join(threads)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\training\coordinator.py", line 385, in join
six.reraise(*self._exc_info_to_raise)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\six.py", line 719, in reraise
raise value
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\training\coordinator.py", line 293, in stop_on_exception
yield
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_run.py", line 342, in run
self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\autograph\impl\api.py", line 595, in wrapper
return func(*args, **kwargs)
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\keras\engine\training.py", line 2329, in get_weights
with self.distribute_strategy.scope():
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 389, in __enter__
_require_cross_replica_or_default_context_extended(
File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 312, in _require_cross_replica_or_default_context_extended
raise RuntimeError(error_message)
RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()
This issue occurs with a single worker as well as with multiple workers. Declaring the model inside strategy.scope() seems unavoidable,
since I am going to train it across multiple workers/GPUs. Is it possible to fetch these weights without saving them to disk?
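For what it's worth, the error message points at the cross-replica context, so my current guess is that get_weights() has to be called directly on the model, outside strategy.run(). Below is a minimal sketch of what I am hoping works, under that assumption; the model.build() call and the input shape (None, 4) are my additions so the Dense layer actually has variables:

import os
import json

# Same single-worker TF_CONFIG as above.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'worker': ["localhost:12345"]},
    'task': {'type': 'worker', 'index': 0},
})
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()
with strategy.scope():
    model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    model.build(input_shape=(None, 4))  # my addition: create the variables

# Guess: calling get_weights() here, outside strategy.run() (i.e. in the
# cross-replica context), should return plain numpy arrays in memory.
weights = model.get_weights()
print([w.shape for w in weights])

Is this the intended pattern, or is there a proper way via the merge_call() mentioned in the error message?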