My code with numba is slower than without numba


My code with Numba is slower than without Numba, and I don't understand why. Please advise.

#with numba
import numpy as np
import tensorflow as tf
from numba import njit, prange
import time

(X_train, Y_train), (X_val, Y_val) = tf.keras.datasets.cifar10.load_data()

@njit  # njit already implies nopython=True
def cal_padding(input, padding):
    output = np.zeros((input.shape[0]+ padding*2, input.shape[1]+ padding*2))
    output[ padding :  padding + input.shape[0],  padding :  padding + input.shape[1]] = input
    return output

@njit  # njit already implies nopython=True
def conv_per_channel(data, kernel, padding = 1, stride = 1): # data is a single channel of the image; kernel is a single channel of the filter bank
    h, w = data.shape
    data = cal_padding(data, padding)
    
    kernel_size = kernel.shape
    h_output = int((h - kernel_size[0] + 2*padding + stride) / stride)
    w_output = int((w - kernel_size[1] + 2*padding + stride) / stride)
    
    tranformed = np.zeros((h_output, w_output), dtype=data.dtype)
    h_kernel, w_kernel = kernel_size
    xRows = []
    yCols = []
    for i in range(w_kernel):
      for j in range(h_kernel):
        yCols.append(j)
        xRows.append(i)

    for i in range(h_output):
      for j in prange(w_output):
        root_pixel = [i,j]
        new_val = 0.0
        for k in range(w_kernel * h_kernel):
          pixel_in_filter_x = xRows[k]
          pixel_in_filter_y = yCols[k]

          pixel_in_input_x = pixel_in_filter_x + root_pixel[0]
          pixel_in_input_y = pixel_in_filter_y + root_pixel[1]
          new_val = new_val + (data[pixel_in_input_x, pixel_in_input_y] * kernel[pixel_in_filter_x][pixel_in_filter_y])
        if new_val > 255:
          new_val = 255
        elif new_val < 0:
          new_val = 0
        tranformed[i,j] = new_val
      
    return tranformed

X_train = np.transpose(X_train, (0,3,1,2))
X_train500 = X_train[0:500]
print(X_train[0][0].shape)

sobelX = np.array([[-1,0,1], [-2,0,2],[-1,0,1]])
start = time.time()
conv_per_channel(X_train[0][0], sobelX)
end = time.time()
print(end-start)
#without numba
import numpy as np
import tensorflow as tf

from numba import jit, prange
import time

(X_train, Y_train), (X_val, Y_val) = tf.keras.datasets.cifar10.load_data()

# @jit(nopython= True)
def cal_padding(input, padding):
    output = np.zeros((input.shape[0]+ padding*2, input.shape[1]+ padding*2))
    output[ padding :  padding + input.shape[0],  padding :  padding + input.shape[1]] = input
    return output

# @jit(nopython = True)
def conv_per_channel(data, kernel, padding = 1, stride = 1): # data is a single channel of the image; kernel is a single channel of the filter bank
    h, w = data.shape
    data = cal_padding(data, padding)
    
    kernel_size = kernel.shape
    h_output = int((h - kernel_size[0] + 2*padding + stride) / stride)
    w_output = int((w - kernel_size[1] + 2*padding + stride) / stride)
    
    tranformed = np.zeros((h_output, w_output), dtype=data.dtype)
    h_kernel, w_kernel = kernel_size
    xRows = []
    yCols = []
    for i in range(w_kernel):
      for j in range(h_kernel):
        yCols.append(j)
        xRows.append(i)

    for i in range(h_output):
      for j in prange(w_output):
        root_pixel = [i,j]
        new_val = 0.0
        for k in range(w_kernel * h_kernel):
          pixel_in_filter_x = xRows[k]
          pixel_in_filter_y = yCols[k]

          pixel_in_input_x = pixel_in_filter_x + root_pixel[0]
          pixel_in_input_y = pixel_in_filter_y + root_pixel[1]
          new_val = new_val + (data[pixel_in_input_x, pixel_in_input_y] * kernel[pixel_in_filter_x][pixel_in_filter_y])
        if new_val > 255:
          new_val = 255
        elif new_val < 0:
          new_val = 0
        tranformed[i,j] = new_val
      
    return tranformed

X_train = np.transpose(X_train, (0,3,1,2))
X_train500 = X_train[0:500]
print(X_train[0][0].shape)

sobelX = np.array([[-1,0,1], [-2,0,2],[-1,0,1]])
start = time.time()
conv_per_channel(X_train[0][0], sobelX)
end = time.time()
print(end-start)

With the Numba @jit() decorator this code runs slower! I tried parallel=True and cache=True as well, but it was even slower.

To show what I mean, here is the output of both versions (the channel shape, then the elapsed time in seconds):

#with numba
(32, 32)
0.6485073566436768

#without numba
(32, 32)
0.007578611373901367

1 Answer

Nick ODell:

This is something which is going to confuse the Numba optimizer:

    xRows = []
    yCols = []
    for i in range(w_kernel):
      for j in range(h_kernel):
        yCols.append(j)
        xRows.append(i)
...
          pixel_in_filter_x = xRows[k]
          pixel_in_filter_y = yCols[k]

Lists in Python are very flexible and can hold objects of many different types. The problem with that, from a performance perspective, is that Numba has no idea what type pixel_in_filter_x has, and it needs that information to generate efficient code. Numba is built around optimizing access to NumPy arrays, not lists.
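
For illustration, a minimal sketch of the same flat-index loop with the offsets stored in typed NumPy arrays might look like the following (this is only a sketch: conv_per_channel_flat is a throwaway name, and padding, stride and the 0-255 clipping are omitted for brevity):

import numpy as np
from numba import njit

# Sketch only: same flat-index structure as the question, but the kernel
# offsets live in typed NumPy arrays instead of Python lists, so Numba can
# infer a concrete integer element type for them.
@njit
def conv_per_channel_flat(data, kernel):
    h_kernel, w_kernel = kernel.shape
    n = h_kernel * w_kernel
    rows = np.empty(n, dtype=np.int64)  # row offset of each kernel tap
    cols = np.empty(n, dtype=np.int64)  # column offset of each kernel tap
    idx = 0
    for a in range(h_kernel):
        for b in range(w_kernel):
            rows[idx] = a
            cols[idx] = b
            idx += 1

    h_out = data.shape[0] - h_kernel + 1
    w_out = data.shape[1] - w_kernel + 1
    out = np.zeros((h_out, w_out))
    for i in range(h_out):
        for j in range(w_out):
            new_val = 0.0
            for k in range(n):
                new_val += data[i + rows[k], j + cols[k]] * kernel[rows[k], cols[k]]
            out[i, j] = new_val
    return out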

Lastly, I found that the fastmath flag made this code about 5% faster.

Example:

@jit(nopython= True)
def cal_padding_nb(input, padding):
    output = np.zeros((input.shape[0]+ padding*2, input.shape[1]+ padding*2))
    output[ padding :  padding + input.shape[0],  padding :  padding + input.shape[1]] = input
    return output

@jit(nopython = True, fastmath=True)
def conv_per_channel_nb(data, kernel, padding = 1, stride = 1):
    h, w = data.shape
    data = cal_padding_nb(data, padding)
    
    kernel_size = kernel.shape
    h_output = int((h - kernel_size[0] + 2*padding + stride) / stride)
    w_output = int((w - kernel_size[1] + 2*padding + stride) / stride)
    
    tranformed = np.zeros((h_output, w_output), dtype=data.dtype)
    h_kernel, w_kernel = kernel_size

    for i in range(h_output):
      for j in prange(w_output):
        new_val = 0.0
        for a in range(w_kernel):
          for b in range(h_kernel):
            pixel_in_filter_x = a
            pixel_in_filter_y = b

            pixel_in_input_x = pixel_in_filter_x + i
            pixel_in_input_y = pixel_in_filter_y + j
            new_val = new_val + (data[pixel_in_input_x, pixel_in_input_y] * kernel[pixel_in_filter_x][pixel_in_filter_y])
        if new_val > 255:
          new_val = 255
        elif new_val < 0:
          new_val = 0
        tranformed[i,j] = new_val
      
    return tranformed

I find this is much faster than the original non-numba version.

Timings, ignoring the first iteration of the Numba version (which includes JIT compilation):

Original, without numba: 11.3 ms ± 92.8 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
Numba, without lists: 31 µs ± 824 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

This is roughly 360 times faster.
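
A side note on methodology: Numba compiles a function the first time it is called with a given argument signature, so timing a single call with time.time(), as in the question, mostly measures compilation rather than execution. A rough sketch of a fairer comparison, assuming the arrays and functions defined above are in scope and using a hypothetical bench helper, could be:

import time
import numpy as np

# Hypothetical helper: time repeated calls to fn(*args) and return the mean per call.
def bench(fn, *args, repeats=100):
    start = time.perf_counter()
    for _ in range(repeats):
        fn(*args)
    return (time.perf_counter() - start) / repeats

channel = X_train[0][0]   # one 32x32 CIFAR-10 channel, as in the question
sobelX = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])

conv_per_channel_nb(channel, sobelX)   # warm-up call: triggers JIT compilation
print("numba   :", bench(conv_per_channel_nb, channel, sobelX))
print("no numba:", bench(conv_per_channel, channel, sobelX))

After the warm-up call, subsequent calls reuse the already-compiled machine code, which is what the "ignoring the first iteration" timings above reflect.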