# -*- Mode: Python -*-
# lambda language
from pdb import set_trace as trace
# used to generate a unique identifier for each node.
class serial_counter:

    """Hands out consecutive integers, starting from zero."""

    def __init__ (self):
        # the value the next call to next() will hand out
        self.counter = 0

    def next (self):
        """Return the current counter value, then advance the counter by one."""
        issued = self.counter
        self.counter = issued + 1
        return issued
# module-wide counter: node.__init__ and vardef.__init__ both draw their
# <serial> attribute from it, so the numbers are unique across both.
serial = serial_counter()
# node types
class node:
# The initial implementation of the node class did the 'right' thing,
# by subclassing <node> to get each of the individual node types.
# Then of course I started using all these nice attributes. For example,
# I would have a 'calls' attribute on each function that would track
# every time it was called.
# Once I started writing code that transformed this tree of nodes, however,
# this all came back to bite me. The problem is that you need to be able
# to safely *copy* trees of these objects. Otherwise you get dangling
# pointers and lists, etc... even simple transformations were coming out
# completely mangled. It's *really* made me appreciate the ideas of pure
# functional data structures!
# So, this was rewritten with <kind>, <params>, and <subs>. Clumsy, but easy
# to copy and rewrite. Accessing these attributes is a real pain, but I
# don't have to worry about losing things, or surprises. Also, walking the
# tree of nodes is much simpler.
# However. The rest of the compiler doesn't want to be rewritten in this way,
# so the last thing we do before handing off to cps.py is to add a bunch of
# attributes to each node, in fix_attribute_names().
# I plan to fix this when it becomes either too clumsy or too embarrassing.
#
# XXX Pyrex solves this interestingly... it uses a special attribute to list
# the *names* of attributes that refer to sub-expressions. I think I
# considered this and discarded it because the 'set of all
# sub-expressions' is still difficult to synthesize, in cases where one
# attribute might hold a single expression, and others might hold sets
# of them.
# generic flag
leaf = False
escapes = False
def __init__ (self, kind, params=(), subs=(), type=None):
self.kind = kind
self.params = params
# XXX consider making this tuple(subs)
self.subs = subs
self.serial = serial.next()
size = 1
for sub in subs:
size += sub.size
self.size = size
self.type = type
self.fix_attribute_names()
def pprint (self, depth=0):
if self.leaf:
leaf = 'L'
else:
leaf = ' '
print '%3d %s' % (self.serial, leaf),
print ' ' * depth, self.kind,
print '[%d]' % (self.size,),
if self.params:
print self.params,
if self.type:
print '%s ' % (self.type,)
else:
print '? '
for sub in self.subs:
sub.pprint (depth+1)
def __repr__ (self):
if self.params:
return '<%s %r %d>' % (self.kind, self.params, self.serial)
else:
return '<%s %d>' % (self.kind, self.serial)
def __iter__ (self):
return walk_node (self)
def is_a (self, kind):
return self.kind == kind
def one_of (self, *kinds):
return self.kind in kinds
def is_var (self, name):
return self.kind == 'varref' and self.params == name
def copy (self):
return node (self.kind, self.params, self.subs, self.type)
def deep_copy (self):
# XXX ugliness. because self.params is sometimes a list, it would behoove
# us to use a copy of that list! However it's not always a list.
if is_a (self.params, list):
params = self.params[:]
else:
params = self.params
r = node (self.kind, params, [ x.deep_copy() for x in self.subs ], self.type)
# special-case: binding positions are not nodes or sub-expressions, but
# we want fresh copies of them as well...
if r.binds():
binds = [ vardef (x.name, x.type) for x in r.get_names() ]
if self.is_a ('let_splat'):
r.params = binds
elif self.is_a ('fix'):
r.params = binds
# update function attributes
for i in range (len (binds)):
if r.subs[i].is_a ('function'):
binds[i].function = r.subs[i]
elif self.is_a ('function'):
# function
r.params[1] = binds
else:
raise ValueError ("new binding construct?")
return r
def binds (self):
return self.kind in ('let_splat', 'function', 'fix')
def get_names (self):
if self.kind == 'function':
return self.params[1]
elif self.kind in ('let_splat', 'fix'):
return self.params
else:
raise ValueError ("get_names() not valid for this node")
def get_body (self):
# get the body of an expression
if self.one_of ('let_splat', 'fix', 'function'):
return self.subs[-1]
else:
return self
def get_rator (self):
assert (self.kind == 'application')
return self.subs[0]
def get_rands (self):
assert (self.kind == 'application')
return self.subs[1:]
def fix_attribute_names (self):
if self.kind == 'varref':
self.name = self.params
elif self.kind == 'varset':
self.name = self.params
self.value = self.subs[0]
elif self.kind == 'literal':
self.ltype, self.value = self.params
elif self.kind == 'constructed':
self.value = self.params
elif self.kind == 'primapp':
self.name, self.name_params = self.params
self.args = self.subs
elif self.kind == 'sequence':
self.exprs = self.subs
elif self.kind == 'cexp':
self.form, self.type_sig = self.params
self.args = self.subs
elif self.kind == 'verify':
self.tc, self.safety = self.params
self.arg = self.subs[0]
elif self.kind == 'conditional':
[self.test_exp, self.then_exp, self.else_exp] = self.subs
elif self.kind == 'function':
self.name, self.formals, self.recursive, self.type = self.params
self.body = self.subs[0]
elif self.kind == 'fix':
self.names = self.params
self.inits = self.subs[:-1]
self.body = self.subs[-1]
elif self.kind == 'let_splat':
self.names = self.params
self.inits = self.subs[:-1]
self.body = self.subs[-1]
elif self.kind == 'let_subst':
self.vars = self.params
self.body = self.subs[0]
elif self.kind == 'application':
self.recursive = self.params
self.rator = self.subs[0]
self.rands = self.subs[1:]
elif self.kind == 'pvcase':
self.alt_formals = self.params
self.value = self.subs[0]
self.alts = self.subs[1:]
elif self.kind == 'nvcase':
self.vtype, self.tags = self.params
self.value = self.subs[0]
self.alts = self.subs[1:-1]
self.else_clause = self.subs[-1]
else:
raise ValueError (self.kind)
def walk_node (n):
    """Generate <n> and then every node below it, parents before children.

    Children are visited in the order they appear in each node's <subs>.
    """
    yield n
    # one live iterator per tree level currently being explored
    pending = [iter (n.subs)]
    while pending:
        for child in pending[-1]:
            yield child
            # descend: the child's own subs are explored before its siblings
            pending.append (iter (child.subs))
            break
        else:
            # this level is exhausted, resume the one above it
            pending.pop()
