Fine Tuning StableDiffusionSafetyChecker model


I have been trying to fine-tune the safety-checker model hosted at https://huggingface.co/CompVis/stable-diffusion-safety-checker, which is used by most AI image generation models.
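For context, this is roughly how I use the stock checker at inference time (a sketch based on how the standard diffusers pipeline calls it; the image path is a placeholder):

import numpy as np
from PIL import Image
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import CLIPImageProcessor

safety_checker = StableDiffusionSafetyChecker.from_pretrained('CompVis/stable-diffusion-safety-checker')
feature_extractor = CLIPImageProcessor.from_pretrained('CompVis/stable-diffusion-safety-checker')

image = Image.open('sample.png').convert('RGB')  # placeholder image
clip_input = feature_extractor(image, return_tensors='pt').pixel_values

# forward takes the raw images plus their CLIP pixel values and returns
# the (possibly blacked-out) images and a per-image NSFW flag
np_image = np.expand_dims(np.asarray(image, dtype=np.float32) / 255.0, axis=0)
checked_image, has_nsfw = safety_checker(images=np_image, clip_input=clip_input)
print(has_nsfw)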

I made small changes to the checker class from the diffusers library to try to get it to train. The final class currently looks like this:

import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel


def cosine_distance(image_embeds, text_embeds):
    # cosine similarity between every image embedding and every concept embedding
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())

class StableDiffusionSafetyChecker(PreTrainedModel):
    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)

        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)

    @torch.no_grad()  # carried over from the original inference-only implementation
    def forward(self, clip_input):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # the upstream code casts these to float32 numpy arrays; the cast is
        # commented out here to keep everything as torch tensors
        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)  # .cpu().float().numpy()
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)  # .cpu().float().numpy()

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = torch.round(concept_cos - concept_threshold + adjustment, decimals=3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
                    adjustment = 0.01

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = torch.round(concept_cos - concept_threshold + adjustment, decimals=3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        has_nsfw_concepts = [('True' if flag else 'False') for flag in has_nsfw_concepts]
        return 0, has_nsfw_concepts  # the leading 0 is a placeholder where a loss would normally go

The code for loading the data and training looks like this:

import json
import os

import numpy as np
import torch
from datasets import load_dataset, load_metric  # load_metric moved to the separate `evaluate` package in newer releases
from transformers import CLIPImageProcessor

ds = load_dataset('training_data')

label = ds['train'].features['label']

feature_extractor_config = json.load(open('./safety_checker/preprocessor_config.json','r'))
feature_extractor = CLIPImageProcessor(**feature_extractor_config, padding=True)
# pipe_sc = StableDiffusionSafetyChecker.from_pretrained('./safety_checker')

def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([x for x in example_batch['image']], padding=True, return_tensors='pt') #.pixel_values

    inputs['label'] = example_batch['label']
    return inputs

prepared_ds = ds.with_transform(transform)
print(prepared_ds)

def collate_fn(batch):
    return {
        'clip_input': torch.stack([x['pixel_values'] for x in batch]),
        # 'label': torch.tensor([x['label'] for x in batch])
    }

metric = load_metric("accuracy")
def compute_metrics(p):
    return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)

labels = ds['train'].features['label'].names
print(labels)

model = StableDiffusionSafetyChecker.from_pretrained(
    './safety_checker',
    num_labels=len(labels),
    id2label={str(i): c for i, c in enumerate(labels)},
    label2id={c: str(i) for i, c in enumerate(labels)}
)

from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="./safety_checker-update",
  per_device_train_batch_size=16,
  evaluation_strategy="steps",
  num_train_epochs=4,
  save_steps=100,
  eval_steps=100,
  logging_steps=10,
  learning_rate=2e-4,
  save_total_limit=2,
  remove_unused_columns=False,
  push_to_hub=False,
  report_to='tensorboard',
  load_best_model_at_end=True,
)

from transformers import Trainer
os.environ["CUDA_VISIBLE_DEVICES"] = ""

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
    train_dataset=prepared_ds["train"],
    eval_dataset=prepared_ds["validation"],
    tokenizer=transform,
)

train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()

But after running the code I am facing this error:

Traceback (most recent call last):
  File "*****/stable_diff_sc_trainer.py", line 80, in <module>
    train_results = trainer.train()
                    ^^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/transformers/trainer.py", line 1961, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/transformers/trainer.py", line 2911, in training_step
    self.accelerator.backward(loss)
  File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/accelerate/accelerator.py", line 2001, in backward
    loss.backward(**kwargs)
    ^^^^^^^^^^^^^
AttributeError: 'float' object has no attribute 'backward'
  0%|          | 0/4 [00:07<?, ?it/s] 

When I searched around on Google, GPT, and Gemini, most sources said that I need to add a loss calculation for training to work. Can someone guide me on what I need to do to be able to train this model?
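From what I have gathered so far, Trainer feeds each batch from collate_fn into forward as keyword arguments and treats the first element of the returned tuple as the loss, which would explain the error: my forward returns 0 instead of a loss tensor, and the @torch.no_grad() decorator blocks gradients anyway. Is something like the sketch below the right direction? (The binary safe/NSFW classification head and the cross-entropy loss are my own guesses, not part of the original checker.)

import torch
import torch.nn as nn

class TrainableSafetyChecker(StableDiffusionSafetyChecker):
    def __init__(self, config):
        super().__init__(config)
        # hypothetical head: classify image embeddings as safe (0) or NSFW (1)
        self.classifier = nn.Linear(config.projection_dim, 2)

    def forward(self, clip_input, labels=None):
        # no @torch.no_grad() here, so gradients can flow during training
        pooled_output = self.vision_model(clip_input)[1]
        image_embeds = self.visual_projection(pooled_output)
        logits = self.classifier(image_embeds)

        loss = None
        if labels is not None:
            loss = nn.functional.cross_entropy(logits, labels)

        # Trainer reads outputs[0] as the loss when labels are supplied
        return (loss, logits) if loss is not None else (logits,)

If that is right, I assume collate_fn would also have to return the commented-out label tensor under the key 'labels' so that Trainer passes it through to forward.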
