Deep Learning - generate patches for 3D multimodal data


I am working on semantic segmentation of brain tumors using deep learning, with the BRATS2015 dataset. It contains MRI scans of 274 patients, each of size 240x240x155, with four modalities per patient (T1, T2, T1c, FLAIR). So I am using the modalities as channels in the network.

In an ideal world, the input to my 3D U-Net would have the shape (batch_size, 240, 240, 155, 4) in channels_last mode. But graphics cards are obviously not equipped to handle data of this size, so I need to split each MRI scan into patches.
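
A quick back-of-the-envelope calculation makes the problem concrete (assuming float32 inputs; the real memory cost is dominated by the network's intermediate feature maps, which are several times larger):

# One full scan: 240 x 240 x 155 voxels x 4 modalities, float32
voxels = 240 * 240 * 155 * 4
print(voxels * 4 / 1e6)  # ~142.8 MB for a single input volume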

Here's where I am confused. It is relatively easy to get patches for single-channel 3D data; there are many libraries and helper functions for that. The issue I am facing is generating patches for multimodal data, i.e., 3D data with a channel dimension.

  • I have considered generating patches for each channel separately and concatenating the results (a toy sketch of what I mean is below), but I worry that I might lose some multichannel information by processing each channel on its own instead of generating patches from the multimodal volume directly.
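
Here is that idea on a dummy array (the shapes and the 4x4x4 patch size are made up for illustration):

import numpy as np

rng = np.random.default_rng(0)
vol = rng.standard_normal((8, 8, 8, 4))  # toy multimodal volume (H, W, D, C)

# One 4x4x4 patch per channel, taken at the SAME spatial coordinates,
# then stacked back along the channel axis
per_channel = np.stack(
    [vol[0:4, 0:4, 0:4, c] for c in range(vol.shape[-1])], axis=-1
)

# The same patch taken directly from the multimodal volume
direct = vol[0:4, 0:4, 0:4, :]

print(per_channel.shape)                    # (4, 4, 4, 4)
print(np.array_equal(per_channel, direct))  # True when the grids are aligned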

I have looked at the patchify library, where we can use the following to generate patches:

from patchify import patchify, unpatchify

# This will split a 3D image into overlapping patches of shape (3, 3, 3)
patches = patchify(image, (3, 3, 3), step=1)

reconstructed_image = unpatchify(patches, image.shape)

But I am not sure how to generate multimodal patches, i.e., patches that keep all four modalities together. Is there a way to do this with patchify or any other library/helper function?
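
For concreteness, here is a rough sketch of the behaviour I am after, written directly against skimage's view_as_windows (which, as far as I can tell, is what patchify wraps); the 64-voxel patch size is arbitrary, and I am assuming that a window spanning the full channel axis keeps the modalities together:

import numpy as np
from skimage.util import view_as_windows

scan = np.zeros((240, 240, 155, 4), dtype=np.float32)  # one patient, 4 modalities

# A window that covers the whole channel axis, so every patch carries
# all four modalities; step=64 makes the patches non-overlapping
# (border voxels are dropped unless the scan is padded first)
patches = view_as_windows(scan, window_shape=(64, 64, 64, 4), step=64)

print(patches.shape)  # (3, 3, 2, 1, 64, 64, 64, 4)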


1 Answer


You might like to patch your dataset with skimage's view_as_blocks. The documentation says:

Block view of the input n-dimensional array (using re-striding). Blocks are non-overlapping views of the input array.

Here is a dummy example.

import numpy as np
from skimage.util.shape import view_as_blocks


# A toy array: batch_size 3, height 4, width 6, depth 8, channels 4
arr = np.arange(3*4*6*8*4).reshape(3, 4, 6, 8, 4)

# Non-overlapping 2x2x2 blocks spanning all 4 channels, one sample at a time
patches = view_as_blocks(arr, block_shape=(1, 2, 2, 2, 4))

print(patches.shape) # (3, 2, 3, 4, 1, 1, 2, 2, 2, 4)

# Print the first patch
print(patches[0,0,0,0,0,:,:,:,:,:])

# [[[[[  0   1   2   3]
#     [  4   5   6   7]]

#    [[ 32  33  34  35]
#     [ 36  37  38  39]]]


#   [[[192 193 194 195]
#     [196 197 198 199]]

#    [[224 225 226 227]
#     [228 229 230 231]]]]]

patches.shape might look confusing at first, so here is a short explanation: the last 5 numbers are the block shape (1, 2, 2, 2, 4), and the first 5 numbers are the number of blocks along the corresponding dimensions. Put simply, the element-wise division arr.shape / block_shape gives you (3, 2, 3, 4, 1).
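
As a quick sanity check of that arithmetic:

import numpy as np

print(np.array((3, 4, 6, 8, 4)) // np.array((1, 2, 2, 2, 4)))  # [3 2 3 4 1]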

A few points to pay attention to:

  1. Each dimension of block_shape must divide evenly into the corresponding dimension of arr_in; view_as_blocks raises a ValueError otherwise.

To satisfy this, you can pad your volume first according to your block_shape (see the sketch after this list).

  2. patches is a re-strided view, so it shares memory with arr; each patch only becomes an independent array once you copy it out (for example, to save it). Since the patches for all 274 scans will not fit in memory anyway, it makes sense to save them all to disk once and then feed them to your model one by one.
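
A minimal sketch of both points for a single scan, assuming 64-voxel patches (the pad-width computation and the patch_*.npy file naming are illustrative choices, not anything prescribed by skimage):

import numpy as np
from skimage.util.shape import view_as_blocks

scan = np.zeros((240, 240, 155, 4), dtype=np.float32)  # one patient volume
block_shape = (64, 64, 64, 4)

# Pad every axis up to the next multiple of the block size so that
# the divisibility requirement is met
pad = [(0, (-s) % b) for s, b in zip(scan.shape, block_shape)]
padded = np.pad(scan, pad, mode="constant")
print(padded.shape)  # (256, 256, 192, 4)

patches = view_as_blocks(padded, block_shape=block_shape)
print(patches.shape)  # (4, 4, 3, 1, 64, 64, 64, 4)

# Save each patch to disk; np.copy materializes the view as a real array
for idx in np.ndindex(*patches.shape[:4]):
    patch = np.copy(patches[idx])  # shape (64, 64, 64, 4)
    np.save("patch_" + "_".join(map(str, idx)) + ".npy", patch)

Note that mode="constant" pads with zeros; for MRI you might prefer the scan's background intensity instead.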