I'm using Llama-2 13B with the following stopping criteria:
```python
stop_words = ["Human:", "Chatbot:", "###"]
stop_words_ids = [tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

generation_config = GenerationConfig( ... stopping_criteria=stopping_criteria )

prompt = tokenizer(text, return_tensors='pt', truncation="only_first", max_length=4096)
prompt = {key: value.to("cuda") for key, value in prompt.items()}
out = model.generate(**prompt, generation_config=generation_config)
res = tokenizer.decode(out[0])
```
The model does not stop at the provided stop words. For example, given a model response such as `I'm feeling good, how about you?### Human: I'm also feeling good.### Chatbot: That's good.`, the model should stop generating at the first `###`. Why does this not work, and how can it be fixed?
I have fine-tuned the model (with Axolotl) on a dataset so that it produces responses in the format shown above.
Implementing a working stopping criterion is unfortunately quite a bit more complicated; I'll explain the technical details at the bottom. You have to make a child class of `StoppingCriteria` and reimplement the logic of its `__call__()` method yourself. This is not done for you, and it can be implemented in many different ways. Luckily, there's some code I was able to piece together that seems to work well.
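Here's a minimal sketch of such a subclass. The class name `StopOnTokens` and the tail-matching logic are just one way to do it; the only hard requirement is the `__call__(input_ids, scores, **kwargs) -> bool` signature:

```python
import torch
from transformers import StoppingCriteria

class StopOnTokens(StoppingCriteria):
    """Stop generation once the output ends with any of the stop token sequences."""

    def __init__(self, stop_token_ids):
        super().__init__()
        # Each entry is a 1-D tensor of token ids for one stop word.
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in self.stop_token_ids:
            n = stop_ids.shape[0]
            # Does the tail of the generated sequence match this stop sequence?
            if input_ids.shape[1] >= n and torch.equal(input_ids[0, -n:], stop_ids.to(input_ids.device)):
                return True
        return False
```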
I wrap this up into a simple interfacing function, because it will look kind of gross if you keep piling stuff sequentially:
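Something like the following works, assuming the `StopOnTokens` class above; the function name `build_stopping_criteria` is arbitrary:

```python
from transformers import StoppingCriteriaList

def build_stopping_criteria(stop_words_list, tokenizer):
    # add_special_tokens=False keeps Llama's BOS token (<s>) out of the stop
    # sequences; with it included, the tail of the output could never match.
    stop_token_ids = [
        tokenizer(w, return_tensors="pt", add_special_tokens=False)["input_ids"].squeeze(0)
        for w in stop_words_list
    ]
    return StoppingCriteriaList([StopOnTokens(stop_token_ids)])
```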
Then, to use this function, you can pass in a list of words you wish the model to stop on:
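For example:

```python
stop_words_list = ["Human:", "Chatbot:", "###"]
stopping_criteria = build_stopping_criteria(stop_words_list, tokenizer)

prompt = tokenizer(text, return_tensors="pt", truncation="only_first", max_length=4096)
prompt = {key: value.to("cuda") for key, value in prompt.items()}

# stopping_criteria is an argument of generate() itself, not a
# GenerationConfig field, so pass it directly here.
out = model.generate(**prompt, stopping_criteria=stopping_criteria)
res = tokenizer.decode(out[0])
```

Note that `stopping_criteria` is passed directly to `model.generate()`; it is not a `GenerationConfig` field, which is likely also part of why the snippet in the question never triggers.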
Unfortunately, your choice of stop words also needs adjusting and testing. Tokenization is weird: even though you probably didn't fine-tune the tokenizer, the same characters may still be tokenized differently depending on where they sit within a sentence. For example, `"###"`, `" ###"`, and `"### "` may all be different tokens depending on their placement, and you may have to pass all of them in your `stop_words_list`. Spaces, newlines, or even other characters before or after each of your stop words can turn it into an entirely different token. Experiment with a few and see what works!
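A quick way to see what you're dealing with is to print the ids each variant produces:

```python
# Inspect how position/whitespace changes Llama's SentencePiece tokenization.
for variant in ["###", " ###", "### ", "\n###"]:
    ids = tokenizer(variant, add_special_tokens=False)["input_ids"]
    print(repr(variant), "->", ids)
```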