# -*- Mode:Python; coding: utf-8 -*-
import nodes
import graph
import sys
from itypes import *
from pprint import pprint as pp
from pdb import set_trace as trace
# shorthand alias for isinstance, used by the type-dispatch chains in this module.
is_a = isinstance
# The 'subst', or type substitution/map, is not an actual data structure;
# rather, it lives in the '.val' attribute of the set of all type variables.
# To 'apply the subst', simply follow the path through each type variable
# until you reach something that is not a tvar.
# This is now only applied *after* unification.
def apply_subst_to_type (t):
    """Return <t> with every known type variable replaced by what it stands for.

    Recursive types are detected on the way: a predicate is flagged as
    'pending' while its arguments are being rebuilt.  Meeting a pending
    node again means we have found a cycle, so a fresh variable is stored
    in its '.mv' slot and returned in place of the node; once the
    arguments are done the rebuilt predicate is wrapped in a 'moo'
    predicate binding that variable.
    """
    def resolve (ty):
        if is_a (ty, str):
            # plain strings are passed through (the rlabel predicate uses them)
            return ty
        # step to the representative of the equivalence class
        ty = ty.find()
        if ty.pending:
            # cycle: we are already in the middle of rebuilding this node
            ty.mv = t_var()
            return ty.mv
        if not is_a (ty, t_predicate):
            return ty
        if ty.name == 'moo':
            # we've already been here!
            return ty
        # replace all known tvars in the arguments
        ty.pending = True
        rebuilt = t_predicate (ty.name, [resolve (arg) for arg in ty.args])
        ty.pending = False
        if ty.mv:
            return moo (ty.mv, rebuilt)
        return rebuilt
    return resolve (t)
# http://en.wikipedia.org/wiki/Disjoint-set_data_structure
# this is Huet's algorithm
# See Kevin Knight: "Unification: A multidisciplinary survey (1989)"
# debugging switch: when true, unify() prints each pair of types it is handed.
glork = False
def unify (t0, t1):
    """Unify the types <t0> and <t1> by merging their equivalence classes.

    Raises TypeError ((u, v)) - carrying the two representatives - when
    the types cannot be made equal.  Recursive ('moo') and row
    ('rlabel'/'rdefault') types are handed to unify_moo() and
    unify_rows() respectively.
    """
    if glork:
        sys.stdout.write ('%s %s\n' % (t0, t1))
    u = t0.find()
    v = t1.find()
    if not (u != v):
        # already in the same class, nothing to do
        return
    u_base, v_base = is_a (u, t_base), is_a (v, t_base)
    u_pred, v_pred = is_a (u, t_predicate), is_a (v, t_predicate)
    # XXX unification would be simpler if all base types were done as no-arg predicates.
    if (u_base and (v_base or v_pred)) or (u_pred and v_base):
        # two distinct base types, or a base type against a predicate
        raise TypeError ((u, v))
    if not (is_a (u, t_var) or is_a (v, t_var)):
        u_moo, v_moo = is_pred (u, 'moo'), is_pred (v, 'moo')
        if u_moo and v_moo:
            # both recursive: merge and decompose like any other predicate
            pass
        elif u_moo or v_moo:
            # note early exit...
            return unify_moo (u, v)
        elif is_pred (u, 'rlabel', 'rdefault') or is_pred (v, 'rlabel', 'rdefault'):
            # note early exit...
            return unify_rows (u, v)
        elif u_pred and v_pred and (u.name != v.name or len (u.args) != len (v.args)):
            raise TypeError ((u, v))
    u.union (v)
    if u_pred and v_pred:
        # decompose: unify the arguments pairwise
        for left, right in zip (u.args, v.args):
            unify (left, right)
