Why is GPyTorch kernel partition size not reducing GPU memory usage?

I am attempting to use exact Gaussian process regression to analyze a dataset of roughly 1 million datapoints. Since a naive exact GP on that many points would require an enormous amount of memory, I followed the example notebook here, reusing large segments of its code and only feeding in my own dataset instead of the example dataset. For now I am using a 10% subset of my original dataset (the subsetting step is sketched after the data-preparation block below) to make sure the code works as written before moving on to a more in-depth analysis.

Most of the code is identical to the link, but I've transcribed the relevant sections below:

Data preparation

import math
import torch
import gpytorch
from datetime import datetime
import sys
from matplotlib import pyplot as plt
sys.path.append('../')
from LBFGS import FullBatchLBFGS # LBFGS.py file from GitHub
import numpy as np
import pandas as pd
import os

data = np.genfromtxt(filename, delimiter=',', skip_header=1, dtype=None) # File of 13 parameters, 4 of which are relevant to my algorithm
num_points = data.shape[0] # number of rows loaded
intensity = data[:, 3]
thickness = data[:, 8]
focal_distance = data[:, 2]
max_energy = data[:, 4]
data = np.dstack((intensity, thickness, focal_distance, max_energy)).reshape(num_points, 4)
data = torch.tensor(data)
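
The 10% subsetting isn't shown above; it amounts to keeping a subset of the rows right after this block, roughly like the sketch below (a random selection here, purely illustrative and not the exact line from my script):

num_points = data.shape[0]
keep = torch.randperm(num_points)[: num_points // 10]  # keep a random ~10% of the rows
data = data[keep]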

Train/test split

N = data.shape[0]
# make train/val/test
n_train = int(0.8 * N)
train_x, train_y = data[:n_train, :-1], data[:n_train, -1]
test_x, test_y = data[n_train:, :-1], data[n_train:, -1]

# normalize features
mean = train_x.mean(dim=-2, keepdim=True)
std = train_x.std(dim=-2, keepdim=True) + 1e-6 # prevent dividing by 0
train_x = (train_x - mean) / std
test_x = (test_x - mean) / std

# normalize labels
mean, std = train_y.mean(),train_y.std()
train_y = (train_y - mean) / std
test_y = (test_y - mean) / std

# make contiguous
train_x, train_y = train_x.contiguous(), train_y.contiguous()
test_x, test_y = test_x.contiguous(), test_y.contiguous()

output_device = torch.device('cuda:0')

train_x, train_y = train_x.to(output_device), train_y.to(output_device)
test_x, test_y = test_x.to(output_device), test_y.to(output_device)
n_devices = torch.cuda.device_count()
print('Planning to run on {} GPUs.'.format(n_devices))

Model definition

class ExactGPModel(gpytorch.models.ExactGP):
    def __init__(self, train_x, train_y, likelihood, n_devices):
        super(ExactGPModel, self).__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        base_covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())
        
        self.covar_module = gpytorch.kernels.MultiDeviceKernel(
            base_covar_module, device_ids=range(n_devices),
            output_device=output_device
        )
    
    def forward(self, x):
        mean_x = self.mean_module(x)
        covar_x = self.covar_module(x)
        return gpytorch.distributions.MultivariateNormal(mean_x, covar_x)

def train(train_x,
          train_y,
          n_devices,
          output_device,
          checkpoint_size,
          preconditioner_size,
          n_training_iter,
):
    likelihood = gpytorch.likelihoods.GaussianLikelihood().to(output_device)
    model = ExactGPModel(train_x, train_y, likelihood, n_devices).to(output_device)
    model = model.double() # Necessary to match float64 type of dataset
    model.train()
    likelihood.train()
    
    optimizer = FullBatchLBFGS(model.parameters(), lr=0.1)
    # "Loss" for GPs - the marginal log likelihood
    mll = gpytorch.mlls.ExactMarginalLogLikelihood(likelihood, model)

    
    with gpytorch.beta_features.checkpoint_kernel(checkpoint_size), \
         gpytorch.settings.max_preconditioner_size(preconditioner_size):

        def closure():
            optimizer.zero_grad()
            output = model(train_x)
            loss = -mll(output, train_y)
            return loss

        loss = closure()
        loss.backward()

        for i in range(n_training_iter):
            options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
            loss, _, _, _, _, _, _, fail = optimizer.step(options)
            
            print('Iter %d/%d - Loss: %.3f   lengthscale: %.3f   noise: %.3f' % (
                i + 1, n_training_iter, loss.item(),
                model.covar_module.module.base_kernel.lengthscale.item(),
                model.likelihood.noise.item()
            ))
            
            if fail:
                print('Convergence reached!')
                break
    
    print(f"Finished training on {train_x.size(0)} data points using {n_devices} GPUs.")
    return model, likelihood

Final code block run to define and run GPU parameter search:

import gc

def find_best_gpu_setting(train_x,
                          train_y,
                          n_devices,
                          output_device,
                          preconditioner_size
):
    N = train_x.size(0)

    # Find the optimum partition/checkpoint size by decreasing in powers of 2
    # Start with no partitioning (size = 0)
    settings = [0] + [int(n) for n in np.ceil(N / 2**np.arange(1, np.floor(np.log2(N))))]

    for checkpoint_size in settings:
        print('Number of devices: {} -- Kernel partition size: {}'.format(n_devices, checkpoint_size))
        try:
            # Try a full forward and backward pass with this setting to check memory usage
            _, _ = train(train_x, train_y,
                         n_devices=n_devices, output_device=output_device,
                         checkpoint_size=checkpoint_size,
                         preconditioner_size=preconditioner_size, n_training_iter=1)

            # when successful, break out of for-loop and jump to finally block
            break
        except RuntimeError as e:
            print('RuntimeError: {}'.format(e))
        except AttributeError as e:
            print('AttributeError: {}'.format(e))
        finally:
            # handle CUDA OOM error
            gc.collect()
            torch.cuda.empty_cache()
    return checkpoint_size

# Set a large enough preconditioner size to reduce the number of CG iterations run
preconditioner_size = 100
checkpoint_size = find_best_gpu_setting(train_x, train_y,
                                        n_devices=n_devices,
                                        output_device=output_device,
                                        preconditioner_size=preconditioner_size)

Which produces the following output:

Number of devices: 1 -- Kernel partition size: 0
RuntimeError: CUDA out of memory. Tried to allocate 63.64 GiB (GPU 0; 31.74 GiB total capacity; 93.80 MiB already allocated; 30.03 GiB free; 104.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Number of devices: 1 -- Kernel partition size: 46209
RuntimeError: CUDA out of memory. Tried to allocate 190.91 GiB (GPU 0; 31.74 GiB total capacity; 153.03 MiB already allocated; 29.96 GiB free; 178.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
Number of devices: 1 -- Kernel partition size: 23105
RuntimeError: CUDA out of memory. Tried to allocate 190.91 GiB (GPU 0; 31.74 GiB total capacity; 153.03 MiB already allocated; 29.96 GiB free; 178.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
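
For context, the initial 63.64 GiB request looks like exactly the size of the full dense training kernel matrix in float64: the first non-zero partition size printed above is 46209 ≈ N/2, so n_train should be about 92418, and a quick back-of-the-envelope check (my own arithmetic, not output from the script) gives the same number:

n_train = 92418               # inferred from the 46209 = ceil(N/2) partition size above
bytes_per_entry = 8           # float64
kernel_gib = n_train**2 * bytes_per_entry / 2**30
print(f"{kernel_gib:.2f} GiB")  # -> 63.64 GiB, matching the first allocation attempt

Interestingly, 190.91 GiB is almost exactly three times that.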

By my understanding, the point of kernel partitioning is that each halving of the partition size should further reduce peak memory usage, until the required memory fits in the memory that is actually available. However, turning partitioning on increased the attempted allocation from 63.64 GiB to 190.91 GiB, regardless of the partition size. I cut off the last few lines of output for readability's sake, but the partition size continued to decay all the way down to 3 without the attempted allocation ever dropping below 190.91 GiB. Is there some aspect of the process that I'm missing?
