I have a long list of texts, each with 512 tokens max:
texts = ["some fairly long text", "some fairly long text2", ... ]
And a model:
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
device = "cuda"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
How can I calculate the average perplexity of the model over all texts?
The first approach I tried (from this SO post) was using the evaluate library:
import evaluate

perplexity = evaluate.load("perplexity", module_type="metric")
input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]
results = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             predictions=input_texts)
print(list(results.keys()))
>>>['perplexities', 'mean_perplexity']
print(round(results["mean_perplexity"], 2))
But this runs out of memory for larger models (e.g., Llama 7B) on my machine. I even tried using batch_size=1, but that didn't seem to make a difference for some reason.
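Concretely, for the larger model the call looked roughly like this (the model id here is just an example; as far as I know batch_size is an optional argument of perplexity.compute):

results = perplexity.compute(model_id="huggyllama/llama-7b",  # example id only
                             add_start_token=False,
                             predictions=texts,
                             batch_size=1)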
The other approach was this Hugging Face tutorial, but there they concatenate all the texts into one long text and score it with a sliding window. That does fit in my memory (I don't understand why it works when a batch size of 1 didn't), but the result isn't quite what I want: I'd rather compute perplexity per text, which is more realistic for my use case.
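For context, my understanding of the tutorial's sliding-window approach is roughly the sketch below (paraphrased from memory, so details such as the stride value and the exact loss averaging may differ from the actual guide):

import torch

# Concatenate all texts and score the resulting token stream with a
# strided sliding window, masking the overlapping context tokens with -100.
encodings = tokenizer("\n\n".join(texts), return_tensors="pt")
max_length = model.config.n_positions  # 1024 for GPT-2
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in range(0, seq_len, stride):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # number of tokens actually scored in this window
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100  # context-only tokens are ignored by the loss

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        nlls.append(outputs.loss)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).mean())
print("Sliding-window perplexity:", ppl.item())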
Is there a better method, one less computationally heavy than evaluate, that calculates perplexity per text rather than over a sliding window?
Update:
I tried the following, but it gives different results from the evaluate.compute method, so I'm not sure what I'm doing wrong:
import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Load the model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)
# List of texts
texts = ["some fairly long text", "some fairly long text2", ...]
# Function to calculate perplexity
def calculate_perplexity(text):
    tokenized_text = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    input_ids = tokenized_text.input_ids
    attention_mask = tokenized_text.attention_mask

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        logits = outputs.logits

    # Flatten the logits and input_ids
    logits = logits.view(-1, logits.shape[-1])
    input_ids = input_ids.view(-1)

    # Compute the cross-entropy loss
    loss = torch.nn.functional.cross_entropy(logits, input_ids, reduction='none')
    perplexity = torch.exp(torch.mean(loss))
    return perplexity.item()
# Calculate perplexity for each text
perplexities = [calculate_perplexity(text) for text in texts]
# Calculate average perplexity
average_perplexity = sum(perplexities) / len(perplexities)
print("Average Perplexity:", average_perplexity)