Can this algorithm be reversed in better than O(N^2) time?

183 Views Asked by At

Say I've got the following algorithm to biject bytes -> Natural:

def rank(s: bytes) -> int:
    """Return the natural number that encodes the byte string ``s``.

    The result is the sum of two parts: the bytes of ``s`` read as a
    big-endian base-256 number, and ``256**0 + ... + 256**(len(s) - 1)``,
    i.e. the number of byte strings strictly shorter than ``s``.
    """
    base = 1 << 8
    body = 0      # bytes of ``s`` interpreted as a base-256 number
    shorter = 0   # count of byte strings with fewer bytes than ``s``
    for position, byte in enumerate(s):
        body = body * base + byte
        shorter += base ** position
    return body + shorter

The decoder (at least to the best of my limited abilities) is as follows:

def unrank(value: int) -> bytes:
    """Decode a natural number back into the byte string it encodes.

    First the length of the result is found by accumulating
    ``256**0 + 256**1 + ...`` until the running total exceeds ``value``;
    the total accumulated before that point is subtracted from ``value``.
    The remainder is then written out as ``length`` base-256 digits,
    most significant byte first.
    """
    import itertools

    base = 1 << 8

    # 1. Determine the length of the result.
    shorter = 0  # number of byte strings shorter than the current length
    for length in itertools.count():  # loop runs O(N) times
        candidate = shorter + base ** length  # long addition is O(N)
        if candidate > value:
            value -= shorter
            break
        shorter = candidate

    # 2. Emit the remainder as fixed-width base-256 digits, filling the
    #    buffer from its last byte towards its first.
    digits = bytearray(length)
    for pos in range(length - 1, -1, -1):
        value, digits[pos] = divmod(value, base)
    return bytes(digits)

Letting N ≈ len(bytes) ≈ log(int), this decoder clearly has a worst-case runtime of O(N^2). Granted, it performs well (<2s runtime) for practical cases (≤32KiB of data), but I'm still curious if it's fundamentally possible to beat that into something that swells less as the inputs get bigger.


# Example / test cases:

# rank: the empty string maps to 0, then all 1-byte strings in byte order
# (1..256), then all 2-byte strings (257..65792), and so on.
assert rank(b"") == 0
assert rank(b"\x00") == 1
assert rank(b"\x01") == 2
...
assert rank(b"\xFF") == 256
assert rank(b"\x00\x00") == 257
assert rank(b"\x00\x01") == 258
...
assert rank(b"\xFF\xFF") == 65792
assert rank(b"\x00\x00\x00") == 65793

# unrank: the same pairs as above, checked in the opposite direction.
assert unrank(0) == b""
assert unrank(1) == b"\x00"
assert unrank(2) == b"\x01"
# ...
assert unrank(256) == b"\xFF"
assert unrank(257) == b"\x00\x00"
assert unrank(258) == b"\x00\x01"
# ...
assert unrank(65792) == b"\xFF\xFF"
assert unrank(65793) == b"\x00\x00\x00"

# A larger value: 2**48 + 1 decodes to a 6-byte string.
assert unrank(2**48+1) == b"\xFE\xFE\xFE\xFE\xFF\x00"
3

There are 3 best solutions below

21
Matt Timmermans On BEST ANSWER

It's clearer if you write your rank function like this:

def rank(s: bytes) -> int:
    """Return the natural number that encodes the byte string ``s``.

    Each byte contributes ``byte + 1`` as a digit in base 256, so every
    digit lies in 1..256 and the empty string maps to 0.
    """
    base = 1 << 8
    total = 0
    for byte in s:
        # Shift the digits accumulated so far and append ``byte + 1``.
        total = total * base + byte + 1
    return total

... that you can write unrank like this:

def unrank(value: int) -> bytes:
    """Decode a natural number back into the byte string it encodes.

    Repeatedly subtracts 1 and splits off the lowest base-256 digit until
    nothing is left. Digits are produced least significant first, so they
    are reversed before being returned. Values that are not positive
    yield ``b""``.
    """
    base = 1 << 8
    digits = []
    remaining = value
    while remaining > 0:
        # Subtracting 1 first maps the digit range 1..256 onto 0..255.
        remaining, low = divmod(remaining - 1, base)
        digits.append(low)
    return bytes(reversed(digits))

