;; -*- Mode: Irken -*-

;; based on pyaes.

;; AES block cipher.
;;
;; `key` is a string whose length must be 16, 24 or 32; any other length
;; raises (:AES/BadKeySize klen).  Both round-key schedules are computed
;; once, when AES is called.  The result is a record with two fields,
;; `encrypt` and `decrypt`; each is a function from one 16-character
;; string to a 16-character string, raising (:AES/BadBlockLength) when
;; given a string of any other length.
;;
;; The lookup tables S, Si, T1..T8, U1..U4 and rcon are not defined in
;; this form; presumably they are supplied by "aes_tables.scm".

(define (AES key)

  (include "aes_tables.scm")

  ;; integer code of the character at index `pos` of `s`.
  (define (byte s pos)
    (char->int (string-ref s pos)))

  ;; pack the four bytes of `s` starting at `pos`, first byte most
  ;; significant.
  (define (be->u32 s pos)
    (logior* (<< (byte s pos) 24)
             (<< (byte s (+ pos 1)) 16)
             (<< (byte s (+ pos 2)) 8)
             (byte s (+ pos 3))))

  ;; list of the 32-bit words of `s`, in string order.
  ;; assumes |s| = 0 mod 4
  (define (string->u32 s)
    (let ((nwords (/ (string-length s) 4))
          (acc '()))
      (for-range k nwords
        (PUSH acc (be->u32 s (* 4 k))))
      (reverse acc)))

  ;; the low eight bits of `n` shifted right by `s`.
  (define (u8>> n s)
    (logand (>> n s) #xff))

  ;; for i in [lo, hi), ascending: v[i] <- v[i] xor v[i-1].
  (define (xor-with-prev! v lo hi)
    (for-range* i lo hi
      (set! v[i] (logxor v[i] v[(- i 1)]))))

  (let ((klen (string-length key))
        ;; the key length is validated here, before anything else is built.
        (rounds (match klen with
                  16 -> 10
                  24 -> 12
                  32 -> 14
                  _  -> (raise (:AES/BadKeySize klen))))
        ;; number of 32-bit words in the key, and half of that.
        (KC (/ klen 4))
        (half (/ KC 2))
        (round-key-count (* (+ rounds 1) 4))
        ;; encryption round keys
        (Ke (make-array ((+ rounds 1) 4) 0))
        ;; decryption round keys
        (Kd (make-array ((+ rounds 1) 4) 0))
        ;; convert key into ints; this vector is the sliding expansion window.
        (tk (list->vector (string->u32 key)))
        ;; next entry of rcon to consume.
        (rindex 0)
        ;; number of round-key words stored so far.
        (filled KC))

    ;; the first KC words of both schedules are the key itself; Kd
    ;; receives them with its row index counted down from `rounds`.
    (for-range i KC
      (let ((row (/ i 4))
            (col (mod i 4)))
        (set! Ke[row][col] tk[i])
        (set! Kd[(- rounds row)][col] tk[i])))

    ;; key expansion
    (while (< filled round-key-count)
      ;; fold the S-box image of the last window word (bytes rotated by
      ;; one position) and the round constant into tk[0].
      (let ((prev tk[(- KC 1)]))
        (set! tk[0]
              (logxor* tk[0]
                       (<< S[(u8>> prev 16)] 24)
                       (<< S[(u8>> prev 8)] 16)
                       (<< S[(u8>> prev 0)] 8)
                       S[(u8>> prev 24)]
                       (<< rcon[rindex] 24))))
      (inc! rindex)
      (if (= KC 8)
          ;; key expansion for 256-bit keys is "slightly different" (fips-197):
          ;; the middle word gets an extra, unrotated S-box substitution.
          (begin
            (xor-with-prev! tk 1 half)
            (let ((mid tk[(- half 1)]))
              (set! tk[half]
                    (logxor* tk[half]
                             S[(u8>> mid 0)]
                             (<< S[(u8>> mid 8)] 8)
                             (<< S[(u8>> mid 16)] 16)
                             (<< S[(u8>> mid 24)] 24))))
            (xor-with-prev! tk (+ half 1) KC))
          (xor-with-prev! tk 1 KC))
      ;; copy values into round key arrays, stopping early once the
      ;; schedules are full.
      (let ((j 0))
        (while (and (< j KC) (< filled round-key-count))
          (let ((row (/ filled 4))
                (col (mod filled 4)))
            (set! Ke[row][col] tk[j])
            (set! Kd[(- rounds row)][col] tk[j]))
          (inc! j)
          (inc! filled))))

    ;; inverse-cipher-ify the decryption round key (fips-197 section 5.3);
    ;; rows 0 and `rounds` are left untouched.
    (for-range* r 1 rounds
      (for-range j 4
        (let ((w Kd[r][j]))
          (set! Kd[r][j]
                (logxor* U1[(u8>> w 24)]
                         U2[(u8>> w 16)]
                         U3[(u8>> w 8)]
                         U4[(u8>> w 0)])))))

    ;; character whose code is the low eight bits of `n`.
    (define (byte->char n)
      (int->char (logand n #xff)))

    ;; encrypt one 16-character block `pt`; returns a 16-character string.
    (define (encrypt pt)
      (when (not (= (string-length pt) 16))
        (raise (:AES/BadBlockLength)))
      (let ((state (list->vector (string->u32 pt)))
            (scratch (make-vector 4 0)))
        ;; convert plaintext to ints ^ key
        (for-range i 4
          (set! state[i] (logxor state[i] Ke[0][i])))
        ;; apply round transforms.  every output word reads several
        ;; state words, so a round is built in `scratch` and copied back
        ;; only once all four words are done.
        (for-range* r 1 rounds
          (for-range i 4
            (let ((c0 state[i])
                  (c1 state[(mod (+ i 1) 4)])
                  (c2 state[(mod (+ i 2) 4)])
                  (c3 state[(mod (+ i 3) 4)]))
              (set! scratch[i]
                    (logxor* T1[(u8>> c0 24)]
                             T2[(u8>> c1 16)]
                             T3[(u8>> c2 8)]
                             T4[(u8>> c3 0)]
                             Ke[r][i]))))
          (for-range i 4
            (set! state[i] scratch[i])))
        ;; the last round is special: plain S-box lookups xored with the
        ;; final round key, emitted one byte at a time.
        (let ((out '()))
          (for-range i 4
            (let ((rk Ke[rounds][i])
                  (c0 state[i])
                  (c1 state[(mod (+ i 1) 4)])
                  (c2 state[(mod (+ i 2) 4)])
                  (c3 state[(mod (+ i 3) 4)]))
              (PUSH out (byte->char (logxor (>> rk 24) S[(u8>> c0 24)])))
              (PUSH out (byte->char (logxor (>> rk 16) S[(u8>> c1 16)])))
              (PUSH out (byte->char (logxor (>> rk 8) S[(u8>> c2 8)])))
              (PUSH out (byte->char (logxor (>> rk 0) S[(u8>> c3 0)])))))
          (list->string (reverse out)))))

    ;; decrypt one 16-character block `ct`; returns a 16-character string.
    ;; same shape as `encrypt`, but with the Kd schedule, tables T5..T8
    ;; and Si, and the state-word offsets taken in the order 0, 3, 2, 1.
    (define (decrypt ct)
      (when (not (= (string-length ct) 16))
        (raise (:AES/BadBlockLength)))
      (let ((state (list->vector (string->u32 ct)))
            (scratch (make-vector 4 0)))
        ;; convert ciphertext to ints ^ key
        (for-range i 4
          (set! state[i] (logxor state[i] Kd[0][i])))
        ;; apply round transforms
        (for-range* r 1 rounds
          (for-range i 4
            (let ((c0 state[i])
                  (c1 state[(mod (+ i 3) 4)])
                  (c2 state[(mod (+ i 2) 4)])
                  (c3 state[(mod (+ i 1) 4)]))
              (set! scratch[i]
                    (logxor* T5[(u8>> c0 24)]
                             T6[(u8>> c1 16)]
                             T7[(u8>> c2 8)]
                             T8[(u8>> c3 0)]
                             Kd[r][i]))))
          (for-range i 4
            (set! state[i] scratch[i])))
        ;; the last round is special
        (let ((out '()))
          (for-range i 4
            (let ((rk Kd[rounds][i])
                  (c0 state[i])
                  (c1 state[(mod (+ i 3) 4)])
                  (c2 state[(mod (+ i 2) 4)])
                  (c3 state[(mod (+ i 1) 4)]))
              (PUSH out (byte->char (logxor (>> rk 24) Si[(u8>> c0 24)])))
              (PUSH out (byte->char (logxor (>> rk 16) Si[(u8>> c1 16)])))
              (PUSH out (byte->char (logxor (>> rk 8) Si[(u8>> c2 8)])))
              (PUSH out (byte->char (logxor (>> rk 0) Si[(u8>> c3 0)])))))
          (list->string (reverse out)))))

    {encrypt=encrypt decrypt=decrypt}
    ))