def walk_up (n):
    """Generate every node of the tree rooted at <n>, children before parents.

    Each node is produced only after all of the nodes below it, so <n>
    itself comes last.  Subs are visited in the order they appear.
    """
    for sub in n.subs:
        # recurse with walk_up (not walk_node) so that the bottom-up order
        # holds at every level, not just for the root.
        for x in walk_up (sub):
            yield x
    yield n
# this is *not* a node!
class vardef:

    """A binding position: the definition of one variable.

    <name> may be given in three forms:
      a str                    -- plain name, with the optional <type> argument
      ['colon', name, type]    -- infix colon type declaration; the type
                                  expression is run through parse_type()
      ['quote', name]          -- type variable argument (sets <tvar>)

    Raises ValueError for any other shape of <name>.
    """

    tvar = False

    def __init__ (self, name, type=None):
        if is_a (name, str):
            self.name = name
            self.type = type
        elif is_a (name, list) and len(name) == 3 and name[0] == 'colon':
            # infix colon syntax for type declaration
            assert (type is None)
            self.name = name[1]
            self.type = parse_type (name[2])
        elif is_a (name, list) and len(name) == 2 and name[0] == 'quote':
            # type variable argument
            self.tvar = True
            self.name = name[1]
            self.type = None
        else:
            # without this the object would be built with no <name>/<type>
            # at all, and fail much later with an AttributeError.
            raise ValueError ("malformed variable definition: %r" % (name,))
        # varset / varref nodes that resolve to this variable
        # (appended to by rename_variables)
        self.assigns = []
        self.refs = []
        # the function node bound to this name, when there is one
        # (set by fix() and node.deep_copy())
        self.function = None
        self.serial = serial.next()
        # NOTE(review): escapes/inline/alias are only initialized here;
        # whatever uses them is outside the code visible in this file.
        self.escapes = False
        self.inline = None
        self.alias = None

    def __repr__ (self):
        #return '{%s.%d}' % (self.name, self.serial)
        if self.type:
            return '{%s:%s}' % (self.name, self.type)
        else:
            return '{%s}' % (self.name,)
        #return '{%s.%d}' % (self.name, len(self.assigns))
