I'm trying to fine-tune OpenLlama 3Bv2 for SequenceClassification, but I have very little experience working with TPUs, and nothing seems to be working correctly. The first 2 batches run perfectly fine, but then I receive this error: BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.
I'm also generally confused about how I should go about training on the TPU and how I'd use mixed precision. In a perfect world I'd be using something like JAX or even TensorFlow, but I'm not too sure how to use those with an HF Transformers model.
Here's my current code:
import torch
import os
import pickle
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
# Paths and model selection
path = "trained_paragraph_1/"
mini = False
pref = "tokenized_llama/mini/" if mini else "tokenized_llama/"
model_path = 'tokenized_llama/model'
# Training parameters
batch_size = 32
val_batch_size = 64
epochs = 1
steps_per_eval = 10000
# Optimization settings
learning_rate = 1e-6
weight_decay = 0.05
warmup_ratio = 0.1
class PreTokenizedTextDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx], dtype=torch.long) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item
# Define the map function for multiprocessing
def _mp_fn(index, flags):
    # Load the pre-tokenized training and validation datasets
    with open(f'{pref}train_tokenized.pkl', 'rb') as f:
        train_encodings, train_labels = pickle.load(f)
    train_dataset = PreTokenizedTextDataset(train_encodings, train_labels)

    with open(f'{pref}val_tokenized.pkl', 'rb') as f:
        val_encodings, val_labels = pickle.load(f)
    val_dataset = PreTokenizedTextDataset(val_encodings, val_labels)

    # Initialize the TPU device
    device = xm.xla_device()

    # Define the DataLoader for training and validation
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=True)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, sampler=val_sampler, pin_memory=True)
    # Initialize the model
    model = LlamaForSequenceClassification.from_pretrained(model_path, num_labels=2)
    # Move the model to the TPU
    model.to(device)

    # Define the optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Define the learning rate scheduler with warmup
    num_training_steps = len(train_loader) * epochs
    num_warmup_steps = int(warmup_ratio * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

    # Training loop with checkpointing and evaluation
    model.train()
    global_step = 0
    best_accuracy = 0.0
    for epoch in range(epochs):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_iterator = para_loader.per_device_loader(device)
        if xm.is_master_ordinal():
            train_iterator = tqdm(train_iterator, desc=f"Epoch {epoch+1}", unit="batch")

        for batch in train_iterator:
            optimizer.zero_grad()
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss
            loss.backward()
            xm.optimizer_step(optimizer)
            scheduler.step()
            global_step += 1

            if xm.is_master_ordinal():
                train_iterator.set_postfix(loss=loss.item())

            if global_step % steps_per_eval == 0 and xm.is_master_ordinal():
                # Checkpointing
                xm.save(model.state_dict(), os.path.join(path, f'checkpoint-{global_step}.pt'))

                # Evaluation
                model.eval()
                total_eval_accuracy = 0
                total_eval_loss = 0
                para_loader = pl.ParallelLoader(val_loader, [device])
                eval_iterator = para_loader.per_device_loader(device)
                for batch in eval_iterator:
                    with torch.no_grad():
                        input_ids = batch['input_ids'].to(device)
                        attention_mask = batch['attention_mask'].to(device)
                        labels = batch['labels'].to(device)
                        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                        logits = outputs.logits
                        loss = outputs.loss
                        total_eval_loss += loss.item()
                        predictions = torch.argmax(logits, dim=-1)
                        total_eval_accuracy += accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())

                avg_val_accuracy = total_eval_accuracy / len(val_loader)
                avg_val_loss = total_eval_loss / len(val_loader)
                xm.master_print(f"Step {global_step}, Validation Loss: {avg_val_loss}, Validation Accuracy: {avg_val_accuracy}")

                # Save the best model
                if avg_val_accuracy > best_accuracy:
                    best_accuracy = avg_val_accuracy
                    xm.save(model.state_dict(), os.path.join(path, 'best_model.pt'))

                model.train()

    # Load the best model
    if xm.is_master_ordinal():
        model.load_state_dict(torch.load(os.path.join(path, 'best_model.pt')))

    # Save the fine-tuned model
    if xm.is_master_ordinal():
        model.save_pretrained(f'{path}fine_tuned_model')
# Start training using xmp.spawn
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), start_method='fork')
If anyone has any idea how I can fix the problem I'm encountering and/or how to improve performance, that would be great. Genuinely, any tips or tricks would be fantastic. Thanks!
I've tried to implement BF16, but I couldn't get it to work.
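For reference, this is roughly what my BF16 attempt looked like: I just set the XLA_USE_BF16 environment variable before spawning the worker processes, which (as far as I understand from the torch_xla docs) makes XLA treat float32 tensors as bfloat16 on the TPU, and left _mp_fn itself unchanged. I'm not sure whether I also need some autocast context around the forward pass, so treat this as a guess rather than a working setup:

import os

# Set before any TPU work happens, i.e. before xmp.spawn creates the worker processes.
# My understanding is that this makes XLA store/compute torch.float32 tensors as bfloat16 on the TPU.
os.environ['XLA_USE_BF16'] = '1'

import torch_xla.distributed.xla_multiprocessing as xmp

FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), start_method='fork')  # _mp_fn is the same function as in the script above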
I also tried switching to TensorFlow, but I couldn't get the OpenLlama model to work there.
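This is roughly what that TensorFlow attempt looked like; TFAutoModelForSequenceClassification with from_pt=True was my guess at the right entry point, and it just errors out for me, presumably because Transformers doesn't ship a TensorFlow port of the Llama architecture:

from transformers import TFAutoModelForSequenceClassification

# My attempt to convert the PyTorch checkpoint into a TF model; this raises an error for me,
# apparently because there is no TensorFlow implementation of Llama to convert the weights into.
model = TFAutoModelForSequenceClassification.from_pretrained(
    'tokenized_llama/model',  # same local checkpoint as model_path above
    num_labels=2,
    from_pt=True,
)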