I've been searching everywhere for a tutorial/guide on loading a CivitAI model (https://civitai.com/models/4823/deliberate) into PyTorch and then using it for inference.
Most of what I found suggests the following:
- Create your base model (it must have the same architecture as the model the checkpoint was saved from).
- Load the checkpoint using `torch.load()`.
- Call `model.load_state_dict(loaded_checkpoint)`.
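For an ordinary PyTorch model whose class you already have, those three steps look roughly like this. `TinyNet` is a hypothetical stand-in for the base model, and the `"state_dict"` key mimics how Lightning-style `.ckpt` files nest their weights. Note that `load_state_dict` is a method of the model, not of `torch`:

```python
import io

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical base model; its architecture must match the saved checkpoint."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


# Simulate a checkpoint file that stores the weights under a "state_dict" key.
source = TinyNet()
buffer = io.BytesIO()
torch.save({"state_dict": source.state_dict()}, buffer)
buffer.seek(0)

# Step 1: build a model with the same architecture.
restored = TinyNet()
# Step 2: load the checkpoint onto the CPU.
checkpoint = torch.load(buffer, map_location="cpu")
# Step 3: copy the weights into the model (strict=True by default, so mismatches raise).
missing, unexpected = restored.load_state_dict(checkpoint["state_dict"])
restored.eval()  # switch layers such as dropout/batch-norm to inference behaviour
```

The catch with a Stable Diffusion `.ckpt` is step 1: there is no single Python class to instantiate.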
However, the models on CivitAI come only as a `.ckpt` file and nothing more, so I cannot do step 1. I know it is possible, because the AUTOMATIC1111 GUI is able to do it.
PS: I know the same Deliberate model is also available on huggingface.co and can be downloaded like the standard Stable Diffusion models, but I'm interested in working with the `.ckpt` file alone, the way AUTO1111 does it.
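```python
import torch
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

model_id = "stabilityai/stable-diffusion-2-1"
model = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
model.scheduler = DPMSolverMultistepScheduler.from_config(model.scheduler.config)

# Load the checkpoint file
ckpt_path = '/Users/XXXX/XXXX/model.ckpt'
checkpoint = torch.load(ckpt_path, map_location="cpu")
# Note: a StableDiffusionPipeline bundles several modules (unet, vae, text_encoder, ...)
# and is not itself an nn.Module, so it has no load_state_dict() or eval()
model.load_state_dict(checkpoint['state_dict'])
model.eval()
# The pipeline returns an output object; the generated images are in .images
image = model(prompt='xxxxxx').images[0]
```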
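One way to see why `load_state_dict` cannot simply be pointed at the pipeline is to inspect the checkpoint's keys. Assuming the `.ckpt` uses the original (CompVis/LDM-style) Stable Diffusion layout that AUTOMATIC1111 loads, the UNet, VAE and text encoder weights all sit in one flat dict under different prefixes, whereas diffusers keeps them as separate components. `summarize_keys` is a hypothetical helper:

```python
from collections import Counter
from typing import Iterable

# Top-level prefixes used by original-format Stable Diffusion checkpoints.
# Assumption: the .ckpt follows this layout (diffusers uses separate unet/vae/text_encoder parts).
ORIGINAL_SD_PREFIXES = {
    "model.diffusion_model": "UNet",
    "first_stage_model": "VAE",
    "cond_stage_model": "text encoder",
}


def summarize_keys(keys: Iterable[str]) -> Counter:
    """Count state-dict keys per Stable Diffusion component; unknown prefixes go to 'other'."""
    counts: Counter = Counter()
    for key in keys:
        for prefix, component in ORIGINAL_SD_PREFIXES.items():
            if key.startswith(prefix + "."):
                counts[component] += 1
                break
        else:  # no known prefix matched
            counts["other"] += 1
    return counts
```

Calling `summarize_keys(checkpoint.get("state_dict", checkpoint).keys())` on the loaded file should show most keys falling under those three components. Renaming and splitting them into diffusers' component layout is the conversion that `from_single_file` performs.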
The requirement is that you install `diffusers` and `peft` (`pip install peft diffusers`).

Loading LoRA `.safetensors`
You can use the `load_lora_weights` method of the Hugging Face `diffusers` Stable Diffusion pipelines. For this demo, I am downloading this LoRA: Styles for Pony Diffusion V6 XL.
Load the pipeline (be careful to load the same base version the LoRA was trained on; here SDXL 1.0), then load the LoRA weights:
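A minimal sketch of those two steps, assuming the LoRA file has already been downloaded locally and a CUDA GPU is available (the function name and file path are hypothetical):

```python
import os


def load_sdxl_with_lora(lora_path: str, device: str = "cuda"):
    """Load the SDXL 1.0 base pipeline and apply a LoRA from a local .safetensors file.

    Assumptions: `lora_path` points to the downloaded LoRA file and a CUDA GPU is available.
    """
    # Heavy imports kept inside the function so this module imports without them.
    import torch
    from diffusers import StableDiffusionXLPipeline

    # The base model must match the one the LoRA was trained against (SDXL 1.0 here).
    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        torch_dtype=torch.float16,
        use_safetensors=True,
    )
    # Point diffusers at the folder and name the .safetensors file explicitly.
    pipe.load_lora_weights(
        os.path.dirname(lora_path) or ".",
        weight_name=os.path.basename(lora_path),
    )
    return pipe.to(device)
```

Usage would then be something like `pipe = load_sdxl_with_lora("loras/pony_styles.safetensors")` followed by `image = pipe(prompt="...").images[0]`, where the file name is only a placeholder.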
Loading checkpoint `.safetensors`

If you want to load a checkpoint instead and only have a `.safetensors` file, you can't rely on `from_pretrained`. According to the documentation "Load safetensors", you need to use `from_single_file` to initialize and load the pipeline. Here I am downloading the checkpoint from: Vapor - A Futuristic Retro Experience.
The following code was tested on Google Colab:
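A minimal standard-library sketch of the download step, assuming you have copied a direct download link from the model page. The helper name and its arguments are hypothetical, and the file name is passed explicitly because download URLs do not always end in the real file name:

```python
import os
import urllib.request


def download_weights(url: str, filename: str, dest_dir: str = "models") -> str:
    """Download a weights file into dest_dir/filename and return the local path."""
    os.makedirs(dest_dir, exist_ok=True)
    dest = os.path.join(dest_dir, filename)
    # Skip the download when re-running the notebook cell and the file already exists.
    if not os.path.exists(dest):
        urllib.request.urlretrieve(url, dest)
    return dest
```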
Then you can load the checkpoint with a single line:
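A minimal sketch of that step, assuming a local checkpoint path and a CUDA GPU. `from_single_file` is called on the pipeline class that matches the checkpoint's base model; the function name and the `sdxl` flag are hypothetical:

```python
def load_single_file_pipeline(ckpt_path: str, sdxl: bool = False, device: str = "cuda"):
    """Build a diffusers pipeline from a single .safetensors or .ckpt checkpoint file.

    `sdxl` is a hypothetical switch: True for SDXL-based checkpoints,
    False for SD 1.x ones such as Deliberate.
    """
    # Heavy imports kept inside the function so this module imports without them.
    import torch
    from diffusers import StableDiffusionPipeline, StableDiffusionXLPipeline

    pipeline_cls = StableDiffusionXLPipeline if sdxl else StableDiffusionPipeline
    # from_single_file converts the original-format weights into diffusers components.
    pipe = pipeline_cls.from_single_file(ckpt_path, torch_dtype=torch.float16)
    return pipe.to(device)
```

Because the same call accepts `.ckpt` files, this is also one route to loading the Deliberate `.ckpt` from the question without first building the model architecture by hand.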