I have two lists of ints, A and B. How can I efficiently ensure that for every element a in A there exists at least one element b in B such that a & b == a (i.e. bitwise-ANDing the two yields that element of A)?

Eg. B = [14, 13 ]

So for elements in A: a1 = 12 is valid, as 12 & 14 = 12; a2 = 3 is not valid, as 3 & 14 = 2 and 3 & 13 = 1.

A and B can have 10^10 entries.

I am reading B as a column from a CSV file and generating A on the fly, and I check the above condition using:

df.applymap(lambda x: (x&a) == a ).any().any()

But as the size of B increases, this check becomes my bottleneck.

1

There are 1 best solutions below

20
Andrej Kesely On

If speed is a concern, you can try to use numba:

import numpy as np
from numba import njit


@njit
def check(a, b):
    """Report whether any value in ``a`` is bit-covered by a value in ``b``.

    A value ``x`` is covered by ``y`` when ``(y & x) == x``, i.e. every bit
    set in ``x`` is also set in ``y``.

    Args:
        a: Iterable of integers to look for a cover for.
        b: Iterable of candidate covering integers.

    Returns:
        True as soon as one (a, b) pair satisfies the cover condition;
        False if no pair does (including when either input is empty).
    """
    for wanted in a:
        for candidate in b:
            # Skip candidates that are missing at least one bit of `wanted`.
            if (candidate & wanted) != wanted:
                continue
            return True
    return False


# Example from the question: 12 is covered by 14 (12 & 14 == 12), while 3 is
# covered by neither 14 nor 13.
a = np.array([12, 3], dtype=np.uint8)
b = np.array([14, 13], dtype=np.uint8)

# check() returns as soon as any pair matches, so the 12/14 pair is enough.
print(check(a, b))

Prints:

True

Benchmark:

from statistics import median
from timeit import repeat

np.random.seed(42)


def setup(size=10_000_000):
    """Build two random ``uint8`` arrays to benchmark ``check`` with.

    Args:
        size: Number of elements in each array. Defaults to 10,000,000,
            the size this benchmark originally hard-coded.

    Returns:
        Tuple ``(a, b)`` of 1-D ``np.uint8`` arrays drawn from NumPy's
        global random state, so ``np.random.seed`` makes them repeatable.
    """
    # randint's upper bound is exclusive: values fall in [0, 255).
    a = np.random.randint(0, 255, size=size, dtype=np.uint8)
    b = np.random.randint(0, 255, size=size, dtype=np.uint8)
    return a, b


# timeit runs the setup string before each of the 1000 repetitions, so array
# generation is not part of the measured time and every call sees fresh data.
# number=1 means each sample is a single check(a, b) call.
t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=1000, number=1, globals=globals()
)

# Report the median sample (seconds per call) to damp outliers.
print(f"t={median(t):.8f}")

On my computer (AMD 5700x) this prints:

t=0.00000496

EDIT: For np.uint64 values from 0 to MAX(np.uint64):

def setup(size=10_000_000):
    """Build two random ``uint64`` arrays to benchmark ``check`` with.

    Args:
        size: Number of elements in each array. Defaults to 10,000,000,
            the size this benchmark originally hard-coded.

    Returns:
        Tuple ``(a, b)`` of 1-D ``np.uint64`` arrays drawn from NumPy's
        global random state, so ``np.random.seed`` makes them repeatable.
    """
    # randint's upper bound is exclusive, so the dtype maximum itself is
    # never produced.
    upper = np.iinfo(np.uint64).max
    a = np.random.randint(0, upper, size=size, dtype=np.uint64)
    b = np.random.randint(0, upper, size=size, dtype=np.uint64)
    return a, b


# timeit runs the setup string before each of the 50 repetitions, so array
# generation is not part of the measured time and every call sees fresh data.
# number=1 means each sample is a single check(a, b) call.
t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)

# Report the median sample (seconds per call) to damp outliers.
print(f"t={median(t):.8f}")

Prints:

t=0.04230742

EDIT 2: Sorting the first 1000 items of the B array by bit count (highest first):