# This implementation of rows is based on the one in ATTPL, all of which are based on Rémy's
# addition of pre() and abs() predicates to Wand's formulation. See section 10.8 of ATTPL,
# or "Type Inference for Records in a Natural Extension of ML" by Rémy.
def unify_rows (ty0, ty1):
    """Unify two types of which at least one is a row ('rlabel' or 'rdefault').

    The rule names in the comments (C-MUTATE-LL, S-MUTATE-GL, ...) are
    those of the ATTPL presentation.  Raises TypeError ((ty0, ty1)) when
    no rule applies.
    """
    label0 = is_pred (ty0, 'rlabel')
    label1 = is_pred (ty1, 'rlabel')

    if label0 and label1:
        l0, t0, d0 = ty0.args
        l1, t1, d1 = ty1.args
        if l0 != l1:
            # distinct head labels, C-MUTATE-LL
            shared = t_var()
            unify (d0, rlabel (l1, t1, shared))
            unify (d1, rlabel (l0, t0, shared))
        else:
            # same label: plain decomposition
            unify (t0, t1)
            unify (d0, d1)
        return

    if label0 or label1:
        # only one is an rlabel
        if label1:
            # ensure that ty0 is the rlabel
            ty0, ty1 = ty1, ty0
        if is_pred (ty1, 'rdefault'):
            # C-MUTATE-DL
            unify (ty1.args[0], ty0.args[1])
            unify (ty1, ty0.args[2])
        elif is_a (ty1, t_predicate):
            # some other predicate
            # S-MUTATE-GL
            heads = [t_var() for _ in ty1.args]
            tails = [t_var() for _ in ty1.args]
            l0, t0, d0 = ty0.args
            functor = ty1.name
            unify (t_predicate (functor, heads), t0)
            unify (t_predicate (functor, tails), d0)
            for arg, head, tail in zip (ty1.args, heads, tails):
                unify (arg, rlabel (l0, head, tail))
        else:
            raise TypeError ((ty0, ty1))
        return

    default0 = is_pred (ty0, 'rdefault')
    default1 = is_pred (ty1, 'rdefault')
    if not (default0 or default1):
        raise TypeError ((ty0, ty1))

    if default1:
        # ensure that ty0 is the rdefault/δ
        ty0, ty1 = ty1, ty0
    if default0 and default1:
        # they're both rdefault - normal decompose here
        assert (len(ty0.args) == 1 and len(ty1.args) == 1)
        # usually rdefault(abs) == rdefault(abs)
        unify (ty0.args[0], ty1.args[0])
    elif is_a (ty1, t_predicate):
        # some other predicate, S-MUTATE-GD
        functor = ty1.name
        fresh = [t_var() for _ in ty1.args]
        unify (ty0.args[0], t_predicate (functor, fresh))
        for arg, tv in zip (ty1.args, fresh):
            unify (arg, rdefault (tv))
    else:
        raise TypeError ((ty0, ty1))
# XXX TODO: verify that all recursive types go through a row type.
# XXX can I be simplified?
def unify_moo (t0, t1):
    """Unify a 'moo' (recursive) type with some other type.

    Whichever argument is the moo has its first argument unified with
    the other type.
    """
    if is_pred (t1, 'moo'):
        the_moo, other = t1, t0
    else:
        the_moo, other = t0, t1
    # is this enough?
    unify (the_moo.args[0], other)
def occurs_in_type (tvar, type):
    """Return True if <tvar> equals any node yielded by walk_type (type)."""
    return any (tvar == node for node in walk_type (type))
# XXX apparently this is done differently in many implementations,
# somehow passing a depth argument around the type_of() functions
# makes this easier?
def occurs_free_in_tenv (tvar, tenv):
    """Return True if <tvar> occurs in some type bound in <tenv>.

    <tenv> is a chain of (rib, parent) pairs ending in a false value;
    each rib is a sequence of (name, type) pairs.  A binding whose type
    is a forall that generalizes <tvar> does not count.
    """
    frame = tenv
    while frame:
        rib, frame = frame
        for _name, bound in rib:
            # skip it if it's shadowed (should never happen...)
            shadowed = is_a (bound, forall) and tvar in bound.gens
            if not shadowed and occurs_in_type (tvar, bound):
                return True
    return False
# if a node has user-supplied type, use it. otherwise
# treat it as a type variable.
# XXX untested in this new solver.
def optional_type (exp, tenv):
    """Return the type annotation on <exp> if it has one, else a fresh type variable.

    <tenv> is accepted but not consulted.
    """
    return exp.type if exp.type else t_var()
