# -*- Python -*-
# See "The Implementation of Functional Programming Languages",
# Chapter 5: "Efficient Compilation of Pattern-Matching".
# http://research.microsoft.com/en-us/um/people/simonpj/papers/slpj-book-1987/
#
# Thanks for the hint, OCaml people! (Xavier Leroy?) They were kind enough to put this reference in
# their source code (ocaml/bytecomp/matching.ml); otherwise I might never have found out about this
# book. And thanks to Simon Peyton Jones for putting his book online.
# Short alias for isinstance(), used for the type tests below.
is_a = isinstance
from pdb import set_trace as trace
from pprint import pprint as pp
from lisp_reader import atom
class variable:
    # Pattern node: a pattern variable, which binds the matched value to
    # <name>.  The name '_' is the wildcard and binds nothing.
    def __init__ (self, name):
        # the name this pattern binds
        self.name = name
    def __repr__ (self):
        # rendered as <name>
        return ''.join (['<', '%s' % (self.name,), '>'])
class literal:
    # Pattern node: matches a literal value.  <value> is expected to be an
    # atom, i.e. an object with .kind and .value attributes (for instance the
    # atom('symbol', ...) built for quoted symbols).
    #
    # Equality is defined by the atom's (kind, value) pair.  Python 2 uses
    # __cmp__ for this; Python 3 ignores __cmp__, so __eq__/__ne__/__hash__
    # are defined explicitly and agree with it.
    def __init__ (self, value):
        self.value = value
    def _key (self):
        # the pair that identifies this literal for comparison and hashing
        v = self.value
        return (v.kind, v.value)
    def __repr__ (self):
        return 'L%s' % (repr(self.value))
    def __cmp__ (self, other):
        # Python 2 three-way comparison: order by (kind, value).  Any
        # non-literal compares as greater than a literal (returns -1).
        if isinstance (other, literal):
            a = self._key ()
            b = other._key ()
            # portable equivalent of cmp(a, b)
            return (a > b) - (a < b)
        else:
            return -1
    def __eq__ (self, other):
        # equal iff other is a literal with the same (kind, value)
        if not isinstance (other, literal):
            return NotImplemented
        return self._key () == other._key ()
    def __ne__ (self, other):
        # Python 2 does not derive != from ==, so spell it out.
        result = self.__eq__ (other)
        if result is NotImplemented:
            return result
        return not result
    def __hash__ (self):
        # consistent with __eq__ (requires a hashable atom value)
        return hash (self._key ())
class constructor:
    # Pattern node: a datatype constructor applied to sub-patterns.
    # <name> has the form 'datatype:alt' (e.g. 'list:cons'); <subs> holds
    # one sub-pattern per field, so len() of the node is its arity.
    def __init__ (self, name, subs):
        datatype, alt = name.split (':')
        self.datatype = datatype
        self.alt = alt
        self.subs = subs
    def __len__ (self):
        # arity of this constructor: the number of sub-patterns
        return len (self.subs)
    def __repr__ (self):
        # rendered as (datatype/alt sub0 sub1 ...)
        fields = ' '.join (map (repr, self.subs))
        return '(%s/%s %s)' % (self.datatype, self.alt, fields)
class record:
    # Pattern node: a record pattern, holding (field-name, sub-pattern) pairs.
    def __init__ (self, pairs):
        # pairs: iterable of (name, sub-pattern).  A field named '...' is
        # dropped: for now, '...' in record patterns is ignored.
        self.pairs = [(name, val) for (name, val) in pairs if name != '...']
        # Keep 'em sorted by field name, so two patterns that list the same
        # fields in different orders end up with the same field list.
        # list.sort is stable, so duplicate names keep their relative order.
        self.pairs.sort (key=lambda pair: pair[0])
    def __repr__ (self):
        # rendered as {name=repr(sub) name=repr(sub) ...}
        return '{%s}' % (' '.join (['%s=%r' % (name, sub) for (name, sub) in self.pairs]))
# A malformed match.
class MatchError (Exception):
    """Raised when the rules of a match are malformed, e.g. a constructor
    arity mismatch or record patterns that name different fields."""
