How to speed up the trampolined CPS version of fib and support mutual recursion in Python?

I have tried to implement a trampoline for a CPS version of the Fibonacci function, but I can't make it fast (by adding a cache) or make it support mutual recursion.

The implementation code:

import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable

START = 0
CONTINUE = 1
CONTINUE_END = 2
RETURN = 3


@dataclass
class CTX:
    kind: int
    result: Any    # TODO ......
    f: Callable
    args: Optional[list]
    kwargs: Optional[dict]


def trampoline(f):
    # Shared per-function state: kind drives the state machine,
    # args/kwargs hold the next pending call.
    ctx = CTX(START, None, None, None, None)

    @functools.wraps(f)
    def decorator(*args, **kwargs):
        nonlocal ctx
        if ctx.kind in (CONTINUE, CONTINUE_END):
            # Nested call made from inside f: record the arguments for
            # the next bounce and return instead of recursing.
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE
            return
        elif ctx.kind == START:
            # Outermost call: record the initial arguments.
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE

        result = None
        while ctx.kind != RETURN:
            args = ctx.args
            kwargs = ctx.kwargs
            result = f(*args, **kwargs)
            if ctx.kind == CONTINUE_END:
                # f made no nested call this time, so result is final.
                ctx.kind = RETURN
            else:
                ctx.kind = CONTINUE_END

        return result

    return decorator

Here is the runnable example.

@functools.lru_cache
def fib(n):
    if n == 0:
        return 1
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)

@trampoline
def fib_cps(n, k):
    if n == 0:
        return k(1)
    elif n == 1:
        return k(1)
    else:
        return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))

def fib_cps_wrapper(n):
    return fib_cps(n, lambda i: i)


@trampoline
def fib_tail(n, acc1=1, acc2=1):
    if n < 2:
        return acc1
    else:
        return fib_tail(n - 1, acc1 + acc2, acc1)


if __name__ == "__main__":
    print(fib(100))
    print(fib_tail(10000))
    print(fib_cps_wrapper(40))

It is too slow for n = 40. The plain fib hits "maximum recursion depth exceeded" when n gets bigger, but after adding lru_cache it is fast. The trampolined tail-recursive version (fib_tail) is fine with respect to recursion depth and runs very fast.
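
For example, with a mutually recursive pair (a quick sketch with names of my own, not part of my original code), the trampoline silently computes a wrong result, because each decorated function keeps its own ctx and the nested bounces corrupt each other's state:

@trampoline
def is_even(n):
    return True if n == 0 else is_odd(n - 1)

@trampoline
def is_odd(n):
    return False if n == 0 else is_even(n - 1)

print(is_even(10))  # prints None instead of True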

Here is some other people's work:

  1. caching for a CPS version: https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
  2. mutual recursion support: https://github.com/0x65/trampoline, but it is too hacky for me to understand.

Best Answer

Looking at the links you have shared, there are a lot of interesting solutions. I was particularly inspired by this and changed a few things. Just to recap: you need a tail-recursive decorator that both caches results from previous executions of the function and supports mutual recursion. There is another interesting discussion about mutual recursion in a tail-recursion context that might help you understand the main problems.


I have written a decorator that supports both caching and mutual recursion: I think it can be further simplified/improved, but it works for the test samples I have chosen:

from collections import namedtuple
import functools

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
def tail_recursive(f):
    f._first_call = True
    f._cache = {}

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if f._first_call:
            f._new_args = args
            f._new_kwargs = kwargs
        
            try:
                f._first_call = False
                while True:
                    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
                    if cache_key in f._cache:
                        return f._cache[cache_key]

                    result = f(*f._new_args, **f._new_kwargs)

                    if not isinstance(result, TailRecArguments):
                        f._cache[cache_key] = result

                    if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                        f._new_args = result.args
                        f._new_kwargs = result.kwargs
                    else:
                        break

                return result
            finally:
                f._first_call = True
        else:
            return TailRecArguments(f, args, kwargs)

    return wrapper

It seems pretty complex at first sight, but it reuses some of the concepts discussed in the links.


Initialization

f._first_call = True
f._cache = {}

Instead of having states like START, CONTINUE and RETURN, in this case I just need to differentiate between the first call (_first_call) and the following ones. In fact, after the first time the function is called, every following call returns a TailRecArguments that stores the parameters.

f._cache is the cache for that specific function.
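
Using the full decorator above, here is a minimal illustration (countdown is a hypothetical function of mine, not from the post): only the outermost call runs the while loop, while every nested call just packages its arguments:

@tail_recursive
def countdown(n):
    return 'done' if n == 0 else countdown(n - 1)

print(countdown(100_000))  # 'done', and no RecursionError: every nested
                           # countdown(n - 1) returns a TailRecArguments
                           # instead of recursing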


Tail-Recursion

if f._first_call:
    f._new_args = args
    f._new_kwargs = kwargs

    try:
        f._first_call = False
        while True:
            result = f(*f._new_args, **f._new_kwargs)

            if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                f._new_args = result.args
                f._new_kwargs = result.kwargs
            else:
                break

        return result
    finally:
        f._first_call = True
else:
    return TailRecArguments(f, args, kwargs)

How does this version of tail recursion work? In the while loop, the function is called over and over with the new arguments packaged by each nested call of the decorated function.

When can I exit the loop? Once the returned value is not of type TailRecArguments, which means that the last call did not recurse but returned an actual value. In that case, I just need to return the result and set f._first_call = True back.

Unfortunately, it is a little more complex than that, because this alone would not work with mutual recursion. The fix is to store the called function in TailRecArguments as well. That way I can check whether the arguments packaged for the next loop iteration belong to the same function (result.wrapped_func == f) or to another tail-recursive one. In the latter case I do not want to deal with those parameters, because they are related to another function; instead, I can return them unchanged, since they will surely be consumed by the while loop of the first tail-recursive function encountered on the way up. The only downside is that f._first_call is reset every time the arguments belong to another function.
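
To make the mutual-recursion flow concrete, here is a rough trace of even(2), using the even/odd pair from the Testing section below (my own annotation; f_even/f_odd denote the undecorated bodies):

even(2)                        -> even's wrapper starts its while loop
  f_even(2) returns odd(1)     -> odd's wrapper starts its own while loop
    f_odd(1) returns even(0)   -> even's wrapper is already running, so it
                                  returns TailRecArguments(f_even, (0,), {})
  odd's loop sees a foreign TailRecArguments and returns it unchanged
even's loop recognizes its own TailRecArguments and continues with n = 0
  f_even(0) returns True       -> not a TailRecArguments, loop exits with True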


Caching

while True:
    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
    if cache_key in f._cache:
        return f._cache[cache_key]

    result = f(*f._new_args, **f._new_kwargs)

    if not isinstance(result, TailRecArguments):
        f._cache[cache_key] = result

Before commenting on the caching mechanism (which is just the very popular memoization technique), it is important to place the caching code correctly: notice that I put it inside the while loop. It cannot be anywhere else, because only inside the while loop is the function called repeatedly, and that is where I can check for cache hits.

I cheated a little for the cache_key creation because I used an internal function of the functools module. It is the one used by the @lru_cache and @cache decorators in that same module, and you can extract its code with

import inspect
import functools
print(inspect.getsource(functools._make_key))

There are other ways to create a cache key from *args and **kwargs, like this one, which again points to the implementation of _make_key. To make your code more stable, avoid using private members, of course.
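
For instance, a minimal public-API replacement could look like this (a sketch of mine; it assumes all arguments are hashable and, unlike _make_key, it does not special-case single arguments):

def make_cache_key(args, kwargs):
    # flatten positional and keyword arguments into one hashable tuple
    return (args, tuple(sorted(kwargs.items())))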

As I said, the rest is memoization, with one additional check: if not isinstance(result, TailRecArguments): .... I want to cache actual values, not the arguments of tail-recursive calls.

(Actually, I think you could temporarily store the cache keys of all intermediate TailRecArguments in a list and, once an actual value is returned by a recursive call, add one cache entry per stored key. It complicates the solution, but that is still acceptable if you have performance issues. It might also introduce bugs in the mutual-recursion case; I am going to work on it if requested.)
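
A sketch of that idea for the single-function case (my own variation on the while loop above; mutual recursion would need extra care):

pending_keys = []
while True:
    cache_key = functools._make_key(f._new_args, f._new_kwargs, False)
    if cache_key in f._cache:
        result = f._cache[cache_key]
        break
    pending_keys.append(cache_key)

    result = f(*f._new_args, **f._new_kwargs)

    if isinstance(result, TailRecArguments) and result.wrapped_func == f:
        f._new_args = result.args
        f._new_kwargs = result.kwargs
    else:
        break

# every intermediate state of a pure tail recursion leads to the same
# final value, so all pending keys can be filled at once
if not isinstance(result, TailRecArguments):
    for key in pending_keys:
        f._cache[key] = result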


Testing

These are a few basic functions I have used to test the decorator:

@tail_recursive
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    """
    return True if n == 0 else odd(n - 1)

@tail_recursive
def odd(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> odd(100)
    False
    >>> odd(101)
    True
    """
    return False if n == 0 else even(n - 1)

@tail_recursive
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(30)
    265252859812191058636308480000000
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

@tail_recursive
def fib(n, a = 0, b = 1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(20)
    >>> fib(30)
    832040
    """
    return a if n == 0 else b if n == 1 else fib(n - 1, b, a + b)

if __name__ == '__main__':
    import doctest
    doctest.testmod()

Note that caching is not very useful in these examples. Take the factorial: fact(10) is never going to reuse fact(8); in fact:

fact(8)         fact(10)
                fact(10, 1)
                fact(9, 10)
fact(8, 1)      fact(8, 90)
...             ...

The accumulator is part of the cache key, so you should change the caching strategy by customizing the parameters you want to cache (see the UPDATE below for exactly that).


UPDATE - Cache Optimization

Here is a partial fix to the cache strategy used in the original answer. The main issue is that including all parameters in the cache key is inefficient, considering how a typical tail-recursive algorithm works (see the factorial example above).

A first possible optimization is to let the user choose which parameters form the cache key and which ones contribute to the values. It is much less readable because of the type hints, but the tests make everything a little clearer:

from collections import namedtuple
from typing import Any, Callable, Dict, Hashable, Iterable
import functools

class Logger:
    def __init__(self, name):
        self._name = name
        self._entries = []
    
    def log(self, s):
        self._entries.append(s)

    def print(self):
        log_prefix = f"[{self._name}] - "
        print(log_prefix + f"\n{log_prefix}".join(self._entries))

TailRecArguments = namedtuple('TailRecArguments', ['wrapped_func', 'args', 'kwargs'])
default_logger = Logger('default')
def tail_recursive(logger: Logger = default_logger, \
        get_cache_key: Callable[[Iterable, Dict], Hashable] = lambda args, kwargs: \
            functools._make_key(args, kwargs, False),\
        get_result_after_cache_hit: Callable[[Any, Iterable, Dict], Any] = lambda value, args, kwargs: \
            value):
    def decorator(f):
        f._first_call = True
        f._cache = {}

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if f._first_call:
                f._new_args = args
                f._new_kwargs = kwargs
            
                try:
                    f._first_call = False
                    f._initial_key = get_cache_key(f._new_args, f._new_kwargs)
                    while True:
                        cache_key = get_cache_key(f._new_args, f._new_kwargs)
                        if cache_key in f._cache:
                            logger.log('cache hit for ' + str(cache_key))
                            return get_result_after_cache_hit(f._cache[cache_key], f._new_args, f._new_kwargs)

                        result = f(*f._new_args, **f._new_kwargs)

                        if not isinstance(result, TailRecArguments):
                            f._cache[f._initial_key] = result

                        if isinstance(result, TailRecArguments) and result.wrapped_func == f:
                            f._new_args = result.args
                            f._new_kwargs = result.kwargs
                        else:
                            break

                    return result
                finally:
                    f._first_call = True
            else:
                return TailRecArguments(f, args, kwargs)

        return wrapper
    return decorator

Aside from the Logger class, which is used just to confirm cache hits, the main difference is that every function now has a new attribute called _initial_key, which stores the key of the first call. This way, if I call fact(5), 5 becomes the _initial_key and the result is stored in f._cache[5].

This can optimize both mutually recursive and tail-recursive functions, but it is ineffective in certain situations. Let's start with the best case:

fact_logger = Logger('fact')
@tail_recursive(logger=fact_logger, get_cache_key=lambda args, kwargs: args[0],\
    get_result_after_cache_hit=lambda value, args, kwargs: value * args[1])
def fact(n, acc=1):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> fact(5)
    120
    >>> fact(30)
    265252859812191058636308480000000
    >>> fact_logger.print()
    [fact] - cache hit for 5
    """
    return acc if n <= 1 else fact(n - 1, acc * n)

The @tail_recursive decorator initialization includes (besides the logger) get_cache_key, which specifies that only the first argument n should be part of the cache key, and get_result_after_cache_hit, which specifies how to produce the final result after a cache hit. In the case above, when fact(30) reaches fact(5, <partial_factorial>), the result is immediately computed as <partial_factorial> * f._cache[5].
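
A quick sanity check of that shortcut, as an independent snippet of mine (not part of the decorator):

import math

p = math.prod(range(6, 31))           # the partial factorial 30 * 29 * ... * 6
assert 120 * p == math.factorial(30)  # 120 == 5! is the cached value for key 5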

The same goes for even-odd, except that in this case the default arguments of tail_recursive are more than enough:

even_logger = Logger('even')
@tail_recursive(logger=even_logger)
def even(n):
    """
    >>> import sys
    >>> sys.setrecursionlimit(30)
    >>> even(100)
    True
    >>> even(101)
    False
    >>> even(104)
    True
    >>> even_logger.print()
    [even] - cache hit for 100
    """
    return True if n == 0 else odd(n - 1)

Unfortunately, this does not work with the Fibonacci function, for example. You can easily convince yourself of that by printing the arguments during each call, which results in something like this:

30 0 1
29 1 1
28 1 2
27 2 3
26 3 5
25 5 8
...

Establishing the cache key rule here would need more complex logic that would probably make the tail_recursive decorator quite unreadable and less portable.
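
To see why, here is a small tracing variant of mine (undecorated, just for illustration): the state reached at n == 25 while computing fib(30) is different from the state a fresh fib(25) call would start from, so a cache key based on n alone would wrongly equate two different computations.

fib_states = []

def trace_fib(n, a=0, b=1):
    # record every (n, a, b) state the tail recursion goes through
    fib_states.append((n, a, b))
    return a if n == 0 else b if n == 1 else trace_fib(n - 1, b, a + b)

trace_fib(30)
print([s for s in fib_states if s[0] == 25])  # [(25, 5, 8)]
# a fresh fib(25) would instead start from (25, 0, 1)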