def varref (name):
    """Build a 'varref' node whose params is <name>; it has no subs."""
    return node ('varref', params=name)
def varset (name, value):
    """Build a 'varset' node: params is <name>, the single sub is <value>."""
    subs = [value]
    return node ('varset', name, subs)
def literal (kind, value):
    """Build a 'literal' node whose params is the pair (kind, value)."""
    tagged = (kind, value)
    return node ('literal', tagged)
# a literal built by constructors & immediates
def constructed (value):
    """Build a 'constructed' node whose params is <value>; it has no subs."""
    return node ('constructed', params=value)
def primapp (name, args, params=None):
    """Build a 'primapp' node: params is (name, params), subs are <args>."""
    head = (name, params)
    return node ('primapp', head, args)
def sequence (exprs):
    """Build a 'sequence' node over <exprs>, typed like its last expression.

    An empty <exprs> is replaced by a single 'undefined' literal.
    """
    exprs = exprs or [literal ('undefined', 'undefined')]
    last = exprs[-1]
    return node ('sequence', (), exprs, type=last.type)
def cexp (form, type_sig, args):
    """Build a 'cexp' node: params is (form, type_sig), subs are <args>."""
    head = (form, type_sig)
    return node ('cexp', head, args)
def conditional (test_exp, then_exp, else_exp):
    """Build a 'conditional' node whose subs are [test, then, else]."""
    branches = [test_exp, then_exp, else_exp]
    return node ('conditional', (), branches)
def function (name, formals, body, type=None):
    """Build a 'function' node with <body> as its only sub.

    params is the list [name, formals, recursive, type]; the recursive
    slot starts out False.
    """
    header = [name, formals, False, type]
    return node ('function', header, [body])
def fix (names, inits, body, type=None):
    """Build a 'fix' node: params is <names>, subs are the inits then the body.

    Every name whose init is a 'function' node gets that node recorded in
    its <function> attribute.
    """
    result = node ('fix', names, inits + [body], type)
    for index, vd in enumerate (names):
        init = inits[index]
        if init.is_a ('function'):
            vd.function = init
    return result
def let_splat (names, inits, body, type=None):
    """Build a 'let_splat' node: params is <names>, subs are the inits then the body."""
    subs = inits + [body]
    return node ('let_splat', names, subs, type)
def let_subst (vars, body):
    """Build a 'let_subst' node: params is <vars>, the single sub is <body>."""
    subs = [body]
    return node ('let_subst', vars, subs)
def application (rator, rands):
    """Build an 'application' node: subs are the operator followed by the operands.

    params (the node's <recursive> attribute) starts out False.
    """
    subs = [rator] + rands
    return node ('application', False, subs)
def pvcase (value, alt_formals, alts):
    """Build a 'pvcase' node: params is <alt_formals>, subs are the value then the alts."""
    subs = [value] + alts
    return node ('pvcase', alt_formals, subs)
def nvcase (vtype, value, tags, alts, else_clause):
    """Build an 'nvcase' node.

    params is (vtype, tags); subs are the value, then the alts, then the
    else clause.
    """
    head = (vtype, tags)
    subs = [value] + alts + [else_clause]
    return node ('nvcase', head, subs)
