How to calculate language model's perplexity for text that exceeds memory?


I have a long list of texts, each at most 512 tokens long:

texts = ["some fairly long text", "some fairly long text2",  ... ]

And a model

from transformers import GPT2LMHeadModel, GPT2TokenizerFast

device = "cuda"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

How can I calculate the average perplexity of the model over all texts?

The first approach I tried (from this SO answer) was using the evaluate library:

import evaluate

perplexity = evaluate.load("perplexity", module_type="metric")
input_texts = ["lorem ipsum", "Happy Birthday!", "Bienvenue"]

results = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             predictions=input_texts)
print(list(results.keys()))
>>>['perplexities', 'mean_perplexity']
print(round(results["mean_perplexity"], 2))

But this runs out of memory on my machine for larger models (e.g., Llama 7B). I even tried batch_size=1, but for some reason that made no difference.
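For reference, this is roughly how I passed the batch size (batch_size and device are arguments the metric's compute call accepts, if I'm reading its docs correctly):

results = perplexity.compute(model_id='gpt2',
                             add_start_token=False,
                             batch_size=1,  # still runs out of memory for larger models
                             device="cuda",
                             predictions=input_texts)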

The other approach was this Hugging Face tutorial, but it concatenates all the texts into one long text and scores it with a sliding window. This somehow fits in my memory (I don't understand why it works when a batch size of 1 didn't), but it yields a perplexity over the concatenation rather than per text, which is less realistic for my use case.
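For context, here is a condensed sketch of that tutorial's sliding-window loop (the stride value and the "\n\n" joining are choices adapted from it):

import torch
from tqdm import tqdm

encodings = tokenizer("\n\n".join(texts), return_tensors="pt")
max_length = model.config.n_positions  # 1024 for GPT-2
stride = 512
seq_len = encodings.input_ids.size(1)

nlls = []
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # tokens actually scored in this window
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    target_ids[:, :-trg_len] = -100  # mask the overlapping context tokens

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        # loss is averaged over scored tokens; scale back to an (approximate) sum
        nlls.append(outputs.loss * trg_len)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

ppl = torch.exp(torch.stack(nlls).sum() / prev_end_loc)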

Is there a better method, less computationally heavy than the evaluate approach, that calculates perplexity per text rather than over a sliding window?

Update:

I tried the following, but I get different results from the evaluate.compute method, so I'm not sure what I'm doing wrong:

import torch
from transformers import GPT2LMHeadModel, GPT2TokenizerFast
# Load the model and tokenizer
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "gpt2-large"
model = GPT2LMHeadModel.from_pretrained(model_id).to(device)
tokenizer = GPT2TokenizerFast.from_pretrained(model_id)

# List of texts
texts = ["some fairly long text", "some fairly long text2", ...]

# Function to calculate perplexity
def calculate_perplexity(text):
    tokenized_text = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
    input_ids = tokenized_text.input_ids
    attention_mask = tokenized_text.attention_mask

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
        logits = outputs.logits

    # Flatten the logits and input_ids
    # (note: no shift is applied here — position i's logits are scored
    # against token i itself, not against the next token)
    logits = logits.view(-1, logits.shape[-1])
    input_ids = input_ids.view(-1)

    # Compute the per-token cross-entropy loss
    loss = torch.nn.functional.cross_entropy(logits, input_ids, reduction='none')
    perplexity = torch.exp(torch.mean(loss))

    return perplexity.item()

# Calculate perplexity for each text
perplexities = [calculate_perplexity(text) for text in texts]

# Calculate average perplexity
average_perplexity = sum(perplexities) / len(perplexities)

print("Average Perplexity:", average_perplexity)