Send all tensors to CUDA in Torch


I'm trying to load a fine-tuned llama-2 model for text generation. As can be seen below, the tokenizer and model are loaded with the transformers library.

import torch
import transformers
from transformers import StoppingCriteria, StoppingCriteriaList

device = 'cuda:3'

tokenizer = transformers.AutoTokenizer.from_pretrained('llama-2-7b-chat-fine-tuned.bin')

stop_list = ['\nHuman:', '\n```\n']
stop_token_ids = [tokenizer(x)['input_ids'] for x in stop_list]
stop_token_ids = [torch.LongTensor(x).to(device) for x in stop_token_ids]

class StopOnTokens(StoppingCriteria):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in stop_token_ids:
            if torch.eq(input_ids[0][-len(stop_ids.to(device)):], stop_ids.to(device)).all():
                return True
        return False

stopping_criteria = StoppingCriteriaList([StopOnTokens()])

generate_text = transformers.pipeline(model='llama-2-7b-chat-fine-tuned.bin',
                                      tokenizer=tokenizer,
                                      return_full_text=True,
                                      task='text-generation',
                                      stopping_criteria=stopping_criteria,
                                      temperature=0.1,
                                      max_new_tokens=512,
                                      repetition_penalty=1.1)


res = generate_text("How are you doing?")
print(res[0]["generated_text"])

However, I get the following error:

Traceback (most recent call last):
  File "test_fine_tuned_llamas_2.py", line 75, in <module>
    res = generate_text("How are you doing?")
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/pipelines/text_generation.py", line 200, in __call__
    return super().__call__(text_inputs, **kwargs)
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1122, in __call__
    return self.run_single(inputs, preprocess_params, forward_params, postprocess_params)
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1129, in run_single
    model_outputs = self.forward(model_inputs, **forward_params)
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/pipelines/base.py", line 1028, in forward
    model_outputs = self._forward(model_inputs, **forward_params)
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/pipelines/text_generation.py", line 261, in _forward
    generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/generation/utils.py", line 1538, in generate
    return self.greedy_search(
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/generation/utils.py", line 2423, in greedy_search
    if stopping_criteria(input_ids, scores):
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/generation/stopping_criteria.py", line 127, in __call__
    return any(criteria(input_ids, scores) for criteria in self)
  File "/sharedvolume/felipe/project1/test_project1/lib/python3.8/site-packages/transformers/generation/stopping_criteria.py", line 127, in <genexpr>
    return any(criteria(input_ids, scores) for criteria in self)
  File "test_fine_tuned_llamas_2.py", line 59, in __call__
    if torch.eq(input_ids[0][-len(stop_ids.to(device)):], stop_ids.to(device)).all():
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cpu!

What tensors am I missing? I thought I had already moved every tensor to CUDA, but some still seem to be on the CPU.
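The traceback points at the comparison inside `StopOnTokens.__call__`: generation runs on whatever device the model was loaded to (apparently the CPU here, since no `device` was passed to `transformers.pipeline`), so `input_ids` arrives on the CPU while the pre-moved `stop_token_ids` live on `cuda:3`. A minimal device-agnostic sketch, assuming the stop sequences are passed in at construction time (a small restructuring of the original class, not code from the question), would move each stop sequence to `input_ids.device` instead of a hard-coded device:

```python
import torch
from transformers import StoppingCriteria


class StopOnTokens(StoppingCriteria):
    """Stop generation once any stop sequence appears at the end of the output."""

    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        for stop_ids in self.stop_token_ids:
            # Move the stop sequence to whatever device generation runs on,
            # instead of assuming a fixed 'cuda:3'.
            stop_ids = stop_ids.to(input_ids.device)
            # Guard against outputs shorter than the stop sequence,
            # then compare only the trailing tokens.
            if input_ids.shape[1] >= len(stop_ids) and torch.equal(
                input_ids[0, -len(stop_ids):], stop_ids
            ):
                return True
        return False
```

Alternatively, passing `device=device` to `transformers.pipeline(...)` so the model itself runs on `cuda:3` would also make the devices agree; both that keyword and the class restructuring above are suggestions about this setup, not something shown in the original code.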
