# -*- Mode:Python; coding: utf-8 -*-
# This is an implementation of the constraint-based inference algorithm described by
# Pottier and Rémy in "Advanced Topics in Types and Programming Languages", chapter 10:
# "The Essence of ML Type Inference".
# Another great reference is a somewhat simplified presentation of the
# same material, but (thankfully) with some context, by Pottier: "A
# modern eye on ML type inference - Old techniques and recent
# developments", available from his home page:
# http://cristal.inria.fr/~fpottier/
# For now, I have ignored subtyping - by having '<' mean '='. This should
# be a pretty easy thing to change, once I'm ready to wrap my head around it.
import nodes
import graph
import sys
import pdb
# debugging hook: calling trace() enters the pdb debugger at the call site
trace = pdb.set_trace
# short alias for isinstance, used for type tests throughout this module
is_a = isinstance
# the simply typed lambda calculus:
# e ::= x | λx.e | e e
# expressions:
# x : <varref>
# λx.e : <function>
# e e : <application>
# types
# t ::= a | (arrow t t)
# (where a = <tvar>)
# constraints:
# C ::= (equals t t) | (and C C) | (exists a C)
# types
from itypes import *
# constraints
def constraint_repr (kind, args):
    # Render a constraint for display.
    #   kind: the label to print
    #   args: sequence of sub-terms, each rendered with repr()
    # With no args only the kind is shown, otherwise '[kind arg0 arg1 ...]'.
    if not args:
        return '%s' % (kind,)
    rendered = ' '.join (map (repr, args))
    return '[%s %s]' % (kind, rendered)
class constraint:
    # Base class for all constraints.  Subclasses override <kind> (the label
    # used when printing) and store their sub-terms in <args>.
    args = ()
    kind = 'abstract'

    def __repr__ (self):
        # printed form is built from the kind label and the sub-terms
        return constraint_repr (self.kind, self.args)
class c_true (constraint):
    # the 'true' constraint: it has no sub-terms.
    args = ()
    kind = 'true'
class c_equals (constraint):
    # An equation between the terms in <args> (printed with '=').
    kind = 'equals'

    def __init__ (self, *args):
        if None in args:
            # NOTE(review): the body of this test was lost from the source;
            # it is assumed to have been the module's debugger hook, flagging
            # an equation built from a missing type.  Confirm against history.
            trace()
        # always record the arguments, whether or not the check above fired
        self.args = args

    def __repr__ (self):
        return constraint_repr ('=', self.args)
class c_and (constraint):
    # Conjunction of exactly two constraints, stored as args[0] and args[1].
    kind = 'and'

    def __init__ (self, c0, c1):
        pair = (c0, c1)
        self.args = pair
class c_exists (constraint):
    # Existential quantification: <vars> are bound in the constraint <sub>.
    kind = 'exists'

    def __init__ (self, vars, sub):
        # keep the parts both as named attributes and in <args> (for printing)
        self.vars, self.sub = vars, sub
        self.args = (vars, sub)
class c_is (constraint):
    # <x> has type <t> iff <t> is an instance of the type scheme associated with <x>
    kind = 'is'

    def __init__ (self, x, t):
        # keep the parts both as named attributes and in <args> (for printing)
        self.x, self.t = x, t
        self.args = (x, t)
class c_let (constraint):
    # let <names>:<vars> [<constraint>] in <body>
    # names[i] is bound to the type variable vars[i].
    kind = 'let'

    def __init__ (self, names, vars, constraint, body):
        self.args = (names, vars, constraint, body)
        self.names, self.vars = names, vars
        self.constraint, self.body = constraint, body
        # associate each vardef with its tvar (side effect on the name objects)
        for i, vardef in enumerate (names):
            vardef.tv = vars[i]
class c_forall (constraint):
    # Universal quantification of <vars> over <constraint>.
    kind = 'forall'

    def __init__ (self, vars, constraint):
        # keep the parts both as named attributes and in <args> (for printing)
        self.vars, self.constraint = vars, constraint
        self.args = (vars, constraint)
def flatten_conj (c):
    # Flatten a left-nested conjunction into a list of its conjuncts.
    #   (and (and (and a b) c) d) => [a, b, c, d]
    # Only the left spine (args[0]) is unfolded; a <c_and> in a right-hand
    # position is returned as a single element.  A constraint that is not a
    # <c_and> yields a one-element list.
    if not is_a (c, c_and):
        return [c]
    # walk down the left spine iteratively (no recursion-depth limit on long
    # conjunctions), collecting the right-hand sides outermost first
    parts = []
    while is_a (c, c_and):
        parts.append (c.args[1])
        c = c.args[0]
    # <c> is now the innermost left operand
    parts.append (c)
    parts.reverse()
    return parts
def pprint_constraint (c):
    # Pretty-print the constraint <c> to stdout as an indented tree,
    # one node per line, followed by a final newline.
    W = sys.stdout.write
    def pp (c, d):
        # <d> is the nesting depth; every node starts on a fresh, indented line
        W ('\n' + (' ' * d))
        if is_a (c, c_let):
            W ('let %s' % ' '.join (['%s:%r' % (c.names[i].name,c.vars[i]) for i in range (len (c.names))]))
            # a trivial (true) binding constraint is not shown
            if not is_a (c.constraint, c_true):
                pp (c.constraint, d+1)
            # NOTE(review): indentation was lost in the source; the newline
            # before 'in' is assumed to be unconditional - confirm.
            W ('\n' + (' ' * d))
            W ('in')
            pp (c.body, d+1)
        elif is_a (c, c_forall):
            W ('forall (%s)' % (','.join ([repr(v) for v in c.vars])))
            pp (c.constraint, d+1)
        elif is_a (c, c_and):
            W ('and')
            for t in flatten_conj (c):
                pp (t, d+1)
        elif is_a (c, c_exists):
            W ('exists (%s)' % (','.join ([repr(v) for v in c.vars])))
            pp (c.sub, d+1)
        else:
            # leaf constraint: fall back to its repr
            W (repr (c))
    pp (c, 0)
    W ('\n')
def check_constraint (c, top_tv):
    # verify that all tvars are bound correctly in the constraint <c>
    #   top_tv: the one type variable considered bound at top level
    # raises UnboundVariable for a tvar with no binder in scope, and
    # ValueError for a tvar bound more than once or an unrecognized term.
    def lookup (v, env):
        # <env> is a linked list of (rib, parent) pairs ending in None;
        # count how many ribs bind <v>.
        count = 0
        while env is not None:
            rib, env = env
            if v in rib:
                count += 1
        if count > 1:
            raise ValueError ("variable bound more than once!")
        elif count == 1:
            # bound exactly once: ok
            pass
        else:
            raise UnboundVariable
    def pp (c, env):
        if is_a (c, c_let):
            # note: the let's own <vars> are not added to the environment here
            pp (c.constraint, env)
            pp (c.body, env)
        elif is_a (c, c_forall):
            rib = c.vars
            pp (c.constraint, (rib, env))
        elif is_a (c, c_and):
            for t in flatten_conj (c):
                pp (t, env)
        elif is_a (c, c_exists):
            pp (c.sub, (c.vars, env))
        elif is_a (c, t_predicate):
            for arg in c.args:
                pp (arg, env)
        elif is_a (c, c_is):
            pp (c.t, env)
        elif is_a (c, c_equals):
            for arg in c.args:
                pp (arg, env)
        elif is_a (c, c_true):
            pass
        elif is_a (c, t_var):
            lookup (c, env)
        elif is_a (c, t_base):
            pass
        elif is_a (c, str):
            # row labels
            pass
        else:
            raise ValueError
    pp (c, (set([top_tv]), None))
