I accessed a Llama-based model on Hugging Face named "LeoLM/leo-hessianai-7b-chat". I downloaded the model on my Mac with the device set to 'MPS'. The download worked; however, when I try to test the model I get the following error:
TypeError: BFloat16 is not supported on MPS
Above that, I see the hint:
FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.
Here is my code:
import torch
from torch import cuda, bfloat16
import transformers

device = torch.device("mps")
model_id = 'LeoLM/leo-hessianai-7b-chat'
#device = f'cuda:{cuda.current_device()}' if cuda.is_available() else 'cpu'
# set quantization configuration to load large model with less GPU memory
# this requires the `bitsandbytes` library
bnb_config = transformers.BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=bfloat16
)
# begin initializing HF items, need auth token for these
hf_auth = 'HF_KEY'
model_config = transformers.AutoConfig.from_pretrained(
    model_id,
    use_auth_token=hf_auth
)
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_id,
    trust_remote_code=False,  # True for flash attention
    config=model_config,
    quantization_config=bnb_config,
    device_map='auto',
    use_auth_token=hf_auth
)
model.eval()
print(f"Model loaded on {device}")
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id,
    use_auth_token=hf_auth
)
generate_text = transformers.pipeline(
    model=model, tokenizer=tokenizer,
    return_full_text=True,  # langchain expects the full text
    task='text-generation',
    # we pass model parameters here too
    temperature=0.0,  # 'randomness' of outputs, 0.0 is the min and 1.0 the max
    max_new_tokens=512,  # max number of tokens to generate in the output
    repetition_penalty=1.1  # without this, output begins repeating
)
res = generate_text("Explain the difference between a country and a continent.")
print(res[0]["generated_text"])
What do I need to change to make it run?
Use the nightly build of PyTorch; you can get the macOS install command from the official source: https://pytorch.org/get-started/locally/
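After installing the nightly, a quick sanity check (a minimal sketch, assuming you are on Apple Silicon with the MPS backend available) tells you whether bfloat16 tensors are now accepted on MPS before you re-run the full script:

import torch

# Confirm which build is installed and that the MPS backend is usable
print(torch.__version__)
print(torch.backends.mps.is_available())

# On builds without bfloat16 support on MPS, the next line raises
# "TypeError: BFloat16 is not supported on MPS"
x = torch.ones(2, 2, dtype=torch.bfloat16, device="mps")
print(x * 2)

If this small test passes but your original script still fails with the LinearFP4 hint, the remaining problem is in the 4-bit quantization path rather than the bfloat16 dtype itself.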