PyTorch dataset: len(train_dataset) returns zero


I am trying to create a custom dataset and dataloader in PyTorch to fine-tune a DONUT model. For context, my dataset is organised as follows:

dataset/
├── train/
│   ├── image1.jpg
│   ├── image2.jpg
│   ├── metadata.jsonl
│   └── ...
├── validation/
│   ├── image1.jpg
│   ├── image2.jpg
│   ├── metadata.jsonl
│   └── ...
└── ...

I have already written my custom dataset class:

import json
import os

import torch
from PIL import Image


class DonutOCRDataset(torch.utils.data.Dataset):
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.transform = transform
        self.data = self.load_data()
    
    def load_data(self):
        data = []
        for folder in os.listdir(self.root_dir):  # Use self.root_dir here
            folder_path = os.path.join(self.root_dir, folder)
            if os.path.isdir(folder_path):
                metadata_path = os.path.join(folder_path, 'metadata.jsonl')
                with open(metadata_path, 'r') as f:
                    metadata = [json.loads(line) for line in f]
                data.extend([(item["file_name"], item["ground_truth"]) for item in metadata])
        return data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img_path, label = self.data[idx]
        img_path_full = os.path.join(self.root_dir, img_path)
        print(f"Loading image: {img_path_full}")
        img = Image.open(img_path_full).convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img, label

Below, I define my transformations and instantiate the datasets and dataloaders:

from torch.utils.data import DataLoader
from torchvision import transforms

# Define root_dir
root_dir = r'C:\Users\Company\Documents\.....\240111_donut_1\dataset'


# Define your transformation
transform = transforms.Compose([
    transforms.Resize((640, 460)),
    transforms.ToTensor(),
])

# Instantiate the dataset
train_dataset = DonutOCRDataset(os.path.join(root_dir, 'train'), transform=transform)
val_dataset = DonutOCRDataset(os.path.join(root_dir, 'validation'), transform=transform)

batch_size = 8
train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False)

However, when I printed len(train_dataset) and len(val_dataset), both returned 0.
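The empty result can be reproduced without PyTorch or any images. The stdlib-only sketch below copies the directory-walking loop from load_data() into a hypothetical helper, load_data_as_in_question, and runs it on a throwaway folder shaped like dataset/train (the file names and labels are made up for illustration):

```python
import json
import os
import tempfile


def load_data_as_in_question(root_dir):
    # Same logic as DonutOCRDataset.load_data(): metadata.jsonl is only
    # looked for inside *subfolders* of root_dir.
    data = []
    for folder in os.listdir(root_dir):
        folder_path = os.path.join(root_dir, folder)
        if os.path.isdir(folder_path):
            metadata_path = os.path.join(folder_path, "metadata.jsonl")
            with open(metadata_path, "r") as f:
                metadata = [json.loads(line) for line in f]
            data.extend([(item["file_name"], item["ground_truth"]) for item in metadata])
    return data


# Throwaway folder shaped like dataset/train: files only, no subfolders.
with tempfile.TemporaryDirectory() as train_dir:
    for name in ("image1.jpg", "image2.jpg"):
        open(os.path.join(train_dir, name), "wb").close()  # empty placeholder files
    with open(os.path.join(train_dir, "metadata.jsonl"), "w") as f:
        f.write(json.dumps({"file_name": "image1.jpg", "ground_truth": "a"}) + "\n")
        f.write(json.dumps({"file_name": "image2.jpg", "ground_truth": "b"}) + "\n")

    # Check which entries the loop would treat as folders, then run it.
    is_dir = {name: os.path.isdir(os.path.join(train_dir, name))
              for name in sorted(os.listdir(train_dir))}
    n_loaded = len(load_data_as_in_question(train_dir))

print(is_dir)    # every entry is a file, so isdir() is False for all of them
print(n_loaded)  # 0
```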

Does anyone know what's wrong with my code?

Answer:

In the DonutOCRDataset object for training, the root_dir you pass is dataset/train. In load_data(), however, you only look for metadata.jsonl inside subdirectories of that folder (the if os.path.isdir(folder_path) check). Judging by your directory tree, dataset/train contains only files, so that condition is never satisfied, self.data stays an empty list, and len() returns 0. The same happens for dataset/validation. Remove the for loop from load_data() and read metadata.jsonl directly from self.root_dir instead.
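A minimal sketch of the loop-free version follows, assuming (as in the question's tree) that each split folder holds metadata.jsonl directly. The helper name load_metadata, the explicit UTF-8 encoding, and the skipping of blank lines are additions for illustration, not part of the original code:

```python
import json
import os
import tempfile


def load_metadata(root_dir):
    """Return (file_name, ground_truth) pairs from root_dir/metadata.jsonl.

    Assumes root_dir is a split folder such as dataset/train that contains
    metadata.jsonl itself, with no subfolders to walk.
    """
    metadata_path = os.path.join(root_dir, "metadata.jsonl")
    data = []
    with open(metadata_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:  # tolerate blank lines, e.g. a trailing empty line
                continue
            item = json.loads(line)
            data.append((item["file_name"], item["ground_truth"]))
    return data


# Quick check on a throwaway folder laid out like dataset/train
# (file names and labels are made up for illustration).
with tempfile.TemporaryDirectory() as train_dir:
    with open(os.path.join(train_dir, "metadata.jsonl"), "w", encoding="utf-8") as f:
        f.write(json.dumps({"file_name": "image1.jpg", "ground_truth": "a"}) + "\n")
        f.write(json.dumps({"file_name": "image2.jpg", "ground_truth": "b"}) + "\n")
        f.write("\n")  # this blank line is skipped by load_metadata
    pairs = load_metadata(train_dir)

print(len(pairs))  # 2
```

With a helper like this, load_data() in DonutOCRDataset could reduce to return load_metadata(self.root_dir). Since __getitem__ already joins self.root_dir with each file_name, the image paths should resolve, assuming the file_name entries are relative to the split folder.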