# stack frames
class frame:
    # base class for the entries kept on the solver's stack
    kind = 'abstract'
class s_empty (frame):
    # the empty stack frame
    kind = 'empty'

# shared instance of the empty frame
empty = s_empty()
class s_and (frame):
    # holds a constraint <c> in the frame's <constraint> attribute
    def __init__ (self, c):
        pending = c
        self.constraint = pending
class s_exists (frame):
    # a frame recording a collection of existentially bound tvars
    def __init__ (self, vars):
        bound = vars
        self.vars = bound
class s_let (frame):
    # A <let> frame.
    #   names/types: the bound names and their types (parallel sequences)
    #   vars:        the type variables owned by this frame (stored as a set)
    #   body:        the constraint of the let body
    #   rank:        stamped onto every variable this frame owns
    # <counter> accumulates how many variables were added via add_vars().

    def __init__ (self, names, types, vars, body, rank):
        if not len(names):
            raise ValueError
        self.names, self.types = names, types
        self.vars = set (vars)
        self.body, self.rank = body, rank
        for tv in vars:
            tv.rank = rank
        self.counter = 0

    def add_vars (self, vars):
        # adopt <vars>: add them to our set and stamp them with our rank
        self.vars.update (vars)
        for tv in vars:
            tv.rank = self.rank
        self.counter += len (vars)
class s_env (frame):
    # after a <let> type scheme has been solved, an <env> frame
    # binds the schemes to the formals.
    #   names/types: the bound names and their types
    #   vars:        stored as a set
    #   u:           stored as-is in <self.u>
    def __init__ (self, names, types, vars, u):
        self.names, self.types = names, types
        self.vars = set (vars)
        self.u = u
# this is a two-phase algorithm
# 1) constraint generation
# 2) constraint solving
class constraint_generator:
    # Phase 1: walk an expression tree and produce a constraint describing
    # its typing.  Each node kind <k> is handled by a method gen_<k>(exp, t)
    # that returns a constraint which holds iff <exp> has type <t>.

    def __init__ (self, context):
        # <context> supplies scc_graph (gen_fix) and datatypes (gen_nvcase)
        self.context = context

    def go (self, exp):
        # entry point: returns (constraint, top-level type variable)
        t = t_var()
        c, top_tv = self.gen (exp, t), t
        #check_constraint (c, top_tv)
        return c, top_tv

    def gen (self, exp, t):
        # Dispatch on exp.kind.  As a side effect the node and its type are
        # cross-linked (exp.tv / t.node).
        # Raises ValueError for a node kind with no gen_<kind> method.
        exp.tv = t
        t.node = exp
        name = 'gen_%s' % exp.kind
        # default of None so that an unknown kind reaches the ValueError
        # below instead of escaping as an AttributeError
        probe = getattr (self, name, None)
        if probe is None:
            raise ValueError (exp.kind)
        return probe (exp, t)

    def gen_varref (self, exp, t):
        return c_is (exp.name, t)

    def gen_varset (self, exp, t):
        # the variable and the assigned value share a type <x>;
        # the assignment itself has type 'undefined'.
        x = t_var()
        return c_exists (
            (x,),
            c_and (
                c_and (
                    c_is (exp.name, x),
                    self.gen (exp.value, x)
                    ),
                c_equals (t, t_undefined())
                )
            )

    def gen_function (self, exp, t):
        if is_pred (t, 'arrow'):
            # lemma 10.4.7
            # <t> is already an arrow: reuse its result/argument types
            rtv, args = t.args[0], list(t.args[1:])
        else:
            rtv, args = t_var(), [t_var() for x in exp.formals]
        # the body must be generated exactly once: gen() mutates the nodes
        if len(args):
            c = c_let (exp.formals, args, c_true(), self.gen (exp.body, rtv))
        else:
            c = self.gen (exp.body, rtv)
        # XXX: in ATTPL, this is a c_supertype relation
        if is_pred (t, 'arrow'):
            return c
        else:
            sub1 = c_equals (t, arrow (rtv, *args))
            return c_exists ([rtv] + args, c_and (c, sub1))

    def gen_application (self, exp, t):
        # the operator has type (args -> t); each operand has its arg type
        args = [t_var() for x in exp.rands]
        c = self.gen (exp.rator, arrow (t, *args))
        for i in range (len(exp.rands)):
            c = c_and (c, self.gen (exp.rands[i], args[i]))
        return c_exists (args, c)

    def gen_primapp (self, exp, t):
        args = [t_var() for x in exp.args]
        if exp.name.startswith ('%vcon/'):
            # XXX don't like having to do this here.
            # variant constructors take a single argument: the lone arg
            # itself, or else a product of all the args
            if len(args) == 1:
                sig = args[0]
            else:
                sig = product (*args)
            c = c_is (exp.name, arrow (t, sig))
        else:
            c = c_is (exp.name, arrow (t, *args))
        for i in range (len(exp.args)):
            c = c_and (c, self.gen (exp.args[i], args[i]))
        return c_exists (args, c)

    def gen_cexp (self, exp, t):
        # exp.type_sig is a (tvars, signature) pair; instantiate it freshly.
        # Raises ValueError if the signature is neither an arrow nor a base type.
        tvars, sig = exp.type_sig
        scheme = instantiate_scheme (c_forall (tvars, sig))
        sig = scheme.constraint
        tvars = scheme.vars
        if is_pred (sig, 'arrow'):
            # result type
            c = c_equals (t, sig.args[0])
            for i in range (len (exp.args)):
                # arg types
                sig_arg = sig.args[i+1]
                if is_pred (sig_arg, 'raw'):
                    # hack: magically hide the 'raw' predicate from the solver
                    sig_arg = sig_arg.args[0]
                c = c_and (c, self.gen (exp.args[i], sig_arg))
            if len(tvars):
                return c_exists (tvars, c)
            else:
                return c
        elif is_a (sig, t_base):
            # plain type?
            return c_equals (t, sig)
        else:
            raise ValueError ("unhandled cexp type")

    def gen_let_splat (self, exp, t):
        # build nested single-name lets from the last binding outward,
        # so each init is in scope of the bindings before it
        r = self.gen (exp.body, t)
        n = len (exp.names)
        for i in range (n-1,-1,-1):
            name = exp.names[i]
            init = exp.inits[i]
            var = t_var()
            if name.type is not None:
                # an explicit type on the name adds an equation
                r = c_and (
                    c_equals (var, name.type),
                    c_let ([name], [var], c_forall ((var,), self.gen (init, var)), r)
                    )
            else:
                r = c_let ([name], [var], c_forall ((var,), self.gen (init, var)), r)
        return r

    def gen_fix (self, exp, t):
        partition = graph.reorder_fix (exp, self.context.scc_graph)
        c0 = self.gen (exp.body, t)
        # Note: deep partitioning magic here
        for part in partition:
            names = [exp.names[i] for i in part]
            funs = [exp.inits[i] for i in part]
            # one var for each function
            fvars = tuple ([t_var() for x in names])
            c1 = list_to_conj (
                [self.gen (funs[i], fvars[i]) for i in range (len (part))]
                )
            # inner/monomorphic binding
            c1 = c_let (names, fvars, c_true(), c1)
            # outer/polymorphic binding
            c1 = c_let (names, fvars, c_forall (fvars, c1), c0)
            c0 = c1
        return c0

    def gen_conditional (self, exp, t):
        # the test is boolean; both branches have the type of the whole
        test_exp = self.gen (exp.test_exp, t_bool())
        then_exp = self.gen (exp.then_exp, t)
        else_exp = self.gen (exp.else_exp, t)
        return c_and (test_exp, c_and (then_exp, else_exp))

    def gen_sequence (self, exp, t):
        n = len (exp.subs)
        tvars = [t_var() for x in range (n-1)]
        c = self.gen (exp.subs[-1], t)
        for i in range (n-1):
            # everything but the last, type it as don't-care
            c = c_and (c, self.gen (exp.subs[i], tvars[i]))
        return c_exists (tvars, c)

    def gen_literal (self, exp, t):
        return c_equals (t, base_types[exp.ltype])

    def gen_constructed (self, exp, t):
        return self.gen (exp.value, t)

    def gen_make_tuple (self, exp, t):
        # XXX don't care about the type of the arg?
        v = t_var()
        return c_exists ((v,), c_and (
            self.gen (exp.subs[0], v),
            c_equals (t, base_types[exp.ttype])
            ))

    def gen_pvcase (self, exp, t):
        # (pvcase <alt_formals> <alt0> <alt1> ...)
        # each <alt> binds a separate set of variables (possibly empty)
        # the last alt binds against either "else" (not yet implemented),
        # or rdefault(abs()).
        alts = exp.alts[:]
        vars = []
        conj = []
        if len(alts) == len (exp.alt_formals):
            # no else clause, a closed sum type
            row = rdefault (abs())
        else:
            # with an else clause, open sum type
            row = t_var()
            vars.append (row)
            conj.append (self.gen (alts[-1], t))
        for i in range (len (exp.alt_formals)):
            alt = alts[i]
            label, type, formals = exp.alt_formals[i]
            # row type extended with this label and its type
            args = [t_var() for x in range (len (formals))]
            vars.extend (args)
            ptype = t_var()
            vars.append (ptype)
            row = rlabel (label, pre(ptype), row)
            if len(formals):
                # XXX this c_let may be redundant, since <alt> contains its own
                # lambda to do the bindings...
                conj.append (c_let (formals, args, c_true(), self.gen (alt, t)))
            else:
                conj.append (self.gen (alt, t))
        conj.append (self.gen (exp.value, rsum (row)))
        return c_exists (vars, list_to_conj (conj))

    def gen_nvcase (self, exp, t):
        # (nvcase <vtype> <val> <alt0> <alt1> ...)
        # like a conditional, but with more branches.
        dt = self.context.datatypes[exp.vtype]
        if len(dt.tvars):
            # it's a type scheme, instantiate it
            # (strange thing to do during constraint generation, probably WRONG)
            scheme = instantiate_scheme (c_forall (dt.tvars, dt.scheme))
            conj = [self.gen (exp.value, scheme.constraint)]
        else:
            conj = [self.gen (exp.value, dt.scheme)]
        for alt in exp.alts:
            if alt is not None:
                conj.append (self.gen (alt, t))
        # this will work even when else_clause is a dummy %%match-error
        conj.append (self.gen (exp.else_clause, t))
        if len(dt.tvars):
            # <scheme> is bound here because the same test held above
            return c_exists (scheme.vars, list_to_conj (conj))
        else:
            return list_to_conj (conj)
