# -*- Mode:Python; coding: utf-8 -*-

# Having a go at the constraint-based inference algorithm described by
#  Pottier and Rémy in "Advanced Topics in Types and Programming
#  Languages", chapter 10: "The Essence of ML Type Inference".
#
# Another great reference is a somewhat simplified presentation of the
#  same material, but (thankfully) with some context, by Pottier: "A
#  modern eye on ML type inference - Old techniques and recent
#  developments", available from his home page:
#  http://cristal.inria.fr/~fpottier/
#
# NOTE: this is a work in progress.
# http://www.nightmare.com/rushing/irken
#

# now with n-ary args
# TODO: s-letall (maybe?), letrec, kind-checking

import nodes
import sys
import pdb
trace = pdb.set_trace
is_a = isinstance

# the simply typed lambda calculus:
#   e ::= x | λx.e | e e
# 
# expressions:
# x    : <varref>
# λx.e : <function>
# e e  : <application>

# types
# t ::= a | (arrow t t)
# (where a = <tvar>)

# constraints:
# C ::= (equals t t) | (and C C) | (exists a C)

# types

tvar_counter = 0

def fresh():
    global tvar_counter
    result = tvar_counter
    tvar_counter += 1
    return result

class _type:
    pass

class t_base (_type):
    def __cmp__ (self, other):
        return cmp (self.__class__, other.__class__)

class t_int (t_base):
    def __repr__ (self):
        return 'int'

class t_char (t_base):
    def __repr__ (self):
        return 'char'

class t_str (t_base):
    def __repr__ (self):
        return 'str'

class t_var (_type):
    next = None
    rank = -1
    letters = 'abcdefghijklmnopqrstuvwxyz'
    eq = None
    def __init__ (self):
        self.id = fresh()
    def __repr__ (self):
        return base_n (self.id, len(self.letters), self.letters)

class t_predicate (_type):
    def __init__ (self, name, args):
        self.name = name
        self.args = tuple (args)
    def __repr__ (self):
        # special case
        if self.name == 'arrow':
            if len(self.args) == 2:
                return '%r->%r' % (self.args[1], self.args[0])
            else:
                return '%r->%r' % (self.args[1:], self.args[0])
        else:
            return '%s%r' % (self.name, self.args)

def is_pred (t, *p):
    # is this a predicate from the set <p>?
    return is_a (t, t_predicate) and t.name in p

def arrow (*sig):
    # sig = (<result_type>, <arg0_type>, <arg1_type>, ...)
    # XXX this might be more clear as (<arg0>, <arg1>, ... <result>)
    return t_predicate ('arrow', sig)
    
# row types
def product (row):
    # a.k.a. 'Π'
    # XXX kind-check that args[0] is a row?
    return t_predicate ('product', (row,))

def sum (row):
    # a.k.a. 'Σ'
    return t_predicate ('sum', (row,))

def rdefault (arg):
    # a.k.a. 'δ'
    return t_predicate ('rdefault', (arg,))

def rlabel (name, type, rest):
    return t_predicate ('rlabel', (name, type, rest))

def abs():
    return t_predicate ('abs', ())

def pre (x):
    return t_predicate ('pre', (x,))

# constraints

def constraint_repr (kind, args):
    if args:
        return '[%s %s]' % (kind, ' '.join ([repr(x) for x in args]))
    else:
        return '%s' % (kind,)

class constraint:
    kind = 'abstract'
    args = ()

    def __repr__ (self):
        return constraint_repr (self.kind, self.args)

class c_true (constraint):
    kind = 'true'
    args = ()

class c_false (constraint):
    kind = 'true'
    args = ()

class c_equals (constraint):
    kind = 'equals'
    def __init__ (self, *args):
        self.args = args
    def __repr__ (self):
        return constraint_repr ('=', self.args)

class c_and (constraint):
    kind = 'and'
    def __init__ (self, c0, c1):
        self.args = (c0, c1)
        
class c_exists (constraint):
    kind = 'exists'
    def __init__ (self, vars, sub):
        self.args = (vars, sub)
        self.vars = vars
        self.sub = sub

