How can I traverse a 3D matrix on a thread-per-row basis in Numba?


I’m trying to implement a runs-based kernel using Numba CUDA where I need to traverse the elements of a 3D matrix on a thread-per-row basis, i.e. each thread is assigned one row and iterates over all elements of that row.

For example, if, for simplicity, I were to use a 2D matrix with 50 rows and 100 columns, I would need to create 50 threads that would go through the 100 elements of their respective row.

Can someone tell me how to do this?

Best answer

It turns out to be quite simple. You only need to launch one thread per row (and, for a 3D matrix, per slice) and have each thread loop over the columns of its own row inside the kernel. Here’s a simple kernel that demonstrates this kind of iteration over a 3D matrix (binary_image). The kernel is part of the CCL (connected-component labeling) algorithm I’m implementing, but that detail can safely be ignored:

from numba import cuda

@cuda.jit
def kernel_1(binary_image, image_width, s_matrix, labels_matrix):
    # notice how we only take the row and slice (depth) indices from the grid;
    # the columns are walked by the loop below
    row, image_slice = cuda.grid(2)
    sm_pos, lm_pos = 0, 0
    span_found = False
    if row < binary_image.shape[0] and image_slice < binary_image.shape[2]:  # guard for rows and slices
        # and here's the traversing over the columns
        for column in range(binary_image.shape[1]):
            if binary_image[row, column, image_slice] == 0:
                if not span_found:  # Connected Component found
                    span_found = True
                    s_matrix[row, sm_pos, image_slice] = column
                    sm_pos = sm_pos + 1
                    # converting 2D coordinate to 1D
                    linearized_index = row * image_width + column
                    labels_matrix[row, lm_pos, image_slice] = linearized_index
                    lm_pos = lm_pos + 1
                else:
                    s_matrix[row, sm_pos, image_slice] = column
            elif binary_image[row, column, image_slice] == 255 and span_found:
                span_found = False
                s_matrix[row, sm_pos, image_slice] = column - 1
                sm_pos = sm_pos + 1
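
The kernel above leaves out the launch itself, so here is a minimal sketch of how the grid could be sized so that there is one thread per (row, slice) pair. The helper name launch_config, the 16x16 block size, the example shape, and the device array names in the commented launch line are all assumptions for illustration, not part of the original code. The helper is plain Python, so it runs without a GPU:

```python
import math

def launch_config(shape, threads_per_block=(16, 16)):
    """Return (blocks_per_grid, threads_per_block) covering rows x slices.

    shape is (rows, columns, slices), matching binary_image.shape in the kernel.
    Columns are deliberately left out: each thread loops over them itself.
    """
    rows, _columns, slices = shape
    blocks_per_grid = (
        math.ceil(rows / threads_per_block[0]),    # first grid axis -> row
        math.ceil(slices / threads_per_block[1]),  # second grid axis -> image_slice
    )
    return blocks_per_grid, threads_per_block

# Hypothetical volume: 50 rows, 100 columns, 4 slices.
blocks, threads = launch_config((50, 100, 4))

# With a CUDA-capable GPU and device arrays already allocated, the launch
# would then look roughly like this (array names are placeholders):
# kernel_1[blocks, threads](d_binary_image, 100, d_s_matrix, d_labels_matrix)
```

Rounding up usually launches more threads than there are rows or slices, which is why the kernel's bounds check on row and image_slice matters: the surplus threads simply do nothing.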