class UnboundVariable (Exception):
    # raised when a name or type variable has no binding in scope
    pass
class multi:
    # a 'standard' multi-equation of the form A=B=C=T where A,B,C are
    # type variables and T is an optional type.

    def __init__ (self, vars, type):
        # vars: collection of type variables (must be non-empty)
        # type: the type they all equal, or None
        self.vars = vars
        self.type = type
        self.rep = self.min_rank()
        for v in self.vars:
            # point them all at the rep var
            if v is not self.rep:
                v.next = self.rep
            # every member (the rep included) links back to this equation
            v.eq = self
        self.rank = self.rep.rank
        #sys.stderr.write ('(%d)' %(len(vars)))

    def min_rank (self):
        # choose the variable with lowest <rank,id>
        # (returns None when there are no variables)
        mv = None
        for v in self.vars:
            if mv is None or v.rank < mv.rank:
                mv = v
            elif v.rank == mv.rank and v.id < mv.id:
                mv = v
        return mv

    def free (self):
        # the type variables appearing directly as arguments of our type
        # (empty unless the type is a predicate)
        free = set()
        if is_a (self.type, t_predicate):
            for arg in self.type.args:
                if is_a (arg, t_var):
                    free.add (arg)
        return free

    def __repr__ (self):
        # variables in id order joined by '=', then the type if there is one
        vars = sorted (self.vars, key=lambda v: v.id)
        r = '='.join (['%r' % v for v in vars])
        if self.type:
            return r + '=%r' % (self.type,)
        else:
            return r
