;; -*- Mode: Irken -*-

(require "self/backend.scm")

;; REMOVE ME when the time comes
;; Print "not yet: <what>" and abort the whole process with exit code -1.
;; Used as a placeholder wherever the ARM backend is still incomplete.
(define (not-yet what)
  (printf "not yet: " what "\n")
  (%exit #f -1))

;; peephole optimization ideas:
;; combine str, str into stp?
;; same with ldr, ldr into ldp
;; elide branch to following code
;; elide labels never referenced

;; runtime register assignments (integer n means register xN).
;; NOTE: x19-x28 are callee-saved under AAPCS64.
(define Rreturn 24)  ;; carries a value back to a continuation. XXX save/restore me!
(define Rscratch 25) ;; clobbered freely by multi-insn sequences
(define Rfreep 26)   ;; free pointer (loaded from _freep on toplevel entry)
(define Rlenv 27)    ;; current lexical environment
(define Rk 28)       ;; current continuation frame
;; another regvar candidate would be 'top'

;; ARM64 condition codes, as used by b.<cond> and csel.
;; Each comment gives the meaning and the NZCV flag test it performs.
;; Constructor order is left untouched on purpose.
(datatype armcc
  (:eq) ;; Equal Z
  (:ne) ;; Not equal !Z
  (:cs) ;; Carry set, Unsigned higher or same C
  (:cc) ;; Carry clear, Unsigned lower !C
  (:mi) ;; Minus, Negative N
  (:pl) ;; Plus, Positive or zero !N
  (:vs) ;; Overflow V
  (:vc) ;; No overflow !V
  (:hi) ;; Unsigned higher C & !Z
  (:ls) ;; Unsigned lower or same !C | Z
  (:ge) ;; Signed greater than or equal N = V
  (:lt) ;; Signed less than N != V
  (:gt) ;; Signed greater than !Z & N = V
  (:le) ;; Signed less than or equal Z | N != V
  )

;; The assembler spelling of a condition code, as a symbol.
;; This is the suffix in `b.<cc>` and the last operand of `csel`.
(define (armcc->name cc)
  (match cc with
    (armcc:eq) -> 'eq
    (armcc:ne) -> 'ne
    (armcc:cs) -> 'cs
    (armcc:cc) -> 'cc
    (armcc:mi) -> 'mi
    (armcc:pl) -> 'pl
    (armcc:vs) -> 'vs
    (armcc:vc) -> 'vc
    (armcc:hi) -> 'hi
    (armcc:ls) -> 'ls
    (armcc:ge) -> 'ge
    (armcc:lt) -> 'lt
    (armcc:gt) -> 'gt
    (armcc:le) -> 'le
    ))

;; question: do we want to keep & update the free regs from cps?
;; I'm imagining doing some kinds of optimizations on this output
;; assembly - I guess peephole - but I can't picture them yet.
;;
;; two different approaches to register allocation at this stage:
;; so the CPS code pretty much did all the work, but we sometimes
;; need extra regs to perform CPS operations - do we just keep a
;; set of 'scratch' registers and use those, or do we try to layer
;; another register allocator on top of the decisions already made
;; in the CPS IL?

;; The "instruction" IR built by cps->arm and turned into text by format-arm.
;; Register operands are plain ints (n => xN).
;; Several entries are pseudo-insns (return, alloc, ref0, reftop, stra, tag,
;; untag, ...). Their textual forms are not A64 mnemonics, so they are
;; presumably macros from include/arm-preamble.s -- TODO confirm.
(datatype arm
  (:return)                   ;; return via <k> (the continuation register)
  (:b string)                 ;; branch to label
  (:br int)                   ;; branch via register
  (:ret)                      ;; C return
  (:mov int int)              ;; mov a <- b
  (:bcond armcc symbol)       ;; conditional branch: cc label
  (:alloc int int int)        ;; dst tag size
  (:stra int int string)      ;; scratch src label
  (:ldra int symbol)          ;; dst label
  (:lit16 int int)            ;; dst literal (rendered as a single `mov`)
  (:lit32 int int)            ;; dst literal (rendered as `ldr xN, =literal`)
  (:ref0 int int)             ;; dst offset (varref 0 n)
  (:set0 int int)             ;; src offset
  (:reftop int int)           ;; dst offset top[n]
  (:str int int int)          ;; src dst offset
  (:ldr int int int)          ;; dst src offset
  (:arith symbol int int int) ;; op dst a b
  (:cmp int int)              ;; a b
  (:tag int)                  ;; [inplace] reg <- (reg<<1)|1
  (:untag int)                ;; [inplace] reg <- reg>>1
  (:csel int int int armcc)   ;; dst a b cc
  (:test int string)          ;; test-reg after-label
  (:label string)             ;; L1:
  (:adr int string)           ;; dst label (address of [local] label)
  (:comment string)           ;; for inline comments
  (:linecom arm string)       ;; comment on a particular line
  ;; disabled for now, inlined.
  ;;(:ref int int int)          ;; dst depth index
  )

;; Render one arm insn as a line of arm64 assembly text.
;; The syntax is Apple-flavored: `;` starts a comment.
;; Several mnemonics produced here are not A64 instructions:
;; return, tag, untag, alloc, varref0, topref, str_addr, ldr_addr.
;; They are presumably macros from include/arm-preamble.s -- TODO confirm.
;; An insn with no rendering is printed and then aborts via not-yet.
(define format-arm
  ;; the return macro is handed the scratch register: it needs one to fetch
  ;; the return address.  Derived from Rscratch rather than hard-coding x25.
  (arm:return)             -> (format "return x" (int Rscratch))
  (arm:mov dst src)        -> (format "mov x" (int dst) ", x" (int src))
  (arm:b name)             -> (format "b " name)
  (arm:br reg)             -> (format "br x" (int reg))
  (arm:ret)                -> "ret"
  (arm:bcond cc label)     -> (format "b." (sym (armcc->name cc)) " " (sym label))
  (arm:lit16 reg val)      -> (format "mov x" (int reg) ", " (int val))
  ;; consider changing this to movz,movk combo and splitting val into 16-bit chunks.
  (arm:lit32 reg val)      -> (format "ldr x" (int reg) ", =" (int val))
  (arm:arith op dst a b)   -> (format (sym op) " x" (int dst) ", x" (int a) ", x" (int b))
  (arm:tag reg)            -> (format "tag x" (int reg))
  (arm:untag reg)          -> (format "untag x" (int reg))
  (arm:csel dst a b cc)    -> (format "csel x" (int dst) ", x" (int a) ", x" (int b) ", " (sym (armcc->name cc)))
  (arm:cmp a b)            -> (format "cmp x" (int a) ", x" (int b))
  ;; the preceding cmp sets Z when the tested reg is #f; branch past the true arm.
  (arm:test x label)       -> (format "b.eq " label)
  (arm:label name)         -> (format name ":")
  (arm:ldr dst src off)    -> (format "ldr x" (int dst) ", [x" (int src) ", #" (int off) "]")
  (arm:str src dst off)    -> (format "str x" (int src) ", [x" (int dst) ", #" (int off) "]")
  ;; operands: dst, type code, field count, byte size including the header word.
  (arm:alloc dst tag size) -> (format "alloc x" (int dst) ", #" (int tag) ", " (int size) ", " (int (* 8 (+ size 1))))
  (arm:adr dst name)       -> (format "adr x" (int dst) ", " name)
  (arm:ref0 dst i)         -> (format "varref0 x" (int dst) ", " (int (* 8 (+ i 2)))) ;; [tc, lenv, v0, v1, ...]
  (arm:reftop dst i)       -> (format "topref x" (int dst) ", " (int (* 8 (+ i 2))))
  (arm:stra r src lab)     -> (format "str_addr x" (int r) ", x" (int src) ", " lab)
  ;; same macro spelling the toplevel prologue uses (e.g. "ldr_addr x26, _freep").
  (arm:ldra dst label)     -> (format "ldr_addr x" (int dst) ", " (sym label))
  (arm:comment s)          -> (format ";; " s)
  (arm:linecom insn com)   -> (format (format-arm insn) " ;; " com)
  x -> (begin
         (printn x)
         (not-yet "arm insn"))
  )

(define (cps->arm cps)
  ;; Lower the CPS insn tree <cps> into a flat list of `arm` insns.
  ;;
  ;; Register conventions used below:
  ;;   Rk       - the current continuation frame
  ;;   Rlenv    - the current lexical environment
  ;;   Rscratch - clobbered freely inside multi-insn sequences
  ;;   Rreturn  - carries a value back to a continuation
  ;; A target register of -1 means the result is unused (dead).
  ;;
  ;; Heap layouts relied on here (8-byte words):
  ;;   closure:      [header, code-address, env]         (emit-close)
  ;;   environment:  [header, parent-env, v0, v1, ...]   (emit-new-env, emit-varref)
  ;;   continuation: [header, k, lenv, return-pc, regs...] (emit-call)

  (let ((used-jumps (find-jumps cps)))

    ;; Fresh local label "L" <prefix> <n>, e.g. (new-label "R") => "LR7".
    (define new-label
      (let ((counter 0))
        (lambda (prefix)
          (inc! counter)
          (format "L" prefix (int counter)))))

    ;; <acc> followed by the code for continuation <k> (nothing if k is null).
    (define (emitk acc k)
      (if (null-cont? k)
          acc
          (append acc (emit k.insn))))

    ;; XXX these should really go in backend.scm and not be duplicated
    ;;     by each backend.

    ;; type codes for user-immediate / user-object datatype alternatives.
    (define (UITAG n) (+ TC_USERIMM (<< n TAGSIZE)))
    (define (UOTAG n) (+ TC_USEROBJ (<< n 2)))

    ;; hacks for datatypes known by the runtime
    (define (get-uotag dtname altname index)
      (match dtname altname with
        'list 'cons -> TC_PAIR
        'symbol 't  -> TC_SYMBOL
        _ _         -> (UOTAG index)))

    (define (get-uitag dtname altname index)
      (match dtname altname with
        'list 'nil   -> TC_NIL
        'bool 'true  -> immediate-true
        'bool 'false -> immediate-false
        _ _          -> (UITAG index)))

    ;; copy <reg> into Rreturn and jump through the current continuation.
    (define (emit-return reg)
      (list (arm:mov Rreturn reg)
            (arm:return)))

    (define (emit-literal lit trg)
      ;; literals are tricky on arm, because of the fixed-size insns
      ;; there are several different ways to load immediates, depending
      ;; on their size (and even bit patterns).
      ;;  * [-65536, 65535]: one `mov`.  The assembler picks MOVZ for
      ;;    non-negative values and MOVN for negative ones.  MOVN cannot
      ;;    produce values below -65536, so those must not take this path.
      ;;  * [-2^31, 2^32): `ldr xN, =value` (literal pool).
      (if (= trg -1)
          (list:nil) ;; dead literal
          (let ((enc (encode-immediate lit)))
            (cond ((and (>= enc -65536) (< enc 65536))
                   (list (arm:lit16 trg enc)))
                  ((and (>= enc (- 0 (<< 1 31))) (< enc (<< 1 32)))
                   (list (arm:lit32 trg enc)))
                  (else
                   (not-yet "64-bit literals"))))))

    (define (emit-arith op arg0 arg1 target)
      ;; The operands are untagged *in place*.  Any operand register the
      ;; result did not overwrite may still hold a live variable, so it is
      ;; re-tagged afterwards.  tag(untag(x)) = x for any tagged (odd) x.
      ;; An operand used twice (arg0 = arg1) is untagged only once.
      ;; XXX think about ways to avoid untag/tag, depending on <op>
      ;; e.g. (2a+1)+(2b+1) == 2a + 2b + 2 == 2(a+b)+2 so leave tagged and sub1 from result.
      (let ((same? (= arg0 arg1)))
        (append
         (list (arm:untag arg0))
         (if same? (list:nil) (list (arm:untag arg1)))
         (list (arm:arith op target arg0 arg1)
               (arm:tag target))
         (if (= arg0 target) (list:nil) (list (arm:tag arg0)))
         (if (or same? (= arg1 target)) (list:nil) (list (arm:tag arg1))))))

    (define cc->armcc
      ;; translate between llvm CC and arm CC
      ;; there are more condition codes, e.g. PL/MI for pos/neg,
      ;; might be useful.
      'eq  -> (armcc:eq)
      'ne  -> (armcc:ne)
      'slt -> (armcc:lt)
      'sle -> (armcc:le)
      'sgt -> (armcc:gt)
      'sge -> (armcc:ge)
      'ugt -> (armcc:hi)
      'ult -> (armcc:cc)
      'ule -> (armcc:ls)
      'uge -> (armcc:cs)
      cc -> (raise (:UnknownConditionCode cc))
      )

    (define (emit-icmp cc arg0 arg1 target)
      ;; want to use CSEL here, which needs the two values in registers
      ;; let's put #t in target, and #f in scratch and csel from there?
      (append
       (list (arm:cmp arg0 arg1))
       ;; we avoid needing another scratch register by putting the cmp
       ;; before wiping out target (which is likely to be arg0 or arg1).
       (emit-literal (literal:bool #t) target)
       (emit-literal (literal:bool #f) Rscratch)
       (list (arm:csel target target Rscratch (cc->armcc cc)))
       ))

    ;; Construct datatype alternative <altname> of <dtname> from <args>.
    ;; Nullary alternatives are immediates; others are heap-allocated.
    (define (emit-dtcon dtname altname args target)
      (match (alist/lookup the-context.datatypes dtname) with
        (maybe:no)
        -> (raise (:NoSuchDatatype "emit-dtcon" dtname))
        (maybe:yes dt)
        -> (let ((alt (dt.get altname))
                 (nargs (length args)))
             (if (= nargs 0) ;; immediate constructor
                 (list (arm:lit32 target (get-uitag dtname altname alt.index)))
                 (append
                  ;; we can't build directly into target because it's likely target is in args.
                  (list (arm:alloc Rscratch (get-uotag dtname altname alt.index) nargs))
                  (map-range i nargs (arm:str (nth args i) Rscratch (* 8 (+ i 1))))
                  (list (arm:mov target Rscratch)))))
        ))

    (define (emit-primop name params type args target)
      (match name params args with
        '%llarith (sexp:symbol op _) (arg0 arg1)   -> (emit-arith op arg0 arg1 target)
        '%llicmp  (sexp:symbol cc _) (arg0 arg1)   -> (emit-icmp cc arg0 arg1 target)
        '%dtcon   (sexp:cons dtname altname) args  -> (emit-dtcon dtname altname args target)
        _ _ _
        -> (begin
             (printf "unknown primop: " (sym name))
             (raise (:Arm/UnknownPrimop)))
        ))

    ;; code for the join continuation <k>, only if jump <jn> is ever used.
    (define (emit-jump-continuation jn k)
      (match (used-jumps::get jn) with
        (maybe:yes free)
        -> (emit k)
        (maybe:no)
        -> (list:nil)
        ))

    ;; if <reg> is #f fall to k1 (at skip-label), otherwise run k0.
    (define (emit-test reg jn k0 k1 k)
      (let ((skip-label (new-label "L"))
            (jcont (emit-jump-continuation jn k.insn)))
        (append
         (emit-literal (literal:bool #f) Rscratch)
         (list
          (arm:cmp Rscratch reg)
          (arm:test reg skip-label))
         (emit k0)
         (list (arm:label skip-label))
         (emit k1)
         jcont)))

    (define (emit-label jn next)
      (list:cons (arm:label (format "J" (int jn))) (emit next)))

    ;; pass <reg> to the join point in <target> (unless the value is dead).
    (define (emit-jump reg target jn free)
      (let ((jump (arm:b (format "J" (int jn)))))
        (if (= target -1)
            (list jump)
            (list (arm:mov target reg) jump))))

    ;; insn:move has two different meanings/uses,
    ;;  from either varref or varset.
    (define (emit-move var src target)
      ;; MOV <dst-ref> <src-reg>
      (cond ((and (>= src 0) (not (= src var)))
             ;; from varset
             ;; XXX target := #u
             (list (arm:mov var src)))
            ((and (>= target 0) (not (= target var)))
             ;; from varref
             (list (arm:mov target var)))
            (else '())))

    (define (emit-alloc tag size target)
      (list (arm:alloc target (UOTAG tag) size)))

    (define (emit-store off arg tup index)
      ;; hmmmm I feel look the purpose of <off> is to capture
      ;; the offset that includes the header?  no?
      ;; str arg, [tup, 8 * (1+off+index)]
      (list (arm:str arg tup (* 8 (+ 1 off index)))))

    ;; Emit the function body out of line (jumped over), then build the
    ;; closure [header, code-address, lenv] in <target>.
    (define (emit-close name nreg body target)
      (let ((l0 (new-label "L"))
            (flabel (format "F" (sym name)))
            (gc-check (if (vars-get-flag name VFLAG-ALLOCATES)
                          (list (arm:comment "gc preamble here"))
                          '())))
        (append
         (list
          (arm:b l0)
          (arm:label flabel))
         gc-check
         (emit body)
         (list
          (arm:label l0)
          (arm:alloc target TC_CLOSURE 2)
          (arm:adr Rscratch flabel)
          (arm:str Rscratch target 8)
          (arm:str Rlenv target 16)
          ))))

    ;; to think about: high-depth varref could be:
    ;; 1) a 'normal' funcall to varref()
    ;; 2) a bl to a scratch-only function to get the right rib?
    ;; 3) same, but with different labels for different depths [only one copy]

    (define (emit-varref depth index target)
      (match depth with
         0 -> (list (arm:ref0 target index))
        -1 -> (list (arm:linecom (arm:reftop target index) (format "topref " (int index))))
        _  -> (append
               ;;(arm:ref target depth index)
               ;; for now, we inline the dive to depth
               (list
                (arm:comment (format "(varref " (int depth) " " (int index) ")"))
                (arm:mov Rscratch Rlenv))
               (n-of depth (arm:ldr Rscratch Rscratch 8))
               (list (arm:ldr target Rscratch (* 8 (+ index 2)))))
        ))

    (define (emit-new-env size top? types target)
      ;; types are just for comment metadata [for now]
      (list:cons
       (arm:alloc target TC_ENV (+ size 1))
       (if top?
           (list (arm:stra Rscratch target "_top"))
           (list:nil))))

    (define (emit-push reg)
      (list
       (arm:str Rlenv reg 8) ;; r[1] = lenv
       (arm:mov Rlenv reg))) ;; lenv = r

    ;; jump to a known function by label, or through the closure's code slot.
    (define (funcall mname funreg)
      (match mname with
        (maybe:yes name)
        -> (list (arm:b (format "F" (sym name))))
        (maybe:no)
        -> (list (arm:ldr Rscratch funreg 8) ;; address via closure
                 (arm:br Rscratch))
        ))

    ;; Tail call: install the callee's environment, then jump.
    ;; The closure's env lives in fun[2]; fun[1] is the code address.
    (define (emit-tail mname fun args)
      (let ((call (funcall mname fun)))
        (match args with
          -1 -> (list:cons (arm:ldr Rlenv fun 16) call) ;; no args: lenv = fun[2]
          _  -> (append
                 (list (arm:ldr Rscratch fun 16)        ;; args[1] = fun[2]
                       (arm:str Rscratch args 8)
                       (arm:mov Rlenv args))            ;; lenv = args
                 call)
          )))

    (define (emit-trcall depth name regs)
      (let ((nargs (length regs))
            (npop (if (= nargs 0) depth (- depth 1)))
            (name (format "F" (sym name)))
            (pops (n-of npop (arm:ldr Rlenv Rlenv 8)))
            (stores (map-range i nargs (arm:str (nth regs i) Rlenv (* 8 (+ 2 i))))))
        (append pops stores (list (arm:b name)))
        ))

    ;; we emit insns for k0, which may or may not jump to fail continuation in k1
    (define (emit-fatbar label jn k0 k1 k)
      (let ((lfail (format "Fail" (int label))))
        (append
         (emit k0)                          ;; k0
         (list (arm:label lfail))           ;; Lfail:
         (emit k1)                          ;; k1
         (emit-jump-continuation jn k.insn) ;; Ljump:
         )))                                ;; k

    (define (emit-fail label npop free)
      (append
       (n-of npop (arm:ldr Rlenv Rlenv 8))
       (list (arm:b (format "Fail" (int label))))))

    ;; Non-tail call.  Push a continuation frame
    ;;   [header, k, lenv, return-pc, free regs...]
    ;; then install the callee's env and jump.  At the return label the
    ;; free regs, lenv and k are restored, and Rreturn is copied into the
    ;; target (if live).
    (define (emit-call mname fun args k)
      (let ((free (sort < k.free))
            (nregs (length free))
            (target k.target)
            (lreturn (new-label "R")))
        (append
         (list (arm:comment "build continuation...")
               (arm:alloc Rscratch TC_SAVE (+ nregs 3))
               (arm:str Rk Rscratch 8)
               (arm:str Rlenv Rscratch 16)
               (arm:mov Rk Rscratch) ;; need scratch for label addr
               (arm:adr Rscratch lreturn) ;; get label
               (arm:str Rscratch Rk 24)) ;; k[3] = label
         (list (arm:comment (format "save free registers: [" (join int->string ", " free) "]")))
         (map-range i nregs (arm:str (nth free i) Rk (* 8 (+ 4 i))))
         (list (arm:comment "perform the call..."))
         (if (= args -1)
             ;; no args: the callee runs directly in the closure's env.
             (list (arm:linecom (arm:ldr Rlenv fun 16) "lenv = fun[2]"))
             (list (arm:linecom (arm:ldr Rscratch fun 16) "args[1] = fun[2]")
                   (arm:str Rscratch args 8)
                   (arm:linecom (arm:mov Rlenv args) "lenv = args")))
         (funcall mname fun)
         (list (arm:comment "label for return...")
               (arm:label lreturn))
         (list (arm:comment (format "restore free registers..." (join int->string ", " free))))
         (map-range i nregs (arm:ldr (nth free i) Rk (* 8 (+ 4 i))))
         (list (arm:comment (format "pop k..."))
               (arm:linecom (arm:ldr Rscratch Rk 16) "lenv = k[2]")
               (arm:mov Rlenv Rscratch)
               (arm:linecom (arm:ldr Rscratch Rk 8) "k = k[1]")
               (arm:mov Rk Rscratch))
         (if (= target -1)
             (list:nil) ;; call result is unused
             (list (arm:linecom (arm:mov target Rreturn) "target = Rreturn")))
         )))

    (define (emit insn)
      ;;(print-insn insn) (newline)
      (match insn with
        (insn:literal lit k)                  -> (emitk (emit-literal lit k.target) k)
        (insn:return target)                  -> (emit-return target)
        (insn:primop name params type args k) -> (emitk (emit-primop name params type args k.target) k)
        (insn:test reg jn k0 k1 k)            -> (emit-test reg jn k0 k1 k)
        (insn:label jn next)                  -> (emit-label jn next)
        (insn:jump reg target jn free)        -> (emit-jump reg target jn free.val)
        (insn:move dst var k)                 -> (emitk (emit-move dst var k.target) k)
        (insn:alloc tag size k)               -> (emitk (emit-alloc tag size k.target) k)
        (insn:store off arg tup i k)          -> (emitk (emit-store off arg tup i) k)
        (insn:close name nreg body k)         -> (emitk (emit-close name nreg body k.target) k)
        (insn:varref d i k)                   -> (emitk (emit-varref d i k.target) k)
        (insn:new-env size top? types k)      -> (emitk (emit-new-env size top? types k.target) k)
        (insn:push r k)                       -> (emitk (emit-push r) k)
        (insn:tail name fun args)             -> (emit-tail name fun args)
        (insn:trcall d n args)                -> (emit-trcall d n args)
        (insn:fatbar lab jn k0 k1 k)          -> (emit-fatbar lab jn k0 k1 k)
        (insn:fail label npop free)           -> (emit-fail label npop free.val)
        (insn:invoke name fun args k)         -> (emitk (emit-call name fun args k) k)

        _ -> (begin
               (print-insn insn)
               (not-yet "cps insn not implemented"))
        ))

    (emit cps)
    ))

;; Write the complete assembly for <arm> (a list of arm insns) through
;; writer <o>, then close it.  The output is:
;;   - the contents of include/arm-preamble.s
;;   - a _toplevel prologue
;;   - the insns themselves
;;   - the Lreturn epilogue, which tail-branches to _exit_continuation
;; <cname> is currently unused.
(define (emit-arm o cname arm)
  (verbose (printf "emit-arm...\n") (flush))
  (o.copy
   (get-file-contents "include/arm-preamble.s"))
  (o.write (format ";;; arm64 output (apple flavored)"))
  (o.write "\t.global _toplevel")
  (o.write "_toplevel:")
  (o.indent)
  ;; preserve x25-x28, which are used as runtime registers.
  ;; XXX x24 (Rreturn) is also callee-saved under AAPCS64 and is not preserved.
  (o.write "stp x25, x26, [sp, -16]!")
  (o.write "stp x27, x28, [sp, -16]!")
  (o.write "ldr_addr x26, _freep")
  (o.write "ldr_addr x27, _lenv")
  (o.write "ldr_addr x28, _k")
  ;; the toplevel continuation's return-pc slot (k[3]) gets the *address*
  ;; of Lreturn.  This is taken with `adr`, exactly as emit-call does for
  ;; its return labels.
  (o.write "adr x25, Lreturn")
  (o.write "str x25, [x28, #24]")
  (for-list insn arm
    (match insn with
      ;; labels go in column 0; everything else is indented.
      (arm:label name)
      -> (begin
           (o.dedent)
           (o.write (format name ":"))
           (o.indent))
      _ -> (o.write (format-arm insn))))
  (o.dedent)
  (o.write "Lreturn:")
  (o.indent)
  (o.write "ldp x27, x28, [sp], 16")
  (o.write "ldp x25, x26, [sp], 16")
  (o.write "b _exit_continuation")
  (o.dedent)
  (o.close)
  ;; (verbose (printf "emit constructed literals...\n") (flush))
  ;; (arm-emit-constructed o)
  ;; (verbose (printf "emit lookup-field hashtables...\n"))
  ;; (emit-arm-lookup-field-hashtables o)
  ;; (verbose (printf "emit metadata...\n"))
  ;; (emit-arm-get-metadata o)
  ;; (verbose (printf "emit declarations...\n"))
  ;; (emit-ffi-declarations o)
  )

;; Lower <cps> to ARM and write it to "<base>.s" (mode 0644).
;; Returns (list asm-path header-path), where header-path is where
;; "include/header1.c" was found in the configured include dirs.
;; That file is only located here; it is opened by find-file and then
;; immediately closed.
(define (compile-to-arm base cps)
  (let ((insns (cps->arm cps))
        (asm-path (format base ".s"))
        (asm-file (file/open-write asm-path #t #o644))
        (writer (make-writer asm-file)))
    (writer.set-indent "\t")
    (notquiet (printf "\n-- ARM output --\n : " asm-path "\n"))
    (emit-arm writer "toplevel" insns)
    (notquiet (printf "wrote " (int (writer.get-total)) " bytes to " asm-path ".\n"))
    (let (((header-path header-file)
           (find-file the-context.options.include-dirs "include/header1.c")))
      (file/close header-file)
      (list asm-path header-path))))