class c_is (constraint):
    # <x> has type <t> iff <t> is an instance of the type scheme associated with <x>
    kind = 'is'
    def __init__ (self, x, t):
        self.args = (x, t)
        self.x = x
        self.t = t

class c_let (constraint):
    kind = 'let'
    def __init__ (self, formal, init, body):
        self.args = (formal, init, body)
        self.formal = formal
        self.init = init
        self.body = body

class c_forall (constraint):
    kind = 'forall'
    def __init__ (self, vars, constraint, type):
        self.args = (vars, constraint, type)
        self.vars = vars
        self.constraint = constraint
        self.type = type

# stack frames
class frame:
    kind = 'abstract'

class s_empty (frame):
    kind = 'empty'

empty = s_empty()

class s_and (frame):
    def __init__ (self, c):
        self.constraint = c

class s_exists (frame):
    def __init__ (self, vars):
        self.vars = vars

class s_let (frame):
    def __init__ (self, formal, vars, type, body, rank):
        self.formal = formal
        self.vars = vars
        self.type = type
        self.body = body
        self.rank = rank
        for v in vars:
            v.rank = rank

    def add_vars (self, vars):
        self.vars += tuple (vars)
        for v in vars:
            v.rank = self.rank

class s_env (frame):
    # after a <let> type scheme has been solved, an <env> frame
    #   binds the scheme to the formal.
    def __init__ (self, formal, type):
        self.formal = formal
        self.type = type

# this is a two-phase algorithm
# 1) constraint generation
# 2) constraint solving

class constraint_generator:

    def go (self, exp):
        t = t_var()
        return self.gen (exp, t), t

    def gen (self, exp, t):
        if is_a (exp, nodes.varref):
            return c_is (exp.name, t)
        elif is_a (exp, nodes.function):
            arg_tvs = [t_var() for x in exp.formals]
            rtv = t_var()
            bod0 = self.gen (exp.body, rtv)
            for i in range (len (exp.formals)):
                bod0 = c_let (exp.formals[i].name, arg_tvs[i], bod0)
            # XXX: in ATTPL, this is a c_supertype relation
            sub1 = c_equals (t, arrow (rtv, *arg_tvs))
            return c_exists ([rtv] + arg_tvs, c_and (bod0, sub1))
        elif is_a (exp, nodes.application):
            arg_tvs = [t_var() for x in exp.rands]
            sub0 = self.gen (exp.rator, arrow (t, *arg_tvs))
            for i in range (len(exp.rands)):
                sub0 = c_and (sub0, self.gen (exp.rands[i], arg_tvs[i]))
            return c_exists (arg_tvs, sub0)
        elif is_a (exp, nodes.let):
            # unary let.  do I generalize this now?
            x = t_var()
            assert (len(exp.names) == 1)
            init0 = self.gen (exp.inits[0], x)
            body0 = self.gen (exp.body, t)
            return c_let (exp.names[0].name, c_forall ((x,), init0, x), body0)
        elif is_a (exp, nodes.literal):
            if exp.kind == 'int':
                return c_equals (t, t_int())
            elif exp.kind == 'char':
                return c_equals (t, t_char())
            else:
                raise ValueError ("unsupported literal type")
        else:
            raise ValueError

class UnboundVariable (Exception):
    pass

class TypeError (Exception):
    pass

class multi:
    # a 'standard' multi-equation of the form A=B=C=T where A,B,C are
    # type variables and T is an optional type.
    def __init__ (self, vars, type):
        self.vars = vars
        self.type = type
        self.rep = self.min_rank()
        for v in self.vars:
            # point them all at the rep var
            if v is not self.rep:
                v.next = self.rep
            v.eq = self
        self.rank = self.rep.rank
        self.free = set()
        ftv (self.free, type)

    def min_rank (self):
        # choose the variable with lowest <rank,id>
        mr = sys.maxint
        mv = None
        for v in self.vars:
            if v.rank < mr:
                mr = v.rank
                mv = v
            elif v.rank == mr:
                if v.id < mv.id:
                    mv = v
        return mv

    def __repr__ (self):
        r = '='.join (['%r' % v for v in self.vars])
        if self.type:
            return r + '=%r' % (self.type,)
        else:
            return r