class unifier:
    # Maintains a conjunction of multi-equations. In the typical HM
    # algorithm, this would be called the 'subst'. It implements the
    # 'union find' algorithm ('disjoint set' data structure).
    #
    # Membership is tracked on the variables themselves: v.in_u is this
    # unifier while v belongs to one of our equations, and v.eq is that
    # equation.  Fresh variables invented while unifying accumulate in
    # <self.exists> for the caller to collect.
    # A type clash is reported by raising TypeError ((ty0, ty1)).

    def __init__ (self, step=False):
        self.step = step
        self.eqs = set()
        self.exists = []
        # memoize decoded tvars
        self.decoded = {}
        self.max_size = 0

    def add (self, vars, type):
        # add a term to the conjunction, e.g. A=B=C=T (where T is optional)
        assert (is_a (vars, set))
        assert (not is_a (type, t_var))
        if is_a (type, t_predicate):
            type = self.try_name_1 (type)
        if (type is None and len(vars) == 1) or len(vars) == 0:
            # s-single: a lone variable with no type (or no variable at all)
            # says nothing, so it is dropped
            pass
        else:
            # any of these vars already present?
            for v in vars:
                if v.in_u is self:
                    # if so, then fuse (which re-adds the merged equation)
                    self.fuse (v.eq, vars, type)
                    return
            # nope, a new equation
            eq = multi (vars, type)
            for v in vars:
                v.in_u = self
            self.eqs.add (eq)

    def add2 (self, *args):
        # add an equation between a random collection of variables and types
        # (at most two non-variable types; more raises ValueError)
        vars = set()
        types = []
        for arg in args:
            if is_a (arg, t_var):
                vars.add (arg)
            else:
                types.append (arg)
        if len(types) == 2:
            self.decompose ((vars, types[0]), (vars, types[1]))
        elif len(types) > 2:
            raise ValueError ("too many types")
        elif len(types) == 1:
            self.add (vars, types[0])
        else:
            self.add (vars, None)

    def try_name_1 (self, type):
        # ensure that a predicate's arguments are type variables,
        # naming them if necessary (rule S-NAME-1).
        # returns <type> itself when nothing needed naming.
        args2 = []
        flag = False
        for arg in type.args:
            if is_a (arg, str):
                # XXX row labels, must be a better way.
                args2.append (arg)
            elif not is_a (arg, t_var):
                # s-name-1: introduce a fresh variable standing for <arg>
                x = t_var()
                self.exists.append (x)
                self.add (set([x]), arg)
                args2.append (x)
                flag = True
            else:
                args2.append (arg)
        if flag:
            return t_predicate (type.name, args2)
        else:
            return type

    def forget (self, eq):
        # drop an equation and release its variables
        self.eqs.remove (eq)
        for v in eq.vars:
            v.in_u = False

    def fuse (self, eq, tvs0, ty0):
        # merge the existing equation <eq> with the new term <tvs0>=<ty0>
        tvs1 = eq.vars
        ty1 = eq.type
        # is a three-way fuse possible? (e.g. A=T0 B=T1; A=B=T2)
        # I don't think so, so let's ignore that possibility for now.
        #self.dprint ('s-fuse')
        self.forget (eq)
        if ty0 and ty1:
            # must unify types
            # A=B=T0 ^ B=C=T1 => A=B=C=T0=T1
            self.decompose ((tvs0, ty0), (tvs1, ty1))
        else:
            # A=B=T0 ^ B=C => A=B=C=T0
            self.add (tvs0.union (tvs1), ty0 or ty1)

    def decompose (self, t0, t1):
        # unify two (vars, type) terms; raises TypeError on a clash
        tvs0, ty0 = t0
        tvs1, ty1 = t1
        tvs = tvs0.union (tvs1)
        if ty0 == ty1:
            # a=b=int=int, etc... => a=b=int
            self.add (tvs, ty0)
        elif is_pred (ty0, 'rlabel', 'rdefault') or is_pred (ty1, 'rlabel', 'rdefault'):
            self.unify_rows (ty0, ty1, tvs)
        elif (is_a (ty0, t_predicate) and is_a (ty1, t_predicate)
              and ty0.name == ty1.name
              and len(ty0.args) == len(ty1.args)):
            #self.dprint ('s-decompose')
            # P(a,b,c)=P(d,e,f)=ε => a=d ^ b=e ^ c=f ^ P(a,b,c)=ε
            for i in range (len (ty0.args)):
                self.add2 (ty0.args[i], ty1.args[i])
            self.add (tvs, ty0)
        elif is_a (ty0, t_int) and is_a (ty1, t_int):
            # covers int16 (need subtyping!)
            self.add (tvs, t_int16())
        else:
            # s-clash
            raise TypeError ((ty0, ty1))

    def unify_rows (self, ty0, ty1, tvs):
        # unify two types, at least one of which is a row (rlabel/rdefault);
        # raises TypeError on a clash
        if is_pred (ty0, 'rlabel') and is_pred (ty1, 'rlabel'):
            if ty0.args[0] != ty1.args[0]:
                # distinct head labels
                #self.dprint ('s-mutate-ll')
                # XXX be concerned about how one of these may have types
                #   and the other has variables.  do we need to check
                #   and reorder them?
                l0, t0, d0 = ty0.args
                l1, t1, d1 = ty1.args
                x = t_var()
                self.exists.append (x)
                self.add2 (d0, rlabel (l1, t1, x))
                self.add2 (d1, rlabel (l0, t0, x))
                self.add (tvs, rlabel (l0, t0, d0))
            else:
                # XXX this should be handled by the normal s-decompose
                l0, t0, d0 = ty0.args
                l1, t1, d1 = ty1.args
                self.add2 (t0, t1)
                self.add2 (d0, d1)
                self.add (tvs, ty0)
        elif is_pred (ty0, 'rlabel') or is_pred (ty1, 'rlabel'):
            # only one is an rlabel
            if is_pred (ty1, 'rlabel'):
                # ensure that ty0 is the rlabel
                ty0, ty1 = ty1, ty0
            if is_pred (ty1, 'rdefault'):
                #self.dprint ('s-mutate-dl')
                x = ty1.args[0]
                assert (is_a (x, t_var))
                self.add2 (x, ty0.args[1])
                self.add2 (ty1, ty0.args[2])
                self.add (tvs, ty1)
            elif is_a (ty1, t_predicate):
                # some other predicate
                #self.dprint ('s-mutate-gl')
                n = len (ty1.args)
                tvars0 = [t_var() for x in ty1.args]
                tvars1 = [t_var() for x in ty1.args]
                self.exists.extend (tvars0)
                self.exists.extend (tvars1)
                l0, t0, d0 = ty0.args
                g = ty1.name
                self.add2 (t_predicate (g, tvars0), t0)
                self.add2 (t_predicate (g, tvars1), d0)
                for i in range (n):
                    self.add2 (ty1.args[i], rlabel (l0, tvars0[i], tvars1[i]))
                self.add (tvs, ty1)
            else:
                # s-clash
                raise TypeError ((ty0, ty1))
        elif is_pred (ty0, 'rdefault',) or is_pred (ty1, 'rdefault'):
            if is_pred (ty1, 'rdefault'):
                # ensure that ty0 is the rdefault/δ
                ty0, ty1 = ty1, ty0
            if is_pred (ty1, 'rdefault'):
                # they're both rdefault - normal decompose here
                assert (len(ty0.args) == 1 and len(ty1.args) == 1)
                self.add2 (ty0.args[0], ty1.args[0])
                self.add (tvs, ty0)
            elif is_a (ty1, t_predicate):
                # some other predicate
                #self.dprint ('s-mutate-gd')
                n = len (ty1.args)
                g = ty1.name
                tvars = [ t_var() for x in ty1.args ]
                self.exists.extend (tvars)
                self.add2 (ty0.args[0], t_predicate (g, tvars))
                for i in range (n):
                    self.add2 (ty1.args[i], rdefault (tvars[i]))
                self.add (tvs, ty0)
            else:
                # s-clash
                raise TypeError ((ty0, ty1))
        else:
            # s-clash
            raise TypeError ((ty0, ty1))

    def split (self, sz):
        # leave in only equations made entirely of 'old' variables
        # this is the U1,U2 split from the rule S-POP-LET
        # returns a new unifier holding the equations that mention any of
        # sz.vars (the 'young' variables); they are removed from this one.
        young = sz.vars
        u2 = unifier()
        to_add = []
        remove = []
        for eq in self.eqs:
            if eq.rep in young or eq.free().intersection (young):
                to_add.append ((eq.vars, eq.type))
                remove.append (eq)
        # mutate only after the scan: self.eqs can't change while iterated
        for eq in remove:
            self.eqs.remove (eq)
            for var in eq.vars:
                var.in_u = False
        for vars, type in to_add:
            u2.add (vars, type)
        return u2

    def dprint (self, msg):
        # debugging output, only when stepping
        if self.step:
            sys.stderr.write ('*** ')
            sys.stderr.write (msg)
            sys.stderr.write ('\n')

    def sanity (self):
        # consistency check; raises ValueError on failure.
        # NOTE(review): reads self.vars, which __init__ does not create -
        # this looks stale and would raise AttributeError if it got that far.
        all = set()
        for eq in self.eqs:
            for v in eq.vars:
                if not v.in_u:
                    sys.stdout.write ('wtf?\n')
                    raise ValueError
                all.add (v)
        if all != set (self.vars.keys()):
            raise ValueError

    def simplify (self, vars, types):
        # remove/unname extra variables, and replace all tvars in types
        # with the rep.
        # returns (unname, types): the set of variables that were dropped,
        # and <types> rewritten (or <types> unchanged if it was empty).
        def p (t):
            # XXX speed hacks, this function is called a *lot*
            #if is_a (t, t_var):
            if t.__class__ is t_var:
                if t.in_u is self:
                    # if it's in our set, return its rep
                    return t.eq.rep
                else:
                    # free variable
                    return t
            # elif is_a (t, t_predicate):
            elif t.__class__ is t_predicate:
                #return t_predicate (t.name, [p(x) for x in t.args])
                # another speed hack: rewrite the predicate in place
                t.args = [p(x) for x in t.args]
                return t
            else:
                return t
        unname = set()
        forget = set()
        for eq in list(self.eqs):
            new_vars = set([eq.rep])
            for v in eq.vars:
                if v is not eq.rep:
                    if v in vars and not v in types:
                        # a non-rep variable that is in <vars> and is not
                        # itself one of <types>: drop it
                        unname.add (v)
                        v.next = eq.rep
                    else:
                        new_vars.add (v)
            eq.vars = new_vars
            if eq.type:
                eq.type = p (eq.type)
            else:
                # 'j=j' helps no one.
                # NOTE(review): as reconstructed, every type-less equation
                # is forgotten, not just single-variable ones - confirm.
                forget.add (eq)
        for v in unname:
            v.in_u = False
        for eq in forget:
            self.forget (eq)
        if types:
            types = [p(x) for x in types]
        return unname, types

    def prune (self, types, vars):
        # Without pruning, the graph (i.e., 'U') can get HUGE (>10000 equations).
        # prune any unconnected parts of the graph.
        # returns the members of <vars> that are still reachable from <types>.
        # first, make an adjacency graph (i.e., digraph -> graph)
        adj = {}
        for eq in self.eqs:
            t = eq.type
            v0 = eq.rep
            if is_a (t, t_predicate):
                for v1 in t.args:
                    adj.setdefault (v0, set()).add (v1)
                    adj.setdefault (v1, set()).add (v0)
        # now search starting from <types>
        visited = set()
        def visit (v):
            if v not in visited:
                visited.add (v)
                if v in adj:
                    for v0 in adj[v]:
                        visit (v0)
        for v in types:
            visit (v)
        # finally, remove any unvisited equations
        for eq in list (self.eqs):
            if eq.rep not in visited:
                self.forget (eq)
        # and any unvisited vars
        return [v for v in vars if v in visited]

    def find_free (self, bound):
        # find the free variables of this unifier
        # NOTE: <bound> is mutated - every variable of every equation is
        # added to it.
        for eq in self.eqs:
            bound.update (eq.vars)
        free = set()
        for eq in self.eqs:
            t = eq.type
            if is_a (t, t_var) and t not in bound:
                free.add (t)
            elif is_a (t, t_predicate):
                for arg in t.args:
                    if is_a (arg, t_var) and arg not in bound:
                        free.add (arg)
        return free

    def reverse_graph (self):
        # map each type variable to the set of equation reps whose type
        # mentions it directly as a predicate argument
        g = {}
        for eq in self.eqs:
            t = eq.type
            if t and is_a (t, t_predicate):
                for arg in t.args:
                    if is_a (arg, t_var):
                        # accumulate *every* referencing rep, not just the first
                        g.setdefault (arg, set()).add (eq.rep)
        return g

    def do_letall (self, xbar, free):
        # find <ybar>, the subset of <xbar> that is 'determined' by <free>
        xbar = set (xbar)
        rg = self.reverse_graph()
        # first pass, any types using vars from <free>
        y = []
        for eq in self.eqs:
            t = eq.type
            #if is_a (t, t_predicate):
            if t.__class__ is t_predicate:
                for v in t.args:
                    if is_a (v, t_var) and v in free:
                        y.append (eq.rep)
        # then follow the reverse graph to everything that (transitively)
        # mentions those reps.  the result is a set, so visiting order is
        # irrelevant and we can pop cheaply from the end of the worklist.
        y2 = set()
        seen = set()
        while y:
            v = y.pop()
            if v not in seen:
                seen.add (v)
                if v in xbar:
                    y2.add (v)
                if v in rg:
                    y.extend (rg[v])
        return y2

    def pprint (self):
        # dump the equations to stdout
        sys.stdout.write ('U: ')
        # sort the equations by representative tvar
        eqs = sorted (self.eqs, key=lambda eq: eq.rep.id)
        for eq in eqs:
            sys.stdout.write ('\t%r\n' % (eq,))
        sys.stdout.write ('\n')
