How to debug ValueError: `FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16?


I'm trying to run PyTorch Lightning Fabric distributed FSDP training with Hugging Face PEFT LoRA fine-tuning on LLaMA 2, but my code fails with:

```
`FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16
  File ".......", line 100, in <module>
    model, optimizer = fabric.setup(model, optimizer)
ValueError: `FlatParameter` requires uniform dtype but got torch.float32 and torch.bfloat16
```

How do I find out which tensors in my model are of float32 type?
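One way to track this down is to group the model's parameters by dtype before calling `fabric.setup`, so you can see exactly which parameter names are still float32. Below is a minimal, hedged sketch: the toy `nn.Sequential` stands in for your PEFT-wrapped LLaMA model (which is not available here), and the partial `.to(torch.bfloat16)` call simulates a model that was only partly converted.

```python
import torch
from torch import nn
from collections import defaultdict

# Toy stand-in for the real PEFT/LLaMA model; one layer is cast to
# bfloat16 to simulate the mixed-dtype state that breaks FSDP flattening.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
model[0].to(torch.bfloat16)

# Group parameter names by dtype to expose the float32 stragglers.
by_dtype = defaultdict(list)
for name, param in model.named_parameters():
    by_dtype[param.dtype].append(name)

for dtype, names in by_dtype.items():
    print(dtype, "->", names)
```

Running this against the real model (right before `fabric.setup`) should list which parameters, often the LoRA adapters or untouched base weights, are still `torch.float32`; casting the whole model to one dtype (e.g. `model.to(torch.bfloat16)`) before setup is one way to make the dtypes uniform.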