class unifier:

    # Maintains a conjunction of multi-equations.  In the typical HM algorithm,
    #  this would be called the 'subst'.

    def __init__ (self, step=False):
        self.step = step
        self.eqs = set()
        self.exists = []

    def add (self, vars, type=None):
        # add a set of terms as a multi-equation to the conjunction.
        assert (is_a (vars, set))
        assert (not is_a (type, t_var))

        if is_a (type, t_predicate):
            type = self.try_name_1 (type)

        if (not type and len(vars) == 1) or (type and len(vars) == 0):
            self.dprint ('s-single')
        else:
            # any of these vars already present?
            for v in vars:
                if v.eq:
                    # if so, then fuse
                    self.fuse (v.eq, vars, type)
                    return
            # nope, a new equation
            eq = multi (vars, type)
            self.eqs.add (eq)

    def add2 (self, *args):
        # add an equation between a random collection of variables and types
        vars = set()
        types = []
        for arg in args:
            if is_a (arg, t_var):
                vars.add (arg)
            else:
                types.append (arg)
        if len(types) == 2:
            self.decompose ((vars, types[0]), (vars, types[1]))
        elif len(types) > 2:
            raise ValueError ("too many types")
        elif len(types) == 1:
            self.add (vars, types[0])
        else:
            self.add (vars, None)

    def is_free (self, var):
        # is <var> free in this equation?
        for eq in self.eqs:
            # any var referenced in a type (that does not
            #  point to another var) is 'free'
            if var in eq.free and not var.next:
                return True
        else:
            return False

    def try_name_1 (self, type):
        # ensure that a predicate's arguments are type variables,
        #  naming them if necessary (rule S-NAME-1).
        args2 = []
        flag = False
        for arg in type.args:
            if is_a (arg, str):
                # XXX row labels, must be a better way.
                args2.append (arg)
            elif not is_a (arg, t_var):
                self.dprint ('s-name-1')
                x = t_var()
                self.exists.append (x)
                self.add (set([x]), arg)
                args2.append (x)
                flag = True
            else:
                args2.append (arg)
        if flag:
            return t_predicate (type.name, args2)
        else:
            return type

    def forget (self, eq):
        self.eqs.remove (eq)
        for v in eq.vars:
            v.eq = None
            v.next = None

    def fuse (self, eq, tvs0, ty0):
        tvs1 = eq.vars
        ty1  = eq.type
        # is a three-way fuse possible? (e.g. A=T0 B=T1; A=B=T2)
        # I don't think so, so let's ignore that possibility for now.
        self.forget (eq)
        self.dprint ('s-fuse')
        if ty0 and ty1:
            # must unify types
            # A=B=T0 ^ B=C=T1 => A=B=C=T0=T1
            self.decompose ((tvs0, ty0), (tvs1, ty1))
        else:
            # A=B=T0 ^ B=C => A=B=C=T0
            self.add (tvs0.union (tvs1), ty0 or ty1)

    def decompose (self, t0, t1):
        tvs0, ty0 = t0
        tvs1, ty1 = t1
        tvs = tvs0.union (tvs1)
        if ty0 == ty1:
            # a=b=int=int, etc... => a=b=int
            self.add (tvs, ty0)
        elif is_pred (ty0, 'rlabel', 'rdefault') or is_pred (ty1, 'rlabel', 'rdefault'):
            self.unify_rows (ty0, ty1, tvs)
        elif (is_a (ty0, t_predicate) and is_a (ty1, t_predicate)
              and ty0.name == ty1.name
              and len(ty0.args) == len(ty1.args)):
            self.dprint ('s-decompose')
            # P(a,b,c)=P(d,e,f)=ε => a=d ^ b=e ^ c=f ^ P(a,b,c)=ε
            for i in range (len (ty0.args)):
                self.add2 (ty0.args[i], ty1.args[i])
            self.add (tvs, ty0)
        else:
            self.dprint ('s-clash')
            raise TypeError ((ty0, ty1))

    def unify_rows (self, ty0, ty1, tvs):
        if is_pred (ty0, 'rlabel') and is_pred (ty1, 'rlabel'):
            if ty0.args[0] != ty1.args[0]:
                # distinct head labels
                self.dprint ('s-mutate-ll')
                # XXX be concerned about how one of these may have types
                #     and the other has variables.  do we need to check
                #     and reorder them?
                l0, t0, d0 = ty0.args
                l1, t1, d1 = ty1.args
                x = t_var()
                self.exists.append (x)
                self.add2 (d0, rlabel (l1, t1, x))
                self.add2 (d1, rlabel (l0, t0, x))
                self.add (tvs, rlabel (l0, t0, d0))
            else:
                # XXX this should be handled by the normal s-decompose
                l0, t0, d0 = ty0.args
                l1, t1, d1 = ty1.args
                self.add2 (t0, t1)
                self.add2 (d0, d1)
                self.add (tvs, ty0)
        elif is_pred (ty0, 'rlabel') or is_pred (ty1, 'rlabel'):
            # only one is an rlabel
            if is_pred (ty1, 'rlabel'):
                # ensure that ty0 is the rlabel
                ty0, ty1 = ty1, ty0
            if is_pred (ty1, 'rdefault'):
                self.dprint ('s-mutate-dl')
                x = ty1.args[0]
                assert (is_a (x, t_var))
                self.add2 (x, ty0.args[1])
                self.add2 (ty1, ty0.args[2])
                self.add (tvs, ty1)
            elif is_a (ty1, t_predicate):
                # some other predicate
                self.dprint ('s-mutate-gl')
                n = len (ty1.args)
                tvars0 = [t_var() for x in ty1.args]
                tvars1 = [t_var() for x in ty1.args]
                self.exists.extend (tvars0)
                self.exists.extend (tvars1)
                l0, t0, d0 = ty0.args
                g = ty1.name
                self.add2 (t_predicate (g, tvars0), t0)
                self.add2 (t_predicate (g, tvars1), d0)
                for i in range (n):
                    self.add2 (ty1.args[i], rlabel (l0, tvars0[i], tvars1[i]))
                self.add (tvs, ty1)
            else:
                self.dprint ('s-clash')
                raise TypeError ((ty0, ty1))
        elif is_pred (ty0, 'rdefault',) or is_pred (ty1, 'rdefault'):
            if is_pred (ty1, 'rdefault'):
                # ensure that ty0 is the rdefault/δ
                ty0, ty1 = ty1, ty0
            if is_a (ty1, t_predicate):
                # some other predicate
                self.dprint ('s-mutate-gd')
                n = len (ty1.args)
                g = ty1.name
                tvars = [ t_var() for x in ty1.args ]
                self.exists.extend (tvars)
                self.add2 (ty0.args[0], t_predicate (g, tvars))
                for i in range (n):
                    self.add2 (ty1.args[i], rdefault (tvars[i]))
                self.add (tvs, ty0)
            else:
                self.dprint ('s-clash')
                raise TypeError ((ty0, ty1))
        else:
            self.dprint ('s-clash')
            raise TypeError ((ty0, ty1))

    def split (self, sz):
        # we leave in only equations made entirely of 'old' variables
        young = set (sz.vars)
        u2 = []
        forget = []
        for eq in self.eqs:
            if eq.rep in young or eq.free.intersection (eq.vars):
                u2.append (c_equals (* list(eq.vars) + [eq.type]))
                forget.append (eq)
        for eq in forget:
            self.forget (eq)
        return list_to_conj (u2)

    def dprint (self, msg):
        if self.step:
            sys.stderr.write ('*** ')
            sys.stderr.write (msg)
            sys.stderr.write ('\n')
            self.pprint()

    def simplify (self):
        # d=c=b=a=x => a=x
        def compress (t):
            if is_a (t, t_var):
                return t.next or t
            elif is_a (t, t_predicate):
                return t_predicate (t.name, [compress(x) for x in t.args])
            else:
                return t
        for eq in self.eqs:
            eq.type = compress (eq.type)
            eq.vars = set ([eq.rep])

    def renumber (self):
        # first, collect every tvar referenced
        tvars = set()
        for eq in self.eqs:
            tvars.update (eq.vars)
            tvars.update (eq.free)
        tvars = list(tvars)
        tvars.sort (lambda a,b: cmp (a.id, b.id))
        print 'renumbering, %d tvars' % (len(tvars),)
        # heh, don't look!
        for i in range (len (tvars)):
            tvars[i].id = i

    def pprint (self):
        sys.stdout.write ('U: ')
        eqs = list (self.eqs)
        # sort the equations by representative tvar
        eqs.sort (lambda a,b: cmp (a.rep.id, b.rep.id))
        for eq in eqs:
            sys.stdout.write ('\t%r\n' % (eq,))
        sys.stdout.write ('\n')

