How can I use a py_function inside a map function?

I wanted to use a tf.py_function inside a tf.data map function. Getting it to work took me a day, despite ChatGPT's assistance.

Since resizing an image with tf.image gives different results from OpenCV's implementation, I wanted to keep using the optimized tf.data.Dataset pipeline with its .map API while calling cv2.resize for the resizing itself.

Here's what worked for me:

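import cv2
import numpy as np
import tensorflow as tf

def resize_with_opencv_ver6(self, image):
    # Runs eagerly inside tf.py_function, so .numpy() is available
    image = image.numpy()
    image = np.squeeze(image)  # (H, W, 1) -> (H, W)
    # cv2.resize expects dsize as (width, height), so self._target_shape
    # is assumed to be stored in that order here
    target_shape = (self._target_shape[0], self._target_shape[1])
    resized = cv2.resize(image, target_shape, interpolation=cv2.INTER_NEAREST)
    resized = tf.expand_dims(resized, axis=-1)  # restore the channel axis

    return resized

def resize_fn(self, image_path, image, label_index):
    im_shape = image.shape
    [image, ] = tf.py_function(self.resize_with_opencv_ver6, [image], [tf.uint8])
    # The output has the target height/width, not the input's, so only the
    # channel dimension of the input shape is carried over
    image.set_shape([None, None, im_shape[-1]])
    return image_path, image, label_index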

resize_fn is then passed to the tf.data Dataset.map API:

dataset = dataset.map(self.resize_fn, num_parallel_calls=self._autotune)

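A short explanation: for single-channel images, cv2.resize drops the channel dimension, so an (H, W, 1) input comes back as a 2-D (H, W) array. The np.squeeze call is therefore optional; tf.expand_dims is what adds the channel axis back, so the function returns a tensor that has the channel dimension. This step assumes grayscale input: a 3-channel image keeps its channel axis through cv2.resize, and tf.expand_dims would then add a spurious fourth axis. Finally, tf.py_function returns a tensor with an unknown static shape, so reading image.shape beforehand and calling set_shape afterwards restores the known channel dimension for later ops; this is helpful but not mandatory here.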

Hope it helps others.