class IncompleteMatch (Exception):
    """Presumably signals a match that does not cover every case
    (unverified: the compiler class does not raise it)."""
# Sentinel code forms.  The compiler compares them by value (==), and the
# same list objects are spliced into the generated tree, so never mutate them.
# FAIL:  marks an alternative that failed and should fall through to the next;
#        fatbar() simplifies it away, and constant_rule() uses it as the
#        else-branch of each literal test.
# ERROR: the default for a whole match, passed in by compiler.compile(); it
#        is what the generated tree uses where no rule applies.
FAIL = ['%%fail']
ERROR = ['%%match-error']
# The next step for this code is to optimize the generated tree, which should be a matter of
# using heuristics to pick which of several pattern columns to begin with. This code always
# starts with the left-most pattern and descends recursively; see first_pats_are() below.
class compiler:
    """Compile pattern-matching rules into a decision tree.

    Follows the match algorithm of chapter 5 of SPJ's book (see the head of
    this file).  The output is a nested list built from 'vcase', 'if',
    'let', 'let_subst' and '%%fatbar' forms, with FAIL and ERROR as
    sentinels.
    """
    def __init__ (self, context):
        # context.datatypes is indexed by datatype name; each entry must
        # provide .arity(alt) and .alts (see constructor_rule).
        self.context = context
        # counter behind gensym()
        self.gensym_counter = 0

    def gensym (self):
        # Return a fresh variable name: 'm0', 'm1', 'm2', ...
        c = self.gensym_counter
        self.gensym_counter += 1
        return 'm%d' % (c,)

    def compile (self, rules, vars):
        """Compile a match.

        rules: non-empty list of (patterns, code) pairs; each patterns list
               holds one raw pattern (see kind()) per variable in <vars>.
        vars:  list of the names of the values being matched, one per
               pattern column.

        Returns (vars, tree).  ERROR is used as the default of the whole
        match, i.e. wherever no rule applies.

        Raises MatchError if <rules> is empty, or if any rule does not have
        exactly len(vars) patterns.  Without this check, a short <vars>
        would make match() silently ignore the trailing pattern columns,
        and a long one would fail deep inside the recursion.
        """
        if not rules:
            raise MatchError ("match has no rules", rules)
        # every rule must have exactly one pattern per variable.
        npats = len (vars)
        for pats, code in rules:
            if len (pats) != npats:
                raise MatchError ("wrong number of patterns in rule", pats)
        rules0 = [([self.kind (x) for x in pats], code) for pats, code in rules]
        return vars, self.match (vars, rules0, ERROR)

    def kind (self, p):
        """Convert a raw pattern <p> into a pattern node.

        ()                      -> the list:nil constructor
        (quote sym)             -> a symbol literal
        ((colon dt alt) sub...) -> the dt:alt constructor over the subs
        (a b ...)               -> nested list:cons constructors, ending in
                                   list:nil, or in the pattern after '.'
                                   for a dotted list (a b . c)
        a string                -> a variable ('_' is the wildcard)
        a 'record' atom         -> a record pattern
        anything else           -> a literal
        """
        if is_a (p, list):
            if len (p) == 0:
                # () -> (list:nil)
                return constructor ('list:nil', [])
            head = p[0]
            if head == 'quote':
                # a symbol
                assert (is_a (p[1], str))
                return literal (atom ('symbol', p[1]))
            elif is_a (head, list) and len (head) == 3 and head[0] == 'colon':
                # a constructor.  The length test must come before head[0]:
                # a leading () sub-pattern, as in (() x), is an empty list
                # and indexing it would raise IndexError.
                return constructor ('%s:%s' % (head[1], head[2]),
                                    [self.kind (x) for x in p[1:]])
            elif head == '.':
                # dotted tail: the pattern after the dot is the cdr
                return self.kind (p[1])
            else:
                # (a b . c) => (list:cons a (list:cons b c))
                return constructor ('list:cons', [self.kind (head), self.kind (p[1:])])
        elif is_a (p, str):
            return variable (p)
        elif is_a (p, atom) and p.kind == 'record':
            return record ([(name, self.kind (sub)) for (name, sub) in p.value])
        else:
            return literal (p)

    def first_pats_are (self, rules, kind):
        # True if the first pattern of every rule is an instance of <kind>
        # (vacuously true when there are no rules).
        for pats, code in rules:
            if not is_a (pats[0], kind):
                return False
        return True

    def match (self, vars, rules, default):
        """Compile <rules> against the columns named by <vars>.

        vars:    variables still to be matched, one per remaining column.
        rules:   list of (pattern-nodes, code) pairs.
        default: code to use when no rule applies.
        """
        if not vars:
            # no columns left: the first remaining rule wins
            if rules:
                empty_pats, code = rules[0]
                return code
            else:
                return default
        if self.first_pats_are (rules, variable):
            # every rule begins with a variable
            return self.variable_rule (vars, rules, default)
        if self.first_pats_are (rules, constructor):
            # every rule begins with a constructor (i.e., no variables)
            return self.constructor_rule (vars, rules, default)
        if self.first_pats_are (rules, record):
            return self.record_rule (vars, rules, default)
        if self.first_pats_are (rules, literal):
            # every rule begins with a constant
            return self.constant_rule (vars, rules, default)
        # a mixture of kinds: handle them in runs
        return self.mixture_rule (vars, rules, default)

    def subst (self, var0, var1, code):
        # Arrange for pattern variable <var1> to stand for <var0> within
        # <code>.  The substitution is recorded in a 'let_subst' form and
        # applied later, during node building (nodes.py).
        if var1 == '_':
            # a wildcard binds nothing, so no substitution is needed
            return code
        elif is_a (code, list) and len (code) and code[0] == 'let_subst':
            # extend an existing let_subst instead of nesting another one
            return ['let_subst', code[1] + [(var1, var0)], code[2]]
        else:
            return ['let_subst', [(var1, var0)], code]

    def variable_rule (self, vars, rules, default):
        # If every rule begins with a variable, we can remove that column
        # from the set of patterns and substitute the var within each body.
        var = vars[0]
        rules0 = [(pats[1:], self.subst (var, pats[0].name, code)) for pats, code in rules]
        return self.match (vars[1:], rules0, default)

    def fatbar (self, e1, e2):
        # Build the book's "fatbar" of e1 and e2 (e1, falling back to e2 if
        # e1 yields FAIL).  A FAIL on either side simplifies away.
        if e1 == FAIL:
            return e2
        elif e2 == FAIL:
            return e1
        else:
            return ['%%fatbar', e1, e2]

    def get_arity (self, rules):
        # Given a set of polymorphic variant rules:
        # 1) compute the constructor arity (from the first rule)
        # 2) verify that they're all the same, else raise MatchError
        arity = len (rules[0][0][0])
        for pats, code in rules[1:]:
            if len (pats[0]) != arity:
                raise MatchError ("arity mismatch in polymorphic variant pattern", rules)
        return arity

    def constructor_rule (self, vars, rules, default):
        """Compile rules whose first patterns are all constructors.

        Used for normal constructors *and* polymorphic variants (a datatype
        of 'None').  Rules are grouped by constructor alt, in order of first
        appearance and keeping their relative order within each alt.  One
        vcase arm is built per alt.

        Raises MatchError on an arity mismatch.
        """
        datatype = rules[0][0][0].datatype
        if datatype != 'None':
            dt = self.context.datatypes[datatype]
        else:
            # polymorphic variant
            dt = None
        # Group the rules by alt.  <order> records first appearance, so the
        # generated cases follow source order rather than dict order.
        alts = {}
        order = []
        for pats, code in rules:
            alt = pats[0].alt
            # XXX raise this as a real syntax error...
            assert (pats[0].datatype == datatype)
            if alt not in alts:
                alts[alt] = []
                order.append (alt)
            alts[alt].append ((pats, code))
        if default != ERROR:
            # the arms fail through to <default> via the fatbar at the end
            default0 = FAIL
        else:
            default0 = default
        cases = []
        for alt in order:
            rules0 = alts[alt]
            if dt:
                arity = dt.arity (alt)
            else:
                arity = self.get_arity (rules0)
            # new variables to stand for the fields of the constructor
            vars0 = [self.gensym () for x in range (arity)]
            # wild[i] stays True while every rule has '_' in field i
            wild = [True] * arity
            rules1 = []
            for pats, code in rules0:
                subs = pats[0].subs
                if len (subs) != arity:
                    raise MatchError ("arity mismatch in variant pattern", rules0)
                rules1.append ((subs + pats[1:], code))
                for i, sub in enumerate (subs):
                    if not (is_a (sub, variable) and sub.name == '_'):
                        wild[i] = False
            # If every pattern has a wildcard for this arg of the
            # constructor, then use '_' rather than the symbol we generated.
            vars1 = vars0[:]
            for i in range (arity):
                if wild[i]:
                    vars1[i] = '_'
            cases.append (
                [[['colon', None, alt]] + vars1, self.match (vars0 + vars[1:], rules1, default0)]
            )
        if dt:
            if len (alts) < len (dt.alts):
                # an incomplete vcase, stick in an else clause.
                cases.append (['else', default0])
            result = ['vcase', datatype, vars[0]] + cases
        else:
            # this will turn into 'pvcase' when the missing datatype is detected
            result = ['vcase', vars[0]] + cases
        if default != ERROR:
            return self.fatbar (result, default)
        else:
            return result

    def record_rule (self, vars, rules, default):
        """Compile rules whose first patterns are all records.

        Every record pattern must name the same fields (record() keeps them
        sorted, so field order does not matter); otherwise MatchError is
        raised.  Each field becomes a new column, bound by a 'let' to a
        variable named '<var>_<field>'.
        """
        def get_sig (pat):
            # the (sorted) list of field names of a record pattern
            return [name for (name, sub) in pat.pairs]
        # sanity check
        sig = get_sig (rules[0][0][0])
        for pats, code in rules[1:]:
            if get_sig (pats[0]) != sig:
                raise MatchError (pats, sig)
        # translate: one new column (and variable) per field
        var = vars[0]
        vars0 = ['%s_%s' % (var, field) for field in sig]
        rules0 = [([sub for (name, sub) in pats[0].pairs] + pats[1:], code)
                  for pats, code in rules]
        bindings = [[v, '%s.%s' % (var, field)] for v, field in zip (vars0, sig)]
        return ['let', bindings, self.match (vars0 + vars[1:], rules0, default)]

    def constant_rule (self, vars, rules, default):
        # This is a simplified version of the constructor rule.  Here I'm
        # departing from the book, which treats constants quite differently -
        # they are translated into guard clauses.  I would like to avoid doing
        # guard clauses until I'm convinced they're necessary.  And I just
        # don't understand why constants should be treated differently from
        # any other constructor.
        #
        # Group runs of consecutive rules that test the same literal.
        groups = []
        last = None
        for pats, code in rules:
            if pats[0] == last:
                groups[-1].append ((pats, code))
            else:
                groups.append ([(pats, code)])
            last = pats[0]
        # Build the tests from the last group backwards, so the first
        # group's test ends up outermost.
        while groups:
            group = groups.pop ()
            # every rule in the group starts with an equal literal
            lit = group[0][0][0]
            rules0 = [(pats[1:], code) for pats, code in group]
            # decide which comparison function to use...
            # eq? works on everything (so far) but strings.
            if lit.value.kind == 'string':
                comp_fun = 'string=?'
            else:
                comp_fun = 'eq?'
            # we use fatbar here to avoid code duplication, which
            # can easily lead to exponential code explosion.
            default = self.fatbar (
                ['if', [comp_fun, lit.value, vars[0]],
                 self.match (vars[1:], rules0, FAIL),
                 FAIL],
                default
            )
        return default

    def mixture_rule (self, vars, rules, default):
        # Partition the rules into maximal runs whose first patterns are of
        # the same class (variable, constructor, record or literal), then
        # compile the runs from last to first.  Each run's default is the
        # compiled code for the runs that follow it.
        parts = []
        last = None
        for pats, code in rules:
            if parts and is_a (pats[0], last):
                parts[-1].append ((pats, code))
            else:
                # start a new partition
                parts.append ([(pats, code)])
                last = pats[0].__class__
        while parts:
            default = self.match (vars, parts.pop (), default)
        return default