How to augment 2 specific classes on face mask dataset for yolov8

54 Views Asked by At

Dataset

I used this dataset to make a face mask detection using Yolov8 to train the model. It is working nicely with 82% accuracy however the dataset is very imbalanced. These are the total for each classes: Total 'without_mask' labels: 717 Total 'mask_weared_incorrect' labels: 123 Total 'with_mask' labels: 3232 I want to know if there is a way to specifically augment the without_mask and mask_weared_incorrect classes.

I tried this code but it doesn't work.

import os
import cv2
import xml.etree.ElementTree as ET
import pandas as pd
import numpy as np
import albumentations as A
from typing import List, Sequence, Tuple

# Define a function to normalize bounding boxes
def normalize_bbox(bbox, rows, cols):
    x_min, y_min, x_max, y_max = bbox

    # Ensure x and y coordinates are within the valid range [0.0, 1.0]
    x_min = max(0, min(x_min / cols, 1.0))
    y_min = max(0, min(y_min / rows, 1.0))
    x_max = max(0, min(x_max / cols, 1.0))
    y_max = max(0, min(y_max / rows, 1.0))

    # Ensure x_max is greater than x_min and y_max is greater than y_min
    x_max = max(x_min, min(x_max, 1.0))
    y_max = max(y_min, min(y_max, 1.0))

    # Swap values if x_max is less than x_min or y_max is less than y_min
    if x_max < x_min:
        x_min, x_max = x_max, x_min
    if y_max < y_min:
        y_min, y_max = y_max, y_min

    return [x_min, y_min, x_max, y_max]

# Define augmentation transformations for without_mask and mask_weared_incorrect classes
augmentation_transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),  # Increase the probability for more aggressive augmentation
    A.Rotate(limit=30, p=0.5),  # Introduce rotation with a probability of 0.5
    # Add more augmentations as needed

    # Specific augmentations for 'without_mask' and 'mask_weared_incorrect' classes
    A.OneOf([
        A.Blur(blur_limit=3, p=0.5),  # Apply blur with a probability of 0.5
        A.RandomSnow(p=0.5),  # Simulate snow on the image with a probability of 0.5
    ], p=0.5),

    A.OneOf([
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=True),  # Sharpen the image
        A.RandomFog(p=1),  # Simulate fog
    ], p=0.5),
], bbox_params=A.BboxParams(format='pascal_voc', min_area=0, min_visibility=0, label_fields=['class_labels']))

# Initialize an empty list to store augmented data
augmented_data = []

# Loop through each annotation file in the 'annotations' list
for annotation_path in annotations:
    # Parse the XML file using ElementTree
    tree = ET.parse(annotation_path)
    root = tree.getroot()
    
    # Extract the filename from the XML file
    filename = root.find('filename').text
    
    # Loop through each object in the annotation
    for obj in root.findall("object"):
        # Extract the label of the object
        label = obj.find("name").text
        
        # Extract bounding box coordinates (xmin, ymin, xmax, ymax)
        bbox = []
        bndbox_tree = obj.find('bndbox')
        bbox.append(int(bndbox_tree.find('xmin').text))
        bbox.append(int(bndbox_tree.find('ymin').text))
        bbox.append(int(bndbox_tree.find('xmax').text))
        bbox.append(int(bndbox_tree.find('ymax').text))
        
        # Apply augmentation only for 'without_mask' and 'mask_weared_incorrect' classes
        if label in ['without_mask', 'mask_weared_incorrect']:
            # Define the augmentation transform with increased probability
            augmentation_transform = A.Compose([
                A.HorizontalFlip(p=0.5),
                A.RandomBrightnessContrast(p=0.5),  # Increase the probability for more aggressive augmentation
                A.Rotate(limit=30, p=0.5),  # Introduce rotation with a probability of 0.5
                # Add more augmentations as needed
                A.OneOf([
                    A.Blur(blur_limit=3, p=0.5),  # Apply blur with a probability of 0.5
                    A.RandomSnow(p=0.5),  # Simulate snow on the image with a probability of 0.5
                ], p=0.5),
                A.OneOf([
                    A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=True),  # Sharpen the image
                    A.RandomFog(p=1),  # Simulate fog
                ], p=0.5),
            ], bbox_params=A.BboxParams(format='pascal_voc', min_area=0, min_visibility=0, label_fields=['class_labels']))
            
            # Apply the defined augmentation transform
            augmented = augmentation_transform(image=cv2.imread(os.path.join(image_directory, filename)),
                                               bboxes=[bbox],
                                               class_labels=[class_id[label]])
            
            # Normalize the augmented bounding box
            normalized_bbox = normalize_bbox(augmented['bboxes'][0], int(size.find('height').text), int(size.find('width').text))
            
            # Update augmented data list
            augmented_data.append({
                'filename': filename,
                'width': int(size.find('width').text),
                'height': int(size.find('height').text),
                'label': label,
                'class_id': class_id[label],
                'bboxes': normalized_bbox
            })
        
        # Append original data to the data dictionary
        data_dict['filename'].append(filename)
        data_dict['width'].append(int(size.find('width').text))
        data_dict['height'].append(int(size.find('height').text))
        data_dict['label'].append(label)
        data_dict['class_id'].append(class_id[label])
        data_dict['bboxes'].append(bbox)

# Print augmented data list
print(f"Augmented Data List: {augmented_data}")

# Create a DataFrame from the data dictionary
df_data = pd.DataFrame(data_dict)

# Create a DataFrame from the augmented data list
df_augmented_data = pd.DataFrame(augmented_data)

# Concatenate original and augmented data
df_data = pd.concat([df_data, df_augmented_data], ignore_index=True)

# Check the new distribution
print(f"Total 'without_mask' labels: {sum(df_data.label == 'without_mask')}")
print(f"Total 'mask_weared_incorrect' labels: {sum(df_data.label == 'mask_weared_incorrect')}")
print(f"Total 'with_mask' labels: {sum(df_data.label == 'with_mask')}")

0

There are 0 best solutions below