How does `enforce_stop_tokens` work in LangChain with Huggingface models?


When we look at the HuggingFacePipeline usage in LangChain, there's a part where the author admits he couldn't find a better way to stop the generation, https://github.com/hwchase17/langchain/blob/master/langchain/llms/huggingface_pipeline.py#L182:

class HuggingFacePipeline(LLM):
    ...
    def _call(
        ...
        if stop is not None:
            # This is a bit hacky, but I can't figure out a better way to enforce
            # stop tokens when making calls to huggingface_hub.
            text = enforce_stop_tokens(text, stop)
        return text

What should I use to add the stop token to the end of the template?


If we look at https://github.com/hwchase17/langchain/blob/master/langchain/llms/utils.py, it's simply a regex split that splits the input string on a list of stop words and takes the first partition of the re.split:

re.split("|".join(stop), text)[0]

Let's try to get a generation output from a Huggingface model, e.g.:

from transformers import pipeline
from transformers import GPT2LMHeadModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
output = generator("Hey Pizza! ")
output

[out]:

[{'generated_text': 'Hey Pizza! 」\n\n「Hurry up, leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and then, Yuigahama came in contact with Ruriko in the middle of the'}]
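
Note that the GPT-2 text-generation pipeline samples by default, so your generated text will differ between runs. To make the example reproducible you can fix the RNG with transformers' set_seed helper, e.g.:

from transformers import set_seed

set_seed(42)  # fix the sampling RNG so the pipeline output is repeatable
output = generator("Hey Pizza! ")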

If we apply the re.split:

import re
def enforce_stop_tokens(text, stop):
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]

stop = ["up", "then"]
text = output[0]['generated_text']

re.split("|".join(stop), text)

[out]:

['Hey Pizza! 」\n\n「Hurry ',
 ', leave the place! 」\n\n「Oi! 」\n\nWhile eating pizza and ',
 ', Yuigahama came in contact with Ruriko in the middle of the']
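
Taking the first partition, which is what enforce_stop_tokens returns, gives the text cut off at the first stop word:

enforce_stop_tokens(text, stop)

[out]:

'Hey Pizza! 」\n\n「Hurry '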

But that isn't useful; I want to split at the point where the generation ends. What tokens do I use to "enforce_stop_tokens"?

1 Answer

You could do this by setting the eos_token_id as your stop term(s); in my testing it worked with a list. See below: eos_token_id stops generation just after the stop word ("once upon a time"), while the regex in enforce_stop_tokens also strips the stop word itself ("once upon a").


from transformers import GPT2LMHeadModel, GPT2Tokenizer
import re

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Define your custom stop terms
stop_terms = ["right", "time"]

# Ensure the stop terms are in the tokenizer's vocabulary
# (note: any token added here gets a randomly initialized embedding)
for term in stop_terms:
    if term not in tokenizer.get_vocab():
        tokenizer.add_tokens([term])
        model.resize_token_embeddings(len(tokenizer))

def enforce_stop_tokens(text, stop):
    """Cut off the text as soon as any stop words occur."""
    return re.split("|".join(stop), text)[0]

# Get the token IDs for your custom stop terms;
# add_prefix_space=True encodes e.g. " time" the way it appears mid-sentence
eos_token_ids_custom = [tokenizer.encode(term, add_prefix_space=True)[0] for term in stop_terms]

# Generate text
input_text = "Once upon "
input_ids = tokenizer.encode(input_text, return_tensors='pt')
output_ids = model.generate(input_ids, eos_token_id=eos_token_ids_custom, max_length=50)

# Decode the output IDs to text
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

print(generated_text) # Once upon a time

print("ENFORCE STOP TOKENS")

truncated_text = enforce_stop_tokens(generated_text, stop_terms)

print(truncated_text) # Once upon a
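
To connect this back to the pipeline usage from the question: the text-generation pipeline forwards extra keyword arguments like eos_token_id to model.generate, so the same early stopping works there too. A sketch, reusing the stop_terms and eos_token_ids_custom defined above:

from transformers import pipeline

generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
result = generator("Once upon ", eos_token_id=eos_token_ids_custom, max_length=50)

# generation halts just after the first stop term;
# enforce_stop_tokens then trims the stop term itself off the end
print(enforce_stop_tokens(result[0]['generated_text'], stop_terms))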