Tail call optimization is a great feature: problems that are recursive by nature are usually easier to solve recursively than iteratively. It’s a shame it’s not available in Python (or in many other languages), so let’s add it.
The following function raises a RecursionError once its argument gets to around 1000, which is CPython’s default recursion limit:

    def my_fn(target_iters, iteration=0):
        if iteration >= target_iters:
            return iteration
        return my_fn(target_iters, iteration+1)

    >>> my_fn(1000)
    ...
    RecursionError: maximum recursion depth exceeded in comparison
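The threshold is not a fixed property of the function; it comes from the interpreter's recursion limit. As a minimal sketch, the snippet below reads the limit with `sys.getrecursionlimit()` and confirms that recursing that deep fails. The variable names `limit` and `hit_limit` are made up for this illustration.

```python
import sys


def my_fn(target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return my_fn(target_iters, iteration + 1)


# The limit is configurable; 1000 is only CPython's usual default.
limit = sys.getrecursionlimit()

try:
    my_fn(limit)  # needs more nested frames than the limit allows
    hit_limit = False
except RecursionError:
    hit_limit = True

print(f"recursion limit: {limit}, hit the limit: {hit_limit}")
```

Raising the limit with `sys.setrecursionlimit()` only moves the problem further out, which is why the rest of this post avoids deep recursion altogether.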
With some extra code, we can add tail call optimization:

    def my_fn2(my_fn2, target_iters, iteration=0):
        if iteration >= target_iters:
            return iteration
        return my_fn2(target_iters, iteration+1)

    optimized_fn = tailcall_optimize(my_fn2)

    >>> optimized_fn(10000)
    10000
Arguably a very contrived example, I admit. But not too bad - the only thing that changed is the signature of the function, everything else is the same.
How it works
Pretty straightforward, actually! The key is - as you may already expect - the my_fn2 passed in as a function argument.
We’re making use of the fact that only the leaf call returns any data. That means none of the calls leading up to the leaf need to keep any state around, so their stack frames can be thrown away.
Let’s look at some pseudocode:

    args = init_args
    while true:
        return_value = my_fn(lambda new_args: update_args(new_args), *args)
        if return_value is not None:
            return return_value
Here it also becomes obvious why only the leaf can return any data. If other nodes returned data while recursing, we’d need a strategy to merge those values back together in order to emulate the behavior of a regular recursive function.
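To make that restriction concrete, here is a small sketch of a function that does not fit the pattern, next to an accumulator-style rewrite that does. Both `factorial` and `factorial_acc` are hypothetical examples chosen for illustration; they are not part of the code in this post.

```python
def factorial(n):
    # Not a tail call: the multiplication happens after the recursive call
    # returns, so every frame on the way down has to stay alive.
    if n <= 1:
        return 1
    return n * factorial(n - 1)


def factorial_acc(n, acc=1):
    # Tail call: the pending work travels along in `acc`,
    # so only the leaf produces a value.
    if n <= 1:
        return acc
    return factorial_acc(n - 1, acc * n)


print(factorial(5), factorial_acc(5))
```

Only the second form can be driven by a loop, because nothing is left to do in a frame once it has made its recursive call.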
The Python code I used is the following:

    def tailcall_optimize(fn):
        def recursion_wrapper(*args, **kwargs):
            is_called = True
            result = None
            val = (args, kwargs)

            def fn_wrapper(*args, **kwargs):
                # Record the arguments of the next call instead of recursing.
                nonlocal val, is_called
                val = (args, kwargs)
                is_called = True

            # Keep calling fn until it finishes without calling fn_wrapper.
            while is_called:
                is_called = False
                next_args, next_kwargs = val
                result = fn(fn_wrapper, *next_args, **next_kwargs)
            return result

        return recursion_wrapper
Instead of relying on the fact that a non-leaf node returns None, I’m tracking whether the fn_wrapper passed in is being called or not - but either works!
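For comparison, the None-based alternative might look roughly like the sketch below. The names `tailcall_optimize_none`, `count_up` and `recurse` are hypothetical and only used for this illustration; the structure mirrors the pseudocode from earlier.

```python
def tailcall_optimize_none(fn):
    # Hypothetical variant: a None return value means "recurse again".
    def recursion_wrapper(*args, **kwargs):
        next_call = (args, kwargs)

        def fn_wrapper(*new_args, **new_kwargs):
            # Store the next arguments; implicitly returns None.
            nonlocal next_call
            next_call = (new_args, new_kwargs)

        while True:
            call_args, call_kwargs = next_call
            result = fn(fn_wrapper, *call_args, **call_kwargs)
            if result is not None:
                return result

    return recursion_wrapper


def count_up(recurse, target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return recurse(target_iters, iteration + 1)


optimized_count_up = tailcall_optimize_none(count_up)
print(optimized_count_up(10000))  # 10000
```

One caveat of this variant: a leaf that legitimately returns None would keep the loop spinning forever, which the call-tracking version avoids.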
This certainly requires a few more checks to bring it into production, but conceptually this is all we need. This could also be easily ported to other languages.
This was all good fun, but let’s address the elephant in the room: how do we make this work with a broader set of recursive functions?
Naturally I looked into that as well…