How can micro batching be implemented in tensorflow 2.x? That is I would like to accumulate gradients for several batches and then update the weights with these accumulated gradients (this would virtually increase my batch size to accumulation steps * batch size).
I tried with the following code:
import numpy as np
import tensorflow as tf
class Model(tf.keras.Model):
def __init__(self, ):
super().__init__()
self.dense = tf.keras.layers.Dense(1)
def call(self, inputs):
return self.dense(inputs)
class Trainer:
def __init__(self, model, num_accumulate):
self.model = model
self.num_accumulate = num_accumulate
self.optimizer = tf.keras.optimizers.Adam()
self.accumulated_gradients = None
def _init_accumulated_gradients_maybe(self):
if self.accumulated_gradients is None:
self.accumulated_gradients = [tf.Variable(var, dtype=var.dtype, trainable=False) for var in self.model.trainable_weights]
self._reset_gradients()
def _reset_gradients(self):
for grad in self.accumulated_gradients:
grad.assign(tf.zeros_like(grad))
def _accumulate_gradients(self, gradients):
for acc_grad, grad in zip(self.accumulated_gradients, gradients):
acc_grad.assign_add( grad / self.num_accumulate )
def get_mae(self, targets, mean_pred):
return tf.reduce_mean(tf.abs(targets - mean_pred))
@tf.function
def train_on_batch(self, dataset_iter):
for _ in range(self.num_accumulate): # problematic
inputs, target = next(dataset_iter)
with tf.GradientTape() as tape:
prediction = self.model(inputs, training=True)
loss = self.get_mae(target, prediction)
gradients = tape.gradient(loss, self.model.trainable_weights)
self._init_accumulated_gradients_maybe()
self._accumulate_gradients(gradients)
gradients = self.accumulated_gradients
self.optimizer.apply_gradients(zip(gradients, self.model.trainable_weights))
self._reset_gradients()
return loss
class DataProvider:
def __init__(self,
batch_size: int = 1,
):
self.batch_size = batch_size
self.in_data = np.random.rand(100,10)
self.out_data = np.random.rand(100,1)
def get_dataset(self):
def generator():
while True:
yield (tf.constant(self.in_data, dtype=tf.float32), tf.constant(self.out_data, dtype=tf.float32))
return tf.data.Dataset.from_generator(
generator,
output_types=(tf.float32, tf.float32),
output_shapes=([None,10], [None,1])
)
num_accumulate = 4
batch_size = 25
nSteps = 10
model = Model()
trainer = Trainer(model, num_accumulate)
dataset_iter = iter(DataProvider(batch_size).get_dataset())
for step in range(1, nSteps):
trainer.train_on_batch(dataset_iter)
However, I ran into two different problems depending on if I use tf.range or range inside the tf.function decorated function.
- Using range: It works with the provided mini model but in my use case the model is significantly bigger (2.6 Mio params) and when I accumulate gradients like this the following error is raised:
2021-04-24 18:19:28.349940: W tensorflow/core/common_runtime/process_function_library_runtime.cc:733] Ignoring multi-device function optimization failure: Deadline exceeded: meta_optimizer exceeded deadline.
My guess is that using range (as far as I understood how tf.function works) every gradient accumulation step is added to the graph instead of repeating this part and adding it only once.
- Replacing range with tf.range raises the following error:
Traceback (most recent call last):
File "/mydirectory/model/test_train copy.py", line 89, in <module>
trainer.train_on_batch(dataset_iter)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
result = self._call(*args, **kwds)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
self._initialize(args, kwds, add_initializers_to=initializers)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 505, in _initialize
self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
graph_function, _, _ = self._maybe_define_function(args, kwargs)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
graph_function = self._create_graph_function(args, kwargs)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2657, in _create_graph_function
func_graph_module.func_graph_from_py_func(
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
func_outputs = python_func(*func_args, **func_kwargs)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
return weak_wrapped_fn().__wrapped__(*args, **kwds)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 3299, in bound_method_wrapper
return wrapped_fn(*args, **kwargs)
File "/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
/mydirectory/model/test_train copy.py:40 train_on_batch *
for _ in tf.range(self.num_accumulate):
/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:343 for_stmt
_tf_range_for_stmt(
/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:526 _tf_range_for_stmt
_tf_while_stmt(
/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:862 _tf_while_stmt
_verify_loop_init_vars(init_vars, symbol_names)
/mydirectory/anaconda3/envs/tf/lib/python3.8/site-packages/tensorflow/python/autograph/operators/control_flow.py:119 _verify_loop_init_vars
raise ValueError('"{}" must be defined before the loop.'.format(name))
ValueError: "loss" must be defined before the loop.
Therefore, I initialized all occuring variables such as gradients, loss and prediction and then it works but it is painfully slow (in my use case) why is that?
What am I missing ? Any help is highly appreciated.