How can I extract object segmentations from the COCO dataset?

467 views

Using the MS COCO segmentation annotations, how can I extract only the segmented objects themselves? For example, given an image of a person standing in front of a house, how can I extract just the person?

1

There is 1 answer below.

0
On

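If your data is already in FiftyOne, you can write a simple function using OpenCV and NumPy to crop out the segmented instances stored in your FiftyOne labels. Each instance is saved as an image whose alpha channel is the instance mask, so everything outside the object (e.g. the house behind the person) becomes transparent. It could look something like this: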

import os
import warnings

import cv2
import numpy as np

import fiftyone as fo
import fiftyone.zoo as foz
from fiftyone import ViewField as F

def _crop_instance(img, mask, bounding_box):
    """Return the crop of ``img`` under ``mask`` with an alpha channel, or ``None``.

    The mask's top-left corner is placed at the bounding box's top-left corner.
    The box is given in relative ``[x, y, w, h]`` coordinates and scaled to
    pixels. Only the box's ``x`` and ``y`` are used; the crop size comes from
    the mask itself.

    The mask-sized window is clipped to the image bounds, and the mask is
    clipped identically, so the colour crop and alpha plane always have
    matching shapes. ``None`` is returned when the window does not overlap the
    image at all.

    Assumes ``img`` is an H x W x C uint8 array (as returned by
    ``cv2.imread``) and ``mask`` is a 2-D array where non-zero means
    "object" -- TODO confirm for non-COCO label sources.
    """
    img_h, img_w = img.shape[:2]
    mask_h, mask_w = mask.shape[:2]

    # Pixel position of the box's top-left corner (truncated, as before).
    x0 = int(bounding_box[0] * img_w)
    y0 = int(bounding_box[1] * img_h)

    # Intersect the mask-sized window with the image. Without this, boxes that
    # reach the right/bottom edge slice fewer image pixels than the mask has,
    # and the channel concatenation below fails with a shape mismatch.
    left, top = max(x0, 0), max(y0, 0)
    right, bottom = min(x0 + mask_w, img_w), min(y0 + mask_h, img_h)
    if right <= left or bottom <= top:
        return None

    crop = img[top:bottom, left:right]
    sub_mask = mask[top - y0:bottom - y0, left - x0:right - x0]

    # Threshold instead of cast-and-scale: a uint8 mask holding 0/255 would
    # otherwise wrap around to 1 when multiplied by 255.
    alpha = (sub_mask > 0).astype(np.uint8) * 255
    return np.concatenate((crop, alpha[:, :, np.newaxis]), axis=2)


def extract_classwise_instances(samples, output_dir, label_field, ext=".png"):
    """Write every segmented instance in ``samples`` to its own image file.

    For each detection in ``sample[label_field].detections`` that carries a
    mask, the image region under the mask is cropped. The mask is appended as
    an alpha channel, so pixels outside the object are fully transparent.
    Files are written to ``<output_dir>/<label>/<detection id><ext>``.

    Args:
        samples: a FiftyOne sample collection (dataset or view) providing
            ``iter_samples``.
        output_dir: root directory for the exported crops. It and the
            per-label sub-directories are created as needed.
        label_field: name of the sample field whose value exposes
            ``.detections`` with per-instance ``mask`` arrays.
        ext: output file extension. It must name a format that stores an
            alpha channel (e.g. ``".png"``), otherwise the transparency is
            lost.
    """
    print("Extracting object instances...")
    for sample in samples.iter_samples(progress=True):
        img = cv2.imread(sample.filepath)
        if img is None:
            # cv2.imread reports unreadable/missing files by returning None
            # rather than raising; skip the sample with a warning.
            warnings.warn(f"Could not read image {sample.filepath!r}; skipping")
            continue

        labels = sample[label_field]
        if labels is None:
            # The label field is unset on this sample: nothing to export.
            continue

        for det in labels.detections:
            if det.mask is None:
                # Box-only detection: there is no segmentation to cut out.
                continue
            instance = _crop_instance(img, det.mask, det.bounding_box)
            if instance is None:
                continue

            label_dir = os.path.join(output_dir, det.label)
            # makedirs tolerates existing directories and also creates
            # output_dir itself if it is missing (os.mkdir did neither).
            os.makedirs(label_dir, exist_ok=True)
            cv2.imwrite(os.path.join(label_dir, det.id + ext), instance)

# Only instances of these categories are exported.
classes = ["person"]
# Sample field that will hold the COCO instance annotations.
label_field = "ground_truth"

# Load at most 20 COCO-2017 validation samples with segmentation labels for
# ``classes``, storing the annotations under ``label_field``.
dataset = foz.load_zoo_dataset(
    "coco-2017",
    label_types=["segmentations"],
    split="validation",
    max_samples=20,
    classes=classes,
    label_field=label_field,
    dataset_name=fo.get_default_dataset_name(),
)

# Restrict the labels to the requested categories so that no other objects
# in the same images get exported.
view = dataset.filter_labels(
    label_field,
    F("label").is_in(classes),
)

# Root directory for the exported crops (one sub-directory per label).
output_dir = "/tmp/coco-segmentations"
os.makedirs(output_dir, exist_ok=True)

extract_classwise_instances(view, output_dir, label_field)