# I believe the 'union-find' data-structure/algorithm is another way
#   of describing HM's use of type variable lookup with path
#   compression.

class solver:

    def __init__ (self, step=True):
        self.step = step
        self.prim_env = self.make_prim_env()

    def dprint (self, msg):
        if self.step:
            sys.stderr.write (msg)
            sys.stderr.write ('\n')

    def solve (self, c):

        self.dprint ('\nHit <return> at each pause (or "t<return>" to enter the debugger)')

        pvars = []
        self.exists = []
        # ensure there are always two items on the stack
        s = [empty, empty]
        u = unifier (self.step)
        c = c

        orig_c = c
        rank = 0

        def push (x):
            s.append (x)

        def pop ():
            s.pop()

        while 1:

            if self.step:
                print 'S:',
                self.pprint_stack (s)
                u.pprint()
                print 'C:', c
                print 'exists:', self.exists

            # the top two elements of the stack
            sy, sz = s[-2], s[-1]

            if self.step:
                print '-----------------------------'
                if raw_input().startswith ('t'):
                    trace()

            # --- solver ---            

            if u.exists:
                self.dprint ('s-ex-1')
                self.move_exists (s, u.exists)
                u.exists = []
            elif is_a (sz, s_exists):
                self.dprint ('s-record-ex')
                self.exists.extend (sz.vars)
                pop()
            elif is_a (c, c_equals):
                self.dprint ('s-solve-eq')
                u.add2 (*c.args)
                c = c_true()
            elif is_a (c, c_is) and is_a (c.x, str):
                self.dprint ('s-solve-id')
                scheme, type = self.lookup (c.x, s), c.t
                # assert that scheme.type is a tvar
                # "Recall that if σ is of the form ∀X0..XN[U].X
                # where X0..XN#ftv(T), then c_is(σ, T) stands for ∃X0..XN.(U ^ X=T)."
                self.dprint ('scheme= %r' % scheme)
                self.dprint ('type=%r' % type)
                if is_a (scheme, c_forall):
                    if not scheme.vars and is_a (scheme.constraint, c_true):
                        c = c_equals (scheme.type, type)
                    else:
                        c = c_exists (scheme.vars, c_and (scheme.constraint, c_equals (scheme.type, type)))
                else:
                    c = c_equals (scheme, type)
            elif is_a (c, c_and):
                self.dprint ('s-solve-and')
                push (s_and (c.args[1]))
                c = c.args[0]
            elif is_a (c, c_exists):
                self.dprint ('s-solve-ex')
                self.move_exists (s, c.vars)
                c = c.sub
            elif is_a (c, c_let):
                self.dprint ('s-solve-let')
                if is_a (c.init, c_forall):
                    vars = c.init.vars
                    push (s_let (c.formal, c.init.vars, c.init.type, c.body, rank))
                    rank += 1
                    c = c.init.constraint
                else:
                    # let x: T in C == let x: ∀∅[true].T in C
                    push (s_let (c.formal, (), c.init, c.body, rank))
                    c = c_true()
            elif is_a (c, c_true):
                if is_a (sz, s_and):
                    self.dprint ('s-pop-and')
                    pop()
                    c = sz.constraint
                elif is_a (sz, s_let) and not is_a (sz.type, t_var):
                    self.dprint ('s-name-2')
                    x = t_var()
                    pop()
                    push (s_let (sz.formal, sz.vars + (x,), x, sz.body, rank))
                    u.add (set([x]), sz.type)
                elif is_a (sz, s_let):
                    unname = []
                    for var in sz.vars:
                        # XXX this isn't quite right - we can subst sz.type with var.rep
                        if var.next and not u.is_free (var) and sz.type is not var:
                            unname.append (var)
                    if unname:
                        self.dprint ('s-unname %r' % (unname,))
                        vars = [x for x in sz.vars if x not in unname]
                        self.dprint ('  new vars=%r' % (vars,))
                        pop()
                        push (s_let (sz.formal, vars, sz.type, sz.body, sz.rank))
                    else:
                        # **** S-LETALL here ****
                        # s-letall will simplify the current scheme by removing some of
                        #   the quantified tvars.
                        # pop-let is a fall-through - after all
                        # the above conditions have been met it turns the <let>
                        # into an <env>.
                        self.dprint ('s-pop-let')
                        pop()
                        push (s_env (sz.formal, c_forall (sz.vars, u.split (sz), sz.type)))
                        c = sz.body
                elif is_a (sz, s_env):
                    # record the type scheme associated with this program variable
                    pvars.append ((sz.formal, sz.type.type))
                    self.dprint ('s-pop-env')
                    pop()
                elif is_a (sz, s_empty):
                    # we're done!
                    self.dprint ('exists=%r' % self.exists)
                    self.dprint ('constraint=%r' % orig_c)
                    return pvars, u
                else:
                    raise ValueError ("unexpected")
            else:
                raise ValueError ("no rule applies")

    def move_exists (self, s, vars):
        # this implements the various S-EX-? rules that attach a set of tvars to
        #   the nearest <let> on the stack.
        n = len (s)
        for i in range (-1, -n, -1):
            if is_a (s[i], s_let):
                s[i].add_vars (vars)
                break
        else:
            self.exists.extend (vars)

    def lookup (self, x, s):
        for f in s:
            if is_a (f, s_and):
                continue
            elif is_a (f, s_exists):
                continue
            elif is_a (f, s_let):
                if f.formal == x:
                    raise ValueError ("shouldn't happen?")
                continue
            elif is_a (f, s_env):
                if f.formal != x:
                    continue
                else:
                    return f.type
        else:
            return self.instantiate (self.lookup_special_names (x))

    def instantiate (self, scheme):
        print 'instantiating %r' % (scheme,)
        map = {}
        def walk (t):
            if is_a (t, int):
                if not map.has_key (t):
                    map[t] = t_var()
                return map[t]
            elif is_a (t, t_predicate):
                return t_predicate (t.name, [walk(x) for x in t.args])
            else:
                return t
        scheme = walk (scheme)
        tvars = tuple (map.values())
        return c_forall (tvars, c_true(), scheme)

    def make_prim_env (self):
        # build type schemes for builtins/primitives
        # these are written in an unusual fashion, using integers as place holders
        #   for type variables that will be instantiated fresh with each lookup.
        int = t_int()
        char = t_char()
        e = {}
        e['%+'] = arrow (int, int, int)
        e['%-'] = arrow (int, int, int)
        e['%make-pair'] = arrow (t_predicate ('pair', (0, 1)), 0, 1)
        e['%pair/first'] = arrow (t_predicate ('pair', (0, 1)), 0)
        e['%pair/second'] = arrow (t_predicate ('pair', (0, 1)), 1)
        e['%char->int'] = arrow (int, char)
        e['%int->char'] = arrow (char, int)
        e['%zed'] = arrow (int)
        # rows
        e['%abs'] = arrow (t_predicate ('abs', ())) # arg discarded
        e['%pre'] = arrow (t_predicate ('pre', (0,)), 0)
        # make a defaulted record
        e['%rmake'] = arrow (product (rdefault (0)), 0)
        # make an empty variant
        e['%vmake'] = arrow (sum (rdefault (abs())), 0)
        # label-related row types are in the 'lookup_special' method
        return e

    def lookup_special_names (self, name):
        # we need this hack for row-related label lookups, since the label
        #  is part of the name.
        print 'looking up %r' % (name,)
        if self.prim_env.has_key (name):
            # the normal path
            return self.prim_env[name]
        elif name.startswith ('%rextend/'):
            what, label = name.split ('/')
            # XXX pre(X)
            # ∀XYZ.Π(l:X;Y) → Z → Π(l:Z;Y)
            return arrow (
                product (rlabel (label, 2, 1)),
                product (rlabel (label, 0, 1)),
                2
                )
        elif name.startswith ('%raccess/'):
            what, label = name.split ('/')
            # XXX pre(X)
            # ∀XY.Π(l:X;Y) → X
            return arrow (0, product (rlabel (label, 0, 1)))
        elif name.startswith ('%vextend/'):
            what, label = name.split ('/')
            # ∀XY.X → Σ(l:pre X;Y)
            return arrow (sum (rlabel (label, pre(0), 1)), 0)
        elif name.startswith ('%vcase/'):
            what, label = name.split ('/')
            # this one's a doozy!
            # ∀XYX'Y'.(X → Y) → (Σ(l:X';Y') → Y) → Σ(l:pre X;Y') → Y
            # ∀XYX'Y'.f0 → f1 → s1 → Y
            f0 = arrow (1, 0)
            f1 = arrow (1, sum (rlabel (label, 2, 3)))
            s1 = sum (rlabel (label, pre (0), 3))
            return arrow (1, f0, f1, s1)
        else:
            raise UnboundVariable (name)

    def pprint_stack (self, s):

        W = sys.stdout.write

        W ('\n')
        n = len(s)
        # the 2 is for the two <empty> sentinels
        for i in range (2,n):
            W ('%2d: ' % (i-2,))
            si = s[i]
            if is_a (si, s_empty):
                W ('[]')
            elif is_a (si, s_and):
                W ('[] ^ %s' % si.constraint)
            elif is_a (si, s_exists):
                W ('exists%r.[]' % si.vars)
            elif is_a (si, s_let):
                W ('let %s: forall%r[[]].%r in %r' % (si.formal, si.vars, si.type, si.body))
            elif is_a (si, s_env):
                W ('env %s: %r in []' % (si.formal, si.type))
            else:
                raise NotImplementedError
            W ('\n')