class forall:
    """A type scheme: the type <type> generalized over the variables in <gens>."""

    def __init__ (self, gens, type):
        self.gens, self.type = gens, type

    def __repr__ (self):
        return '<forall {0!r} {1!r}>'.format (self.gens, self.type)
def build_type_scheme (type, tenv, name):
    """Generalize <type> over the type variables that do not occur in <tenv>.

    The substitution is applied to <type> first.  Returns a forall when
    at least one variable can be generalized, otherwise the substituted
    type itself.  <name> is accepted but not used.  Raises ValueError on
    a type node of an unrecognized kind.
    """
    generic = set()

    def collect (t):
        if is_a (t, t_var):
            if not occurs_free_in_tenv (t, tenv):
                generic.add (t)
            return
        if is_pred (t, 'moo'):
            # only the second argument of a moo is searched
            collect (t.args[1])
            return
        if is_a (t, t_predicate):
            for arg in t.args:
                collect (arg)
            return
        if is_a (t, t_base) or is_a (t, str):
            # nothing to generalize here
            return
        # NOTE(review): moo_var is not defined in this file; presumably it
        # comes from itypes - confirm.
        if is_a (t, moo_var):
            collect (t.tvar)
            return
        raise ValueError

    resolved = apply_subst_to_type (type)
    collect (resolved)
    return forall (generic, resolved) if generic else resolved
def instantiate_type (type, tvar, fresh_tvar):
    """Return a copy of <type> with each occurrence of <tvar> replaced by <fresh_tvar>.

    Variables may be t_var instances or plain integers.  Predicates are
    rebuilt recursively; anything else is returned as-is.
    """
    def subst (t):
        if is_a (t, (t_var, int)):
            return fresh_tvar if t == tvar else t
        if is_a (t, t_predicate):
            return t_predicate (t.name, [subst (arg) for arg in t.args])
        return t
    return subst (type)
def instantiate_type_scheme (tscheme):
    """Return the body of <tscheme> with a fresh type variable for each generic variable."""
    result = tscheme.type
    # one substitution pass per generalized variable
    for generic in tscheme.gens:
        result = instantiate_type (result, generic, t_var())
    return result
def apply_tenv (tenv, name):
    """Look <name> up in the type environment and return its type.

    Ribs are searched innermost first.  A binding that holds a type
    scheme (forall) is instantiated before being returned.  Raises
    ValueError (name) when no rib binds <name>.
    """
    frame = tenv
    while frame:
        rib, frame = frame
        # walk the rib backwards for the sake of let*
        for var, bound in reversed (rib):
            if var != name:
                continue
            # is this a type scheme?
            if is_a (bound, forall):
                return instantiate_type_scheme (bound)
            return bound
    raise ValueError (name)
class UnboundVariable (Exception):
    """Raised by typer.lookup_special_names() for a name it does not recognize."""
