RuntimeError while running get_weights() in strategy.run in tensorflow


I am new to tf.distribute and I do not know how to get the weights of a model directly in memory. My sample code below raises a RuntimeError.

import os
import json

# Dump the cluster information to `'TF_CONFIG'`.
tf_config = {
    'cluster': {
        'worker': ["localhost:12345"],
    },
    'task': {'type': 'worker', 'index': 0}
}

os.environ.pop('TF_CONFIG', None)
os.environ['TF_CONFIG'] = json.dumps(tf_config)
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import tensorflow as tf

strategy = tf.distribute.MultiWorkerMirroredStrategy()

with strategy.scope():
    model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])

# RuntimeError
weights = strategy.run(model.get_weights)

This is the error log:

Traceback (most recent call last):
  File "C:\Users\USER\Documents\example.py", line 24, in <module>
    weights = strategy.run(model.get_weights)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 1312, in run
    return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 2888, in call_for_each_replica
    return self._call_for_each_replica(fn, args, kwargs)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_strategy.py", line 676, in _call_for_each_replica
    return mirrored_run.call_for_each_replica(
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_run.py", line 100, in call_for_each_replica
    return _call_for_each_replica(strategy, fn, args, kwargs)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_run.py", line 242, in _call_for_each_replica
    coord.join(threads)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\training\coordinator.py", line 385, in join
    six.reraise(*self._exc_info_to_raise)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\six.py", line 719, in reraise
    raise value
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\training\coordinator.py", line 293, in stop_on_exception
    yield
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\mirrored_run.py", line 342, in run
    self.main_result = self.main_fn(*self.main_args, **self.main_kwargs)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\autograph\impl\api.py", line 595, in wrapper
    return func(*args, **kwargs)
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\keras\engine\training.py", line 2329, in get_weights
    with self.distribute_strategy.scope():
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 389, in __enter__
    _require_cross_replica_or_default_context_extended(
  File "C:\Users\USER\AppData\Roaming\Python\Python39\site-packages\tensorflow\python\distribute\distribute_lib.py", line 312, in _require_cross_replica_or_default_context_extended
    raise RuntimeError(error_message)
RuntimeError: Method requires being in cross-replica context, use get_replica_context().merge_call()

This issue happens with a single worker as well as with multiple workers. Declaring the model inside strategy.scope() seems unavoidable, since I am going to train it across multiple workers/GPUs. Is it possible to fetch these weights without saving them to disk?
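The error message hints at the cause: get_weights() must run in cross-replica context, while strategy.run executes its function in per-replica context. A workaround I tried is to call get_weights() directly, outside strategy.run; the variables created inside strategy.scope() are mirrored, and reading them from the outer (cross-replica) context works. A minimal sketch, using a single-machine MirroredStrategy for illustration (the cross-replica rule should be the same for MultiWorkerMirroredStrategy), with the model explicitly built so the Dense layer's variables exist:

```python
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"  # CPU only, as in the question

import tensorflow as tf

# Single-machine stand-in for MultiWorkerMirroredStrategy; the
# replica-context vs. cross-replica-context distinction is the same.
strategy = tf.distribute.MirroredStrategy()

with strategy.scope():
    model = tf.keras.models.Sequential([tf.keras.layers.Dense(10)])
    model.build(input_shape=(None, 4))  # create variables: kernel (4, 10), bias (10,)

# Call get_weights() directly (cross-replica context) instead of
# wrapping it in strategy.run (replica context).
weights = model.get_weights()
print(len(weights), weights[0].shape, weights[1].shape)
```

If this is correct, no round-trip through the filesystem is needed; get_weights() returns plain NumPy arrays in memory.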