# ================================================================================
class ConfusedError (Exception):
    """Exception type declared by this module.

    Nothing in the part of the module reviewed here raises it.
    """
import itypes
# should this be moved to transform.py? [or itypes.py?]
def parse_type (exp, tvars=None):
    """Convert the s-expression <exp> describing a type into an itypes object.

    exp   -- one of:
               ['quote', name]          a type variable
               [arg ..., '->', result]  an arrow type
               [name, arg ...]          a predicate applied to arguments
               a str                    a base type from itypes.base_types,
                                        otherwise a nullary predicate
               a 'record' atom          a record type; a last field named
                                        '...' gets a fresh type variable as
                                        its tail instead of the default
    tvars -- optional dict mapping type variable names to itypes.t_var
             objects.  It is updated in place, so a caller can pass its own
             dict to collect the variables that were seen.

    Raises ValueError for anything that does not match one of these shapes.
    """
    if tvars is None:
        tvars = {}

    tv_counter = serial_counter()

    def new_tvar():
        # invent a type variable named r0, r1, ...
        return get_tvar ('r%d' % (tv_counter.next()))

    def get_tvar (name):
        # one t_var per name: repeated mentions share the same object
        if not tvars.has_key (name):
            tvars[name] = itypes.t_var()
        return tvars[name]

    def pfun (x):
        if is_a (x, list):
            if len(x) and x[0] == 'quote':
                # a type variable
                return get_tvar (x[1])
            elif len(x) >= 2 and x[-2] == '->':
                # an arrow type
                result_type = pfun (x[-1])
                arg_types = tuple ([pfun(y) for y in x[:-2]])
                return itypes.arrow (result_type, *arg_types)
            elif len(x) > 0 and is_a (x[0], str):
                # a predicate
                arg_types = tuple ([pfun(y) for y in x[1:]])
                return itypes.t_predicate (x[0], arg_types)
            else:
                raise ValueError ("malformed type: %r" % (x,))
        elif is_a (x, str):
            if itypes.base_types.has_key (x):
                return itypes.base_types[x]
            else:
                # allow nullary constructors
                #raise ValueError ("unknown type: %r" % (x,))
                return itypes.t_predicate (x, ())
        elif is_a (x, atom) and x.kind == 'record':
            # record type
            # the fields are folded in from last to first
            pairs = x.value[:]
            pairs.reverse()
            # guard on <pairs> so a record with no fields takes the
            # default-tail branch instead of raising IndexError.
            if pairs and pairs[0][0] == '...':
                t = new_tvar()
                pairs = pairs[1:]
            else:
                t = itypes.rdefault (itypes.abs())
            for fname, ftype in pairs:
                t = itypes.rlabel (fname, itypes.pre (pfun (ftype)), t)
            return itypes.rproduct (t)
        else:
            raise ValueError ("malformed type: %r" % (x,))

    #print 'pfun (%r) => %r' % (exp, pfun (exp))
    return pfun (exp)