class typer:
    """Assigns a type to every node of an expression tree.

    A node of kind <k> is handled by the method `type_of_<k>`; type
    constraints are solved on the spot through the module-level unify().
    The <context> object supplies: verbose, print_types, datatypes,
    variant_labels, scc_graph and type_error_lines.
    """

    def __init__ (self, context):
        self.context = context
        self.verbose = self.context.verbose

    def go (self, exp):
        """Type the program <exp> and return the type of its root.

        Afterwards every node's '.type' is replaced by its fully
        substituted form.  A type error (already reported by
        print_type_error) terminates the process with exit status 1.
        """
        self.exp = exp
        tenv = (self.initial_type_environment(), None)
        try:
            result = self.type_of (exp, tenv)
        except TypeError:
            sys.exit (1)
        for node in exp:
            if node.type:
                if not hasattr (node.type, 'final'):
                    # cache
                    node.type.final = apply_subst_to_type (node.type)
                node.type = node.type.final
        if self.verbose or self.context.print_types:
            for n in exp:
                if n.is_a ('function'):
                    # same output as the statement `print n.name, n.type`
                    sys.stdout.write ('%s %s\n' % (n.name, n.type))
        return result

    def initial_type_environment (self):
        """Return the outermost rib of the type environment (currently always empty)."""
        constructors = []
        # disabled: would pre-bind every datatype constructor.
        # NOTE(review): 'the_type_map' is not defined in this file.
        if False:
            for name, dt in self.context.datatypes.iteritems():
                poly_dt = build_type_scheme (dt, None, name)
                # store this type scheme in the type map
                the_type_map[name] = poly_dt
                for name in dt.get_datatype_constructors():
                    constructors.append ((name, poly_dt))
        return constructors

    def unify (self, t0, t1, tenv, exp):
        """Unify <t0> with <t1>; on failure report the error near <exp> and re-raise."""
        try:
            return unify (t0, t1)
        except TypeError as terr:
            self.print_type_error (exp, terr)

    def print_type_error (self, exp, terr):
        """Write a description of <terr> to stderr, then re-raise it.

        Shows the two conflicting types and the nodes surrounding <exp>
        in a depth-first listing of the whole program.  Must be called
        from inside the 'except' clause handling <terr>, because it ends
        with a bare 'raise'.
        """
        t0, t1 = terr.args[0]
        W = sys.stderr.write
        W ('\n---------------\nType Error:\n')
        W (' t0: %r\n' % (t0,))
        W (' t1: %r\n' % (t1,))
        W ('\nnear:\n')
        # find the portion of the program
        flat = []
        def walk_depth (n, d):
            flat.append ((n, d))
            for sub in n.subs:
                walk_depth (sub, d+1)
        walk_depth (self.exp, 0)
        # XXX this capability needs to be outside this file
        def near (n):
            lines = self.context.type_error_lines
            # we want <lines> before and after
            total = len (flat)
            start = 0
            end = total
            for i in range (total):
                if flat[i][0] is n:
                    start = max (i-lines, start)
                    # +1: the slice end is exclusive and must still
                    # cover <lines> nodes after the offending one.
                    end = min (i+lines+1, end)
                    break
            for ni, depth in flat[start:end]:
                if ni is n:
                    indent = '--'
                else:
                    indent = ' '
                W ('%s%r\n' % (indent * depth, ni))
        near (exp)
        raise

    def type_of (self, exp, tenv):
        """Compute the type of <exp>, store it in exp.type and return it."""
        method = getattr (self, 'type_of_%s' % (exp.kind,))
        exp.type = method (exp, tenv)
        return exp.type

    def type_of_literal (self, exp, tenv):
        return base_types[exp.ltype]

    def type_of_constructed (self, exp, tenv):
        return self.type_of (exp.value, tenv)

    def type_of_cexp (self, exp, tenv):
        """Type a cexp from its declared signature (exp.type_sig)."""
        tvars, sig = exp.type_sig
        sig = instantiate_type_scheme (forall (tvars, sig))
        if not is_pred (sig, 'arrow'):
            return sig
        # args[0] of an arrow is the result, the rest are the parameters
        result_type = sig.args[0]
        for i, arg_type in enumerate (sig.args[1:]):
            arg = exp.args[i]
            if is_pred (arg_type, 'raw'):
                # hack: magically hide the 'raw' predicate
                arg_type = arg_type.args[0]
            ta = self.type_of (arg, tenv)
            self.unify (ta, arg_type, tenv, arg)
        return result_type

    def type_of_conditional (self, exp, tenv):
        t1 = self.type_of (exp.test_exp, tenv)
        self.unify (t1, t_predicate ('bool', ()), tenv, exp.test_exp)
        t2 = self.type_of (exp.then_exp, tenv)
        t3 = self.type_of (exp.else_exp, tenv)
        # both arms must agree
        self.unify (t2, t3, tenv, exp)
        return t2

    def type_of_let_splat (self, exp, tenv):
        """Type a let*: each init is typed in an environment holding the earlier names."""
        for i in range (len (exp.inits)):
            init = exp.inits[i]
            name = exp.names[i]
            ta = self.type_of (init, tenv)
            # user-supplied type
            if name.type is not None:
                self.unify (ta, name.type, tenv, exp)
            tenv = ([(name.name, ta)], tenv)
        return self.type_of (exp.body, tenv)

    def type_of_function (self, exp, tenv):
        type_rib = []
        arg_types = []
        for formal in exp.formals:
            t = optional_type (formal, tenv)
            arg_types.append (t)
            type_rib.append ((formal.name, t))
        body_type = self.type_of (exp.body, (type_rib, tenv))
        r = arrow (body_type, *arg_types)
        # useful during complex type debugging
        #if exp.name:
        #    print exp.name, apply_subst_to_type (r)
        return r

    def type_of_application (self, exp, tenv):
        rator_type = self.type_of (exp.rator, tenv)
        # normal application
        arg_types = [self.type_of (rand, tenv) for rand in exp.rands]
        result_type = t_var() # new type variable
        self.unify (rator_type, arrow (result_type, *arg_types), tenv, exp)
        return result_type

    def type_of_varref (self, exp, tenv):
        return apply_tenv (tenv, exp.name)

    def type_of_varset (self, exp, tenv):
        # XXX implement the no-generalize rule for vars that are assigned.
        t1 = apply_tenv (tenv, exp.name)
        t2 = self.type_of (exp.value, tenv)
        self.unify (t1, t2, tenv, exp.value)
        return t_undefined()

    def type_of_sequence (self, exp, tenv):
        for sub in exp.subs[:-1]:
            # everything but the last, type it as don't-care
            self.type_of (sub, tenv)
        return self.type_of (exp.subs[-1], tenv)

    def type_of_primapp (self, exp, tenv):
        # look it up in the environment.
        scheme = self.lookup_special_names (exp.name, exp.name_params)
        sig = instantiate_type_scheme (scheme)
        # XXX almost identical to type_of_cexp(), factor it out.
        result_type = sig.args[0]
        arg_types = sig.args[1:]
        for i, arg in enumerate (exp.args):
            ta = self.type_of (arg, tenv)
            self.unify (ta, arg_types[i], tenv, arg)
        return result_type

    def lookup_special_names (self, name, params):
        """Return the type scheme (a forall) of the primitive called <name>.

        <params> carries extra data for the '&vcase' and '&vget'
        primitives.  Inside a scheme, small integers stand for its
        generic variables, and the first argument given to arrow() is
        the result type (type_of_primapp reads it back as sig.args[0]).
        Raises UnboundVariable when <name> matches nothing.
        """
        if name == '%rmake':
            return forall ((), arrow (rproduct (rdefault (abs()))))
        elif name.startswith ('%rextend/'):
            what, label = name.split ('/')
            # ∀XYZ.(Π(l:X;Y), Z) → Π(l:pre(Z);Y)
            return forall (
                (0,1,2),
                arrow (
                    rproduct (rlabel (label, pre(2), 1)),
                    rproduct (rlabel (label, 0, 1)),
                    2
                    )
                )
        elif name.startswith ('%raccess/'):
            what, label = name.split ('/')
            # ∀XY.Π(l:pre(X);Y) → X
            return forall ((0,1), arrow (0, rproduct (rlabel (label, pre(0), 1))))
        elif name.startswith ('%rset/'):
            what, label = name.split ('/')
            # ∀XY.(Π(l:pre(X);Y), X) → undefined
            return forall ((0,1), arrow (t_undefined(), rproduct (rlabel (label, pre(0), 1)), 0))
        elif name == '%vfail':
            return forall ((0,), arrow (0, rsum (rdefault (abs()))))
        elif name.startswith ('%dtcon/'):
            # lookup the type of the particular constructor
            what, dtname, label = name.split ('/')
            dt = self.context.datatypes[dtname]
            # e.g. list := nil | cons X list
            # %dtcon/list/cons := ∀X.(X,list(X)) → list(X)
            args = dt.constructors[label]
            return forall (dt.tvars, arrow (dt.scheme, *args))
        elif name.startswith ('%vcon/'):
            what, label, arity = name.split ('/')
            arity = int(arity)
            # remember each unique variant label
            self.remember_variant_label (label)
            if arity == 0:
                # ∀X.() → Σ(l:pre (Π());X)
                return forall ((1,), arrow (rsum (rlabel (label, pre (product()), 1))))
            elif arity == 1:
                # ∀XY.X → Σ(l:pre X;Y)
                return forall ((0,1), arrow (rsum (rlabel (label, pre(0), 1)), 0))
            else:
                # ∀ABCD.Π(A,B,C) → Σ(l:pre (Π(A,B,C));D)
                args = tuple (range (arity))
                return forall (
                    list (range (arity+1)),
                    arrow (rsum (rlabel (label, pre (product (*args)), arity)), *args)
                    )
        elif name == '&vcase':
            label, arity = params
            # ∀012345.(3,4,5) → 0, Σ(l:1;2) → 0, Σ(l:pre(Π(3,4,5);2) → 0
            # ∀012345.f0,f1,s1 → 0
            args = list (range (3, arity+3))
            # success continuation
            f0 = arrow (0, *args)
            # failure continuation
            f1 = arrow (0, rsum (rlabel (label, 1, 2)))
            # the sum argument
            if arity == 1:
                t = args[0]
            else:
                t = product (*args)
            s1 = rsum (rlabel (label, pre (t), 2))
            return forall (list (range (arity+3)), arrow (0, f0, f1, s1))
        elif name == '&vget':
            label, arity, index = params
            # list(), so that the concatenation below works with any range type
            args = list (range (arity))
            rest = arity
            # e.g., to pick the second arg:
            # ∀0123. Σ(l:pre (0,1,2);3) → 1
            if arity > 1:
                vtype = rsum (rlabel (label, pre (product (*args)), rest))
            else:
                vtype = rsum (rlabel (label, pre (args[0]), rest))
            return forall (args + [arity], arrow (args[index], vtype))
        elif name.startswith ('%nvget/'):
            what, dtype, label, index = name.split ('/')
            dt = self.context.datatypes[dtype]
            ti = dt.constructors[label][int(index)]
            return forall (dt.tvars[:], arrow (ti, dt.scheme))
        elif name.startswith ('%vector-literal/'):
            what, arity = name.split ('/')
            arg_types = (0,) * int (arity)
            return forall ((0,), arrow (vector(0), *arg_types))
        elif name.startswith ('%make-vector'):
            return forall ((0,), arrow (vector(0), t_int(), 0))
        elif name.startswith ('%make-vec16'):
            return forall ((), arrow (vector(t_int16()), t_int()))
        elif name == '%%array-ref':
            return forall ((0,), arrow (0, vector (0), t_int()))
        elif name == '%%array-set':
            return forall ((0,), arrow (t_undefined(), vector (0), t_int(), 0))
        elif name == '%vec16-set':
            return forall ((), arrow (t_undefined(), vector(t_int16()), t_int(), t_int16()))
        elif name == '%vec16-ref':
            # a ref takes (vector, index) only, like %%array-ref; the
            # value argument belongs to %vec16-set.
            return forall ((), arrow (t_int16(), vector(t_int16()), t_int()))
        # ------
        # pattern matching
        # ------
        elif name == '%%match-error':
            return forall ((0,), arrow (0))
        elif name == '%%fatbar':
            return forall ((0,), arrow (0, 0, 0))
        elif name == '%%fail':
            return forall ((0,), arrow (0))
        # -------
        elif name.count (':') == 1:
            # a constructor used in a 'constructed literal'
            dt, alt = name.split (':')
            # <params> must be forwarded: it is a required argument
            return self.lookup_special_names ('%%dtcon/%s/%s' % (dt, alt), params)
        else:
            raise UnboundVariable (name)

    # XXX consider recording record labels at this point as well
    def remember_variant_label (self, label):
        """Give <label> the next free index in context.variant_labels, unless already known."""
        vl = self.context.variant_labels
        if label not in vl:
            # adjust for the hacked pre-installed labels like 'cons' and 'nil'.
            vl[label] = len (vl)

    def type_of_fix (self, exp, tenv):
        """Type a fix (letrec) one strongly-connected group of inits at a time.

        Within a group the names are bound to plain types while the
        inits are typed; afterwards the group's names are rebound to
        generalized type schemes for everything that follows.
        """
        # reorder fix into dependency order
        partition = graph.reorder_fix (exp, self.context.scc_graph)
        n = len (exp.inits)
        init_tvars = [None] * n
        init_types = [None] * n
        n2 = 0
        # new type var for each init (or user type)
        for i in range (n):
            if exp.names[i].type:
                # user-annotated type
                init_tvars[i] = exp.names[i].type
            else:
                init_tvars[i] = t_var()
        for part in partition:
            # build temp tenv for typing the inits
            type_rib = [(exp.names[i].name, init_tvars[i]) for i in part]
            temp_tenv = (type_rib, tenv)
            # type each init in temp_tenv
            for i in part:
                init = exp.inits[i]
                ti = self.type_of (init, temp_tenv)
                self.unify (ti, init_tvars[i], temp_tenv, init)
                init_types[i] = apply_subst_to_type (ti)
            # now extend the environment with type schemes instead
            type_rib = []
            for i in part:
                name = exp.names[i]
                tsi = build_type_scheme (init_types[i], tenv, name)
                type_rib.append ((name.name, tsi))
            # we now have a polymorphic environment for this subset
            tenv = (type_rib, tenv)
            n2 += len (type_rib)
        # every init must belong to exactly one group
        assert (n2 == n)
        # and type the body in that tenv
        return self.type_of (exp.body, tenv)

    def type_of_pvcase (self, exp, tenv):
        """Type a polymorphic-variant case expression.

        (pvcase <alt_formals> <alt0> <alt1> ...)
        Each <alt> binds a separate set of variables (possibly empty).
        One more alt than alt_formals means the last alt is an else
        clause, which makes the sum type open; otherwise it is closed
        with rdefault(abs()).
        """
        alts = exp.alts[:]
        tv_exp = t_var()
        if len(alts) == len (exp.alt_formals):
            # no else clause, a closed sum type
            row = rdefault (abs())
        else:
            # with an else clause, open sum type
            row = t_var()
        for i, (label, n, _formals) in enumerate (exp.alt_formals):
            alt = alts[i]
            # row type extended with this label and its type
            args = [t_var() for x in range (n)]
            if len(args) == 1:
                row = rlabel (label, pre (args[0]), row)
            else:
                row = rlabel (label, pre(product (*args)), row)
            t_alt = self.type_of (alt, tenv)
            # each alt must have the same type
            self.unify (tv_exp, t_alt, tenv, exp)
        if len(alts) > len (exp.alt_formals):
            # an else clause
            self.unify (tv_exp, self.type_of (alts[-1], tenv), tenv, exp)
        # the value must have the row type determined
        # by the set of polyvariant alternatives.
        t_val = self.type_of (exp.value, tenv)
        self.unify (rsum (row), t_val, tenv, exp)
        # tv_exp has been unified with every alt, and - unlike the loop
        # variable - it is defined even when there are no labelled alts.
        return tv_exp

    def type_of_nvcase (self, exp, tenv):
        """Type a case over a declared datatype.

        (nvcase <vtype> <val> <alt0> <alt1> ...)
        Like a conditional, but with more branches.
        """
        dt = self.context.datatypes[exp.vtype]
        t_val = self.type_of (exp.value, tenv)
        if len(dt.tvars):
            # it's a type scheme, instantiate it
            dt_type = instantiate_type_scheme (forall (dt.tvars, dt.scheme))
            self.unify (t_val, dt_type, tenv, exp)
        else:
            self.unify (t_val, dt.scheme, tenv, exp)
        # each alt has the same type
        tv_exp = t_var()
        for alt in exp.alts:
            self.unify (tv_exp, self.type_of (alt, tenv), tenv, exp)
        # this will work even when else_clause is a dummy %%match-error
        self.unify (tv_exp, self.type_of (exp.else_clause, tenv), tenv, exp)
        return tv_exp