Creating a custom TensorFlow dataset by subclassing tf.data.Dataset


Question: Is there a clean and straightforward way to create a custom dataset in TensorFlow by subclassing tf.data.Dataset, similar to the functionality available in PyTorch?

Details: I'm currently working on a project that involves training deep learning models using TensorFlow. In PyTorch, I found it convenient to create custom datasets by subclassing torch.utils.data.Dataset, which allowed me to encapsulate the data loading and preprocessing logic.

from torch.utils.data import Dataset


class FaceLandmarksDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, csv_file, root_dir, transform=None):
        ...

    def __len__(self):
        ...

    def __getitem__(self, idx):
        ...

However, in TensorFlow, I'm struggling to find a similar mechanism for creating custom datasets. I've been using the tf.data.Dataset API for handling data pipelines, but I haven't come across a way to subclass it and define my own custom dataset.

Is there a recommended approach in TensorFlow for achieving this? Ideally, I would like to have the flexibility to implement custom data loading, preprocessing, and augmentation logic within the dataset subclass, as it provides a clean and modular structure.

Any guidance or examples on how to create a custom dataset by subclassing tf.data.Dataset would be greatly appreciated. Thank you!

P.S. A similar question has already been asked in "Is there a proper way to subclass Tensorflow's Dataset?", with no satisfactory answer.

1 Answer

I don't know if this is exactly what you want, but if you want to create a custom dataset (similar to a DataModule in PyTorch Lightning), the simplest way is to subclass tfds.dataset_builders.TfDataBuilder and override the __init__ method with your custom data loading, preprocessing, and augmentation logic. See https://www.tensorflow.org/datasets/format_specific_dataset_builders#defining_a_new_dataset_builder_class. You can also attach metadata about your dataset, pass in a batch_size, etc., and use built-in construction functions such as tf.data.Dataset.from_tensor_slices.
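A minimal sketch of that approach, loosely following the pattern shown in the linked docs. The builder name, the toy tensors, and the feature spec are placeholders for your own loading and preprocessing logic, and some keyword arguments (description, etc.) may be optional depending on your TFDS version:

import tensorflow as tf
import tensorflow_datasets as tfds

class MyFaceLandmarksBuilder(tfds.dataset_builders.TfDataBuilder):
    def __init__(self):
        # Custom loading/preprocessing would go here; toy tensors are stand-ins.
        ds_train = tf.data.Dataset.from_tensor_slices([1, 2, 3])
        ds_test = tf.data.Dataset.from_tensor_slices([4, 5])
        super().__init__(
            name="my_face_landmarks",  # placeholder name
            version="1.0.0",
            split_datasets={"train": ds_train, "test": ds_test},
            features=tfds.features.FeaturesDict({
                "number": tfds.features.Scalar(dtype=tf.int64),
            }),
            description="Toy dataset built from in-memory tensors.",
        )

# Prepare once, then read it back like any other TFDS dataset.
builder = MyFaceLandmarksBuilder()
builder.download_and_prepare()
ds_train = builder.as_dataset(split="train")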

If you want a generator that looks more like torch.utils.data.Dataset, you can instead subclass tf.keras.utils.Sequence (https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence), which has a similar structure (the main difference is that __getitem__ should return a whole batch, not a single element as in PyTorch). You can then pass that object directly to model.fit().
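A rough sketch along those lines, mirroring your FaceLandmarksDataset. The image_paths/landmarks arguments, the 224x224 resize, and the /255 normalization are assumptions standing in for your own data and preprocessing:

import math
import numpy as np
import tensorflow as tf

class FaceLandmarksSequence(tf.keras.utils.Sequence):
    """Yields (images, landmarks) one *batch* at a time."""

    def __init__(self, image_paths, landmarks, batch_size=32):
        super().__init__()
        self.image_paths = image_paths
        self.landmarks = landmarks
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch, not number of samples.
        return math.ceil(len(self.image_paths) / self.batch_size)

    def __getitem__(self, idx):
        # Unlike PyTorch's Dataset.__getitem__, idx indexes a whole batch.
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        batch_paths = self.image_paths[lo:hi]
        batch_y = np.asarray(self.landmarks[lo:hi], dtype="float32")
        batch_x = np.stack([
            tf.keras.utils.img_to_array(
                tf.keras.utils.load_img(p, target_size=(224, 224))  # assumed size
            ) / 255.0
            for p in batch_paths
        ])
        return batch_x, batch_y

You can then train with something like model.fit(FaceLandmarksSequence(paths, landmarks, batch_size=32), epochs=10).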

So you can combine the two approaches: define data loading, preprocessing, splits, etc. in __init__() of the TfDataBuilder, and then use a Sequence subclass to define the logic for loading a single batch.