from lisp_reader import atom
# module-level shorthand for type tests: is_a (obj, cls).
# not to be confused with the node.is_a() method, which compares <kind>.
is_a = isinstance
class walker:
"""The walker converts from 's-expression' => 'node tree' representation"""
def __init__ (self, context):
self.context = context
def walk_exp (self, exp):
WALK = self.walk_exp
if is_a (exp, str):
return varref (exp)
elif is_a (exp, atom):
return literal (exp.kind, exp.value)
elif is_a (exp, list):
rator = exp[0]
simple = is_a (rator, str)
if simple:
if rator == '%%cexp':
assert (is_a (exp[2], atom))
assert (exp[2].kind == 'string')
tvars = {}
type_sig = parse_type (exp[1], tvars)
form = exp[2].value
return cexp (form, (tvars.values(), type_sig), [ WALK (x) for x in exp[3:]])
elif rator == '%nvcase':
ignore, vtype, value, alts, ealt = exp
dt = self.context.datatypes[vtype]
tags = [x[0] for x in alts]
alts = [x[1] for x in alts]
return nvcase (vtype, WALK(value), tags, [WALK (x) for x in alts], WALK(ealt))
elif rator.startswith ('%'):
return primapp (rator, [WALK (x) for x in exp[1:]])
elif rator.startswith ('&'):
# primap with parameters
return primapp (rator, [WALK (x) for x in exp[2:]], exp[1])
elif rator == 'begin':
return sequence ([WALK (x) for x in exp[1:]])
elif rator == 'set_bang':
ignore, name, val = exp
return varset (name, WALK (val))
elif rator == 'quote':
return literal (exp[1].kind, exp[1].value)
elif rator == 'constructed':
return constructed (WALK (exp[1]))
elif rator == 'if':
return conditional (WALK (exp[1]), WALK (exp[2]), WALK (exp[3]))
elif rator == 'function':
fun_name, fun_type = exp[1]
formals = exp[2]
formals = [vardef (x) for x in formals]
return function (fun_name, formals, WALK (exp[3]), fun_type)
elif rator == 'let_splat':
ignore, vars, body = exp
names = [vardef(x[0]) for x in vars]
inits = [WALK (x[1]) for x in vars]
return let_splat (names, inits, WALK (body))
elif rator == 'let_subst':
ignore, vars, body = exp
return let_subst (vars, WALK (body))
elif rator == 'fix':
ignore, names, inits, body = exp
names = [vardef (x) for x in names]
inits = [WALK (x) for x in inits]
return fix (names, inits, WALK (body))
else:
# a varref application
return application (WALK (rator), [WALK (x) for x in exp[1:]])
else:
# a non-simple application
return application (WALK (rator), [ WALK (x) for x in exp[1:]])
else:
raise ValueError, exp
def go (self, exp):
exp = self.walk_exp (exp)
for node in exp:
node.fix_attribute_names()
return exp
# walk the node tree, applying subst nodes.
def apply_substs (exp):
    """Apply every 'let_subst' node in the tree and return the resulting tree.

    Each let_subst node is replaced by its (rewritten) body; inside that
    body, varref/varset nodes naming a substituted variable have their
    params replaced by the target name.  Names bound by let_splat, function
    or fix shadow any substitution of the same name.

    The tree is updated destructively (params and subs are overwritten), so
    attributes derived by fix_attribute_names() are stale afterwards.
    """

    # the environment is None, or a pair (rib, enclosing environment) where
    # a rib is a list of (from-name, to-name) pairs.

    def shadow (names, lenv):
        # rebuild <lenv> without any substitution for one of <names>
        if lenv is None:
            return lenv
        rib, tail = lenv
        kept = []
        for entry in rib:
            if entry[0] not in names:
                kept.append ((entry[0], entry[1]))
        return (kept, shadow (names, tail))

    def lookup (name, lenv):
        # innermost substitution for <name>, or <name> itself if there is none
        while lenv:
            rib, lenv = lenv
            for source, target in rib:
                if source == name:
                    return target
        return name

    def walk (exp, lenv):
        if exp.binds():
            bound = [vd.name for vd in exp.get_names()]
            lenv = shadow (bound, lenv)
        elif exp.is_a ('let_subst'):
            # filter out wildcards from match expressions, and
            # follow chains of replacements...
            pairs = [(fn, lookup (tn, lenv)) for fn, tn in exp.params if fn != '_']
            # the let_subst node itself disappears: its body takes its place
            return walk (exp.subs[0], (pairs, lenv))
        elif exp.one_of ('varref', 'varset'):
            exp.params = lookup (exp.params, lenv)
        # XXX urgh, destructive update bad, bad, bad.
        exp.subs = [walk (sub, lenv) for sub in exp.subs]
        return exp

    return walk (exp, None)
