Why does a JAX + STAX model take more GPU memory than needed?


I'm trying to run a JAX + STAX model in a Kaggle kernel on a GPU, but it fails with an out-of-memory (OOM) error. I've set XLA_PYTHON_CLIENT_PREALLOCATE to false to avoid preallocating GPU memory, and I also tried setting XLA_PYTHON_CLIENT_ALLOCATOR to platform; neither helped.

The default device is set to the CPU from the start because I do not want all the data stored on the GPU. The model and each batch are sent to the GPU manually. The size of the variables (model parameters, data, ...) shouldn't be a problem, since the same code runs smoothly on the CPU without OOM errors.

I've also profiled the model's device memory. To get a GPU profile at all, I had to write a second version of the code in which the GPU is the default device and all the data is stored there; if I run the profiler on the original code, where the CPU is the default, I only get a profile of the CPU data. I also had to reduce the batch size to 10 for that version to finish training. The profile shows only the memory needed to store the data and parameters (≈ 5.5 GB), but when I check GPU usage with other Python functions it is much larger (≈ 14.6 GB). With batch_size = 100, memory also reaches 14.6 GB during the first mini-batch, and training cannot get past it.
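To put these numbers in context, here is a back-of-the-envelope sketch of the float32 sizes implied by the model in the code below. It assumes 200×200×3 float32 images, two classes, the stax layers exactly as written, and Adam's two moment buffers per parameter. It ignores XLA workspace, padding and allocator fragmentation, so it is a rough estimate, not a measurement.

```python
# Back-of-the-envelope float32 sizes for the stax model in this post.
# Assumptions: 200x200x3 float32 images, 2 classes, layers exactly as written,
# Adam with two moment buffers per parameter; XLA workspace, padding and
# allocator fragmentation are ignored.
BYTES_PER_FLOAT32 = 4
GIB = 1024 ** 3

HEIGHT, WIDTH, CHANNELS, CLASSES, FILTERS = 200, 200, 3, 2, 64


def conv_params(in_channels, out_channels, kernel=3):
    """Kernel weights plus one bias per output channel."""
    return kernel * kernel * in_channels * out_channels + out_channels


def dense_params(n_in, n_out):
    """Weight matrix plus bias vector."""
    return n_in * n_out + n_out


# 'SAME' padding with stride 1 keeps the 200x200 spatial size, so Flatten
# produces 200 * 200 * 64 features that feed Dense(128).
flat_features = HEIGHT * WIDTH * FILTERS
n_params = (conv_params(CHANNELS, FILTERS) + conv_params(FILTERS, FILTERS)
            + dense_params(flat_features, 128) + dense_params(128, CLASSES))

params_gib = n_params * BYTES_PER_FLOAT32 / GIB
adam_gib = 2 * params_gib  # Adam keeps mu and nu, each the size of the params
dataset_gib = 2 * 2000 * HEIGHT * WIDTH * CHANNELS * BYTES_PER_FLOAT32 / GIB


def feature_map_gib(batch_size):
    """Size of one (batch, 200, 200, 64) float32 conv/ReLU output."""
    return batch_size * HEIGHT * WIDTH * FILTERS * BYTES_PER_FLOAT32 / GIB


if __name__ == "__main__":
    print(f"parameters:         {n_params:,} ({params_gib:.2f} GiB)")
    print(f"Adam moments:       {adam_gib:.2f} GiB")
    print(f"dataset (2 arrays): {dataset_gib:.2f} GiB")
    print(f"one feature map:    {feature_map_gib(100):.2f} GiB at batch 100")
```

Under these assumptions nearly all of the roughly 328M parameters sit in the first Dense(128) layer, and parameters + Adam moments + the two image arrays add up to about 5.45 GiB, close to the ≈ 5.5 GB in the profile. Each (batch, 200, 200, 64) feature map adds roughly another 0.95 GiB at batch_size = 100.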

Here is the simplified version of the code I used:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = 'false'
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform' # Tried this, didn't help

import jax
import jax.numpy as jnp
import optax
from jax import jit, random, value_and_grad
from jax.example_libraries import stax
from jax.example_libraries.stax import Conv, Dense, Flatten, Relu

# Make the CPU the default device; otherwise every array would be placed on the GPU by default
jax.config.update('jax_platform_name', 'cpu')

# Use the first GPU if one is available, otherwise fall back to the CPU
try:
    print('Available GPU Devices: ', jax.devices("gpu"))
    device = jax.devices("gpu")[0]
    gpu_available = True
except RuntimeError:
    device = jax.devices("cpu")[0]
    gpu_available = False

# Load data into jax device arrays of dimensions (2000, 200, 200, 3)...

InitializationFunction, ApplyFunction = stax.serial(
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Flatten, Dense(128), Relu, Dense(2),)

key = random.PRNGKey(2793)
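# init returns (output_shape, parameters); only the parameters need to go to the device
output_shape, parameters = InitializationFunction(rng = key, input_shape = (100, image_width, image_height, number_of_channels))
parameters = jax.device_put(parameters, device)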
optimizer = optax.adam(0.001)
optimizer_state = jax.device_put(optimizer.init(parameters), device)

def Loss(parameters, inputs, targets):
    predictions = ApplyFunction(parameters, inputs)
    loss = jnp.mean(optax.softmax_cross_entropy(predictions, targets))
    return loss

@jit
def Step(parameters, optimizer_state, inputs, targets):
    loss, gradients = value_and_grad(Loss)(parameters, inputs, targets)
    updates, optimizer_state = optimizer.update(gradients, optimizer_state, parameters)
    parameters = optax.apply_updates(parameters, updates)
    return parameters, optimizer_state, loss

epochs, batch_size = 2, 100
number_of_classes = int(jnp.max(train_set['class_numbers'])) + 1
key, subkey = random.split(key)
keys_epochs = random.split(subkey, epochs)

for epoch in range(epochs):
    random_indices_order = random.permutation(keys_epochs[epoch], jnp.arange(len(train_set['images'])))

    for batch_number in range(len(train_set['images']) // batch_size):
        start = batch_number * batch_size
        end = (batch_number + 1) * batch_size
        batch_indices = random_indices_order[start:end]
        # Gather the batch on the CPU, then copy only this batch to the GPU
        batch_inputs = jax.device_put(jnp.take(train_set['images'], batch_indices, 0), device)
        batch_targets = jax.device_put(jax.nn.one_hot(jnp.take(train_set['class_numbers'], batch_indices, 0), number_of_classes), device)
        parameters, optimizer_state, loss = Step(parameters, optimizer_state, inputs = batch_inputs, targets = batch_targets)
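One mechanism that can make a jitted training step need extra memory on top of the persistent state is that the step's outputs (new parameters, new optimizer state) get fresh buffers while the inputs are still alive. `jax.jit` has a `donate_argnums` option that lets XLA reuse input buffers for outputs. The sketch below is a minimal, hypothetical illustration of that option, not a verified fix for this model: to run standalone it swaps in a tiny stax MLP and a hand-written SGD update instead of `optax.adam`. For the `Step` above, the analogous change would presumably be decorating it with `functools.partial(jit, donate_argnums=(0, 1))` so that `parameters` and `optimizer_state` are donated.

```python
import jax
import jax.numpy as jnp
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu

# Hypothetical stand-in model; only the donation pattern matters here.
init_fun, apply_fun = stax.serial(Dense(16), Relu, Dense(2))


def loss_fn(params, inputs, targets):
    """Softmax cross-entropy with one-hot targets."""
    log_probs = jax.nn.log_softmax(apply_fun(params, inputs))
    return -jnp.mean(jnp.sum(targets * log_probs, axis=-1))


def sgd_step(params, inputs, targets):
    """One plain SGD step (stand-in for the optax.adam update in the post)."""
    loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    new_params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    return new_params, loss


# donate_argnums=(0,) allows XLA to write the updated parameters into the
# input parameter buffers instead of allocating a second full copy.
step = jax.jit(sgd_step, donate_argnums=(0,))


def train(num_steps=20):
    init_key, data_key = jax.random.split(jax.random.PRNGKey(0))
    _, params = init_fun(init_key, (-1, 8))
    inputs = jax.random.normal(data_key, (32, 8))
    targets = jax.nn.one_hot(jnp.arange(32) % 2, 2)
    losses = []
    for _ in range(num_steps):
        # The old params were donated, so always rebind to the returned ones.
        params, loss = step(params, inputs, targets)
        losses.append(float(loss))
    return losses


if __name__ == "__main__":
    print(train())
```

Donated arrays must not be read after the call, which is why the loop rebinds `params` every step. If a backend cannot use a donated buffer, JAX emits a warning and allocates a new one instead.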

My questions are:

  1. Why is more GPU memory used than the variables need, and more than the JAX device memory profile captures? What is the excess memory used for, how can I track it, and how can I prevent it?
  2. How can I capture both CPU and GPU memory with JAX device memory profiling? It only captures CPU memory when the CPU is the default device, even though a GPU is available and in use as well.
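Regarding the second question, `jax.profiler.save_device_memory_profile` takes an optional `backend` argument, so a sketch like the following could write one profile per backend without changing the default device. It assumes the backend names "cpu" and "gpu"; the helpers `available_platforms` and `save_profiles` are hypothetical names.

```python
import os

import jax
import jax.profiler


def available_platforms(candidates=("cpu", "gpu")):
    """Return the platforms in `candidates` for which JAX can see a device."""
    found = []
    for platform in candidates:
        try:
            jax.devices(platform)
        except RuntimeError:  # backend not installed or failed to initialize
            continue
        found.append(platform)
    return found


def save_profiles(prefix="memory"):
    """Write one pprof device-memory profile per available backend."""
    paths = []
    for platform in available_platforms():
        path = f"{prefix}_{platform}.prof"
        # `backend=` selects which backend's live buffers are recorded,
        # regardless of which device is the default.
        jax.profiler.save_device_memory_profile(path, backend=platform)
        paths.append(os.path.abspath(path))
    return paths


if __name__ == "__main__":
    print(save_profiles())
```

Calling `save_profiles()` right after a `Step` would, under these assumptions, produce `memory_cpu.prof` and `memory_gpu.prof`, which can be opened with pprof (for example `pprof --web memory_gpu.prof`).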

Here is the result of device memory profiling on the GPU for the version in which the GPU is the default device and stores the entire dataset (2 × (2000, 200, 200, 3) ≈ 1.79 GB), with the batch size reduced to 10:

[Figure: GPU JAX device memory profile for batch size 10]
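A device memory profile lists the buffers that are alive at the moment it is captured, so temporaries created inside `Step` (gradients, Adam updates, activations kept for the backward pass) may not show up in it. One way to look at them, sketched here under the assumption that the GPU backend exposes allocator counters through `Device.memory_stats()`, is to read the peak counter after a step. `memory_report` and `to_gib` are hypothetical helpers, and since the available keys can differ between JAX versions, the function keeps only the ones present.

```python
import jax


def memory_report(device):
    """Allocator counters for `device`, or None if the backend exposes none.

    GPU devices usually report stats; CPU devices often return None.
    """
    stats = device.memory_stats()
    if not stats:
        return None
    wanted = ("bytes_in_use", "peak_bytes_in_use", "bytes_limit")
    return {key: stats[key] for key in wanted if key in stats}


def to_gib(num_bytes):
    """Convert a byte count to GiB."""
    return num_bytes / 1024 ** 3


if __name__ == "__main__":
    report = memory_report(jax.devices()[0])
    if report and "peak_bytes_in_use" in report:
        print(f"peak: {to_gib(report['peak_bytes_in_use']):.2f} GiB")
    else:
        print("no allocator stats for this device:", report)
```

Printing the peak right after `Step` in the training loop may give a number closer to what external tools report than the profile does, because the allocator generally keeps freed blocks cached for reuse rather than returning them to the driver.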
