I have tried to implement a trampoline for a CPS version of the Fibonacci function, but I can't make it fast (by adding a cache) or make it support mutual recursion.
The implementation:
```python
import functools
from dataclasses import dataclass
from typing import Optional, Any, Callable

# Trampoline states
START = 0         # no trampolined call has started yet
CONTINUE = 1      # new arguments have been stored for the next loop iteration
CONTINUE_END = 2  # arguments consumed; if the next call does not recurse, stop
RETURN = 3        # final result available


@dataclass
class CTX:
    kind: int
    result: Any  # TODO ......
    f: Callable
    args: Optional[list]
    kwargs: Optional[dict]


def trampoline(f):
    ctx = CTX(START, None, None, None, None)

    @functools.wraps(f)
    def decorator(*args, **kwargs):
        nonlocal ctx
        if ctx.kind in (CONTINUE, CONTINUE_END):
            # Recursive call: just record the new arguments and return;
            # the driver loop below will call f again with them.
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE
            return
        elif ctx.kind == START:
            # First (driver) call: run f in a loop until it stops recursing.
            ctx.args = args
            ctx.kwargs = kwargs
            ctx.kind = CONTINUE
            result = None
            while ctx.kind != RETURN:
                args = ctx.args
                kwargs = ctx.kwargs
                result = f(*args, **kwargs)
                if ctx.kind == CONTINUE_END:
                    ctx.kind = RETURN
                else:
                    ctx.kind = CONTINUE_END
            return result

    return decorator
```
Here is the runnable example:

```python
@functools.lru_cache
def fib(n):
    if n == 0:
        return 1
    elif n == 1:
        return 1
    else:
        return fib(n - 1) + fib(n - 2)


@trampoline
def fib_cps(n, k):
    if n == 0:
        return k(1)
    elif n == 1:
        return k(1)
    else:
        return fib_cps(n - 1, lambda v1: fib_cps(n - 2, lambda v2: k(v1 + v2)))


def fib_cps_wrapper(n):
    return fib_cps(n, lambda i: i)


@trampoline
def fib_tail(n, acc1=1, acc2=1):
    if n < 2:
        return acc1
    else:
        return fib_tail(n - 1, acc1 + acc2, acc1)


if __name__ == "__main__":
    print(fib(100))
    print(fib_tail(10000))
    print(fib_cps_wrapper(40))
```
The CPS version is too slow even for n = 40. The plain recursive `fib` raises "maximum recursion depth exceeded" when n gets bigger, but after adding `lru_cache` it is fast. The trampolined iterative version (`fib_tail`) has no recursion-depth problem and runs very fast.
Here is some other people's work:
- CPS version with cache support: https://davywybiral.blogspot.com/2008/11/trampolining-for-recursion.html
- mutual recursion support: https://github.com/0x65/trampoline (but it is too hacky for me to understand)
Looking at the links you have shared, there are a lot of interesting solutions. I was particularly inspired by this and changed a few things. Just to recap: you need a tail-recursive decorator that both caches results from previous executions of the function and supports mutual recursion(?). There is another interesting discussion about mutual recursion in a tail-recursion context that might help you understand the main problems.
I have written a decorator that handles both caching and mutual recursion. I think it can be further simplified/improved, but it works for the test samples I have chosen.
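In outline, the decorator can be sketched like this (a simplified version; the individual pieces — `TailRecArguments`, `_first_call`, `_cache` — are explained in the sections below, and the cache key relies on the private `functools._make_key` helper discussed under Caching):

```python
import functools
from dataclasses import dataclass
from typing import Callable


@dataclass
class TailRecArguments:
    """Arguments of a recursive call, returned instead of actually recursing."""
    wrapped_func: Callable
    args: tuple
    kwargs: dict


def tail_recursive(f):
    f._first_call = True   # distinguishes the driver call from recursive ones
    f._cache = {}          # per-function memoization cache

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        if not f._first_call:
            # Recursive call: hand the arguments back to the driver loop.
            return TailRecArguments(f, args, kwargs)

        f._first_call = False
        try:
            result = TailRecArguments(f, args, kwargs)
            while isinstance(result, TailRecArguments):
                if result.wrapped_func != f:
                    # Mutual recursion: these arguments belong to another
                    # tail-recursive function; let its own loop handle them.
                    return result
                key = functools._make_key(result.args, result.kwargs, typed=False)
                if key in f._cache:
                    result = f._cache[key]
                else:
                    result = f(*result.args, **result.kwargs)
                    if not isinstance(result, TailRecArguments):
                        f._cache[key] = result
            return result
        finally:
            f._first_call = True

    return wrapper
```

With this in place, an outer call acts as the driver: every nested call just returns its arguments, and the loop keeps the Python stack flat.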
It seems pretty complex at first sight, but it reuses some of the concepts discussed in the links.
Initialization

Instead of having states like `START`, `CONTINUE` and `RETURN`, in this case I just need to differentiate between the first call (`_first_call`) and the following ones. In fact, after the first time a function is called, the next calls return a `TailRecArguments` object that stores the parameters. `f._cache` is the cache for that specific function.

Tail-Recursion
How does this version of tail-recursion work? In the `while` loop, the function is continuously called with the new arguments returned after the first time the decorated function is called.

When can I exit from the loop? Once the returned value is not of type `TailRecArguments`, which means that the last function call did not recursively call itself but returned an actual value. In that case, I just need to return the result and set `f._first_call = True`.

Unfortunately, it is a little more complex than that, because it would not work with mutual recursion. The fix here is to store in `TailRecArguments` the called function as well. In this way I can check if the arguments used for the next loop are for the same function (`result.wrapped_func == f`) or for another tail-recursive one. In the latter case, I do not want to deal with those parameters, because they are related to another function; instead I can return them, since they will surely be executed in the `while` loop of the first tail-recursive function encountered. The only downside is that `f._first_call` is reset every time the arguments belong to another function.
Caching

Before commenting on the caching mechanism (which is just the very popular memoization technique), it is important to place the caching code correctly: notice that I put it inside the `while` loop. It cannot be anywhere else, because only inside the `while` loop is the function called repeatedly, so that is where I can check for cache hits.

I cheated a little for the `cache_key` creation, because I have used an internal function of the `functools` module. It is the one used by the `@cache` decorator in that same module, and you can extract its code as shown below.
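For example (one possibility; `_make_key` is a private CPython helper, so this is shown only for illustration):

```python
import functools
import inspect

# _make_key is the internal helper that functools.lru_cache / functools.cache
# use to turn *args and **kwargs into a hashable key.
print(inspect.getsource(functools._make_key))

# It can also be called directly:
key = functools._make_key((3,), {"acc": 1}, typed=False)
```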
There are other ways to create a cache key from `*args` and `**kwargs`, like this one, that again points to the implementation of `_make_key`. To make your code more stable, avoid using private members, of course.

As I said, the rest is memoization, with one additional check: `if not isinstance(result, TailRecArguments): ...`. I want to cache values, not the arguments of tail-recursive calls.

(Actually, I think you could temporarily store all the `TailRecArguments` in a list and, once an actual value is returned by a recursive call, add as many cache entries as there are items in that list. It would complicate the solution, but that is still acceptable if you have performance issues. It might also raise some bugs in the mutual-recursion case; I am going to work on it if requested.)

Testing
These are a few basic functions I have used to test the decorator:
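For instance, something along these lines (my minimal versions of the factorial and even/odd examples discussed below, plus the tail-recursive Fibonacci from the question):

```python
@tail_recursive
def fact(n, acc=1):
    # Tail-recursive factorial: acc carries the partial product.
    return acc if n <= 1 else fact(n - 1, acc * n)


@tail_recursive
def is_even(n):
    return True if n == 0 else is_odd(n - 1)


@tail_recursive
def is_odd(n):
    return False if n == 0 else is_even(n - 1)


@tail_recursive
def fib_tail(n, acc1=1, acc2=1):
    return acc1 if n < 2 else fib_tail(n - 1, acc1 + acc2, acc1)


if __name__ == "__main__":
    print(fact(10))         # 3628800
    print(is_even(10001))   # False, and no RecursionError despite the depth
    print(fib_tail(10000))  # works thanks to the driver loop
```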
Note that caching is not very useful in these examples; take the factorial, for example: `fact(10)` is never going to use a cached `fact(8)`, and in fact `fact(8)` (with its default accumulator) never shows up among the calls made by `fact(10)`. The accumulator is part of the cached key, so you should change the caching strategy by customizing the parameters you want to cache (again, if needed I can propose a solution for that too); the trace below illustrates the problem.
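For instance, a hypothetical trace of `fact(10)` with the sketch above (the exact key format depends on `_make_key`):

```python
# fact(10)          -> key built from (10,)
# fact(9, 10)       -> key built from (9, 10)
# fact(8, 90)       -> key built from (8, 90)
# ...
# fact(1, 3628800)  -> returns 3628800; this is the only value that gets cached
#
# A later call fact(8) builds its key from (8,), which never occurred above,
# so the cache cannot help it at all.
```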
UPDATE - Cache Optimization
Here is a partial fix to the cache strategy used in the original answer. The main issue is that including all parameters in the cache key is inefficient, considering how a general tail-recursive algorithm works (see the factorial example above).
A first possible optimization is to let the user choose which parameters form the key and which ones form the value. It is much less readable because of the type hints, but the tests make everything a little clearer:
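A sketch of that version (simplified; `TailRecArguments` is the same dataclass as before, and `Logger` is only there to confirm cache hits):

```python
import functools
from dataclasses import dataclass
from typing import Any, Callable, Optional


@dataclass
class TailRecArguments:
    wrapped_func: Callable
    args: tuple
    kwargs: dict


class Logger:
    """Tiny helper used only to confirm cache hits during the tests."""
    def log(self, msg: str) -> None:
        print(msg)


def _default_key(*args, **kwargs):
    # Default: every argument is part of the cache key.
    return functools._make_key(args, kwargs, typed=False)


def _default_after_hit(cached, *args, **kwargs):
    # Default: a cache hit is already the final result.
    return cached


def tail_recursive(
    logger: Optional[Logger] = None,
    get_cache_key: Callable[..., Any] = _default_key,
    get_result_after_cache_hit: Callable[..., Any] = _default_after_hit,
):
    def decorator(f):
        f._first_call = True
        f._cache = {}
        f._initial_key = None   # key of the outermost (driver) call

        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            if not f._first_call:
                return TailRecArguments(f, args, kwargs)

            f._first_call = False
            f._initial_key = get_cache_key(*args, **kwargs)
            try:
                result = TailRecArguments(f, args, kwargs)
                while isinstance(result, TailRecArguments):
                    if result.wrapped_func != f:
                        return result   # mutual recursion: not ours to run
                    key = get_cache_key(*result.args, **result.kwargs)
                    if key in f._cache:
                        if logger:
                            logger.log(f"cache hit for key {key!r}")
                        result = get_result_after_cache_hit(
                            f._cache[key], *result.args, **result.kwargs)
                    else:
                        result = f(*result.args, **result.kwargs)
                # Store the final value under the key of the *first* call.
                f._cache[f._initial_key] = result
                return result
            finally:
                f._first_call = True

        return wrapper
    return decorator
```

The difference from the first version is only in how the key is built and in what happens on a cache hit.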
Aside from the `Logger` class, which is used just to confirm cache hits, the main difference is that every function now has a new member called `_initial_key`, which stores the key of the first call. In this way, if I call `fact(5)`, `5` becomes the `_initial_key` and the result is put in `f._cache[5]`.

This can optimize both mutually recursive and tail-recursive functions, but it is ineffective in certain situations. Let's start from the best case:
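For the factorial, only `n` forms the key and a cache hit is combined with the partial accumulator (a sketch; the lambdas are illustrative):

```python
@tail_recursive(
    logger=Logger(),
    # Only the first argument n is part of the cache key.
    get_cache_key=lambda n, acc=1: n,
    # On a hit, combine the cached n! with the partial product accumulated so far.
    get_result_after_cache_hit=lambda cached, n, acc=1: acc * cached,
)
def fact(n, acc=1):
    return acc if n <= 1 else fact(n - 1, acc * n)


print(fact(5))    # 120, computed from scratch and stored under key 5
print(fact(30))   # hits the cache at fact(5, <partial>): <partial> * f._cache[5]
```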
The `@tail_recursive` decorator initialization includes (besides the logger) `get_cache_key`, which specifies that only the first argument `n` should be part of the cache key, and `get_result_after_cache_hit`, which specifies how to produce the final result after a cache hit. In the above case, when `fact(30)` reaches `fact(5, <partial_factorial>)`, the result is immediately computed as `<partial_factorial> * f._cache[5]`.
The same goes for even/odd mutual recursion, except that in this case the default arguments of `tail_recursive` are more than enough (see the sketch below). Unfortunately, this does not work with the Fibonacci function, for example. You can easily convince yourself of it by printing the arguments during each call, resulting in something like the trace shown after the snippet.
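For example (building on the parameterized sketch above):

```python
@tail_recursive(logger=Logger())
def is_even(n):
    return True if n == 0 else is_odd(n - 1)


@tail_recursive(logger=Logger())
def is_odd(n):
    return False if n == 0 else is_even(n - 1)


print(is_even(10))   # True, computed step by step
print(is_even(10))   # True again, this time logged as a cache hit for key 10
```

An illustrative trace of `fib_tail(6, 1, 1)`: both accumulators change at every step, so a cached value for a given `n` cannot be recombined with the current call by anything as simple as a single multiplication.

```python
# fib_tail(6, 1, 1)
# fib_tail(5, 2, 1)
# fib_tail(4, 3, 2)
# fib_tail(3, 5, 3)
# fib_tail(2, 8, 5)
# fib_tail(1, 13, 8)   -> returns 13
```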
Establishing the right caching-key rule here would need more complex logic, which would probably make the `tail_recursive` decorator quite unreadable and less portable.