Open recursion (C++)

In the previous post, we explored how we could leverage open recursion to solve a dynamic programming problem while keeping the following aspects decoupled:

  • The recurrence relation: the solution to our problem
  • The memoization strategy: indexing into a vector
  • The order in which we compute the sub-solutions

Today we will see how to apply this trick in a more mainstream language, C++.

 

Open recursion in C++


Our first step will be to translate the recurrence formula into a naive and very inefficient C++ implementation of our solution to the counting binary search trees problem:

long long bst_count(int n)
{
    if (n <= 1) return 1;

    long long sub_counts = 0;
    for (int i = 0; i < n; ++i)
        sub_counts += bst_count(i) * bst_count(n-1-i);
    return sub_counts;
}

 
Adding open recursion results in a pretty simple change for our C++ implementation:

  • Add a template parameter “recur” as the first argument
  • Replace each recursive call with a call to “recur”

Following this recipe leads to the following C++ implementation:

template<typename Recur>
long long bst_count(Recur recur, int n)
{
    if (n <= 1) return 1;

    long long sub_counts = 0;
    for (int i = 0; i < n; ++i)
        sub_counts += recur(i) * recur(n-1-i);
    return sub_counts;
}

 
Similar to what we did in Haskell, we can get back the naive algorithm by introducing a wrapper function handling the two-step recursion:

long long bst_count_naive(int n)
{
    auto recur = [](int k) { return bst_count_naive(k); };
    return bst_count(recur, n);
}

 
This function still has the same terrible performance. It is time to improve our algorithm, by leveraging Dynamic Programming techniques.

 

Adding memoization


You know the drill: we can now exploit this open recursion to insert some memoization in the middle of the recursion.

We will use the same memoization strategy we used previously: indexing into a vector to seek the results of our previously computed sub-solutions.

long long bst_count_memo(int n)
{
    std::vector<long long> memo_table(n+1);
    for (int i = 0; i <= n; ++i)
        memo_table[i] = bst_count([&](int k) { return memo_table[k]; }, i);
    return memo_table.back();
}

 
Even if we ignore the integer overflow issue (let us imagine we use an unbounded integer representation), there is still a big difference between this C++ implementation and the corresponding implementation in Haskell:

memoBstCount :: Int -> Integer
memoBstCount n = Vector.last memo
  where
    memo = Vector.generate (n+1) (bstCount (memo Vector.!))

 
In C++, the order of evaluation now has a direct impact on the correctness of our solution. Had we computed the sub-solutions in the wrong order, we would have gotten a completely erroneous result.

So although we applied the same idea, we achieve less decoupling in C++ than in Haskell: the implementer of the memoization must still care about the insides of the recurrence formula. Can we do better?

 

Adding laziness


To get rid of the coupling to the order of evaluation, we will take some of the ideas from Haskell, and introduce laziness into our solution. We just need a helper function that will:

  • Index into the vector to check for a previously computed value
  • Return this value if it could be found
  • Perform the computation otherwise, and store it in the vector

This gives us the following C++ implementation:

static long long lazy_impl(std::vector<long long>& memo, int n)
{
    if (memo[n]) return memo[n];
    auto recur = [&](int k) { return lazy_impl(memo, k); };
    return memo[n] = bst_count(recur, n);
}

long long bst_count_lazy(int n)
{
    std::vector<long long> memo(n+1);
    return lazy_impl(memo, n);
}

 
You might wonder whether we could have used a lambda instead of introducing a named function. Unfortunately not: since a lambda has no name, it cannot refer to itself inside its own body.

The following would not compile!

long long bst_count_lazy(int n)
{
    std::vector<long long> memo(n+1);
    auto lazy_impl = [&](int k) {
        if (memo[k]) return memo[k];
        auto recur = [&](int i) { return lazy_impl(i); }; // error: 'lazy_impl' used in its own initializer
        return memo[k] = bst_count(recur, k);
    };
    return lazy_impl(n);
}

 
We now have the same level of decoupling as with our Haskell solution.

 

Conclusion


