How to set numba signature with nested lists?


I'm trying to return a nested list, but I'm running into a conversion error. Below is a small piece of code that reproduces the error:

from numba import njit, prange

@njit("ListType(ListType(ListType(int32)))(int32, int32)", fastmath = True, parallel = True, cache = True)
def test(x, y):
    a = []
    for i in prange(10):
        b = []
        for j in range(4):
            c = []
            for k in range(5):
                c.append(k)
            b.append(c)
        a.append(b)
    return a

Error: [traceback was posted as a screenshot, not available]


Answer by Kocas

I try to avoid using empty lists with numba, mainly because an empty list cannot be typed: there is no element from which numba could infer the item type. Check out nb.typeof([]), which raises a ValueError.

I am not sure whether your output can be preallocated, but if it can, you could consider arrays instead, which would also bring substantial performance benefits. Here is an attempt:

from numba import njit, prange, int32
import numpy as np

@njit(int32[:,:,:](int32, int32), fastmath=True, parallel=True, cache=True)
def test(x, y):
    out = np.zeros((10, x, y), dtype=int32)
    for i in prange(10):
        for j in range(x):
            for k in range(y):
                out[i, j, k] = k
    return out

That said, you might indeed need lists for your application, in which case this answer might not be of much use.

Answer by Bhaskar Dhariyal

This worked for me, using numba.typed.List instead of [] and letting numba infer the signature:

from numba import njit, prange
from numba.typed import List

@njit(fastmath = True, parallel = True, cache = True)
def test(x, y):
    a = List()
    for i in prange(10):
        b = List()
        for j in range(4):
            c = List()
            for k in range(5):
                c.append(k)
            b.append(c)
        a.append(b)
    return a

Answer by Rutger Kassies

Your signature is fine, but the type of list you create inside the function has to match it: use a numba.typed.List instead of [].

from numba import njit, prange
from numba.typed import List
from numba.types import int32

@njit("ListType(ListType(ListType(int32)))(int32, int32)", fastmath=True, parallel=True, cache=True)
def test(x, y):
    a = List.empty_list(List.empty_list(List.empty_list(int32)))
    for i in prange(10):
        b = List.empty_list(List.empty_list(int32))
        for j in range(4):
            c = List.empty_list(int32)
            for k in range(5):
                c.append(int32(k))
            b.append(c)
        a.append(b)
    return a

I don't think you should expect much from appending to a List in parallel in this case.