(thanks for providing test cases)

The above version of unrank is still quadratic due to the cost of operations on long integers, so here is a less readable version of the same algorithm that is actually O(n) in Python:

def unrank(value: int) -> bytes:
    """Decode a natural number back into the byte string it encodes.

    The value is converted once to a little-endian byte buffer; after
    that only individual bytes are touched. For each position, starting
    at the least significant byte, 1 is subtracted at that position with
    the borrow propagating into higher bytes. When a borrow runs past the
    top of the buffer, the bytes from the current position upward were
    all zero, so the result consists of the bytes below that position.
    """
    digits = bytearray(value.to_bytes(value.bit_length() // 8 + 1, 'little'))
    size = len(digits)
    for start in range(size):
        pos = start
        # Subtract 1 at ``start``: every zero byte on the way wraps to 255
        # and passes the borrow on to the next higher byte.
        while pos < size and digits[pos] == 0:
            digits[pos] = 255
            pos += 1
        if pos == size:
            # Borrow went off the end: drop this byte and all above it.
            del digits[start:]
            break
        digits[pos] -= 1
    # The buffer is little-endian; the result is most significant first.
    digits.reverse()
    return bytes(digits)
0
n. m. could be an AI On

The reversal algorithm in plain English is like this:

  1. Write down the hexadecimal representation of the number; if the number of digits is odd, prepend a leading zero.
  2. Group the digits in pairs.
  3. For each pair viewed as a number 0 to 255, starting from the rightmost (least significant) one:
    • Subtract 1 from the pair modulo 256
    • If the subtraction underflows (the pair wraps around below 0), borrow from the pair to the left; if there is no pair to the left, remove the pair that caused the underflow. (You only ever need to subtract 1 or 2 at each step).

It is clear that the complexity is O(N) (N being the number of digits) provided you have access to the individual bits of the input.

2
Kelly Bundy On

A faster linear solution:

def unrank(value: int) -> bytes:
    """Decode a natural number back into the byte string it encodes.

    ``N`` starts as an upper estimate of the result length derived from
    the bit length of ``value``. ``offset`` is the integer whose ``N``
    big-endian bytes are all ``0x01``, i.e. ``256**0 + ... + 256**(N-1)``.
    If that exceeds ``value`` the estimate was one byte too long, so both
    ``N`` and ``offset`` are reduced by one byte. The result is
    ``value - offset`` written as exactly ``N`` big-endian bytes.

    The byte order is passed explicitly: ``int.from_bytes`` and
    ``int.to_bytes`` only default to ``'big'`` from Python 3.11 on, and
    raise ``TypeError`` on older versions when it is omitted.

    Raises:
        OverflowError: if ``value`` is negative (raised by ``to_bytes``).
    """
    N = value.bit_length() // 8 + 1
    offset = int.from_bytes(b'\1' * N, 'big')
    if offset > value:
        N -= 1
        offset >>= 8
    return (value - offset).to_bytes(N, 'big')

Times for unranking the rank of 10000 random bytes:

1433.748 ms  unrank_original
 192.317 ms  unrank_Matt
   1.877 ms  unrank_Matt2
   0.049 ms  unrank_Kelly

Benchmark code, includes some correctness checks (Attempt This Online!):

from itertools import product
import os
from time import perf_counter as time


def rank(s: bytes) -> int:
    """Return the natural number that encodes the byte string ``s``.

    Adds the bytes of ``s`` read as a big-endian base-256 number to
    ``256**0 + ... + 256**(len(s) - 1)``, the number of byte strings
    strictly shorter than ``s``.
    """
    base = 1 << 8
    body = 0      # bytes of ``s`` interpreted as a base-256 number
    shorter = 0   # count of byte strings with fewer bytes than ``s``
    for position, byte in enumerate(s):
        body = body * base + byte
        shorter += base ** position
    return body + shorter


def unrank_original(value: int) -> bytes:
    """Decode ``value`` with the question's original quadratic algorithm.

    The length is found by accumulating ``256**0 + 256**1 + ...`` until
    the running total exceeds ``value``; the total accumulated before
    that point is subtracted, and the remainder is written out as
    ``length`` base-256 digits, most significant byte first.
    """
    import itertools

    base = 1 << 8

    # 1. Determine the length of the result.
    shorter = 0  # number of byte strings shorter than the current length
    for length in itertools.count():  # loop runs O(N) times
        candidate = shorter + base ** length  # long addition is O(N)
        if candidate > value:
            value -= shorter
            break
        shorter = candidate

    # 2. Emit the remainder as fixed-width base-256 digits, filling the
    #    buffer from its last byte towards its first.
    digits = bytearray(length)
    for pos in range(length - 1, -1, -1):
        value, digits[pos] = divmod(value, base)
    return bytes(digits)


def unrank_Kelly(value: int) -> bytes:
    """Decode ``value`` using one big-integer subtraction.

    ``N`` starts as an upper estimate of the result length derived from
    the bit length of ``value``. ``offset`` is the integer whose ``N``
    big-endian bytes are all ``0x01``, i.e. ``256**0 + ... + 256**(N-1)``.
    If that exceeds ``value`` the estimate was one byte too long, so both
    ``N`` and ``offset`` are reduced by one byte. The result is
    ``value - offset`` written as exactly ``N`` big-endian bytes.

    The byte order is passed explicitly: ``int.from_bytes`` and
    ``int.to_bytes`` only default to ``'big'`` from Python 3.11 on, and
    raise ``TypeError`` on older versions when it is omitted.

    Raises:
        OverflowError: if ``value`` is negative (raised by ``to_bytes``).
    """
    N = value.bit_length() // 8 + 1
    offset = int.from_bytes(b'\1' * N, 'big')
    if offset > value:
        N -= 1
        offset >>= 8
    return (value - offset).to_bytes(N, 'big')


def unrank_Matt(value: int) -> bytes:
    """Decode ``value`` by peeling off one base-256 digit per iteration.

    Repeatedly subtracts 1 and splits off the lowest base-256 digit until
    nothing is left. Digits are produced least significant first, so they
    are reversed before being returned. Values that are not positive
    yield ``b""``.
    """
    base = 1 << 8
    digits = []
    remaining = value
    while remaining > 0:
        # Subtracting 1 first maps the digit range 1..256 onto 0..255.
        remaining, low = divmod(remaining - 1, base)
        digits.append(low)
    return bytes(reversed(digits))


def unrank_Matt2(value: int) -> bytes:
    """Decode ``value`` byte by byte on a little-endian buffer.

    The value is converted once to a little-endian byte buffer; after
    that only individual bytes are touched. For each position, starting
    at the least significant byte, 1 is subtracted at that position with
    the borrow propagating into higher bytes. When a borrow runs past the
    top of the buffer, the bytes from the current position upward were
    all zero, so the result consists of the bytes below that position.
    """
    digits = bytearray(value.to_bytes(value.bit_length() // 8 + 1, 'little'))
    size = len(digits)
    for start in range(size):
        pos = start
        # Subtract 1 at ``start``: every zero byte on the way wraps to 255
        # and passes the borrow on to the next higher byte.
        while pos < size and digits[pos] == 0:
            digits[pos] = 255
            pos += 1
        if pos == size:
            # Borrow went off the end: drop this byte and all above it.
            del digits[start:]
            break
        digits[pos] -= 1
    # The buffer is little-endian; the result is most significant first.
    digits.reverse()
    return bytes(digits)


# All decoder implementations under test, in the order they are reported.
funcs = unrank_original, unrank_Matt, unrank_Matt2, unrank_Kelly

# Correctness
def check(s: bytes) -> None:
    """Assert that every decoder in ``funcs`` maps ``rank(s)`` back to ``s``."""
    r = rank(s)
    for f in funcs:
        assert f(r) == s
# Exhaustively check all strings of length 0..5 built from the ten byte
# values 0..4 and 251..255, i.e. values at both ends of the byte range.
for N in range(6):
    for p in product([*range(5), *range(251, 256)], repeat=N):
        check(bytes(p))
# Plus one longer random input.
check(os.urandom(1000))

# Speed
# Each decoder is run once on the rank of 10000 random bytes; ``time`` is
# ``time.perf_counter`` and the elapsed time is printed in milliseconds.
r = rank(os.urandom(10000))
for f in funcs:
    t0 = time()
    f(r)
    t = time() - t0
    print(f'{t*1e3:8.3f} ms ', f.__name__)