;; -*- Mode: Irken -*-

;; based on pyaes.

;; AES: construct an AES block cipher for `key`.
;;
;; `key` must be a string of 16, 24 or 32 characters (128/192/256-bit
;; key), which selects 10, 12 or 14 rounds respectively; any other
;; length raises (:AES/BadKeySize klen).
;;
;; Returns the record {encrypt decrypt}.  Each field is a function that
;; takes exactly one 16-character block and returns a 16-character block;
;; any other block length raises (:AES/BadBlockLength).  Only single
;; blocks are processed here -- no chaining mode or padding is applied.
;;
;; NOTE(review): the cipher indexes lookup tables with key- and
;; data-dependent bytes (table-based AES), which may leak information
;; through cache timing.  Confirm this is acceptable for the use case.
(define (AES key)

  ;; Presumably supplies the tables used below but not defined in this
  ;; file: S, Si, rcon, T1..T8 and U1..U4 -- confirm in aes_tables.scm.
  (include "aes_tables.scm")

  ;; byte: integer code of the character at index `pos` in `s`.
  (define (byte s pos)
    (char->int (string-ref s pos)))

  ;; be->u32: read the 4 characters of `s` starting at `pos` as one
  ;; big-endian 32-bit word (s[pos] becomes the most significant byte).
  (define (be->u32 s pos)
    (let ((b0 (byte s (+ pos 0)))
          (b1 (byte s (+ pos 1)))
          (b2 (byte s (+ pos 2)))
          (b3 (byte s (+ pos 3))))
      (logior* (<< b0 24) (<< b1 16) (<< b2 8) (<< b3 0))
      ))

  ;; string->u32: split `s` into a list of big-endian 32-bit words, in
  ;; string order.  Assumes (string-length s) is a multiple of 4.
  (define (string->u32 s)
    (let ((slen (string-length s))
          (r '()))
      (for-range i (/ slen 4)
        (PUSH r (be->u32 s (* i 4))))
      ;; PUSH prepends, so reverse to restore string order.
      (reverse r)))

  ;; u8>>: shift `n` right by `s` bits and keep the low 8 bits,
  ;; i.e. extract one byte of a word.
  (define (u8>> n s)
    (logand #xff (>> n s)))

  (let ((klen (string-length key))
        ;; number of rounds is determined by the key length
        (rounds (match klen with
                  16 -> 10
                  24 -> 12
                  32 -> 14
                  _  -> (raise (:AES/BadKeySize klen))))
        ;; encryption round keys: (rounds + 1) rows of 4 words each
        (Ke (make-array ((+ rounds 1) 4) 0))
        ;; decryption round keys: same shape, but filled in reverse
        ;; round order (row `rounds - k` holds encryption round k)
        (Kd (make-array ((+ rounds 1) 4) 0))
        ;; number of 32-bit words in the key (4, 6 or 8)
        (KC (/ klen 4))
        ;; working copy of the key as big-endian words; the expansion
        ;; loop below overwrites it in place each pass
        (tk (list->vector (string->u32 key)))
        ;; next index into the round-constant table `rcon`
        (rindex 0)
        ;; total number of round-key words to produce
        (round-key-count (* 4 (+ 1 rounds)))
        ;; scratch word
        (tt 0)
        ;; number of round-key words stored so far; the first KC words
        ;; come straight from the key (copied just below)
        (t KC))
    ;; copy the raw key words into the first round-key slots
    (for-range i KC
      (set! Ke[(/ i 4)][(mod i 4)] tk[i])
      (set! Kd[(- rounds (/ i 4))][(mod i 4)] tk[i]))
    ;; key expansion (fips-197 section 5.2): each pass derives KC new
    ;; words in `tk` from the previous KC words, then stores as many of
    ;; them as are still needed.
    (while (< t round-key-count)
      ;; tk[0] ^= S-box applied to the last word rotated left by one
      ;; byte, plus the round constant in the top byte.
      (set! tt tk[(- KC 1)])
      (set! tk[0] (logxor*
                   tk[0]
                   (<< S[(u8>> tt 16)] 24)
                   (<< S[(u8>> tt  8)] 16)
                   (<< S[(u8>> tt  0)]  8)
                   (<< S[(u8>> tt 24)]  0)
                   (<< rcon[rindex] 24)))
      (inc! rindex)
      (if (not (= KC 8))
          ;; 128/192-bit keys: each remaining word is xored with its
          ;; (already updated) predecessor.
          (for-range* i 1 KC
            (set! tk[i] (logxor tk[i] tk[(- i 1)])))
          ;; key expansion for 256-bit keys is "slightly different" (fips-197):
          ;; the middle word additionally gets the S-box applied to its
          ;; predecessor (no rotation, no round constant).
          (begin
            (for-range* i 1 (/ KC 2)
              (set! tk[i] (logxor tk[i] tk[(- i 1)])))
            (set! tt tk[(- (/ KC 2) 1)])
            (set! tk[(/ KC 2)] (logxor*
                                tk[(/ KC 2)]
                                (<< S[(u8>> tt  0)]  0)
                                (<< S[(u8>> tt  8)]  8)
                                (<< S[(u8>> tt 16)] 16)
                                (<< S[(u8>> tt 24)] 24)))
            (for-range* i (+ 1 (/ KC 2)) KC
              (set! tk[i] (logxor tk[i] tk[(- i 1)])))))
      ;; copy up to KC new words into the round key arrays, stopping once
      ;; all round-key-count words are filled; Kd mirrors Ke with the
      ;; round order reversed.
      (let ((j 0))
        (while (and (< j KC) (< t round-key-count))
          (set! Ke[(/ t 4)][(mod t 4)] tk[j])
          (set! Kd[(- rounds (/ t 4))][(mod t 4)] tk[j])
          (inc! j)
          (inc! t))))
    ;; inverse-cipher-ify the decryption round keys (fips-197 section 5.3):
    ;; only the inner rounds 1 .. rounds-1 are transformed; Kd[0] and
    ;; Kd[rounds] are left as-is.  U1..U4 presumably implement
    ;; InvMixColumns per byte -- confirm in aes_tables.scm.
    (for-range* r 1 rounds
      (for-range j 4
        (set! tt Kd[r][j])
        (set! Kd[r][j] (logxor*
                        U1[(u8>> tt 24)]
                        U2[(u8>> tt 16)]
                        U3[(u8>> tt  8)]
                        U4[(u8>> tt  0)]))
        ))

    ;; byte->char: the low 8 bits of `n` as a character.
    (define (byte->char n)
      (int->char (logand #xff n)))

    ;; encrypt: encrypt the single 16-character block `pt` and return the
    ;; 16-character ciphertext.  Raises (:AES/BadBlockLength) if `pt` is
    ;; not exactly 16 characters.  Each call allocates its own state
    ;; vectors `a` and `t`; Ke is only read.
    (define (encrypt pt)
      (when (not (= 16 (string-length pt)))
        (raise (:AES/BadBlockLength)))
      ;; t: current state as 4 big-endian words; a: next state
      (let ((a (make-vector 4 0))
            (t (list->vector (string->u32 pt))))
        ;; convert plaintext to ints ^ key (initial AddRoundKey with Ke[0])
        (for-range i 4 (set! t[i] (logxor t[i] Ke[0][i])))
        ;; apply round transforms for rounds 1 .. rounds-1.  Each output
        ;; word i draws one byte from each of words i, i+1, i+2, i+3 (mod 4),
        ;; which is how the row shifting is expressed here.  T1..T4
        ;; presumably fold the S-box and column mixing into one lookup
        ;; (T-table formulation) -- confirm in aes_tables.scm.
        (for-range* r 1 rounds
          (for-range i 4
            (set! a[i] (logxor*
                        T1[(u8>> t[(mod (+ i 0) 4)] 24)]
                        T2[(u8>> t[(mod (+ i 1) 4)] 16)]
                        T3[(u8>> t[(mod (+ i 2) 4)]  8)]
                        T4[(u8>> t[(mod (+ i 3) 4)]  0)]
                        Ke[r][i])))
          (for-range i 4 (set! t[i] a[i])))
        ;; the last round is special: it uses the plain S-box instead of
        ;; the T tables (AES's final round omits column mixing), xored with
        ;; the matching byte of the last round key Ke[rounds].  Output bytes
        ;; are PUSHed (prepended) and reversed at the end to restore order.
        (let ((result '()))
          (for-range i 4
            (let ((tt Ke[rounds][i]))
              (PUSH result (byte->char (logxor (>> tt 24) S[(u8>> t[(mod (+ i 0) 4)] 24)])))
              (PUSH result (byte->char (logxor (>> tt 16) S[(u8>> t[(mod (+ i 1) 4)] 16)])))
              (PUSH result (byte->char (logxor (>> tt  8) S[(u8>> t[(mod (+ i 2) 4)]  8)])))
              (PUSH result (byte->char (logxor (>> tt  0) S[(u8>> t[(mod (+ i 3) 4)]  0)])))
              ))
          (list->string (reverse result)))))

    ;; decrypt: decrypt the single 16-character block `ct` and return the
    ;; 16-character plaintext.  Raises (:AES/BadBlockLength) if `ct` is
    ;; not exactly 16 characters.  Mirrors `encrypt`, but uses Kd (whose
    ;; row 0 is the last encryption round key), tables T5..T8 / Si, and
    ;; the reversed column offsets i, i+3, i+2, i+1 (mod 4).
    (define (decrypt ct)
      (when (not (= 16 (string-length ct)))
        (raise (:AES/BadBlockLength)))
      ;; t: current state as 4 big-endian words; a: next state
      (let ((a (make-vector 4 0))
            (t (list->vector (string->u32 ct))))
        ;; convert ciphertext to ints ^ key (initial AddRoundKey with Kd[0])
        (for-range i 4 (set! t[i] (logxor t[i] Kd[0][i])))
        ;; apply round transforms for rounds 1 .. rounds-1; T5..T8 are
        ;; presumably the inverse-cipher T tables -- confirm in aes_tables.scm.
        (for-range* r 1 rounds
          (for-range i 4
            (set! a[i] (logxor*
                        T5[(u8>> t[(mod (+ i 0) 4)] 24)]
                        T6[(u8>> t[(mod (+ i 3) 4)] 16)]
                        T7[(u8>> t[(mod (+ i 2) 4)]  8)]
                        T8[(u8>> t[(mod (+ i 1) 4)]  0)]
                        Kd[r][i])))
          (for-range i 4 (set! t[i] a[i])))
        ;; the last round is special: substitute through Si (presumably
        ;; the inverse S-box) and xor with Kd[rounds] (the original key
        ;; words).  Output bytes are PUSHed (prepended), then reversed.
        (let ((result '()))
          (for-range i 4
            (let ((tt Kd[rounds][i]))
              (PUSH result (byte->char (logxor (>> tt 24) Si[(u8>> t[(mod (+ i 0) 4)] 24)])))
              (PUSH result (byte->char (logxor (>> tt 16) Si[(u8>> t[(mod (+ i 3) 4)] 16)])))
              (PUSH result (byte->char (logxor (>> tt  8) Si[(u8>> t[(mod (+ i 2) 4)]  8)])))
              (PUSH result (byte->char (logxor (>> tt  0) Si[(u8>> t[(mod (+ i 1) 4)]  0)])))
              ))
          (list->string (reverse result)))))

    ;; the cipher object: single-block encrypt/decrypt closures over Ke/Kd
    {encrypt=encrypt decrypt=decrypt}
    ))