We showed that the open recursion trick known from the Haskell world can very easily be brought to C++. Using it, we can decouple the memoization strategy of a Dynamic Programming solution from the recurrence relation.

Although it might seem that we achieved a total decoupling between the recurrence relation and the memoization part, the implementer of the memoization strategy still needs to care about which sub-solutions to compute.

Can we do something about it? We will look into this subject in the next post.

Open recursion (Haskell)

There is a large class of problems that we can solve by nice mathematical recurrence relations.

The recurrence is usually simple to read and reason about, and describes the solution concisely. But a naive translation of this recurrence into code will very often lead to a highly inefficient algorithm.

When the sub-problems of the recurrence relation form a DAG, the usual trick is to use Dynamic Programming to speed up the computation. But it often results in a complete change of the code, hiding the recurrence, sometimes to the point we cannot recognize the problem anymore.

Thankfully, with open recursion, we can have the best of both worlds!

 

Counting Binary Search Trees


Before introducing the technique, let us first take a nice example whose solution can be described by a simple recurrence relation: counting binary search trees.

We observe that we can cut any collection [X1, X2 .. Xn] in two parts, around one of its elements Xi. Then, we can:

  • Recursively count all BSTs on the left part [X1 .. Xi)
  • Recursively count all BSTs on the right part (Xi .. Xn]
  • Multiply these two counts to get the number of BSTs rooted in Xi

As we have to do this for all i in [1..N], the recurrence relation becomes:

bstCount :: Int -> Integer
bstCount n
  | n <= 1 = 1
  | otherwise = sum [bstCount i * bstCount (n-1-i) | i <- [0..n-1]]

 
But this algorithm is terribly inefficient: we keep recomputing the same sub-solutions over and over again.

We know the next step, which is to memoize the solutions to the sub-problems. However, as our original goal was to do it without compromising the code readability, let us first introduce open recursion.

 

Adding open recursion


Open recursion consists of avoiding direct recursion by adding an extra layer of indirection. It typically means transforming our recurrence relation to take a new parameter, a function that will be called instead of recurring.

By doing so, the recurrence formula loses its recursive nature. Here is how it would translate into Haskell:

bstCount :: (Int -> Integer) -> Int -> Integer
bstCount rec n
  | n <= 1 = 1
  | otherwise = sum [rec i * rec (n-1-i) | i <- [0..n-1]]

 
How can we get our recurrence relation back? By introducing a wrapper function that will give itself to the “bstCount” recurrence. Instead of having a direct recursion, we have a two-step recursion. This is best explained by example:

bstCountNaive :: Int -> Integer
bstCountNaive = bstCount bstCountNaive

 
By simple renaming, we can see that it can be expressed as: x = f x. So the naive recursive algorithm we had earlier is effectively the fixed point of the open recurrence relation. Which can be written in Haskell as:

import Data.Function(fix)

bstCountNaive :: Int -> Integer
bstCountNaive = fix bstCount

 

Adding memoization


We can now exploit this open recursion to insert some memoization in the middle of the recursion. In our specific case, the sub-problems exhibit a simple structure:

  • We can compute a vector of the results of each sub-problem 0..N
  • The recursive part then consists of indexing into this vector in O(1)

So instead of triggering a recursion, we search the sub-solution result into our memoization vector, which translates into the following Haskell code:

memoBstCount :: Int -> Integer
memoBstCount n = Vector.last memo
  where
    memo = Vector.generate (n+1) (bstCount (memo Vector.!))

 
How can it even work? You are witnessing here the magic of Haskell: laziness helps us refer to the item we are computing inside its own computation.

Following this change, each sub-solution N now takes at most N constant-time steps to compute, thanks to memoization. This gives us a quadratic overall complexity (the result of summing 1 to N).

 

Conclusion


Using open recursion in combination with Haskell's laziness has effectively let us decouple the following aspects:

  • The recurrence relation, solution to our problem
  • The memoization strategy, indexing into a vector
  • The order in which we compute these sub-solutions

As a result, we get all the benefits of having a simple recurrence relation, untainted by the implementation details required to get an efficient implementation.

In the next post, we will see how to apply this trick in a more mainstream language: C++.