Tail call optimization in Python
I've been using Common Lisp recently, and one thing a Lisp always makes me appreciate is tail call optimization.
It's a bit of a shame it isn't available in more languages, so let's see if it can be added to Python.
TL;DR
The following function raises a RecursionError once its argument reaches roughly 1000 (CPython's default recursion limit):
```python
def my_fn(target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return my_fn(target_iters, iteration+1)
```

```
>>> my_fn(1000)
...
RecursionError: maximum recursion depth exceeded in comparison
```
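A quick way to see where that threshold comes from is sys.getrecursionlimit(), which reports CPython's recursion limit (1000 unless it has been changed). The sketch below, which repeats my_fn so it runs on its own, shows a shallow call finishing and a call as deep as the limit failing. Raising the limit with sys.setrecursionlimit only moves the threshold: every call still adds a stack frame.

```python
import sys


def my_fn(target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return my_fn(target_iters, iteration + 1)


limit = sys.getrecursionlimit()  # 1000 unless changed via sys.setrecursionlimit
print(limit)

print(my_fn(500))  # shallow enough to finish: prints 500

try:
    # Needs limit + 1 nested my_fn frames, more than the limit allows.
    my_fn(limit)
except RecursionError as exc:
    print("RecursionError:", exc)
```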
With some extra code, we can add tail call optimization:
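```python
def my_fn2(my_fn2, target_iters, iteration=0):
    # The first parameter deliberately shadows the function's own name:
    # it receives the stand-in for the recursive call.
    if iteration >= target_iters:
        return iteration
    return my_fn2(target_iters, iteration+1)

optimized_fn = tailcall_optimize(my_fn2)
```

```
>>> optimized_fn(10000)
10000
```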
Admittedly, this is a very contrived example. Still, the change isn't too bad: only the signature of the function changed (it now receives the function to "recurse" into as its first argument); everything else is the same.
How it works
Pretty straightforward, actually! As you may already expect, the key is the my_fn2 that is passed in as a function argument.
We're making use of the fact that only the leaf call returns any data. Every call leading up to the leaf just passes new arguments on and returns whatever the next call returns, so none of those calls needs to keep any state around.
Let's look at some pseudocode:
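```
args = init_args
while True:
    # my_fn either calls the callback (a non-leaf step, which returns None)
    # or returns the final value (the leaf)
    return_value = my_fn(lambda *new_args: update_args(new_args), *args)
    if return_value is not None:
        return return_value
```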
Here it also becomes obvious why only the leaf can return any data. If other nodes returned data while recursing, we'd need a strategy to merge those values back together in order to emulate the behavior of a regular recursive function.
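To make that concrete: sum_to below (a hypothetical example, not from the original) is not tail-recursive, because each call still has to add n to the result after the recursive call returns. A common way around this is to carry the partial result forward in an extra accumulator argument, so that only the leaf returns the answer. The sketch repeats tailcall_optimize so it runs on its own.

```python
def tailcall_optimize(fn):
    def recursion_wrapper(*args, **kwargs):
        is_called = True
        result = None
        val = (args, kwargs)

        def fn_wrapper(*args, **kwargs):
            nonlocal val, is_called
            val = (args, kwargs)
            is_called = True

        while is_called:
            is_called = False
            result = fn(fn_wrapper, *val[0], **val[1])
        return result

    return recursion_wrapper


def sum_to(n):
    # Not a tail call: "n + ..." still runs after the recursive call returns,
    # so every frame has to stay alive to hold its own n.
    if n == 0:
        return 0
    return n + sum_to(n - 1)


def sum_to_acc(self_call, n, acc=0):
    # Tail call: the running total travels in acc, and only the leaf returns it.
    if n == 0:
        return acc
    return self_call(n - 1, acc + n)


optimized_sum = tailcall_optimize(sum_to_acc)
print(sum_to(100), optimized_sum(100))  # 5050 5050
print(optimized_sum(100_000))  # 5000050000, no RecursionError
```

Passing the non-tail version through the wrapper would not work: fn_wrapper returns None, so an expression like n + self_call(n - 1) would raise a TypeError.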
The Python code I used is the following:
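```python
def tailcall_optimize(fn):
    def recursion_wrapper(*args, **kwargs):
        is_called = True  # makes the loop below run at least once
        result = None
        val = (args, kwargs)  # arguments for the next call of fn

        def fn_wrapper(*args, **kwargs):
            # Stands in for the recursive call: instead of recursing, it only
            # records the new arguments and flags that another round is needed.
            nonlocal val, is_called
            val = (args, kwargs)
            is_called = True

        # Keep calling fn until a call finishes without invoking fn_wrapper,
        # i.e. until we reach the leaf; its return value is the result.
        while is_called:
            is_called = False
            result = fn(fn_wrapper, *val[0], **val[1])
        return result

    return recursion_wrapper
```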
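One way to check that the wrapper really keeps the stack flat is to record the stack depth on every step. In the sketch below, countdown and the depths list are illustrative names, not from the original. It applies tailcall_optimize with decorator syntax, which works because tailcall_optimize takes a function and returns a function, and it measures depth with inspect.stack(0) (the 0 skips reading source lines).

```python
import inspect


def tailcall_optimize(fn):
    def recursion_wrapper(*args, **kwargs):
        is_called = True
        result = None
        val = (args, kwargs)

        def fn_wrapper(*args, **kwargs):
            nonlocal val, is_called
            val = (args, kwargs)
            is_called = True

        while is_called:
            is_called = False
            result = fn(fn_wrapper, *val[0], **val[1])
        return result

    return recursion_wrapper


depths = []  # stack depth observed at each step


@tailcall_optimize
def countdown(self_call, n):
    depths.append(len(inspect.stack(0)))  # frames currently on the stack
    if n == 0:
        return "done"
    return self_call(n - 1)


print(countdown(50))  # done
print(len(depths), len(set(depths)))  # 51 1: every step ran at the same depth
```

Every step is invoked from the same while loop inside recursion_wrapper, which is why the depth never grows.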
Instead of relying on the fact that non-leaf calls return None, I'm tracking whether the fn_wrapper passed in gets called, but either approach works!
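For comparison, here is a minimal sketch of the None-based variant from the pseudocode; the name tailcall_optimize_none is made up for this example. It assumes the leaf never returns None itself: if it did, the loop would mistake the leaf for a non-leaf call and run it again with the same arguments forever. Tracking whether the wrapper was called, as tailcall_optimize does, avoids that restriction.

```python
def tailcall_optimize_none(fn):
    def recursion_wrapper(*args, **kwargs):
        val = (args, kwargs)  # arguments for the next call of fn

        def fn_wrapper(*args, **kwargs):
            # Record the arguments for the next round. Returning None (implicitly)
            # is what tells the loop below that this call was not the leaf.
            nonlocal val
            val = (args, kwargs)

        while True:
            result = fn(fn_wrapper, *val[0], **val[1])
            if result is not None:  # only the leaf returns a real value
                return result

    return recursion_wrapper


def my_fn2(my_fn2, target_iters, iteration=0):
    if iteration >= target_iters:
        return iteration
    return my_fn2(target_iters, iteration + 1)


print(tailcall_optimize_none(my_fn2)(10000))  # 10000
```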