I have been trying to fine-tune the basic model for safety checker hosted at https://huggingface.co/CompVis/stable-diffusion-safety-checker which is mostly used by all the AI image gen models.
I have tried to make small changes to the main code in Diffusers library and try to get it to train. The final class looks like this right now
def cosine_distance(image_embeds, text_embeds):
normalized_image_embeds = nn.functional.normalize(image_embeds)
normalized_text_embeds = nn.functional.normalize(text_embeds)
return torch.mm(normalized_image_embeds, normalized_text_embeds.t())
class StableDiffusionSafetyChecker(PreTrainedModel):
config_class = CLIPConfig
_no_split_modules = ["CLIPEncoderLayer"]
def __init__(self, config: CLIPConfig):
super().__init__(config)
self.vision_model = CLIPVisionModel(config.vision_config)
self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
@torch.no_grad()
def forward(self, clip_input):
pooled_output = self.vision_model(clip_input)[1] # pooled_output
image_embeds = self.visual_projection(pooled_output)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds) #.cpu().float().numpy()
cos_dist = cosine_distance(image_embeds, self.concept_embeds) #.cpu().float().numpy()
result = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign images
adjustment = 0.0
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = torch.round(concept_cos - concept_threshold + adjustment, decimals=3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = torch.round(concept_cos - concept_threshold + adjustment, decimals=3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)
result.append(result_img)
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
has_nsfw_concepts = [('True' if flag else 'False') for flag in has_nsfw_concepts]
return 0, has_nsfw_concepts
and the code for loading data and training looks like this:
ds = load_dataset('training_data')
label = ds['train'].features['label']
feature_extractor_config = json.load(open('./safety_checker/preprocessor_config.json','r'))
feature_extractor = CLIPImageProcessor(**feature_extractor_config, padding=True)
# pipe_sc = StableDiffusionSafetyChecker.from_pretrained('./safety_checker')
def transform(example_batch):
# Take a list of PIL images and turn them to pixel values
inputs = feature_extractor([x for x in example_batch['image']], padding=True, return_tensors='pt') #.pixel_values
inputs['label'] = example_batch['label']
return inputs
prepared_ds = ds.with_transform(transform)
print(prepared_ds)
def collate_fn(batch):
return {
'clip_input': torch.stack([x['pixel_values'] for x in batch]),
# 'label': torch.tensor([x['label'] for x in batch])
}
metric = load_metric("accuracy")
def compute_metrics(p):
return metric.compute(predictions=np.argmax(p.predictions, axis=1), references=p.label_ids)
labels = ds['train'].features['label'].names
print(labels)
model = StableDiffusionSafetyChecker.from_pretrained(
'./safety_checker',
num_labels=len(labels),
id2label={str(i): c for i, c in enumerate(labels)},
label2id={c: str(i) for i, c in enumerate(labels)}
)
from transformers import TrainingArguments
training_args = TrainingArguments(
output_dir="./safety_checker-update",
per_device_train_batch_size=16,
evaluation_strategy="steps",
num_train_epochs=4,
save_steps=100,
eval_steps=100,
logging_steps=10,
learning_rate=2e-4,
save_total_limit=2,
remove_unused_columns=False,
push_to_hub=False,
report_to='tensorboard',
load_best_model_at_end=True,
)
from transformers import Trainer
os.environ["CUDA_VISIBLE_DEVICES"] = ""
trainer = Trainer(
model=model,
args=training_args,
data_collator=collate_fn,
compute_metrics=compute_metrics,
train_dataset=prepared_ds["train"],
eval_dataset=prepared_ds["validation"],
tokenizer=transform,
)
train_results = trainer.train()
trainer.save_model()
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
trainer.save_state()
But after Running the code I am facing this error
Traceback (most recent call last):
File "*****/stable_diff_sc_trainer.py", line 80, in <module>
train_results = trainer.train()
^^^^^^^^^^^^^^^
File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/transformers/trainer.py", line 1624, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/transformers/trainer.py", line 1961, in _inner_training_loop
tr_loss_step = self.training_step(model, inputs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/transformers/trainer.py", line 2911, in training_step
self.accelerator.backward(loss)
File "/home/ubuntu/anaconda3/envs/image/lib/python3.11/site-packages/accelerate/accelerator.py", line 2001, in backward
loss.backward(**kwargs)
^^^^^^^^^^^^^
AttributeError: 'float' object has no attribute 'backward'
0%| | 0/4 [00:07<?, ?it/s]
When I searched around on google, gpt and gemini, most of the sources said that I need to add the part of loss calculation for it to work. Can someone guide me on what I will need to do to be able to train this model?