Adding LoRA to Checkpoint Model

I'm still new to the world of Stable Diffusion. I'm trying to add LoRA weights to an original checkpoint model. This is the approach I came up with to get the job done:

for i in range(12):
    lora_key = f"lora_te_text_model_encoder_layers_{i}_mlp_fc1"
    model_model[f"cond_stage_model.transformer.text_model.encoder.layers.{i}.mlp.fc1.weight"] += (
        model_lora[f"{lora_key}.alpha"]
        * torch.matmul(
            model_lora[f"{lora_key}.lora_up.weight"].float(),
            model_lora[f"{lora_key}.lora_down.weight"].float(),
        )
    )

However, I'm unable to map the following LoRA keys to the checkpoint weights, because I can't find similar keys in the original model:

lora_unet_down_blocks_0_attentions_0_proj_in.alpha
lora_unet_down_blocks_0_attentions_0_proj_in.lora_down.weight
lora_unet_down_blocks_0_attentions_0_proj_in.lora_up.weight

The closest keys I could find in the original model are these:

model.diffusion_model.input_blocks.0.0.weight
model.diffusion_model.input_blocks.1.1.proj_in.weight

I have built a Stable Diffusion program using resources from the internet, and I want to be able to add LoRAs to it so that it produces more realistic images for a project. I could be wrong about the way I'm adding the LoRAs.

There is 1 answer below.

First of all, merging a LoRA into an SD checkpoint with the original source code is quite tedious; you will spend a lot of time on it.

It is better to use Diffusers to get the job done, and it is easy to use.
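
For example, with a recent version of Diffusers, something along these lines should work (the checkpoint and LoRA file names are placeholders):

import torch
from diffusers import StableDiffusionPipeline

# load the base checkpoint from a single .safetensors/.ckpt file
pipe = StableDiffusionPipeline.from_single_file(
    "model.safetensors", torch_dtype=torch.float16
).to("cuda")

# load and apply the LoRA weights on top of the pipeline
pipe.load_lora_weights(".", weight_name="my_lora.safetensors")

image = pipe("a photo of an astronaut riding a horse").images[0]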

If you really want to add the LoRA this way, you can use the notes I recorded while debugging my own code below as a reference. Basically, the LoRA weights contain two parts: the text encoder and the UNet.

Every transformer block in the checkpoint you use should have the corresponding LoRA block added to it.
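
As a rough sketch of that idea (not exact working code for every LoRA out there): for each LoRA module you read its .lora_down.weight/.lora_up.weight pair plus its .alpha, scale the product by alpha / rank (the usual kohya-style scaling, as far as I know), and add it to the matching checkpoint weight. The key mapping below (e.g. lora_unet_down_blocks_0_attentions_0_proj_in -> model.diffusion_model.input_blocks.1.1.proj_in.weight) is my assumption based on the common diffusers/LDM naming conversion, and the file names are placeholders, so double-check it against your own files:

from safetensors.torch import load_file

sd = load_file("model.safetensors")       # checkpoint state dict (LDM layout)
lora = load_file("my_lora.safetensors")   # kohya-style LoRA state dict

# illustrative mapping: LoRA module name -> checkpoint weight key (assumed)
key_map = {
    "lora_te_text_model_encoder_layers_0_mlp_fc1":
        "cond_stage_model.transformer.text_model.encoder.layers.0.mlp.fc1.weight",
    "lora_unet_down_blocks_0_attentions_0_proj_in":
        "model.diffusion_model.input_blocks.1.1.proj_in.weight",
}

for lora_name, ckpt_key in key_map.items():
    down = lora[f"{lora_name}.lora_down.weight"].float()
    up = lora[f"{lora_name}.lora_up.weight"].float()
    alpha = float(lora[f"{lora_name}.alpha"])
    scale = alpha / down.shape[0]             # alpha / rank

    # 1x1-conv LoRAs (e.g. proj_in) are stored as 4D tensors; flatten them to 2D
    delta = (up.flatten(1) @ down.flatten(1)) * scale
    sd[ckpt_key] = sd[ckpt_key] + delta.reshape(sd[ckpt_key].shape).to(sd[ckpt_key].dtype)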

TextEncoder

Before adding the LoRA:

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)

After adding the LoRA:

CLIPTextModel(
  (text_model): CLIPTextTransformer(
    (embeddings): CLIPTextEmbeddings(
      (token_embedding): Embedding(49408, 768)
      (position_embedding): Embedding(77, 768)
    )
    (encoder): CLIPEncoder(
      (layers): ModuleList(
        (0-11): 12 x CLIPEncoderLayer(
          (self_attn): CLIPAttention(
            (k_proj): PatchedLoraProjection(
              (regular_linear_layer): Linear(in_features=768, out_features=768, bias=True)
              (lora_linear_layer): LoRALinearLayer(
                (down): Linear(in_features=768, out_features=4, bias=False)
                (up): Linear(in_features=4, out_features=768, bias=False)
              )
            )
            (v_proj): PatchedLoraProjection(
              (regular_linear_layer): Linear(in_features=768, out_features=768, bias=True)
              (lora_linear_layer): LoRALinearLayer(
                (down): Linear(in_features=768, out_features=4, bias=False)
                (up): Linear(in_features=4, out_features=768, bias=False)
              )
            )
            (q_proj): PatchedLoraProjection(
              (regular_linear_layer): Linear(in_features=768, out_features=768, bias=True)
              (lora_linear_layer): LoRALinearLayer(
                (down): Linear(in_features=768, out_features=4, bias=False)
                (up): Linear(in_features=4, out_features=768, bias=False)
              )
            )
            (out_proj): PatchedLoraProjection(
              (regular_linear_layer): Linear(in_features=768, out_features=768, bias=True)
              (lora_linear_layer): LoRALinearLayer(
                (down): Linear(in_features=768, out_features=4, bias=False)
                (up): Linear(in_features=4, out_features=768, bias=False)
              )
            )
          )
          (layer_norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): CLIPMLP(
            (activation_fn): QuickGELUActivation()
            (fc1): Linear(in_features=768, out_features=3072, bias=True)
            (fc2): Linear(in_features=3072, out_features=768, bias=True)
          )
          (layer_norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
    (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
)
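
Each of those PatchedLoraProjection modules effectively computes the original linear output plus a scaled low-rank path. A minimal sketch of the idea (my own simplified version, not Diffusers' actual implementation):

import torch.nn as nn

class LoRALinear(nn.Module):
    """Linear layer with a low-rank LoRA branch added on top of it."""
    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base                                   # the original (frozen) layer
        self.down = nn.Linear(base.in_features, rank, bias=False)
        self.up = nn.Linear(rank, base.out_features, bias=False)
        self.scale = scale                                 # typically alpha / rank
        nn.init.zeros_(self.up.weight)                     # so the LoRA starts as a no-op

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(x))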

UNet

This part is quite large, so I will only give you some examples to explain what exactly the LoRA contains.

The LoRA weights' UNet part: