Optimize pytorch data loader for reading small patches in full HD images

1.2k Views Asked by At

I'm training my neural network using the PyTorch framework. The data consists of full HD images (1920x1080), but in each iteration I only need to crop a random 256x256 patch from these images. My network is relatively small (5 conv layers), so data loading is the bottleneck. I've provided my current code below. Is there any way to optimize data loading and speed up training?

Code:

from pathlib import Path

import numpy
import skimage.io
import torch.utils.data as data

import Imath
import OpenEXR


class Ours(data.Dataset):
    """
    Dataset of random square patches cropped from rendered video frames.

    Each item pairs two consecutive RGB frames (``view_num + 1`` and
    ``view_num + 2``) with the mask and depth map of the first frame.
    All four are cropped at the same randomly chosen location.

    NOTE(review): an earlier docstring claimed that patches are selected to
    contain at least one unknown pixel. No such check is implemented; the
    patch location is drawn uniformly among positions where the whole patch
    fits inside the frame.
    """

    def __init__(self, data_dirpath, split_name, patch_size, image_size=(1080, 1920)):
        """
        Index every (video, view) pair found under ``data_dirpath / split_name``.

        :param data_dirpath: root directory of the dataset.
        :param split_name: name of the split sub-directory (e.g. ``'train'``).
        :param patch_size: side length, in pixels, of the square patch returned.
        :param image_size: ``(height, width)`` of the full-resolution frames.
            Used to keep every sampled patch entirely inside the frame.
        :raises ValueError: if ``patch_size`` is not positive or does not fit
            inside ``image_size``.
        """
        super(Ours, self).__init__()
        height, width = image_size
        if not 0 < patch_size <= min(height, width):
            raise ValueError(
                f'patch_size={patch_size} must be positive and fit inside image_size={image_size}'
            )
        self.dataroot = Path(data_dirpath) / split_name
        self.patch_size = patch_size
        self.image_size = (height, width)
        self.video_names = []
        for video_path in sorted(self.dataroot.iterdir()):
            # view_num = i * 12 + j with j < 11. Presumably the views come in
            # 4 groups of 12, and the last view of each group is skipped so
            # that the frame pair (view_num + 1, view_num + 2) never crosses
            # a group boundary. TODO confirm against the data layout.
            for i in range(4):
                for j in range(11):
                    view_num = i * 12 + j
                    self.video_names.append((video_path.stem, view_num))

    def __getitem__(self, index):
        """
        Return a dict with keys 'frame1', 'frame2', 'mask' and 'depth'.

        All four arrays are cropped at the same random location.
        """
        video_name, view_num = self.video_names[index]

        patch_start_pt = self._sample_patch_start()

        frame1_path = self.dataroot / video_name / f'render/rgb/{view_num + 1:04}.png'
        frame2_path = self.dataroot / video_name / f'render/rgb/{view_num + 2:04}.png'
        depth_path = self.dataroot / video_name / f'render/depth/{view_num + 1:04}.exr'
        mask_path = self.dataroot / video_name / f'render/masks/{view_num + 1:04}.png'
        frame1 = self.get_image(frame1_path, patch_start_pt)
        frame2 = self.get_image(frame2_path, patch_start_pt)
        mask = self.get_mask(mask_path, patch_start_pt)
        depth = self.get_depth(depth_path, patch_start_pt, mask)

        data_dict = {
            'frame1': frame1,
            'frame2': frame2,
            'mask': mask,
            'depth': depth,
        }
        return data_dict

    def __len__(self):
        return len(self.video_names)

    def _sample_patch_start(self):
        """Return a random ``(row, col)`` top-left corner whose patch fits inside the frame."""
        height, width = self.image_size
        # randint's upper bound is exclusive; the "+ 1" makes the bottom-most
        # and right-most valid corners reachable.
        row = numpy.random.randint(height - self.patch_size + 1)
        col = numpy.random.randint(width - self.patch_size + 1)
        return row, col

    def get_mask(self, path: Path, patch_start_point: tuple):
        """
        Load the mask image at ``path`` and crop it to the patch.

        :returns: the cropped mask with a new leading axis. Presumably the mask
            is single-channel, giving shape (1, patch, patch). Confirm this.
        """
        h, w = patch_start_point
        mask = skimage.io.imread(path.as_posix())
        return mask[h:h + self.patch_size, w:w + self.patch_size][None]

    def get_image(self, path: Path, patch_start_point: tuple):
        """
        Load an RGB frame, crop it to the patch and return it channel-first.

        The values are mapped from [0, 255] to [-1, 1]. This assumes 8-bit
        images; confirm.
        """
        h, w = patch_start_point
        image = skimage.io.imread(path.as_posix())
        # Crop before the float conversion so that only the patch is converted.
        # Keep just the first three channels (e.g. drops alpha if present).
        image = image[h:h + self.patch_size, w:w + self.patch_size, :3]
        image = image.astype(numpy.float32) / 255 * 2 - 1
        # HWC -> CHW
        image_cf = numpy.moveaxis(image, [0, 1, 2], [1, 2, 0])
        return image_cf

    def get_depth(self, path: Path, patch_start_point: tuple, mask: numpy.ndarray):
        """
        Read the depth map from channel 'B' of the EXR file and crop it to the patch.

        The crop is multiplied by ``mask``.

        :returns: float32 array of shape (1, patch, patch) (after broadcasting
            with ``mask``).
        """
        h, w = patch_start_point

        exrfile = OpenEXR.InputFile(path.as_posix())
        try:
            # channel() returns the pixels of the *data* window, so the buffer
            # must be reshaped with the dataWindow size. The display window
            # may differ from it.
            data_window = exrfile.header()['dataWindow']
            raw_bytes = exrfile.channel('B', Imath.PixelType(Imath.PixelType.FLOAT))
        finally:
            exrfile.close()
        height = data_window.max.y - data_window.min.y + 1
        width = data_window.max.x - data_window.min.x + 1
        depth = numpy.frombuffer(raw_bytes, dtype=numpy.float32).reshape(height, width)

        depth = depth[h:h + self.patch_size, w:w + self.patch_size][None]
        return depth * mask

Finally, I'm creating a DataLoader as follows:

train_data_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
)

What I've tried so far:

  1. I've searched for a way to read only part of an image. Unfortunately, I didn't find any leads; it looks like Python image libraries always read the full image.
  2. I'm planning to read more patches from each image so that I need to read fewer images. But in PyTorch, `__getitem__()` has to return a single sample, not a batch, so each call can read only one patch.
  3. I'm planning to work around this as follows:
     - Read 4 patches in `__getitem__()` and return an array of shape (4,3,256,256) instead of (3,256,256).
     - The DataLoader then yields batches of shape (BS,4,3,256,256) instead of (BS,3,256,256).
     - Merge the first two dimensions (e.g. with `x.flatten(0, 1)` or `x.view(-1, 3, 256, 256)`) to get (BS*4,3,256,256).
     - This lets me divide the batch size (BS) by 4, which will hopefully make data loading up to 4 times faster.

Are there any other options? I'm open to all kinds of suggestions. Thanks!

0

There are 0 best solutions below