;; -*- Mode: Irken -*-

;; REDC
;; https://en.wikipedia.org/wiki/Montgomery_modular_multiplication

(require "lib/basis.scm")
(require "lib/bignum.scm")
(require "lib/codecs/hex.scm")

;; Extended Euclid (recursive form).
;; (egcd a b) -> (:tuple g x y) with g = gcd(a, b) and a*x + b*y = g.
;; Note that x or y may be negative.
(define egcd
  0 b -> (:tuple b 0 1)
  a b -> (let (((q r) (divmod b a))
               ((g y x) (egcd r a)))
           ;; b = q*a + r and r*y + a*x = g  =>  a*(x - q*y) + b*y = g
           (:tuple g (- x (* q y)) y))
  )

;; Multiplicative inverse of `a` modulo `p`, as a value in [0, p).
;; Negative `a` is handled via inv(a) = p - inv(-a).
;; Raises (:NoInverse ...) when gcd(a, p) != 1.
(define (mod-inv a p)
  (if (< a 0)
      (- p (mod-inv (- 0 a) p))
      (let (((g x y) (egcd a p)))
        (if (not (= g 1))
            (raise (:NoInverse a))
            ;; egcd's x is often negative; normalize into [0, p) no matter
            ;; whether `mod` floors or truncates for negative dividends.
            (let ((r (mod x p)))
              (if (< r 0) (+ r p) r))))))

;; Iterative extended Euclid (same idea as egcd, "but more french").
;; Returns (:tuple s t) with a*s + b*t = r0, where r0 is the final remainder
;; (the gcd). Only `s` is tracked in the loop; `t` is recovered at the end by
;; dividing by b, so b = 0 is handled up front: a*1 + 0*0 = a.
(define (bezout a b)
  (if (zero? b)
      (:tuple 1 0)
      (let loop ((s 0) (s0 1) (r b) (r0 a))
        (if (not (zero? r))
            (let ((q (/ r0 r)))
              (loop (- s0 (* q s)) s (- r0 (* q r)) r))
            (:tuple s0 (/ (- r0 (* s0 a)) b))))))

;; Montgomery context for an arbitrary radix R (gcd(R, N) must be 1, or
;; mod-inv raises). Returns a record of:
;;   redc : T -> T * R^-1 mod N   (requires 0 <= T < R*N)
;;   tm   : a -> a * R mod N      (into Montgomery form)
;;   fm   : same as redc          (out of Montgomery form)
(define (mont N R)
  (let ((R' (mod-inv R N))          ;; R * R' = 1 (mod N)
        (N' (/ (* R' R) N))         ;; intended: R*R' - N*N' = 1 (printed below)
        (r2n (mod (* R R) N)))      ;; R^2 mod N
    (printf "N = " (int N) " R = " (int R) "\n")
    (printf "N' = " (int N') " R' = " (int R') "\n")
    (printf "RR' - NN' = " (int (- (* R R') (* N N'))) "\n")
    (printf "R^2N = " (int r2n) "\n")
    ;; montgomery reduction: with T < R*N, t < 2N so one subtraction suffices.
    (define (redc T)
      (let ((m (mod (* (mod T R) N') R))
            (t (/ (+ T (* m N)) R)))
        (if (>= t N) (- t N) t)))
    ;; redc(a * R^2) = a * R mod N
    (define (tm a)
      (redc (* (mod a N) r2n)))
    {redc=redc tm=tm fm=redc}
    ))

;; Montgomery context with R = 2^bits: "mod R" becomes a mask and
;; "divide by R" becomes a right shift.
;; NOTE: N must be odd (otherwise gcd(R, N) != 1 and mod-inv raises) and
;; should be < R so that the product `tm` passes to redc stays below R*N.
;; Returns a record with redc/tm/fm as in `mont`, plus `pow`.
(define (mont2 N bits)
  (let ((R (<< 1 bits))
        (R' (mod-inv R N))
        (N' (/ (* R' R) N))
        (r2n (mod (* R R) N))
        (one (mod R N))             ;; Montgomery form of 1
        (mask (- R 1)))             ;; (logand x mask) == x mod R
    (printf "bits " (int bits) "\n")
    (printf "mask " (zpad bits (bin mask)) "\n")
    (printf "N = " (int N) " R = " (int R) "\n")
    (printf "N' = " (int N') " R' = " (int R') "\n")
    (printf "RR' - NN' = " (int (- (* R R') (* N N'))) "\n")
    (printf "R^2N = " (int r2n) "\n")
    ;; montgomery reduction, with mask/shift in place of mod/divide by R.
    (define (redc T)
      (let ((m (logand (* (logand T mask) N') mask))
            (t (>> (+ T (* m N)) bits)))
        (if (>= t N) (- t N) t)))
    ;; into Montgomery form: a * R mod N
    (define (tm a)
      (redc (* (mod a N) r2n)))
    ;; Right-to-left square-and-multiply on Montgomery-form values.
    ;; `x` is expected in Montgomery form (e.g. from tm). The final redc
    ;; converts the result out of Montgomery form; returning `z` instead
    ;; would leave it in Montgomery form.
    (define (pow x n)
      (let ((z one))
        (while (> n 0)
          (when (odd? n)
            (set! z (redc (* z x))))
          (set! x (redc (* x x)))
          (set! n (>> n 1)))
        (redc z)))
    {redc=redc tm=tm fm=redc pow=pow}
    ))

;; (mont 17 100)
;; (mont 251 256)
;; (mont 1021 1024)
;; (mont 2039 2048)

;; Demo: N = 13 with R = 2^4 = 16.
(let ((N 13)
      (bits 4)
      (M (mont2 N bits)))
  ;; columns: i, tm(i), fm(tm(i)), fm(tm(i + 7))
  (for-range i N
    (printf (lpad 3 (int i))
            (lpad 3 (int (M.tm i)))
            (lpad 3 (int (M.fm (M.tm i))))
            (lpad 3 (int (M.fm (M.tm (+ i 7)))))
            "\n"))
  (printf "redc(12) = " (int (M.redc 12)) "\n")
  (printf "fm(4) = " (int (M.fm 4)) "\n")
  (printf "tm(4) = " (int (M.tm 4)) "\n")
  (printf "redc(9*12) = " (int (M.redc (* 9 12))) "\n")
  (printf "pow(3,5) = " (int (M.pow (M.tm 3) 5)) "\n")
  (printf "pow(7,12) = " (int (M.pow (M.tm 7) 12)) "\n")
  )