Microbatching (accumulating gradients) in TensorFlow 2.x with tf.function


How can microbatching be implemented in TensorFlow 2.x? That is, I would like to accumulate gradients over several batches and then update the weights with the accumulated gradients. This virtually increases my batch size to accumulation steps * batch size (e.g., with a batch size of 25 and 4 accumulation steps, the effective batch size is 100).

I tried with the following code:

import numpy as np
import tensorflow as tf

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs):
        return self.dense(inputs)


class Trainer:
    def __init__(self, model, num_accumulate):
        self.model = model
        self.num_accumulate = num_accumulate
        self.optimizer = tf.keras.optimizers.Adam()
        self.accumulated_gradients = None

    def _init_accumulated_gradients_maybe(self):
        # Lazily create one accumulator variable per trainable weight;
        # the weights only exist once the model has been built (first call).
        if self.accumulated_gradients is None:
            self.accumulated_gradients = [
                tf.Variable(var, dtype=var.dtype, trainable=False)
                for var in self.model.trainable_weights
            ]
            self._reset_gradients()

    def _reset_gradients(self):
        for grad in self.accumulated_gradients:
            grad.assign(tf.zeros_like(grad))

    def _accumulate_gradients(self, gradients):
        # Average over the accumulation window so the final update matches
        # a single large batch of size num_accumulate * batch_size.
        for acc_grad, grad in zip(self.accumulated_gradients, gradients):
            acc_grad.assign_add(grad / self.num_accumulate)

    def get_mae(self, targets, mean_pred):
        return tf.reduce_mean(tf.abs(targets - mean_pred))

    @tf.function
    def train_on_batch(self, dataset_iter):
        for _ in range(self.num_accumulate):  # problematic, see below
            inputs, target = next(dataset_iter)

            with tf.GradientTape() as tape:
                prediction = self.model(inputs, training=True)
                loss = self.get_mae(target, prediction)

            gradients = tape.gradient(loss, self.model.trainable_weights)

            self._init_accumulated_gradients_maybe()
            self._accumulate_gradients(gradients)
            gradients = self.accumulated_gradients

        self.optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))
        self._reset_gradients()

        return loss

class DataProvider:
    def __init__(self, batch_size: int = 1):
        self.batch_size = batch_size
        self.in_data = np.random.rand(100,10)
        self.out_data = np.random.rand(100,1)

    def get_dataset(self):
        def generator():
            while True:
                # Sample a random batch so that batch_size actually takes effect.
                idx = np.random.randint(0, len(self.in_data), self.batch_size)
                yield (tf.constant(self.in_data[idx], dtype=tf.float32),
                       tf.constant(self.out_data[idx], dtype=tf.float32))

        return tf.data.Dataset.from_generator(
            generator,
            output_types=(tf.float32, tf.float32),
            output_shapes=([None, 10], [None, 1]),
        )


num_accumulate = 4
batch_size = 25
nSteps = 10

model = Model()
trainer = Trainer(model, num_accumulate)
dataset_iter = iter(DataProvider(batch_size).get_dataset())

for step in range(nSteps):
    trainer.train_on_batch(dataset_iter)

However, I ran into two different problems, depending on whether I use range or tf.range inside the tf.function-decorated method.

  1. Using range: it works with the mini model provided above, but in my use case the model is significantly bigger (2.6 million parameters), and when I accumulate gradients like this, the following warning is logged:

2021-04-24 18:19:28.349940: W tensorflow/core/common_runtime/process_function_library_runtime.cc:733] Ignoring multi-device function optimization failure: Deadline exceeded: meta_optimizer exceeded deadline.

My guess is that with range (as far as I understand how tf.function works), the loop is unrolled at trace time, so every gradient accumulation step is added to the graph separately instead of the loop body being added once and repeated.
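
Here is a minimal, self-contained sketch (separate from the model above) of what I mean: a Python range loop is unrolled at trace time, so its body is duplicated in the graph, while tf.range is converted by AutoGraph into a single tf.while_loop:

import tensorflow as tf

@tf.function
def unrolled(x):
    for _ in range(4):     # Python loop: the body is traced 4 times (unrolled)
        x = x + 1.0
    return x

@tf.function
def graph_loop(x):
    for _ in tf.range(4):  # AutoGraph converts this into one tf.while_loop
        x = x + 1.0
    return x

x = tf.constant(0.0)
print(unrolled(x).numpy(), graph_loop(x).numpy())  # 4.0 4.0, same result
# But the unrolled graph grows with the iteration count:
print(len(unrolled.get_concrete_function(x).graph.as_graph_def().node))
print(len(graph_loop.get_concrete_function(x).graph.as_graph_def().node))

If that is what happens in my training function, the whole forward and backward pass is duplicated num_accumulate times in the graph, which for a 2.6 million parameter model could explain why the meta_optimizer runs out of time.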

  2. Replacing range with tf.range raises the following error:
Traceback (most recent call last):
  File "/mydirectory/model/test_train copy.py", line 89, in <module>
    trainer.train_on_batch(dataset_iter)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 505, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2657, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3299, in bound_method_wrapper
    return wrapped_fn(*args, **kwargs)
  File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /mydirectory/model/test_train copy.py:40 train_on_batch  *
        for _ in tf.range(self.num_accumulate):
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:343 for_stmt
        _tf_range_for_stmt(
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:526 _tf_range_for_stmt
        _tf_while_stmt(
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:862 _tf_while_stmt
        _verify_loop_init_vars(init_vars, symbol_names)
    /mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:119 _verify_loop_init_vars
        raise ValueError('"{}" must be defined before the loop.'.format(name))

    ValueError: "loss" must be defined before the loop.

Therefore, I initialized all variables occurring in the loop, such as gradients, loss, and prediction, before the loop. With that it works, but it is painfully slow (in my use case). Why is that?
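
For reference, this is roughly what that workaround looks like, as a sketch of train_on_batch. In this version only loss actually needs to be defined up front, since it is the only value used after the loop; the accumulator variables are assumed to have been created eagerly, before the first trace:

    @tf.function
    def train_on_batch(self, dataset_iter):
        # Loop-carried values must be defined before a graph-mode loop
        # and keep a consistent dtype and shape across iterations.
        loss = tf.constant(0.0)
        for _ in tf.range(self.num_accumulate):
            inputs, target = next(dataset_iter)
            with tf.GradientTape() as tape:
                prediction = self.model(inputs, training=True)
                loss = self.get_mae(target, prediction)
            gradients = tape.gradient(loss, self.model.trainable_weights)
            self._accumulate_gradients(gradients)
        self.optimizer.apply_gradients(zip(self.accumulated_gradients, self.model.trainable_weights))
        self._reset_gradients()
        return loss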

What am I missing? Any help is highly appreciated.
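
For completeness, a pattern I have seen suggested (sketched here, not what I am currently running) is to keep the accumulation loop in eager Python, outside the compiled functions, so nothing is unrolled and no loop variables are needed:

    @tf.function
    def accumulate_step(self, inputs, target):
        # One forward/backward pass; the gradients go into the accumulators.
        with tf.GradientTape() as tape:
            prediction = self.model(inputs, training=True)
            loss = self.get_mae(target, prediction)
        gradients = tape.gradient(loss, self.model.trainable_weights)
        self._init_accumulated_gradients_maybe()
        self._accumulate_gradients(gradients)
        return loss

    @tf.function
    def apply_step(self):
        # Apply the averaged gradients once per accumulation window.
        self.optimizer.apply_gradients(zip(self.accumulated_gradients, self.model.trainable_weights))
        self._reset_gradients()

called from an eager loop:

for step in range(nSteps):
    for _ in range(num_accumulate):
        inputs, target = next(dataset_iter)
        trainer.accumulate_step(inputs, target)
    trainer.apply_step()

But I would prefer to keep the whole accumulation step inside one tf.function if possible.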
