# -*- Mode: Python -*-

# simple lambda language for the purpose of exercising the type inference engine

import lisp_reader
is_a = isinstance  # short alias used for the node/type dispatch below

class node:
    """Common base class for every AST node built by walker."""

class vardef (node):
    """A binding occurrence of a variable (a lambda formal or a let name).

    ``alpha`` starts at 0; rename_variables overwrites it with a unique
    index and appends that index to ``name``.
    """
    def __init__ (self, name):
        self.name, self.alpha = name, 0

class varref (node):
    """A use (reference) of a variable, identified by name."""

    def __init__ (self, name):
        # the name as written in the source; rename_variables may rewrite it
        self.name = name

class literal (node):
    """A constant; ``kind`` and ``value`` are copied from a reader atom."""
    def __init__ (self, kind, value):
        self.kind, self.value = kind, value

class function (node):
    """A lambda: ``formals`` is a list of vardef, ``body`` a single node."""
    def __init__ (self, formals, body):
        self.formals, self.body = formals, body

class application (node):
    """A call: ``rator`` is the operator node, ``rands`` the operand nodes."""
    def __init__ (self, rator, rands):
        self.rator, self.rands = rator, rands

class let (node):
    """A let binding form.

    ``names[i]`` (a vardef) is bound to the value of ``inits[i]``, and the
    names are visible in ``body`` (rename_variables resolves the inits in
    the enclosing scope, not in the let's own scope).
    """
    def __init__ (self, names, inits, body):
        self.names, self.inits, self.body = names, inits, body

class walker:
    """Translate reader output (nested lists, strings, atoms) into node objects."""

    def go (self, exp):
        """Return the node tree for the reader s-expression <exp>."""
        return self.walk (exp)

    def walk (self, exp):
        """Recursively convert one s-expression into a node.

        Accepted shapes:
          name (str)                                  -> varref
          lisp_reader.atom                            -> literal
          ['lambda', [formal0, formal1, ...], body]   -> function
          ['let', [[name0, init0], ...], body]        -> let
          [rator, rand0, rand1, ...]                  -> application

        Raises ValueError for an empty list or for any other kind of object.
        """
        # Note: for simplicity, no real syntax checking here
        if is_a (exp, str):
            return varref (exp)
        elif is_a (exp, lisp_reader.atom):
            return literal (exp.kind, exp.value)
        elif is_a (exp, list):
            if not exp:
                # there is no operator to apply; fail loudly rather than IndexError
                raise ValueError ("cannot walk an empty list")
            rator = exp[0]
            # only a bare symbol can introduce a special form
            keyword = rator if is_a (rator, str) else None
            if keyword == 'lambda':
                # ['lambda', [formal0, formal1, ...], body]
                ignore, formals, body = exp
                return function ([vardef (x) for x in formals], self.walk (body))
            elif keyword == 'let':
                # ['let', [[name0, init0], [name1, init1], ...], body]
                bindings = exp[1]
                names = [vardef (x[0]) for x in bindings]
                inits = [self.walk (x[1]) for x in bindings]
                return let (names, inits, self.walk (exp[2]))
            else:
                return application (self.walk (rator), [self.walk (x) for x in exp[1:]])
        else:
            # previously fell through and returned None, which only failed
            # much later with an unrelated error
            raise ValueError ("unknown expression: %r" % (exp,))

def rename_variables (exp):
    """Alpha-convert the node tree <exp> in place.

    Every vardef receives a unique integer ``alpha``, assigned in the order
    the vardefs are reached by a depth-first walk. Its name is then
    rewritten to ``<name>_<alpha>``. Every varref is resolved against its
    lexical environment, and the matching vardef is stored as ``ref.var``.
    The varref's name is rewritten the same way. Names starting with '%'
    are treated as primitives and left untouched.

    Raises ValueError on an unbound variable or an unknown node type.
    Returns None; the tree is modified in place.
    """
    # every vardef seen, in the order its alpha index was assigned
    all_vardefs = []

    def lookup (name, lenv):
        # lenv is a linked list of ribs: (rib, outer_lenv) ... ending in None.
        # Inner ribs are searched first so inner bindings shadow outer ones.
        while lenv:
            rib, lenv = lenv
            for vd in rib:
                if vd.name == name:
                    return vd
        raise ValueError ("unbound variable: %r" % (name,))

    def number (defs):
        # give each vardef in <defs> the next unused alpha index
        for vd in defs:
            vd.alpha = len (all_vardefs)
            all_vardefs.append (vd)

    def rename (exp, lenv):
        if is_a (exp, let):
            number (exp.names)
            # inits are resolved in the enclosing environment: the let's
            # own names are not in scope for them
            for init in exp.inits:
                rename (init, lenv)
            rename (exp.body, (exp.names, lenv))
        elif is_a (exp, function):
            number (exp.formals)
            rename (exp.body, (exp.formals, lenv))
        elif is_a (exp, varref):
            # '%'-prefixed names are primitives, not lexically bound
            if not exp.name.startswith ('%'):
                exp.var = lookup (exp.name, lenv)
                exp.name = '%s_%d' % (exp.name, exp.var.alpha)
        elif is_a (exp, application):
            rename (exp.rator, lenv)
            for rand in exp.rands:
                rename (rand, lenv)
        elif not is_a (exp, literal):
            raise ValueError ("unknown node type")

    rename (exp, None)
    # vardef names are rewritten only after the walk, because lookup()
    # matches varrefs against the *original* vardef names
    for vd in all_vardefs:
        vd.name = '%s_%d' % (vd.name, vd.alpha)