def list_to_conj (l):
    # convert list <l> into a conjunction built with <c_and>
    if len(l) == 0:
        return c_true()
    elif len(l) == 1:
        return l[0]
    else:
        r = l[0]
        for x in l[1:]:
            r = c_and (r, x)
        return r

def ftv (s, t):
    # accumulate free type variables into the set <s>
    if is_a (t, t_var):
        s.add (t)
    elif is_a (t, t_predicate):
        for arg in t.args:
            ftv (s, arg)
    elif is_a (t, t_base):
        pass
    elif is_a (t, str):
        pass
    elif t is None:
        pass
    else:
        raise ValueError ("unknown type object")

def print_solution (pvars, u, top_tv):

    print pvars
    u.pprint()
    print top_tv

    def lookup (x):
        # XXX this can probably be simplified because of u.simplify
        if is_a (x, t_var):
            if x.eq:
                if x.eq.type:
                    return lookup (x.eq.type)
                else:
                    return x.eq.rep
            else:
                return x
        elif is_a (x, t_predicate):
            return t_predicate (x.name, [lookup(y) for y in x.args])
        elif is_a (x, t_base):
            return x
        elif is_a (x, str):
            # XXX row labels
            return x
        else:
            raise ValueError ("unknown type object")

    for pvar, tvar in pvars:
        print '%r: %r' % (pvar, lookup (tvar))

    print 'program: %r' % lookup (top_tv)