from statistics import median
from timeit import repeat

np.random.seed(42)


# https://stackoverflow.com/a/68943135/10035985
@njit
def bit_count(arr):
    """Return the number of set bits in each element of an integer array.

    Uses the branch-free "parallel" population count: bit pairs, then
    nibbles, then bytes are summed in place, and a final multiply gathers
    the per-byte counts into the most significant byte.

    Args:
        arr: NumPy integer array; the masks below are truncated to its dtype.

    Returns:
        Array of per-element bit counts, computed with ``arr``'s dtype
        arithmetic.
    """
    # Make the values type-agnostic (as long as it's integers)
    t = arr.dtype.type
    # Intended as an all-ones value of the array's dtype, used to trim the
    # 64-bit constants below to the element width.
    mask = t(-1)
    s55 = t(0x5555555555555555 & mask)  # Add more digits for 128bit support
    s33 = t(0x3333333333333333 & mask)
    s0F = t(0x0F0F0F0F0F0F0F0F & mask)
    s01 = t(0x0101010101010101 & mask)

    # Each 2-bit field now holds the count of its two original bits.
    arr = arr - ((arr >> np.uint8(1)) & s55)
    # Combine neighbouring 2-bit counts into 4-bit counts.
    arr = (arr & s33) + ((arr >> np.uint8(2)) & s33)
    # Combine neighbouring 4-bit counts into per-byte counts.
    arr = (arr + (arr >> np.uint8(4))) & s0F
    # Multiplying by 0x0101... accumulates all byte counts in the top byte;
    # the shift moves that byte down to the low position.
    return (arr * s01) >> np.uint16((8 * (arr.itemsize - 1)))


@njit
def check(a, b):
    """Report whether any value in ``a`` is bit-covered by a value in ``b``.

    A value ``x`` is covered by ``y`` when ``(y & x) == x``. The first 1000
    entries of ``b`` are probed in descending bit-count order (presumably
    because values with more set bits are more likely to cover a given
    ``x``), followed by the remaining entries in their original order.

    The caller's ``b`` is left untouched: the reordered head is a separate
    copy rather than an in-place rewrite of ``b[:1000]``.

    Args:
        a: 1-D integer array of values to find a cover for.
        b: 1-D integer array of candidate covering values.

    Returns:
        True as soon as one (a, b) pair satisfies the cover condition,
        otherwise False.
    """
    head = b[:1000]
    # Fancy indexing copies, so sorting the head never writes back into b.
    ranked_head = head[np.argsort(bit_count(head))[::-1]]
    tail = b[1000:]

    for val_a in a:
        for val_b in ranked_head:
            if (val_b & val_a) == val_a:
                return True
        for val_b in tail:
            if (val_b & val_a) == val_a:
                return True
    return False


def setup(size=10_000_000):
    """Build two random ``uint64`` arrays to benchmark ``check`` with.

    Args:
        size: Number of elements in each array. Defaults to 10,000,000,
            the size this benchmark originally hard-coded.

    Returns:
        Tuple ``(a, b)`` of 1-D ``np.uint64`` arrays drawn from NumPy's
        global random state, so ``np.random.seed`` makes them repeatable.
    """
    # randint's upper bound is exclusive, so the dtype maximum itself is
    # never produced.
    upper = np.iinfo(np.uint64).max
    a = np.random.randint(0, upper, size=size, dtype=np.uint64)
    b = np.random.randint(0, upper, size=size, dtype=np.uint64)
    return a, b


# Warm-up call on tiny uint64 arrays: numba compiles an njit function on its
# first call for a given argument type, so this keeps compilation out of the
# timed runs below.
check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))

# timeit runs the setup string before each of the 50 repetitions; number=1
# means each sample is a single check(a, b) call on fresh arrays.
t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)

# Report the median sample (seconds per call) to damp outliers.
print(f"t={median(t):.8f}")

Prints:

t=0.04198640

EDIT3: If you want to check that every element of A has a match in array B:

@njit
def check(a, b):
    """Report whether every value in ``a`` is bit-covered by some value in ``b``.

    A value ``x`` is covered by ``y`` when ``(y & x) == x``, i.e. every bit
    set in ``x`` is also set in ``y``.

    Args:
        a: Iterable of integers that each need a cover.
        b: Iterable of candidate covering integers.

    Returns:
        False at the first element of ``a`` with no cover in ``b``;
        True otherwise (including when ``a`` is empty).
    """
    for wanted in a:
        covered = False
        for candidate in b:
            if (candidate & wanted) == wanted:
                covered = True
                break
        # One uncovered element is enough to fail the whole check.
        if not covered:
            return False
    return True


def setup(size=10_000_000):
    """Build two random ``uint64`` arrays to benchmark ``check`` with.

    Args:
        size: Number of elements in each array. Defaults to 10,000,000,
            the size this benchmark originally hard-coded.

    Returns:
        Tuple ``(a, b)`` of 1-D ``np.uint64`` arrays drawn from NumPy's
        global random state, so ``np.random.seed`` makes them repeatable.
    """
    # randint's upper bound is exclusive, so the dtype maximum itself is
    # never produced.
    upper = np.iinfo(np.uint64).max
    a = np.random.randint(0, upper, size=size, dtype=np.uint64)
    b = np.random.randint(0, upper, size=size, dtype=np.uint64)
    return a, b


# Warm-up call on tiny uint64 arrays: numba compiles an njit function on its
# first call for a given argument type, so this keeps compilation out of the
# timed runs below.
check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))

# timeit runs the setup string before each of the 50 repetitions; number=1
# means each sample is a single check(a, b) call on fresh arrays.
t = repeat(
    "check(a, b)", setup="a, b = setup()", repeat=50, number=1, globals=globals()
)

# Report the median sample (seconds per call) to damp outliers.
print(f"t={median(t):.8f}")

Prints:

t=0.00415215

EDIT4: Parallel version that creates a mask:

from statistics import median
from timeit import repeat

import numpy as np
from numba import njit, prange

@njit(parallel=True)
def check(a, b):
    """Build a per-element mask of which values in ``a`` are covered by ``b``.

    A value ``x`` is covered by ``y`` when ``(y & x) == x``. Each element of
    ``a`` is handled in its own ``prange`` iteration and writes only its own
    slot of the result.

    Args:
        a: Integer array of values to find a cover for.
        b: Iterable of candidate covering integers.

    Returns:
        ``np.uint8`` array shaped like ``a`` holding 1 where the element has
        at least one cover in ``b`` and 0 where it has none.
    """
    covered = np.zeros_like(a, dtype=np.uint8)

    for idx in prange(len(a)):
        wanted = a[idx]

        for candidate in b:
            if (candidate & wanted) == wanted:
                # First cover found; no need to scan the rest of b.
                covered[idx] = 1
                break

    return covered


def setup(size=1_000_000):
    """Build two random ``uint64`` arrays to benchmark ``check`` with.

    Args:
        size: Number of elements in each array. Defaults to 1,000,000,
            the size this benchmark originally hard-coded.

    Returns:
        Tuple ``(a, b)`` of 1-D ``np.uint64`` arrays drawn from NumPy's
        global random state, so ``np.random.seed`` makes them repeatable.
    """
    # randint's upper bound is exclusive, so the dtype maximum itself is
    # never produced.
    upper = np.iinfo(np.uint64).max
    a = np.random.randint(0, upper, size=size, dtype=np.uint64)
    b = np.random.randint(0, upper, size=size, dtype=np.uint64)
    return a, b


# Warm-up call on tiny uint64 arrays: numba compiles an njit function on its
# first call for a given argument type, so this keeps compilation out of the
# timed run below.
check(np.array([1, 2, 3], dtype=np.uint64), np.array([1, 2, 3], dtype=np.uint64))

# A single repetition of a single call: the full mask is computed once.
t = repeat("check(a, b)", setup="a, b = setup()", repeat=1, number=1, globals=globals())

# With one sample, the median is simply that sample (seconds).
print(f"t={median(t):.8f}")

Prints:

t=79.79640480