Failing to optimize unet model with dtype=torch.float16


I'm trying to use null-text inversion with Stable Diffusion 1.5 (or SDXL 1.0) with dtype=torch.float16, and I get a black image. It works well in float32. After debugging, I found that the gradients turn to zero in the UNet layers because of the float16 dtype. [Screenshot: gradients of the UNet's first up_block in float32.] [Screenshot: gradients of the UNet's first up_block in float16.] The float16 gradients are so small that some of them underflow to zero, which makes the gradients in the following layers smaller and smaller until they all reach zero.
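For reference, the underflow itself is easy to reproduce outside the model; float16 simply cannot represent magnitudes this small (this is a minimal standalone check, not my actual code):

```python
import torch

# float16 only represents magnitudes down to ~6.1e-5 (normal) /
# ~6.0e-8 (subnormal); anything smaller rounds to exactly zero.
print(torch.finfo(torch.float16).tiny)  # 6.1035e-05, smallest normal float16
print(torch.finfo(torch.float32).tiny)  # 1.1755e-38, smallest normal float32

g = torch.tensor(1e-8, dtype=torch.float32)
print(g.to(torch.float16))              # tensor(0., dtype=torch.float16) -- underflow
```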

I tried scaling the loss up by a factor of 10 (also 1e3 and 1e4) and lowering the learning rate, but both attempts failed. What can I do to solve this? I want to run it with torch.float16 o(╥﹏╥)o
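For completeness, here is a simplified, self-contained sketch of the static loss scaling I tried; a toy fp16 model stands in for the SD 1.5 UNet, and `x` stands in for the null-text embedding I actually optimize:

```python
import torch

device = "cuda"  # half-precision ops assume a GPU
torch.manual_seed(0)

# toy stand-in for the fp16 UNet; x stands in for the optimized null-text embedding
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64), torch.nn.GELU(), torch.nn.Linear(64, 64)
).half().to(device)
x = torch.randn(8, 64, dtype=torch.float16, device=device, requires_grad=True)
target = torch.randn(8, 64, dtype=torch.float16, device=device)
optimizer = torch.optim.Adam([x], lr=1e-2)

scale = 1e4  # also tried 10 and 1e3
for step in range(10):
    loss = torch.nn.functional.mse_loss(model(x), target)
    optimizer.zero_grad()
    (loss * scale).backward()  # scale the loss up before backward...
    x.grad /= scale            # ...and scale the gradient back down before the step
    optimizer.step()
```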
