Continuation-Passing Style

See the Wikipedia article on continuation-passing style to get started. These notes start from the factorial example there (in Scheme) and show how to convert that code to C using 'closure conversion'.

Note: a previous version of this document used 'lambda lifting' and 'closure conversion' interchangeably (I blame Wikipedia). But 'lambda lifting' refers to the technique of passing free variables as extra arguments, which makes a closure unnecessary. However, it works only for known functions, that is, functions whose every call site is visible and can supply those extra arguments. Any function that escapes requires a closure data structure of some kind.

The CPS transform introduces lots of new, escaping functions, so if your target language doesn't have closures you will need to do closure conversion.

transforming code to continuation-passing-style

First off, note that the accumulator (tail-recursive) version of the factorial function is the 'correct' way to write it. But the 'naive' version lets us show how normal stack-like behavior can be handled without a stack.

In C, the original function looks like this:

// --------------------------------------------------------------------------------
// original
// --------------------------------------------------------------------------------

#include <stdio.h>

int
fact (int n)
{
  if (n == 0) {
    return 1;
  } else {
    return n * fact (n-1);
  }
}

int
main (int argc, char * argv[])
{
  fprintf (stdout, "fact(5)==%d\n", fact(5));
  return 0;
}

continuation-passing style with free variables

The CPS version of this in Scheme looks like this:

(define (factorial& n k)
  (=& n 0 (lambda (b)
            (if b
                (k 1)
                (-& n 1 (lambda (nm1)
                          ; the continuation grows in the recursive call:
                          (factorial& nm1 (lambda (f)
                                            (*& n f k)))))))))

Translating this directly to C is tricky, because the lambda expressions refer to free variables (n and k). Luckily, gcc supports nested functions as a GNU C extension (these notes call them 'lexical functions'), which will let us get something close:

// --------------------------------------------------------------------------------
// cps converted, using gcc's lexical functions to handle the free variables.
// --------------------------------------------------------------------------------

typedef int factk (int);

int
ret_cps (int n)
{
  return n;
}

int
minus_cps (int a, int b, factk k)
{
  return k (a-b);
}

int
mul_cps (int a, int b, factk k)
{
  return k (a*b);
}

int
eq_cps (int a, int b, factk k)
{
  return k (a==b);
}

int
factcps (int n, factk k)
{
  int factcps_1 (int b) {
    int factcps_2 (int nm1) {
      int factcps_3 (int f) {
        return mul_cps (n, f, k);
      }
      return factcps (nm1, factcps_3);
    }
    if (b) {
      return k (1);
    } else {
      return minus_cps (n, 1, factcps_2);
    }
  }
  return eq_cps (n, 0, factcps_1);
}

#include <stdio.h>
int
main (int argc, char * argv[])
{
  int result = factcps (5, ret_cps);
  fprintf (stdout, "factcps(5)==%d\n", result);
  return 0;
}

[Nested functions are a GNU C extension. Depending on your installation of gcc, you may need the -fnested-functions option to compile this; clang does not support nested functions at all.] Note how each of the auxiliary 'factcps' functions is nested inside the previous one, so each has access to the free variables surrounding it.

If we want to compile this with a normal C compiler, we need to get rid of these lexical functions. But first, let's simplify things by 'inlining' some of the primitive operations.

inlining of primitive operators

step one: inline eq_cps:

  int
  factcps (int n, factk k)
  {
    int factcps_2 (int nm1) {
      int factcps_3 (int f) {
        return mul_cps (n, f, k);
      }
      return factcps (nm1, factcps_3);
    }
    if (n == 0) {
      return k (1);
    } else {
      return minus_cps (n, 1, factcps_2);
    }
  }

step two: inline minus_cps:

  int
  factcps (int n, factk k)
  {
    int factcps_3 (int f) {
      return mul_cps (n, f, k);
    }
    if (n == 0) {
      return k (1);
    } else {
      return factcps (n-1, factcps_3);
    }
  }

step three: inline mul_cps:

  int
  factcps (int n, factk k)
  {
    int factcps_3 (int f) {
      return k (n * f);
    }
    if (n == 0) {
      return k (1);
    } else {
      return factcps (n-1, factcps_3);
    }
  }

closure conversion

Now we have only one 'internal' function left, which makes the job of closure conversion a little easier. Up to now, we've represented continuations as plain pointers to functions. Now, though, we need to add some smarts: the continuation must carry not just a function pointer, but also the values of that function's free variables (for factcps_3, these are n and k).

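To do that, we change our representation of continuations to a structure. We also make the CPS functions return void: results are handed forward to continuations rather than returned, and the final continuation stores the answer somewhere (a global, in the heap version further on).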

struct _cont;

typedef void (*factk)(int, struct _cont *);

typedef struct _cont {
  factk fun;
  struct _cont * k;
  int n;
} cont;

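We move the definition of factcps_3 outside of factcps. Since it is no longer nested, it can't see n and k directly. Instead it receives its own closure as an argument and reads them from there. So from now on, a continuation c is invoked with a value v as c->fun (v, c):

void
factcps_3 (int f, cont * self) {
  // self->n and self->k are the saved free variables
  self->k->fun (self->n * f, self->k);
}

We create the continuation object on the stack inside factcps:

void
factcps (int n, cont * k)
{
  if (n == 0) {
    k->fun (1, k);
  } else {
    cont k0 = { factcps_3, k, n };
    factcps (n-1, &k0);
  }
}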

Tricky: notice that one of the free variables we store in the continuation struct is the previous value of k. Grok this! It's critical to seeing how the C stack is replaced by a linked list of continuations...

heap-allocated continuation closures

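We're *very* close now to being completely liberated from the C stack, but not quite. Notice how we allocated those continuation closures as local variables inside factcps? Those locals live in factcps's stack frame, so the frame must stay alive for as long as the continuation might be used. That defeats the purpose, and it also keeps the compiler from turning the call into a jump, because a tail call would reuse the very frame the continuation lives in. So we need to allocate the closures from the heap instead. The final result will have something *resembling* a stack of frames, but it will actually live on the heap.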

For now, we'll just make a fake 'heap' using a pre-allocated array of continuations. In a real implementation you'd have a garbage collector to do the job of reclaiming these:

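#include <stdio.h>
#include <stdlib.h>

struct _cont;

typedef void (*factk)(int, struct _cont *);

typedef struct _cont {
  factk fun;
  struct _cont * k;
  int n;
} cont;

#define HEAP_SIZE 100
static cont cont_heap[HEAP_SIZE];
static int cont_counter = 0;
static int the_result;

// allocate the next continuation from the fake heap (never reclaimed)
static
cont *
get_cont (void)
{
  if (cont_counter >= HEAP_SIZE) {
    fprintf (stderr, "out of continuations\n");
    exit (1);
  }
  return &cont_heap[cont_counter++];
}

// note that ret_cps ignores its continuation: it ends the chain
static
void
ret_cps (int n, cont * self)
{
  (void) self;
  the_result = n;
}

// a continuation c is invoked with a value v as c->fun (v, c), so
// factcps_3 receives its own closure and reads the saved n and k
static
void
factcps_3 (int f, cont * self) {
  self->k->fun (self->n * f, self->k);
}

void
factcps (int n, cont * k)
{
  if (n == 0) {
    k->fun (1, k);
  } else {
    cont * k0 = get_cont();
    k0->fun = factcps_3;
    k0->k = k;
    k0->n = n;
    factcps (n-1, k0);
  }
}

int
main (int argc, char * argv[])
{
  cont * k = get_cont();    // the root continuation
  k->fun = ret_cps;
  k->k = NULL;
  k->n = 0;
  factcps (5, k);
  fprintf (stdout, "factcps(5)==%d\n", the_result);
  return 0;
}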

The chain of continuation closures is a linked list of pending returns from function calls, and it takes the place of the classic C stack of frames.

N.B. This entire presentation relies on the C compiler performing tail-call optimization, which the C standard does not require (gcc and clang typically do it when optimizing, e.g. at -O2). Notice that in every CPS function (especially in the last examples) each call is a tail call: once it's made, the caller has nothing left to do but return. When optimized properly, this turns each 'call' into a 'goto'. Definitely check the output of your compiler.

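If you're using clang/llvm (which I recommend highly), you should look at the LLVM IR output from the compiler. With optimization on, calls the optimizer has marked as tail calls show up as 'tail call' in the IR, and the assembly shows whether they actually became jumps (on x86, look for 'jmp' where you'd expect 'call'). At -O2 the IR is already in SSA form, so there's no need to run 'opt -mem2reg' over it by hand:

  $ clang -O2 -S -emit-llvm factcps_heap.c -o factcps_heap.ll
  $ less factcps_heap.ll
  $ clang -O2 -S factcps_heap.c -o factcps_heap.s
  $ less factcps_heap.s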

the source code

Note that you'll need gcc to compile the 'lexical' and 'simplified' versions, because they use nested functions.


Last modified: Sat Mar 24 14:31:23 PDT 2012