def instantiate_scheme (scheme):
    # instantiate a human-style type scheme (as returned from lookup_special_names())
    # Returns a new <c_forall> in which every quantified variable has been
    # replaced by a fresh tvar; anything that is not a <c_forall> is
    # returned unchanged.
    if not is_a (scheme, c_forall):
        return scheme
    vars = scheme.vars
    nvars = []
    map = {}
    for v in vars:
        # fresh tvar for each quantified tvar
        tv = t_var()
        map[v] = tv
        nvars.append (tv)
    def f (c):
        # rebuild the type with quantified variables renamed
        if is_a (c, t_predicate):
            return t_predicate (c.name, [f(x) for x in c.args])
        elif is_a (c, t_var) or is_a (c, int):
            # quantified variables may be tvars or plain ints
            if c in vars:
                return map[c]
            else:
                return c
        else:
            return c
    return c_forall (nvars, f (scheme.constraint))
class solver:
def __init__ (self, context, exp, verbose=False, step=False):
self.context = context
self.exp = exp
self.step = step
# xxx need to split the notion of verbose and step
self.step = step
self.try_unname = False
def dprint (self, msg):
if self.step:
sys.stderr.write (msg)
sys.stderr.write ('\n')
def solve (self, c):
#self.dprint ('\nHit <return> at each pause (or "t<return>" to enter the debugger)')
pvars = {}
self.exists = []
# ensure there are always two items on the stack
s = [empty, empty]
u = unifier (step=self.step)
c = c
orig_c = c
rank = 0
steps = 0
def push (x):
s.append (x)
def pop ():
while 1:
if self.step:
print 'step %d' % steps
print 'S:',
self.pprint_stack (s)
print 'C:', c
print 'exists:', self.exists
steps += 1
# the top two elements of the stack
sy, sz = s[-2], s[-1]
if self.step:
print '-----------------------------'
if raw_input().startswith ('t'):
if self.try_unname:
self.do_extra_unname (u, s)
# --- solver ---
if u.exists:
#self.dprint ('s-ex-1')
self.move_exists (s, u.exists)
u.exists = []
elif is_a (sz, s_exists):
#self.dprint ('s-record-ex')
self.exists.extend (sz.vars)
elif is_a (c, c_equals):
#self.dprint ('s-solve-eq')
u.add2 (*c.args)
except TypeError, terr:
self.find_type_error (terr, c.args, u, s)
c = c_true()
elif is_a (c, c_is) and is_a (c.x, str):
#self.dprint ('s-solve-id')
# if I pass c.t down into lookup(), it's easier to attach the necessary
# c_equals() since there are two types of instantiation.
con = self.lookup (c.x, s, c.t)
# "Recall that if σ is of the form ∀X0..XN[U].X
# where X0..XN#ftv(T), then c_is(σ, T) stands for ∃X0..XN.(U ^ X=T)."
#self.dprint ('name=%s' % (c.x,))
#self.dprint ('scheme= %r' % (scheme,))
#self.dprint ('type=%r' % c.t)
if is_a (con, c_forall):
c = c_exists (con.vars, con.constraint)
c = con
elif is_a (c, c_and):
#self.dprint ('s-solve-and')
push (s_and (c.args[1]))
c = c.args[0]
elif is_a (c, c_exists):
#self.dprint ('s-solve-ex')
self.move_exists (s, c.vars)
c = c.sub
elif is_a (c, c_let):
#self.dprint ('s-solve-let')
if is_a (c.constraint, c_forall):
push (s_let (c.names, c.vars, c.constraint.vars, c.body, rank))
rank += 1
c = c.constraint.constraint
# let x: T in C == let x: ∀∅[true].T in C
push (s_let (c.names, c.vars, (), c.body, rank))
rank += 1
c = c_true()
elif is_a (c, c_true):
if is_a (sz, s_and):
#self.dprint ('s-pop-and')
c = sz.constraint
elif is_a (sz, s_let):
unname, types = u.simplify (sz.vars, sz.types)
if unname and sz.vars:
#self.dprint ('s-unname %s' % (ps(unname),))
vars = [x for x in sz.vars if x not in unname]
#self.dprint (' old vars=%s' % (ps(sz.vars),))
#self.dprint (' new vars=%s' % (vars,))
#self.dprint (' old types=%s' % (sz.types,))
#self.dprint (' new types=%s' % (types,))
#print 'unnamed %s %d' % (sz.names, len(unname))
sz.vars.difference_update (unname)
sz.types = types
#push (s_let (sz.names, types, vars, sz.body, sz.rank))
if len(sz.vars):
free = u.find_free (set (sz.vars))
# am I sure about this?
free.difference_update (set (sz.types))
if free:
# s-letall is only applicable if there are free variables in U.
# partition sz.vars into Ybar and Xbar, where ∃Xbar.U determines Ybar
# ------------- s-letall
ybar = u.do_letall (sz.vars, free)
# make these guys 'old'
#print 'moving...', ybar
self.move_exists (s, ybar)
vars = [x for x in sz.vars if x not in ybar]
push (s_let (sz.names, sz.types, vars, sz.body, sz.rank))
# XXX if we move Ybar up, will the changes in rank be visible in U,
# or do we need to do that manually?
# ------------- s-letall
sz = s[-1]
# the conditions have been met; turn the <let> into an <env>.
#self.dprint ('s-pop-let')
u2 = u.split (sz)
#print 'split'
sz.vars = u2.prune (sz.types, sz.vars)
#sys.stderr.write ('[%d %d]%r\n' % (len (sz.vars), len(u.eqs), sz.names))
push (s_env (sz.names, sz.types, sz.vars, u2))
rank -= 1
c = sz.body
elif is_a (sz, s_env):
#self.dprint ('s-pop-env')
elif is_a (sz, s_empty):
# we're done!
#self.dprint ('exists=%r' % self.exists)
#self.dprint ('constraint=%r' % orig_c)
return pvars
raise ValueError ("unexpected")
raise ValueError ("no rule applies")
def move_exists (self, s, vars):
# this implements the various S-EX-? rules that attach a set of tvars to
# the nearest <let> on the stack.
n = len (s)
for i in range (-1, -n, -1):
if is_a (s[i], s_let):
#print 'adding %r to s_let binding %r' % (vars, s[i].names)
for v in vars:
if v in s[i].vars:
s[i].add_vars (vars)
# 1000 - empirically gives the fastest solver times
if s[i].counter > 1000:
self.try_unname = True
s[i].counter = 0
#sys.stderr.write ('{%d}' % (len(s[i].vars)))
self.exists.extend (vars)
def do_extra_unname (self, u, s):
# find the first s_let on the stack
i = -1
while 1:
if is_a (s[i], s_let):
si = s[i]
i -= 1
unname, ignore = u.simplify (si.vars, set())
si.vars.difference_update (unname)
self.try_unname = False
def instantiate_constraint (self, i, env, t, generalize=True):
# instantiate a constraint-based type scheme (as found on the stack in an s_env)
# can we do most of this work inside the unifier class?
# ok, we have an env frame. we want to instantiate this scheme.
# env.vars are the quantified ones, we need to replace them
#print 'instantiating', env.names[i], env.vars, env.types, env.u.eqs
scheme = env.types[i]
new = {}
if generalize:
for v in env.vars:
new[v] = t_var()
used = set()
eqs = list(env.u.eqs)
conj = []
#print 'new=', new
for i in range (len (eqs)):
eq = eqs[i]
new_vars = [new.get (x,x) for x in eq.vars]
new_vars.sort (lambda a,b: cmp (a.id, b.id))
if is_a (eq.type, t_predicate):
new_args = []
for arg in eq.type.args:
if is_a (arg, t_var):
# first, replace any var with its rep
if arg.in_u is env.u:
v = arg.eq.rep
v = arg
# then, rename it if applicable
v = new.get (v, v)
new_args.append (v)
elif is_a (arg, str):
# row labels
new_args.append (arg)
raise ValueError ("small term constraint?")
new_type = t_predicate (eq.type.name, new_args)
elif is_a (eq.type, t_base):
new_type = eq.type
# XXX can this make sense? references to variables should
# have been simplified to pick the highest-ranking variable,
# so a type-less equation is not useful.
#raise ValueError
new_type = None
# XXX uncomment this and we break f_letpoly.scm
# v=v=v=t => v=v & v=v & v=v & v=t
if new_type:
obs = new_vars + [new_type]
obs = new_vars
c = c_equals (obs[0], obs[1])
for ob in obs[2:]:
c = c_and (c, c_equals (obs[0], ob))
conj.append (c)
# XXX we should remove any vars from new_vars that were not referenced!
if scheme in env.vars and generalize:
scheme = new[scheme]
conj.append (c_equals (t, scheme))
if len(new):
return c_forall (new.values(), list_to_conj (conj))
return list_to_conj (conj)
def lookup (self, x, s, t):
n = -1
while 1:
f = s[n]
n -= 1
if is_a (f, s_env):
for i in range (len (f.names)):
if f.names[i].name == x:
# our cheap version of the value restriction: restrict <ref> cells
# to a monomorphic type. later, let's do the whole
# 'expansive/non-expansive' version.
generalize = not f.names[i].assigns
#print '%s %r' % (f.names[i], generalize)
return self.instantiate_constraint (i, f, t, generalize)
elif f is empty:
scheme = instantiate_scheme (self.lookup_special_names (x))
return c_forall (scheme.vars, c_equals (t, scheme.constraint))
# A trick I've used here is to encode the arity into the name of
# some of the prims, making it possible to return a correct arrow
# type for each. I'm not entirely happy with it, but at least
# it's clear.
def lookup_special_names (self, name):
if name == '%rmake':
return c_forall ((), arrow (rproduct (rdefault (abs()))))
elif name.startswith ('%rextend/'):
what, label = name.split ('/')
# ∀XYZ.(Π(l:X;Y), Z) → Π(l:pre(Z);Y)
return c_forall (
arrow (
rproduct (rlabel (label, pre(2), 1)),
rproduct (rlabel (label, 0, 1)),
elif name.startswith ('%raccess/'):
what, label = name.split ('/')
# ∀XY.Π(l:pre(X);Y) → X
return c_forall ((0,1), arrow (0, rproduct (rlabel (label, pre(0), 1))))
elif name.startswith ('%rset/'):
what, label = name.split ('/')
# ∀XY.(Π(l:pre(X);Y), X) → undefined
return c_forall ((0,1), arrow (t_undefined(), rproduct (rlabel (label, pre(0), 1)), 0))
elif name == '%vfail':
return c_forall ((0,), arrow (0, rsum (rdefault (abs()))))
elif name.startswith ('%dtcon/'):
# lookup the type of the particular constructor
what, dtname, label = name.split ('/')
dt = self.context.datatypes[dtname]
# e.g. list := nil | cons X list
# %dtcons/list/cons := ∀X.(X,list(X)) → list(X)
args = dt.constructors[label]
return c_forall (dt.tvars, arrow (dt.scheme, *args))
elif name.startswith ('%vcon/'):
what, label, arity = name.split ('/')
arity = int(arity)
# remember each unique variant label
self.remember_variant_label (label)
if arity == 0:
# ∀X.() → Σ(l:pre (Π());X)
return c_forall ((1,), arrow (rsum (rlabel (label, pre (product()), 1)), product()))
elif arity == 1:
# ∀XY.X → Σ(l:pre X;Y)
return c_forall ((0,1), arrow (rsum (rlabel (label, pre(0), 1)), 0))
# ∀ABCD.Π(A,B,C) → Σ(l:pre (Π(A,B,C));D)
args = tuple(range (arity))
return c_forall (range(arity+1), arrow (rsum (rlabel (label, pre (product(*args)), arity)), product (*args)))
elif name.startswith ('%vcase/'):
what, label, arity = name.split ('/')
arity = int (arity)
# ∀012345.(3,4,5) → 0, Σ(l:1;2) → 0, Σ(l:pre(Π(3,4,5);2) → 0
# ∀012345.f0,f1,s1 → 0
args = range (3, arity+3)
# success continuation
f0 = arrow (0, *args)
# failure continuation
f1 = arrow (0, rsum (rlabel (label, 1, 2)))
# the sum argument
if arity == 1:
t = args[0]
t = product (*args)
s1 = rsum (rlabel (label, pre (t), 2))
return c_forall (range(arity+3), arrow (0, f0, f1, s1))
elif name.startswith ('%vget/'):
what, label, arity, index = name.split ('/')
arity = int (arity)
index = int (index)
args = range (arity)
rest = arity
# e.g., to pick the second arg:
# ∀0123. Σ(l:pre (0,1,2);3) → 1
if arity > 1:
vtype = rsum (rlabel (label, pre (product (*args)), rest))
vtype = rsum (rlabel (label, pre (args[0]), rest))
return c_forall (args + [arity], arrow (args[index], vtype))
elif name.startswith ('%nvget/'):
what, dtype, label, index = name.split ('/')
dt = self.context.datatypes[dtype]
ti = dt.constructors[label][int(index)]
return c_forall (dt.tvars[:], arrow (ti, dt.scheme))
elif name.startswith ('%vector-literal/'):
what, arity = name.split ('/')
arg_types = (0,) * int (arity)
return c_forall ((0,), arrow (vector(0), *arg_types))
elif name.startswith ('%make-vector'):
return c_forall ((0,), arrow (vector(0), t_int(), 0))
elif name.startswith ('%make-vec16'):
return c_forall ((), arrow (vector(t_int16()), t_int()))
elif name == '%%array-ref':
return c_forall ((0,), arrow (0, vector (0), t_int()))
elif name == '%%array-set':
return c_forall ((0,), arrow (t_undefined(), vector (0), t_int(), 0))
elif name == '%vec16-set':
return c_forall ((), arrow (t_undefined(), vector(t_int16()), t_int(), t_int16()))
# ------
# pattern matching
# ------
elif name == '%%match-error':
return c_forall ((0,), arrow (0))
elif name == '%%fatbar':
return c_forall ((0,0), arrow (0, 0, 0))
elif name == '%%fail':
return c_forall ((0,), arrow (0))
# -------
elif name.count (':') == 1:
# a constructor used in a 'constructed literal'
dt, alt = name.split (':')
return self.lookup_special_names ('%%dtcon/%s/%s' % (dt, alt))
raise UnboundVariable (name)
# XXX consider recording record labels at this point as well
def remember_variant_label (self, label):
    """Record <label> in self.context.variant_labels the first time it is seen.

    A new label is given the index `len (variant_labels) - nvariant_offset`,
    computed before the label is inserted.  A label that is already present
    keeps its existing index.
    """
    vl = self.context.variant_labels
    # `in` instead of the deprecated dict.has_key()
    if label not in vl:
        # adjust for the hacked pre-installed labels like 'cons' and 'nil'.
        vl[label] = len (vl) - self.context.nvariant_offset
def pprint_stack (self, s):
    """Write a readable dump of the stack <s> to stdout, one frame per line.

    The first two entries of <s> are skipped (they are the two <empty>
    sentinels); the remaining frames are numbered from 0.

    Raises NotImplementedError for a frame that is not one of the known
    frame kinds (s_empty, s_and, s_exists, s_let, s_env).
    """
    W = sys.stdout.write

    def bindings (frame):
        # 'name:type' for each binding of a let/env frame, joined with ';'
        return ';'.join (['%s:%r' % (name.name, ty) for name, ty in zip (frame.names, frame.types)])

    W ('\n')
    # the 2 is for the two <empty> sentinels
    for i in range (2, len (s)):
        W ('%2d: ' % (i-2,))
        si = s[i]
        if is_a (si, s_empty):
            W ('[]')
        elif is_a (si, s_and):
            W ('[] ^ %s' % si.constraint)
        elif is_a (si, s_exists):
            W ('exists%r.[]' % si.vars)
        elif is_a (si, s_let):
            W ('let %s: forall %s[[]] in %r' % (bindings (si), ps (si.vars), si.body))
        elif is_a (si, s_env):
            W ('env %s: %s %r in []' % (bindings (si), ps (si.vars), si.u.eqs))
        else:
            # only an unrecognized frame kind is an error
            raise NotImplementedError
        W ('\n')
def find_type_error (self, args, terr, u, s):
    """Print the program nodes related to a type error, with nearby context.

    <terr> is a pair of types (ty0, ty1).  For each one that is a type
    variable, a related node is looked up: the variable's own node, or
    failing that the node of any variable in the same equation.  The
    nodes found are printed, followed by the lines around each of them in
    a depth-first listing of self.exp.

    <args>, <u> and <s> are accepted but not used here.
    """
    ty0, ty1 = terr
    W = sys.stdout.write

    def find_node (v):
        # try to find a related node for each tvar
        if is_a (v, t_var):
            if v.node:
                # the node of <v> itself, not of ty0
                return v.node
            for vi in v.eq.vars:
                if vi.node:
                    return vi.node
        return None

    n0 = find_node (ty0)
    n1 = find_node (ty1)
    if n0 is n1:
        W ('node %r\n' % (n0,))
    else:
        if n0:
            W ('node0 %r\n' % (n0,))
        if n1:
            W ('node1 %r\n' % (n1,))

    # find the portion of the program:
    # flatten the tree into (node, depth) pairs, depth-first
    flat = []
    def walk_depth (n, d):
        flat.append ((n, d))
        for sub in n.subs:
            walk_depth (sub, d+1)
    walk_depth (self.exp, 0)

    def near (n):
        lines = self.context.type_error_lines
        # we want <lines> before and after
        total = len (flat)
        start = 0
        end = total
        for i in range (total):
            if flat[i][0] is n:
                start = max (i - lines, 0)
                # the slice end is exclusive, hence the +1
                end = min (i + lines + 1, total)
                break
        for ni, depth in flat[start:end]:
            # the offending node is marked with dashes
            if ni is n:
                indent = '--'
            else:
                indent = ' '
            W ('%s%r\n' % (indent * depth, ni))

    if n0:
        W ('\n n0:\n')
        near (n0)
    if n1 and (n0 != n1):
        W ('\n n1:\n')
        near (n1)
def ps (s):
    """Format the elements of <s> as '{a, b, c}', ordered by their `id` attribute.

    Each element is shown with its repr; an empty collection gives '{}'.
    """
    # a key function gives the same (stable) order as the old cmp-based sort
    ordered = sorted (s, key=lambda v: v.id)
    # strip the surrounding '[' and ']' from the list repr
    return '{%s}' % repr (ordered)[1:-1]
def list_to_conj (l):
    """Fold the constraints in <l> into one conjunction built with <c_and>.

    An empty <l> gives c_true(), a single element is returned unchanged,
    and longer lists nest to the left: c_and (c_and (l[0], l[1]), l[2]) ...
    """
    if not len (l):
        return c_true()
    conj = l[0]
    for i in range (1, len (l)):
        conj = c_and (conj, l[i])
    return conj
def read_string (s):
    """Run the lisp reader over the string <s> and return what its read() gives."""
    from cStringIO import StringIO
    from lisp_reader import reader
    # the reader takes a file-like object, so wrap the string
    return reader (StringIO (s)).read()
def get_record_sig (t):
    """Return the labels of a record type as a tuple, or None.

    <t> has the shape rproduct (rlabel (...)), optionally wrapped in a
    <c_forall>.  The row is walked label by label:

      * a label whose type is a 'pre' predicate is collected;
      * 'rdefault' ends the row;
      * a type variable ends the row and adds '...';
      * anything else makes the result None.

    Labels are returned in the order they appear in the row.
    """
    # rproduct (rlabel (...))
    if is_a (t, c_forall):
        t = t.constraint
    assert (is_pred (t, 'rproduct'))
    labels = []
    t = t.args[0]
    while 1:
        if is_pred (t, 'rlabel'):
            label, field_type, rest = t.args
            if is_pred (field_type, 'pre'):
                labels.append (label)
            t = rest
        elif is_pred (t, 'rdefault'):
            # end of the row
            break
        elif is_a (t, t_var):
            # end of the row, with an unknown tail
            labels.append ('...')
            break
        else:
            return None
    return tuple (labels)
class typer:
def __init__ (self, context, verbose, step):
self.context = context
self.verbose = verbose
self.step = step
def go (self, exp):
cg = constraint_generator (self.context)
c, top_tv = cg.go (exp)
print 'solving...'
if self.verbose:
pprint_constraint (c)
s = solver (self.context, exp, self.verbose, self.step)
m = s.solve (c)
print 'decoding...'
for node in exp:
node.type = decode (node.tv)
if node.one_of ('let_splat', 'fix'):
for name in node.names:
if not name.type:
name.type = decode (name.tv)
print name
if self.context.print_types and node.is_a ('function') and node.name:
print node.name, '\t', node.type
def decode (t):
    """Return <t> with its type variables replaced by what they are bound to.

    A type variable whose equation has a type decodes to that type;
    otherwise its `next` link is followed when set; otherwise the variable
    itself is returned.  A predicate is rebuilt with decoded arguments.
    Anything else is returned unchanged.
    """
    # Terms already visited are returned as-is, which stops the walk on
    # cyclic structures.
    # NOTE(review): this also applies to a term that is merely shared, so
    # its second occurrence is not decoded -- confirm that is intended.
    visited = set()

    def walk (ty):
        if ty in visited:
            return ty
        visited.add (ty)
        if is_a (ty, t_var):
            if ty.eq and ty.eq.type:
                return walk (ty.eq.type)
            if ty.next:
                return walk (ty.next)
            return ty
        if is_a (ty, t_predicate):
            return t_predicate (ty.name, [walk (arg) for arg in ty.args])
        return ty

    return walk (t)
def test (s, step=True):
    """Read, transform and type the source expression in the string <s>.

    <step> is passed on to the typer.
    """
    import transform
    import nodes
    global tvar_counter
    tvar_counter = -1
    # wrap everything in a top-level <let>
    s = "(let ((top %s)) top)" % s
    exp = read_string (s)
    t = transform.transformer (1)
    exp2 = t.go ([exp])
    w = nodes.walker()
    exp3 = w.go (exp2)
    # alpha conversion
    nodes.rename_variables (exp3)
    # typer takes (context, verbose, step); the old two-argument call
    # raised TypeError.
    # NOTE(review): verbose is off here, and the context is still an empty
    # dict although typer.go reads attributes from it -- confirm what
    # context object should be supplied.
    t = typer ({}, False, step)
    t.go (exp3)
if __name__ == '__main__':
    # '-v' on the command line turns stepping on
    step = '-v' in sys.argv
    # interactive test mode: an empty line, or end of input, exits
    while 1:
        sys.stdout.write ('> ')
        try:
            line = raw_input()
        except EOFError:
            break
        if not line:
            break
        test (line, step=step)