Tail call optimization in Python

I've been using Common Lisp recently, and one thing using a Lisp always makes me appreciate is tail call optimization.

It's a bit of a shame it isn't available in more languages, so let's see whether it can be added to Python.

TL;DR

The following function raises a RecursionError once its argument reaches roughly 1000 (CPython's default recursion limit):

def my_fn(target_iters, iteration=0):
  if iteration >= target_iters:
    return iteration

  return my_fn(target_iters, iteration+1)
>>> my_fn(1000)
...
RecursionError: maximum recursion depth exceeded in comparison
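The 1000 comes from CPython's recursion limit, which `sys.getrecursionlimit()` reports. Here's a minimal sketch for observing the failure without crashing the session. `hits_recursion_limit` is a hypothetical helper, and the exact depth at which the error triggers can vary slightly between Python versions and environments:

```python
import sys

def my_fn(target_iters, iteration=0):
  if iteration >= target_iters:
    return iteration
  return my_fn(target_iters, iteration + 1)

def hits_recursion_limit(n):
  # Hypothetical helper: True if my_fn(n) exceeds the recursion limit
  try:
    my_fn(n)
  except RecursionError:
    return True
  return False

# CPython's default is 1000 unless changed via sys.setrecursionlimit()
print(sys.getrecursionlimit())
print(hits_recursion_limit(100))     # False
print(hits_recursion_limit(10_000))  # True
```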

With a bit of extra code, we can add tail call optimization:

def my_fn2(my_fn2, target_iters, iteration=0):
  if iteration >= target_iters:
    return iteration

  return my_fn2(target_iters, iteration+1)

>>> optimized_fn = tailcall_optimize(my_fn2)
>>> optimized_fn(10000)
10000

Admittedly, it's a very contrived example. But the cost is low: the only thing that changed is the function's signature, which now takes the function to recurse into as its first argument. The body is otherwise the same.

How it works

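Pretty straightforward, actually! The key, as you may already expect, is the my_fn2 passed in as the first argument. Calling it doesn't recurse: it only records the arguments for the next call, and a loop outside the function then makes that call. This pattern is commonly known as a trampoline.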

This relies on the fact that only the leaf call (the base case) returns any data. Every call on the way to the leaf just passes along whatever its recursive call returns, so none of them needs to keep any state around.
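To see why no state is needed, here is my_fn rewritten by hand into roughly what tail call optimization turns it into: each "recursive call" simply rebinds the parameters and jumps back to the top. This is a minimal sketch, and `my_fn_loop` is a hypothetical name:

```python
def my_fn_loop(target_iters, iteration=0):
  # The tail call my_fn(target_iters, iteration + 1) becomes: update the
  # parameters in place, then start over. No stack frames pile up.
  while True:
    if iteration >= target_iters:
      return iteration
    iteration += 1  # target_iters is passed through unchanged
```

Calling `my_fn_loop(10_000)` returns 10000 without touching the recursion limit.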

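Let's look at a simplified version first:

def trampoline(my_fn, *init_args):
  args = init_args
  def update_args(*new_args):  # record the next call's arguments
    nonlocal args
    args = new_args
  while True:
    return_value = my_fn(update_args, *args)
    if return_value is not None:  # only the leaf returns data
      return return_value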

Here it also becomes obvious why only the leaf may return data. If other calls returned data while recursing, we'd need a strategy to merge those values back together to emulate the behavior of a regular recursive function.
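For example, the textbook factorial `n * fact(n - 1)` doesn't fit: the multiplication happens after the recursive call returns, so every call still has work left to do. The usual fix is to carry the partial result in an extra accumulator argument, making the recursive call the last thing the function does so that only the leaf returns the result. Here's a minimal sketch. `fact` and `acc` are illustrative names, and a copy of the tailcall_optimize function from this post is included so the snippet runs on its own:

```python
import math

def tailcall_optimize(fn):
  # Copy of the wrapper from this post so the snippet is self-contained
  def recursion_wrapper(*args, **kwargs):
    is_called = True
    result = None
    val = (args, kwargs)

    def fn_wrapper(*args, **kwargs):
      nonlocal val, is_called
      val = (args, kwargs)
      is_called = True

    while is_called:
      is_called = False
      result = fn(fn_wrapper, *val[0], **val[1])
    return result

  return recursion_wrapper


def fact(fact, n, acc=1):
  # The running product travels in acc, so the leaf returns the final
  # answer and there is nothing to merge on the way back up
  if n <= 1:
    return acc
  return fact(n - 1, acc * n)

optimized_fact = tailcall_optimize(fact)
print(optimized_fact(5))                                # 120
print(optimized_fact(3000) == math.factorial(3000))    # True
```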

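The Python code I used is the following:

def tailcall_optimize(fn):
  def recursion_wrapper(*args, **kwargs):
    # Start as True so the loop below runs at least once
    is_called = True
    result = None
    # Positional and keyword arguments for the next call of fn
    val = (args, kwargs)

    def fn_wrapper(*args, **kwargs):
      # Stand-in for the recursive call: instead of recursing, record
      # the arguments and signal the loop to call fn again
      nonlocal val, is_called
      val = (args, kwargs)
      is_called = True

    # Keep calling fn until a call finishes without "recursing"
    while is_called:
      is_called = False
      result = fn(fn_wrapper, *val[0], **val[1])
    return result

  return recursion_wrapper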

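Instead of relying on non-leaf calls returning None, I track whether the fn_wrapper passed in was called. Either approach works for the examples here, but tracking the call also lets the leaf legitimately return None; with a None check, such a function would loop forever.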
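As an illustration of a leaf that returns None, here's a minimal sketch of a tail-recursive linear search that returns None when nothing matches. It also uses decorator syntax, which works because the parameter `find_index` shadows the decorated module-level name inside the body. `find_index` is an illustrative name, and tailcall_optimize is repeated so the snippet runs on its own. A loop that stopped only on a non-None return value would never finish the not-found case.

```python
def tailcall_optimize(fn):
  # Same wrapper as in this post, repeated so the snippet is self-contained
  def recursion_wrapper(*args, **kwargs):
    is_called = True
    result = None
    val = (args, kwargs)

    def fn_wrapper(*args, **kwargs):
      nonlocal val, is_called
      val = (args, kwargs)
      is_called = True

    while is_called:
      is_called = False
      result = fn(fn_wrapper, *val[0], **val[1])
    return result

  return recursion_wrapper


@tailcall_optimize
def find_index(find_index, items, target, i=0):
  # Inside the body, find_index is the callback parameter, not the
  # decorated module-level function
  if i >= len(items):
    return None  # not found: the leaf legitimately returns None
  if items[i] == target:
    return i
  return find_index(items, target, i + 1)

data = list(range(50_000))
print(find_index(data, 49_999))  # 49999
print(find_index(data, -1))      # None
```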