I'm working on a project that uses Hugging Face's diffusers library, specifically the Stable Diffusion XL base model ("stabilityai/stable-diffusion-xl-base-1.0"). My goal is to work with the latents directly inside a custom callback during the diffusion process, so I can visualize both the denoised prediction and the slightly less noisy input for the next step.

Here's a simplified version of my code:

from diffusers import DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
import torch
from transformers import CLIPTextModel, CLIPTokenizer

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16,
).to("cuda")

pipeline.scheduler = LMSDiscreteScheduler(
    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

# (Re-loading the UNet here is redundant; from_pretrained above already
# loads the same weights into pipeline.unet. Kept for completeness.)
pipeline.unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet", torch_dtype=torch.float16).to("cuda")

prompt = "A capybara holding a sword whilst wearing a knight's costume,"

# Stand-alone tokenizer/encoder for building embeddings used in the callback.
# (Note: this is only the first of SDXL's two text encoders, and it is loaded
# in float32 while the rest of the pipeline runs in float16.)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_input = tokenizer([prompt], padding="max_length",
                       max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

text_encoder = CLIPTextModel.from_pretrained(
    "openai/clip-vit-large-patch14").to("cuda")

with torch.no_grad():
    text_embeddings = text_encoder(text_input.input_ids.to("cuda"))[0]

def decode_tensors(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    # Duplicate the latents for classifier-free guidance.
    latent_model_input = torch.cat([latents] * 2)
    sigma = pipeline.scheduler.sigmas[step]  # currently unused; intended for the denoised-prediction step
    latent_model_input = pipeline.scheduler.scale_model_input(
        latent_model_input, timestep)

    # This UNet call is what raises the TypeError below.
    with torch.no_grad():
        noise_pred = pipeline.unet(latent_model_input, timestep,
                                   encoder_hidden_states=text_embeddings)["sample"]

    return callback_kwargs

image = pipeline(
    height=1024,
    width=1024,
    prompt=prompt,
    negative_prompt="",
    guidance_scale=7.5,
    num_inference_steps=20,
    callback_on_step_end=decode_tensors,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]

image.save("./imgs/final.png")
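
For context, once the UNet call works, my plan for turning a latent into a viewable image is roughly the helper below. latents_to_pil is just a name I made up; it assumes the VAE's configured scaling factor and the pipeline's image processor, and it is untested:

# Sketch of my intended visualization helper (untested).
def latents_to_pil(pipe, latents):
    # Undo the scaling applied when images were encoded into latent space.
    latents = latents / pipe.vae.config.scaling_factor
    with torch.no_grad():
        image = pipe.vae.decode(latents, return_dict=False)[0]
    # Tensor -> PIL. (I may need to run the VAE in float32 if fp16 gives
    # black images, which I've read can happen with the SDXL base VAE.)
    return pipe.image_processor.postprocess(image, output_type="pil")[0]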

However, when I run this script, I encounter the following error:

Traceback (most recent call last):
  File "C:\Users\myalt\Desktop\svd\sdxl.py", line 43, in <module>
    image = pipeline(
  File "C:\Users\myalt\Desktop\svd\venv\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "C:\Users\myalt\Desktop\svd\venv\lib\site-packages\diffusers\pipelines\stable_diffusion_xl\pipeline_stable_diffusion_xl.py", line 1286, in __call__
    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
  File "C:\Users\myalt\Desktop\svd\sdxl.py", line 37, in decode_tensors
    noise_pred = pipeline.unet(latent_model_input, timestep,
  File "C:\Users\myalt\Desktop\svd\venv\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "C:\Users\myalt\Desktop\svd\venv\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\myalt\Desktop\svd\venv\lib\site-packages\diffusers\models\unets\unet_2d_condition.py", line 1013, in forward
    if "text_embeds" not in added_cond_kwargs:
TypeError: argument of type 'NoneType' is not iterable

The error is raised by the UNet call inside my decode_tensors callback. Judging from the failing check (if "text_embeds" not in added_cond_kwargs:), the UNet seems to expect an added_cond_kwargs argument that I'm not passing, but I'm not sure what it should contain or where to get it. My environment is set up correctly, and I'm using recent versions of the relevant libraries.

Could anyone help me understand what's causing this error and how to fix it? I'm particularly puzzled by how encoder_hidden_states is handled here, and whether there's something else I'm missing in how the UNet is supposed to be set up or called in this context.
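
My current guess, from reading pipeline_stable_diffusion_xl.py, is that the SDXL UNet also needs an added_cond_kwargs dict (pooled text embeddings plus time/size ids), and that I should reuse the pipeline's own conditioning tensors rather than my stand-alone embeddings (SDXL concatenates two text encoders, so my single-encoder text_embeddings are presumably the wrong shape anyway). Below is a sketch of what I think the callback should look like; the tensor names come from the pipeline's _callback_tensor_inputs list, and the sigma indexing is a guess, so please treat this as an attempt rather than a verified fix:

def decode_tensors(pipe, step, timestep, callback_kwargs):
    latents = callback_kwargs["latents"]
    # When classifier-free guidance is on, these arrive already duplicated
    # along the batch axis as [negative, positive].
    prompt_embeds = callback_kwargs["prompt_embeds"]
    add_text_embeds = callback_kwargs["add_text_embeds"]
    add_time_ids = callback_kwargs["add_time_ids"]

    # Duplicate latents to match the CFG-doubled batch, then scale.
    latent_model_input = torch.cat([latents] * 2)
    latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, timestep)

    with torch.no_grad():
        noise_pred = pipe.unet(
            latent_model_input,
            timestep,
            encoder_hidden_states=prompt_embeds,
            added_cond_kwargs={"text_embeds": add_text_embeds,
                               "time_ids": add_time_ids},
        )["sample"]

    # Classifier-free guidance (7.5 matches guidance_scale above), then the
    # denoised estimate for an epsilon-prediction model.
    noise_uncond, noise_text = noise_pred.chunk(2)
    noise_cfg = noise_uncond + 7.5 * (noise_text - noise_uncond)
    # I'm unsure whether sigmas[step] or sigmas[step + 1] is correct here,
    # since this callback fires after scheduler.step().
    sigma = pipe.scheduler.sigmas[step]
    pred_x0 = latents - sigma * noise_cfg  # what I'd hand to latents_to_pil()

    return callback_kwargs

and the pipeline call would then request the extra tensors:

    callback_on_step_end_tensor_inputs=["latents", "prompt_embeds",
                                        "add_text_embeds", "add_time_ids"],

Is this the right direction, or is there a supported way to get the UNet's prediction from inside the callback?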
