NVIDIA DALI - Load specific frames of a video with nvidia.dali.fn.readers.video


I am trying to use NVIDIA DALI to read frame sequences from a video dataset (as PyTorch tensors). Not all parts of the videos are useful, and I have a database that stores which frames to get from each video, but I don't see how to tell DALI to load exactly those frames.

I have the following minimal example so far:

```python
import random

import nvidia.dali.fn as fn
from nvidia.dali.pipeline import Pipeline
from nvidia.dali.plugin.pytorch import DALIGenericIterator


class DataIterator:
    """Iterator that loads sequence info from the database and outputs the filenames and frames to read."""

    def __init__(self):
        self.database = MyDatabase(...)  # my own database wrapper, not shown
        self.dataset_length = 1000
        self.samples_seen = 0

    def __iter__(self):
        # DALI calls iter() again after a reset, so restart the epoch counter here
        self.samples_seen = 0
        return self

    def __next__(self):
        # Without an explicit end of epoch, DALIGenericIterator would never stop
        if self.samples_seen >= self.dataset_length:
            raise StopIteration
        self.samples_seen += 1

        # Get a random index from my dataset
        index = random.randint(0, self.dataset_length - 1)

        # Load the data associated with that index from the DB
        filenames, frames = self.database.get_sequence(index)
        # Do other stuff, like getting labels, omitted here

        # One batch (list of samples) per output; each sample must be array-like
        # (e.g. a NumPy array), so Python strings have to be encoded first
        return [filenames], [frames]

    def __len__(self):
        return self.dataset_length

    next = __next__


def create_pipeline():
    data_iterator = DataIterator()
    # batch_size, num_threads and device_id are needed to build the pipeline
    pipeline = Pipeline(batch_size=1, num_threads=2, device_id=0)

    with pipeline:
        filenames, frames = fn.external_source(source=data_iterator, num_outputs=2)
        # TODO: Here I want to load 'frames' from the video 'filenames'.
        # Let's say frames 21 to 31 from video "video1.mp4", 18 to 28 from video "video2.mp4"...
        sequences = fn.readers.video(...)
        sequences = fn.do_other_stuff(sequences)  # placeholder for further processing
        pipeline.set_outputs(sequences, filenames, frames)
    return pipeline


if __name__ == "__main__":
    pipeline = create_pipeline()

    iterator = DALIGenericIterator(pipeline, ["sequences", "filenames", "frames"])

    for epoch in range(3):
        for i, batch in enumerate(iterator):
            # Do stuff with my sequences, filenames, and frames
            pass
        iterator.reset()
```
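The `filenames` output of `external_source` has to travel through DALI as numeric data rather than as Python strings. A minimal sketch, assuming UTF-8 paths and two hypothetical helpers (`encode_filename` and `decode_filename`, which are not part of DALI): each path is encoded as a 1-D `uint8` NumPy array before `__next__` returns it, and decoded back from the batch that `DALIGenericIterator` yields.

```python
import numpy as np


def encode_filename(name: str) -> np.ndarray:
    """Hypothetical helper: encode a path as a 1-D uint8 array for external_source."""
    # .copy() because np.frombuffer returns a read-only view of the bytes object
    return np.frombuffer(name.encode("utf-8"), dtype=np.uint8).copy()


def decode_filename(data) -> str:
    """Hypothetical helper: turn a uint8 array (or anything convertible to one) back into a path."""
    return np.asarray(data, dtype=np.uint8).tobytes().decode("utf-8")


# A batch of size 1, in the list-of-samples form that external_source expects
batch = [encode_filename("video1.mp4")]
print(batch[0].dtype, batch[0].shape)  # uint8 (10,)
print(decode_filename(batch[0]))       # video1.mp4
```

Assuming `filenames` is a single path string, the return line in `__next__` would become `return [encode_filename(filenames)], [frames]`, and with a batch size of 1, `decode_filename(batch[0]["filenames"][0].cpu().numpy())` should recover the string on the PyTorch side. With larger batches, paths of different lengths would need padding, since the PyTorch iterator expects uniformly shaped outputs.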

My question is: how do I use nvidia.dali.fn.readers.video to fetch exactly the frames I need?

I have seen a lot of examples (like these) that fetch random sequences from a set of videos, but it is unclear to me how to control which sequences are loaded.
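One possible direction, sketched under assumptions: reader arguments such as `filenames` and `file_list` are fixed when the pipeline is built, so frame ranges coming from `external_source` cannot steer `fn.readers.video`. The reader does accept a `file_list` text file with lines of the form `path label [start end]`, and with `file_list_frame_num=True` the start/end values are read as frame numbers instead of timestamps. The sketch below assumes the database can be queried up front for hypothetical `(path, first_frame, last_frame)` rows with inclusive bounds (the `records` list is made-up example data, and `write_dali_file_list` is a hypothetical helper). Each row's index is used as the label so the reader's label output points back to the DB row. The DALI docs state that equal start and end frames give an empty sequence, which suggests the end is exclusive, hence the `+ 1`; this is worth verifying against the DALI version in use.

```python
from pathlib import Path
from typing import Iterable, List, Tuple

# (path, first_frame, last_frame), both bounds inclusive (hypothetical DB row shape)
Record = Tuple[str, int, int]


def write_dali_file_list(records: Iterable[Record], out_path: str) -> List[str]:
    """Write a file_list for fn.readers.video and return the lines written.

    Each line is "<path> <label> <start> <end>". The label is the record's
    position in `records`, so the reader's label output identifies the DB row.
    The end is written as last_frame + 1, assuming DALI treats it as exclusive.
    Paths must not contain spaces, since the format is whitespace-separated.
    """
    lines = []
    for label, (path, first, last) in enumerate(records):
        if last < first:
            raise ValueError(f"empty frame range for {path}: {first}..{last}")
        lines.append(f"{path} {label} {first} {last + 1}")
    Path(out_path).write_text("\n".join(lines) + "\n")
    return lines


records = [("video1.mp4", 21, 31), ("video2.mp4", 18, 28)]
for line in write_dali_file_list(records, "file_list.txt"):
    print(line)
# video1.mp4 0 21 32
# video2.mp4 1 18 29
```

The reader could then be created with something like `fn.readers.video(device="gpu", file_list="file_list.txt", file_list_frame_num=True, sequence_length=11)`. Frames 21 to 31 are 11 frames, so a `sequence_length` of 11 should give exactly one sequence per line; longer ranges would be split into several sequences. Random sampling would then move from the Python iterator to the reader's `random_shuffle` argument.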

Thanks in advance for any help!
