;; -*- Mode: Irken -*-

;; REDC: Montgomery reduction and modular multiplication experiments.
;; https://en.wikipedia.org/wiki/Montgomery_modular_multiplication

(require "lib/basis.scm")
(require "lib/bignum.scm")
(require "lib/codecs/hex.scm")

;; Extended Euclid (recursive).
;; (egcd a b) returns (:tuple g x y) with a*x + b*y = g = gcd(a, b)
;; (for the non-negative inputs this file uses -- TODO confirm for negatives).
(define (egcd a b)
  (if (zero? a)
      ;; gcd(0, b) = b = 0*0 + b*1
      (:tuple b 0 1)
      (let (((quo rmd) (divmod b a))
            ((g s t) (egcd rmd a)))
        ;; rmd*s + a*t = g and b = quo*a + rmd, hence
        ;; a*(t - quo*s) + b*s = g.
        (:tuple g (- t (* quo s)) s))))

;; Modular inverse: returns x in [0, p) with (a * x) = 1 (mod p).
;; Raises (:NoInverse a) when gcd(a, p) != 1.  Assumes p > 0.
;; The Bezout coefficient from egcd may be negative (e.g. -4 for a=16,
;; p=13), so the result is folded explicitly into [0, p): callers such as
;; mont/mont2 compute N' = (R'*R)/N and need R' non-negative for
;; R*R' - N*N' = 1 to hold.
(define (mod-inv a p)
  (let ((inv (if (< a 0)
                 ;; (-a)^-1 = -(a^-1)
                 (- 0 (mod-inv (- 0 a) p))
                 (let (((g x y) (egcd a p)))
                   (if (not (= g 1))
                       (raise (:NoInverse a))
                       x))))
        (r (mod inv p)))
    ;; guard against a C-style (truncating) remainder for negative inv
    (if (< r 0) (+ r p) r)))

;; this is egcd, but more french

(define (bezout a b)
  (let loop ((s 0) (s0 1) (r b) (r0 a))
    (if (not (zero? r))
        (let ((q (/ r0 r)))
          (loop (- s0 (* q s)) s
                (- r0 (* q r)) r))
        (:tuple s0 (/ (- r0 (* s0 a)) b))
        )))

;; General Montgomery context for modulus N and radix R; R must be
;; invertible mod N (mod-inv raises :NoInverse otherwise).
;; Prints the derived constants, then returns a record of closures:
;;   redc -- Montgomery reduction: T -> T * R^-1 mod N
;;   tm   -- into Montgomery form:  a -> a * R mod N
;;   fm   -- out of Montgomery form (same function as redc)
(define (mont N R)
  (let ((Rinv (mod-inv R N))
        ;; intended so that R*R' - N*N' = 1 (the third printf checks it),
        ;; i.e. N*N' = -1 mod R
        (Nprime (/ (* Rinv R) N))
        (rr-mod-n (mod (* R R) N)))

    (printf "N  = " (int N)  " R  = " (int R)  "\n")
    (printf "N' = " (int Nprime) " R' = " (int Rinv) "\n")
    (printf "RR' - NN' = " (int (- (* R Rinv) (* N Nprime))) "\n")
    (printf "R^2N      = " (int rr-mod-n) "\n")

    ;; One REDC step.  The single conditional subtraction assumes T lies in
    ;; REDC's usual range 0 <= T < R*N (not checked here).
    (define (mont-reduce T)
      (let ((m (mod (* Nprime (mod T R)) R))
            (t (/ (+ T (* m N)) R)))
        (if (< t N) t (- t N))))

    ;; a*R mod N, computed as REDC(a * (R^2 mod N))
    (define (to-mont a)
      (mont-reduce (* rr-mod-n (mod a N))))

    {redc=mont-reduce tm=to-mont fm=mont-reduce}
    ))

;; R = 2^bits
(define (mont2 N bits)
  (let ((R (<< 1 bits))
        (R' (mod-inv R N))
        (N' (/ (* R' R) N))
        (r2n (mod (* R R) N))
        (one (mod R N))
        (mask (- (<< 1 bits) 1)))

    (printf "bits " (int bits) "\n")
    (printf "mask " (zpad bits (bin mask)) "\n")
    (printf "N  = " (int N)  " R  = " (int R)  "\n")
    (printf "N' = " (int N') " R' = " (int R') "\n")
    (printf "RR' - NN' = " (int (- (* R R') (* N N'))) "\n")
    (printf "R^2N      = " (int r2n) "\n")

    ;; montgomery reduction
    (define (redc T)
      (let ((m (logand (* (logand T mask) N') mask))
            (t (>> (+ T (* m N)) bits)))
        (if (>= t N)
            (- t N)
            t)))

    (define (tm a)
      (redc (* (mod a N) r2n)))

    (define (pow x n)
      (let ((z one))
        (while (> n 0)
          (when (odd? n)
            (set! z (redc (* z x))))
          (set! x (redc (* x x)))
          (set! n (>> n 1)))
        (redc z)
        ;;z
        ))

    {redc=redc tm=tm fm=redc pow=pow}

    ))

;; (mont 17 100)
;; (mont 251 256)
;; (mont 1021 1024)
;; (mont 2039 2048)

;; Demo: Montgomery arithmetic with N = 13 and R = 2^bits = 16.
;; Table columns: i, tm(i) (i in Montgomery form), fm(tm(i)) (should
;; round-trip to i), fm(tm(i+7)) (should be (i+7) mod N).
(let ((N 13)
      (bits 4)
      (M (mont2 N bits)))
  (for-range i N
    (printf (lpad 3 (int i))
            (lpad 3 (int (M.tm i)))
            (lpad 3 (int (M.fm (M.tm i))))
            (lpad 3 (int (M.fm (M.tm (+ i 7)))))
            "\n"
            ))
  (printf "redc(12)   = " (int (M.redc 12)) "\n")
  (printf "fm(4)      = " (int (M.fm 4)) "\n")
  (printf "tm(4)      = " (int (M.tm 4)) "\n")
  (printf "redc(9*12) = " (int (M.redc (* 9 12))) "\n")
  ;; pow takes its base in Montgomery form; 3^5 mod 13 = 243 mod 13 = 9
  (printf "pow(3,5)   = " (int (M.pow (M.tm 3) 5)) "\n")
  ;; 13 is prime, so 7^12 mod 13 = 1 (Fermat's little theorem)
  (printf "pow(7,12)  = " (int (M.pow (M.tm 7) 12)) "\n")
  )