def base_n (n, base, digits):
    # return a string representation of <n> in <base>, using <digits>
    s = []
    while 1:
        n, r = divmod (n, base)
        s.insert (0, digits[r])
        if not n:
            break
    return ''.join (s)

def read_string (s):
    import cStringIO
    import lisp_reader
    sf = cStringIO.StringIO (s)
    r = lisp_reader.reader (sf)
    return r.read()

def test (s, step=True):
    import nodes
    global tvar_counter
    tvar_counter = 0
    # wrap everything in a top-level <let>
    # if we omit this, s-compress won't work on top-level exists
    print 'expression=', s
    s = "(let ((top %s)) top)" % s
    exp = read_string (s)
    w = nodes.walker()
    exp2 = w.go (exp)
    # alpha conversion
    nodes.rename_variables (exp2)
    cg = constraint_generator()
    c, top_tv = cg.go (exp2)
    print 'constraint=', c
    m, u = solver(step).solve(c)
    u.simplify()
    u.renumber()
    print_solution (m, u, top_tv)

tests = [
    "5",
    "(lambda (x) 3)",
    "(lambda (x) x)",
    "(let ((f (lambda (x) x))) (f 5))",
    "(lambda (x) (lambda (y) x))",
    "(let ((f (lambda (x) x))) f)",
    "((lambda (x) x) (lambda (x) x))",
    "(%+ 3 4)",
    "(%- (%+ 3 4) 2)",
    "(%make-pair 3 4)",
    "(%make-pair 3 #\\A)",
    "(%pair/first (%make-pair 3 #\\A))",
    "(%pair/second (%make-pair 3 #\\A))",
    "(%char->int #\\A)",
    "(%int->char 65)",
    # zero-argument function
    "(%zed)",
    # this causes a cycle, breaks print-solution
    #"(let ((f (lambda (x) 3))) (%int->char (f f)))",
    # row tests
    # make a default row - maps all symbols to a function: δ(a->a)
    "(%rmake (lambda (x) x))",
    # lookup via a random field name on a row of type δ(int)
    "(%raccess/field (%rmake 0))",
    # extend a row: (label:char;δ(int))
    "(%rextend/label (%rmake 0) #\\A)",
    # dereference via the defined label
    "(%raccess/label (%rextend/label (%rmake 0) #\\A))",
    # dereference via an undefined label
    "(%raccess/fnord (%rextend/label (%rmake 0) #\\A))",
    # extend a row twice: (other:a->a; label:char; δ(int))
    "(%rextend/other (%rextend/label (%rmake 0) #\\A) (lambda (x) x))",
    # deref one defined label
    "(%raccess/label (%rextend/other (%rextend/label (%rmake 0) #\\A) (lambda (x) x)))",
    # deref the other
    "(%raccess/other (%rextend/other (%rextend/label (%rmake 0) #\\A) (lambda (x) x)))",
    # deref an undefined label
    "(%raccess/fnord (%rextend/other (%rextend/label (%rmake 0) #\\A) (lambda (x) x)))",
    # build an extended record, store it in a variable, then deref to pull the identity function out and use it.
"""(let ((rec0 (%rextend/other (%rextend/label (%rmake 0) #\\A) (lambda (x) x))))
     (let ((f0 (%raccess/other rec0)))
       (f0 9)))""",
    # abs/pre
    "(%abs)",
    "(%rmake (%abs))",
    "(%rextend/l0 (%rmake (%abs)) (%pre #\\A))",
    # variants
    "(%vextend/l0 3)",
    # three args: f0 is X->Y, f1 is sum->Y, s0 is sum. returns Y.
    "(%vcase/l0 (lambda (x) 3) (lambda (y) 4) (%vextend/l0 9))"
    # now with the wrong variant (note the difference in the type of <y>)
    "(%vcase/l0 (lambda (x) 3) (lambda (y) 4) (%vextend/l1 9))"
    ]

if __name__ == '__main__':
    if '-t' in sys.argv:
        for t in tests:
            test (t, step=False)
    elif '-p' in sys.argv:
        import profile
        profile.run ("test (tests[-1], step=False)")
    elif '-l' in sys.argv:
        # try out the very last test
        test (tests[-1], step=True)
    elif '-i' in sys.argv:
        # interactive
        while 1:
            sys.stdout.write ('> ')
            line = raw_input()
            if not line:
                break
            else:
                #test (line, step=False)
                test (line, step=True)
    else:
        test ("(lambda (x) x)")