# alpha conversion
def rename_variables (exp, context):
    """Alpha-convert the tree <exp>: give every bound variable a unique name.

    Pending let_subst nodes are applied first (see apply_substs).  Then
    every vardef gets an <alpha> index, every varref/varset that resolves
    to a vardef is renamed to 'name_alpha' and recorded in that vardef's
    refs/assigns, and finally the vardefs themselves are renamed.  A vardef
    whose name starts with '&' just has the '&' removed.

    exp     -- root node of the tree; it is modified in place
    context -- must provide <datatypes>, a mapping whose keys are accepted
               as names that need no binding

    Returns a dict mapping each final variable name to its vardef.
    Raises ValueError for a reference to an unbound variable.
    """
    vardefs = []
    datatypes = context.datatypes

    def lookup_var (name, lenv):
        # the environment is None or a pair (list of vardefs, outer environment)
        while lenv:
            rib, lenv = lenv
            # walk rib backwards for the sake of <let*>
            # (e.g., (let ((x 1) (x 2)) ...))
            for x in reversed (rib):
                if x.name == name:
                    return x
        if datatypes.has_key (name):
            return None
        elif name.startswith ('&'):
            return None
        else:
            raise ValueError ("unbound variable: %r" % (name,))

    # walk <exp>, inventing a new name for each <vardef>,
    # renaming varref/varset as we go...
    def rename (exp, lenv):
        if exp.binds():
            defs = exp.get_names()
            for vd in defs:
                vd.alpha = len (vardefs)
                vardefs.append (vd)
            if exp.is_a ('let_splat'):
                # this one is tricky
                names = []
                lenv = (names, lenv)
                for i in range (len (defs)):
                    # add each name only after its init
                    init = exp.subs[i]
                    rename (init, lenv)
                    names.append (defs[i])
                # now all the inits are done, rename body
                rename (exp.subs[-1], lenv)
                # ugh, non-local exit
                return
            else:
                # normal binding behavior
                lenv = (defs, lenv)
            if exp.is_a ('fix'):
                # rename functions
                for i in range (len (defs)):
                    if exp.subs[i].is_a ('function'):
                        if not defs[i].name.startswith ('&'):
                            exp.subs[i].params[0] = '%s_%d' % (defs[i].name, defs[i].alpha)
            for sub in exp.subs:
                rename (sub, lenv)
        elif exp.is_a ('pvcase'):
            # this is a strangely shaped binding construct
            # note: nvcase uses let internally for binding, and does not need to be here.
            # XXX I think this is probably true of pvcase too, now.
            rename (exp.value, lenv)
            n = len (exp.alts)
            for i in range (n):
                selector, defs = exp.alt_formals[i]
                alt = exp.alts[i]
                for vd in defs:
                    vd.alpha = len (vardefs)
                    vardefs.append (vd)
                # each alternative gets its own environment.  <lenv> must
                # not be rebound here: otherwise the formals of one
                # alternative stay visible in all later ones and can
                # capture an outer variable of the same name.
                rename (alt, (defs, lenv))
        elif exp.one_of ('varref', 'varset'):
            name = exp.params
            probe = lookup_var (name, lenv)
            # probe is None for datatype names and '&' names: those are
            # left exactly as they are.
            if probe:
                exp.var = probe
                if exp.is_a ('varset'):
                    probe.assigns.append (exp)
                else:
                    probe.refs.append (exp)
                exp.params = exp.name = '%s_%d' % (name, exp.var.alpha)
            for sub in exp.subs:
                rename (sub, lenv)
        else:
            for sub in exp.subs:
                rename (sub, lenv)

    # first, apply any pending substs
    exp = apply_substs (exp)
    # because of the destructive update we gotta redo this
    for n in exp:
        n.fix_attribute_names()
    #exp.pprint()
    rename (exp, None)
    # now go back and change the names of the vardefs
    # (done last, because lookup_var matches on the original names)
    for vd in vardefs:
        if vd.name.startswith ('&'):
            vd.name = vd.name[1:]
        else:
            vd.name = '%s_%d' % (vd.name, vd.alpha)
    result = {}
    for vd in vardefs:
        result[vd.name] = vd
    return result
# leave this here for tests/t_lex.scm
# (a bare expression statement: it is evaluated and discarded, nothing in
# this module uses its value)
42