TF 2 Keras model benchmarking with JIT

I'm trying to benchmark some TF2 Keras code - specifically, comparing JIT-compiled performance against non-JIT. Without JIT, tf.test.Benchmark gives reasonable-looking results: memory usage roughly consistent with nvidia-smi output and wall time very close to model.fit. With JIT, however, it reports tiny memory usage (<1 MB vs 2.2 GB without JIT) and times consistently ~30% lower than the time model.fit takes.

Code provided below. I have 3 main questions:

  1. How do I get an accurate idea of memory usage of JIT models?
  2. What is the source of the speed discrepancy between benchmarked call and model.fit with JIT models?
  3. What's the TF 2 way of doing this? I'm using sessions and tf.compat.v1.data.make_one_shot_iterator, but surely there's a way using @tf.function or something (a rough sketch of what I have in mind follows the full code below)? Are there non-TF tools that can do this better?
from absl import logging
import tensorflow as tf
import tensorflow_datasets as tfds

ALLOW_GROWTH = False  # set to True when checking memory against nvidia-smi output
JIT = True

TFDS_NAME = 'mnist'
SHAPE = (28, 28, 1)
BATCH_SIZE = 64
NUM_CLASSES = 10
NUM_LAYERS = 20
UNITS = 4096
TRY_GCS = False  # set to True if running on Colab
TRAIN_STEPS = 200
BURN_ITERS = 200
MIN_ITERS = 200


def model_fn(inp):
    layers = tf.keras.layers
    x = layers.Flatten()(inp)
    for _ in range(NUM_LAYERS):
        x = layers.Dense(UNITS)(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation('relu')(x)
    logits = layers.Dense(NUM_CLASSES)(x)
    model = tf.keras.Model(inp, logits)
    model.compile(
        optimizer=tf.keras.optimizers.SGD(),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
    return model


def get_dataset():
    return tfds.load(
        TFDS_NAME,
        split='train',
        as_supervised=True,
        in_memory=True,
        try_gcs=TRY_GCS).repeat().shuffle(1024).map(
            lambda image, label: (tf.cast(image, tf.float32) / 255, label),
            tf.data.experimental.AUTOTUNE).batch(BATCH_SIZE).prefetch(
                tf.data.experimental.AUTOTUNE)


def fit(epochs=2, steps_per_epoch=TRAIN_STEPS):
    dataset = get_dataset()
    model = model_fn(tf.keras.Input(shape=SHAPE, dtype=tf.float32))
    model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=epochs)


def benchmark(burn_iters=BURN_ITERS, min_iters=MIN_ITERS):
    with tf.Graph().as_default():
        dataset = get_dataset()
        image, labels = tf.compat.v1.data.make_one_shot_iterator(
            dataset).get_next()
        model = model_fn(tf.keras.Input(tensor=image))
        logits, = model.outputs
        optimizer = model.optimizer
        weights = model.weights
        loss = model.loss(labels, logits)
        grads = optimizer.get_gradients(loss, weights)
        grads_and_vars = tuple(
            (g, v) for g, v in zip(grads, weights) if g is not None)
        op = optimizer.apply_gradients(grads_and_vars)
        op = tf.group((op,) + tuple(model.updates))  # <--- include BatchNorm moving-statistics updates

        bm = tf.test.Benchmark()
        with tf.compat.v1.Session() as sess:
            logging.info('Initializing variables...')

            variables = model.weights + optimizer.weights
            for name in ('learning_rate', 'momentum'):
                a = getattr(optimizer, name, None)
                if isinstance(a, tf.Variable):
                    variables.append(a)
            sess.run([v.initializer for v in variables])

            logging.info('Starting benchmarking...')
            result = bm.run_op_benchmark(sess,
                                         op,
                                         burn_iters=burn_iters,
                                         min_iters=min_iters)
            logging.info('Wall time (ms): {}'.format(result['wall_time'] *
                                                     1000))
            gpu_mem = result['extras'].get(
                'allocator_maximum_num_bytes_GPU_0_bfc', 0)
            logging.info('Memory (Mb):    {}'.format(gpu_mem / 1024**2))


logging.set_verbosity(logging.INFO)
tf.config.optimizer.set_jit(JIT)
for device in tf.config.experimental.get_visible_devices('GPU'):
    tf.config.experimental.set_memory_growth(device, ALLOW_GROWTH)
benchmark()
fit()
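
For question 3, the kind of thing I have in mind is sketched below - untested, reusing the helpers above, and timed with a plain Python loop rather than tf.test.Benchmark. I don't know whether experimental_compile=True (spelled jit_compile in later releases) behaves the same as the global tf.config.optimizer.set_jit(True) call, nor how to read peak memory out of this approach.

import time


def benchmark_v2(burn_iters=BURN_ITERS, min_iters=MIN_ITERS):
    dataset = get_dataset()
    model = model_fn(tf.keras.Input(shape=SHAPE, dtype=tf.float32))
    loss_fn = model.loss  # the SparseCategoricalCrossentropy from model_fn

    # Request XLA for just this function; newer TF versions call this
    # argument `jit_compile` instead of `experimental_compile`.
    @tf.function(experimental_compile=True)
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            logits = model(images, training=True)
            loss = loss_fn(labels, logits)
        grads = tape.gradient(loss, model.trainable_variables)
        model.optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss

    it = iter(dataset)
    # Burn-in so tracing/compilation isn't included in the timing.
    for _ in range(burn_iters):
        train_step(*next(it))

    start = time.time()
    for _ in range(min_iters):
        loss = train_step(*next(it))
    loss.numpy()  # block on the last step before stopping the clock
    print('Wall time (ms): {}'.format(
        (time.time() - start) / min_iters * 1000))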

1 Answer

Regarding question 2: it seems that while Keras models construct optimized graphs, they don't take advantage of JIT compilation unless the model itself is built in graph mode. I managed to get roughly the same timings between benchmarking and fitting by constructing the model in graph mode, i.e.

def fit(epochs=2, steps_per_epoch=TRAIN_STEPS):
    with tf.Graph().as_default():
        dataset = get_dataset()
        model = model_fn(tf.keras.Input(shape=SHAPE, dtype=tf.float32))
        model.fit(dataset, steps_per_epoch=steps_per_epoch, epochs=epochs)

Having said that, this resulted in slower times with other models, though I haven't been able to reduce those models to minimal examples that demonstrate the behaviour.

Parts 1 and 3 of the original post remain open...
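
For part 1, one direction I haven't verified is to query the runtime's allocator stats directly via tf.config.experimental.get_memory_info (available from roughly TF 2.4; reset_memory_stats needs a newer release again). I don't know whether those numbers account for XLA's temporary buffers, so cross-checking against nvidia-smi with ALLOW_GROWTH enabled would still be worthwhile:

# Untested sketch: ask TF for its peak allocation around a training run.
# The returned dict has 'current' and 'peak' keys, both in bytes.
tf.config.experimental.reset_memory_stats('GPU:0')
fit()
peak = tf.config.experimental.get_memory_info('GPU:0')['peak']
print('Peak memory (Mb): {}'.format(peak / 1024**2))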