Memoization and Anonymous Recursion

Keith Farmer brought it to my attention that there is at least a little confusion about how closures work.  Hopefully, I can help shed a little light on the subject.  The question is why doesn't the following code actually memoize fib in the call to Test?

Func<int, int> fib = null;
fib = n => n > 1 ? fib(n - 1) + fib(n - 2) : n;
Test(fib.Memoize());

Before I explain why this code doesn't work, I want to return to the example that I used in my last post.

Func<int, int> fib = null;
fib = n => n > 1 ? fib(n - 1) + fib(n - 2) : n;
Func<int, int> fibCopy = fib;
Console.WriteLine(fib(6)); // displays 8
Console.WriteLine(fibCopy(6)); // displays 8
fib = n => n * 2;
Console.WriteLine(fib(6)); // displays 12
Console.WriteLine(fibCopy(6)); // displays 18

Probably the easiest way to see why this code behaves so strangely is to look at the code the C# compiler generates.  This can easily be done with ildasm.exe or Reflector.  There are two lambdas defined in the code.  The first one captures the local variable named fib and the second does not capture any locals.  Because fib is captured by a lambda, it must be hoisted onto the heap.  So a class is declared that contains the captured local.

class DisplayClass
{
    public Func<int, int> fib;
}

Now the compiler must emit a method that corresponds to the first lambda.  This method will need access to the members of DisplayClass and must also be convertible to a Func<int,int> delegate.  Today the compiler achieves this by emitting the method on the display class as an instance method (we could also have used currying but that is a story for another day).

class DisplayClass
{
    public Func<int, int> fib;

    public int M1(int n)
    {
        return n > 1 ? fib(n - 1) + fib(n - 2) : n; // <-- 1st lambda body
    }
}

The second lambda does not capture any local variables.  So this method can be emitted as a static method.

static int M2(int n)
{
    return n * 2; // <-- 2nd lambda body
}

Finally, we are ready to show the emitted code for the original code fragment.

DisplayClass locals = new DisplayClass();
locals.fib = null;
locals.fib = locals.M1;
Func<int, int> fibCopy = locals.fib;
Console.WriteLine(locals.fib(6)); // displays 8
Console.WriteLine(fibCopy(6)); // displays 8
locals.fib = M2;
Console.WriteLine(locals.fib(6)); // displays 12
Console.WriteLine(fibCopy(6)); // displays 18

Notice how the first call to fib is really just a call to M1, and when M1 "recurses" on fib it just ends up calling M1 again, because that is what is assigned to fib.  The call to fibCopy is a little trickier: the original call is also a call to M1, but when it "recurses" it invokes fib rather than fibCopy, and fib also happens to be M1 at the time.  So the first two calls behave as expected.

Now it starts to get a little strange.  First we assign M2 to fib.  Then when we invoke fib on the next line, it no longer invokes our original "fib" function M1 but instead invokes M2.  This of course displays the result of multiplying 6 by 2.

Now for the strangest part, we invoke fibCopy.  But fibCopy actually still references M1 and not M2 and since 6 > 1 then it "recurses" by invoking fib twice with 5 and 4 respectively and then summing the results.  But the calls to fib actually invoke M2 now.  So the results of the calls to fib are 10 and 8 which then are summed to produce 18.

Now let's return to our original problem.

Func<int, int> fib = null;
fib = n => n > 1 ? fib(n - 1) + fib(n - 2) : n;
Test(fib.Memoize());

Notice that the function passed to Test is memoized but the body of the function still calls the unmemoized fib.  The function itself is memoized but all of the "recursive" calls are not.  So it probably will not behave as intended.  This can be corrected by doing the following.

Func<int, int> fib = null;
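// Type arguments are spelled out because they cannot be inferred from an untyped lambda.
fib = Memoize<int, int>(n => n > 1 ? fib(n - 1) + fib(n - 2) : n);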
Test(fib);

Now the calls to fib when n > 1 are made to the memoized fib delegate.
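
To make the difference concrete, here is a minimal sketch that counts how many times the lambda body runs under each wiring.  The Memoize shown here is only an assumed, dictionary-based shape for the helper used in this post, and the class and method names (MemoizeSketch, BodyRunsOuterOnly, BodyRunsRecursive) are made up for illustration.

```csharp
using System;
using System.Collections.Generic;

static class MemoizeSketch
{
    // Assumed shape of Memoize: cache one result per argument in a dictionary.
    public static Func<A, R> Memoize<A, R>(this Func<A, R> f)
    {
        var cache = new Dictionary<A, R>();
        return a =>
        {
            R value;
            if (!cache.TryGetValue(a, out value))
            {
                value = f(a);
                cache[a] = value;
            }
            return value;
        };
    }

    // Memoizes only the outer delegate; the body still recurses through the raw fib.
    // Returns how many times the lambda body ran.
    public static int BodyRunsOuterOnly(int n)
    {
        int runs = 0;
        Func<int, int> fib = null;
        fib = x => { runs++; return x > 1 ? fib(x - 1) + fib(x - 2) : x; };
        Func<int, int> memoized = fib.Memoize();
        memoized(n);
        return runs;
    }

    // Assigns the memoized delegate back to fib, so the recursive calls hit the cache.
    public static int BodyRunsRecursive(int n)
    {
        int runs = 0;
        Func<int, int> fib = null;
        fib = Memoize<int, int>(x => { runs++; return x > 1 ? fib(x - 1) + fib(x - 2) : x; });
        fib(n);
        return runs;
    }
}
```

With this sketch, BodyRunsOuterOnly(10) reports 177 runs of the body, while BodyRunsRecursive(10) reports 11, one for each argument from 0 through 10.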

If you read my last post on anonymous recursion in C#, which introduces the Y combinator, then you may have tried the following.

Func<int, int> fib = Y<int, int>(f => n => n > 1 ? f(n - 1) + f(n - 2) : n).Memoize();

This behaves very similarly to the incorrectly written fib above.  Calls to the function itself are memoized but all of the recursive calls are not.  This is because fib is memoized but f is not.  Erik Meijer pointed me to a fantastic paper that discusses how to memoize functions that are the fixed point of functionals.  In this paper, Cook and Launchbury present a function written in Haskell that enables memoized anonymous recursive functions.

memoFix f = let g = f h
                h = memo g
            in h

Using the MemoizeFix function in C# looks very similar to using the Y combinator, but instead of just returning a recursive function, it would also memoize the function and all the recursive calls.

Func<int, int> memoFib = MemoizeFix<int, int>(f => n => n > 1 ? f(n - 1) + f(n - 2) : n);

But there are two problems with implementing the function in C#.  First, Haskell lets programmers write mutually recursive definitions (g is defined in terms of h and h is defined in terms of g).  Second, Haskell is a lazy language.  A straightforward implementation in C# will produce either a null reference exception or a stack overflow because of the eager evaluation and mutually recursive definitions.  Fortunately, lambdas can be used to solve both problems.

static Func<A, R> MemoizeFix<A, R>(this Func<Func<A, R>, Func<A, R>> f)
{
    Func<A, R> g = null;
    Func<A, R> h = null;
    g = a => f(h)(a);   // defer f(h) until g is actually invoked
    h = g.Memoize();
    return h;
}

Here we declare both g and h before they are assigned their final values.  Then, instead of assigning f(h) to g, which would invoke f immediately while h is still null and lead to a null reference exception on the first recursive call, we assign a lambda to g.  That lambda defers the call to f(h) until g is actually invoked, by which time h has a non-null value.  The rest of the function is very similar to the Haskell version.
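
For contrast, here is a sketch of the eager translation that the lambda avoids.  EagerMemoizeFix and RecursiveCallThrows are hypothetical names used only for this illustration, and Memoize is again an assumed dictionary-based helper.

```csharp
using System;
using System.Collections.Generic;

static class EagerFixSketch
{
    // Assumed shape of Memoize: cache one result per argument in a dictionary.
    public static Func<A, R> Memoize<A, R>(this Func<A, R> f)
    {
        var cache = new Dictionary<A, R>();
        return a =>
        {
            R value;
            if (!cache.TryGetValue(a, out value))
            {
                value = f(a);
                cache[a] = value;
            }
            return value;
        };
    }

    // Hypothetical eager translation of memoFix: f(h) is evaluated right away.
    public static Func<A, R> EagerMemoizeFix<A, R>(Func<Func<A, R>, Func<A, R>> f)
    {
        Func<A, R> h = null;
        Func<A, R> g = f(h);   // f runs now and receives null, because h has no value yet
        h = g.Memoize();       // too late: the function returned by f already holds the null
        return h;
    }

    // Returns true if the first recursive call fails with a NullReferenceException.
    public static bool RecursiveCallThrows()
    {
        Func<int, int> fib = EagerMemoizeFix<int, int>(f => n => n > 1 ? f(n - 1) + f(n - 2) : n);
        try
        {
            fib(5);
            return false;
        }
        catch (NullReferenceException)
        {
            return true;
        }
    }
}
```

In this sketch the base cases still work, since they never touch f, but any argument greater than 1 tries to invoke the null delegate and throws.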

Notice what happens when memoFib is invoked.  This actually invokes h, which was defined in MemoizeFix, and h is really just a memoized version of g.  Since this is the first invocation of h, there is no precomputed value, so g is invoked.  But when g is invoked, it invokes f, passing h (the memoized g) as the function to use for recursion.  This is exactly what we want.  Beautiful, isn't it?
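
As a rough check of that claim, the following sketch counts how often the lambda body runs when fib is built with Y and then memoized, versus when it is built with MemoizeFix.  The Y here is only an assumed stand-in (a plain fixed-point helper, not necessarily the definition from the earlier post), Memoize is an assumed dictionary-based helper, and RunsWithY and RunsWithMemoizeFix are illustrative names.

```csharp
using System;
using System.Collections.Generic;

static class MemoizeFixSketch
{
    // Assumed shape of Memoize: cache one result per argument in a dictionary.
    public static Func<A, R> Memoize<A, R>(this Func<A, R> f)
    {
        var cache = new Dictionary<A, R>();
        return a =>
        {
            R value;
            if (!cache.TryGetValue(a, out value))
            {
                value = f(a);
                cache[a] = value;
            }
            return value;
        };
    }

    // Assumed stand-in for Y: ties the recursive knot without any caching.
    public static Func<A, R> Y<A, R>(Func<Func<A, R>, Func<A, R>> f)
    {
        Func<A, R> g = null;
        g = a => f(g)(a);
        return g;
    }

    // Same definition as in the post: recursion goes through the memoized h.
    public static Func<A, R> MemoizeFix<A, R>(this Func<Func<A, R>, Func<A, R>> f)
    {
        Func<A, R> g = null;
        Func<A, R> h = null;
        g = a => f(h)(a);
        h = g.Memoize();
        return h;
    }

    // Y first, Memoize afterwards: only the outermost call is cached.
    public static int RunsWithY(int n)
    {
        int runs = 0;
        Func<int, int> fib =
            Y<int, int>(f => x => { runs++; return x > 1 ? f(x - 1) + f(x - 2) : x; }).Memoize();
        fib(n);
        return runs;
    }

    // MemoizeFix: every recursive call goes through the cache.
    public static int RunsWithMemoizeFix(int n)
    {
        int runs = 0;
        Func<int, int> fib =
            MemoizeFix<int, int>(f => x => { runs++; return x > 1 ? f(x - 1) + f(x - 2) : x; });
        fib(n);
        return runs;
    }
}
```

Under these assumptions, RunsWithY(10) reports 177 runs of the body and RunsWithMemoizeFix(10) reports 11, which matches the walkthrough above: each argument is computed by g once, and every later request for it is answered by h.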