-- | Prime fields: arithmetic modulo a fixed prime 'p'.
--
-- TODO: do this properly, and add fast implementations for specialized prime
-- fields.

{-# LANGUAGE BangPatterns #-}
module Math.Algebra.ModP where

--------------------------------------------------------------------------------

import Data.Bits
import Data.Ratio
import Data.Int

--------------------------------------------------------------------------------

-- | The field characteristic: the Mersenne prime @2^31 - 1 = 2147483647@.
--
-- In practice this was found to be significantly faster than the larger
-- prime @2^63 - 25@.
p :: Int64
p = 2147483647

-- p = 20551         -- max coefficient in 3/1/8 is 20460
-- p = 2^31 - 1      -- @2^31-1@ 
-- p = 2^33 - 9      -- @2^33-9@
-- p = 2^62 - 57     -- @2^62-57@
-- p = 2^63 - 25     -- @2^63-25@


--------------------------------------------------------------------------------

-- | An element of the prime field @Z_p@, represented by an 'Int64'.
-- Values built with 'mkZp' lie in @[0, p)@; the raw 'Zp' constructor
-- performs no such check.
newtype Zp = Zp Int64
  deriving (Eq, Show)

-- | Extract the stored representative of a field element as an 'Int'.
fromZp :: Zp -> Int
fromZp z = case z of
  Zp k -> fromIntegral k

-- | Build a field element from any integral value, reducing it into
-- @[0, p)@.
--
-- The reduction is carried out in 'Integer'. Converting the input to
-- 'Int64' first would silently wrap values outside the 'Int64' range
-- (a large 'Integer', a 'Word64' above @2^63@, ...) modulo @2^64@ before
-- the reduction modulo 'p', producing the wrong residue.
mkZp :: Integral a => a -> Zp
mkZp n = Zp (fromInteger (toInteger n `mod` toInteger p))

--------------------------------------------------------------------------------

-- | Ring operations modulo 'p'.
--
-- @Z_p@ has no ordering, so 'abs' is the identity and 'signum' is the
-- constant 1; together they still satisfy @abs x * signum x == x@.
-- 'negate' uses the class default (@0 - x@, i.e. 'subZp').
instance Num Zp where
  (+)           = addZp
  (-)           = subZp
  (*)           = mulZp
  -- Reduce in 'Integer' before narrowing to 'Int64': narrowing first would
  -- wrap out-of-range values modulo 2^64 (with the default p, @2^64@ would
  -- become 0 instead of 4). This also avoids an ambiguous intermediate type.
  fromInteger n = Zp (fromInteger (n `mod` toInteger p))
  abs           = id
  signum _      = Zp 1

-- | Division in @Z_p@.
--
-- Note: since 'invZp_euclid' returns 0 for 0, @recip 0@ is 0 and therefore
-- @x / 0@ yields 0 instead of raising an error.
instance Fractional Zp where
  recip z = case z of
    Zp a -> mkZp (invZp_euclid a)
  (/) x y = x * recip y
  fromRational q = fromInteger num / fromInteger den
    where
      num = numerator q
      den = denominator q

--------------------------------------------------------------------------------

-- | Addition modulo 'p' for operands in @[0, p)@.
--
-- The sum is below @2p@, so one conditional subtraction suffices. A
-- negative sum can only come from 'Int64' wraparound (moduli close to
-- @2^63@); subtracting 'p' then still gives the right residue under
-- two's-complement wrapping.
addZp :: Zp -> Zp -> Zp
addZp (Zp x) (Zp y)
  | inRange   = Zp s
  | otherwise = Zp (s - p)
  where
    s       = x + y
    inRange = 0 <= s && s < p

-- | Subtraction modulo 'p' for operands in @[0, p)@: when the plain
-- difference would be negative, 'p' is added back.
subZp :: Zp -> Zp -> Zp
subZp (Zp x) (Zp y)
  | y > x     = Zp (x + p - y)
  | otherwise = Zp (x - y)

-- | Multiplication modulo 'p' for operands in @[0, p)@.
--
-- If every product of two residues fits in an 'Int64', i.e.
-- @(p-1)^2 <= maxBound@ (true for the default @p = 2^31-1@, where products
-- stay below @2^62@), the product is reduced directly in machine
-- arithmetic, avoiding an 'Integer' round-trip on every multiplication.
-- For larger moduli (e.g. @2^63-25@) the product could overflow, so it is
-- widened to 'Integer' instead.
mulZp :: Zp -> Zp -> Zp
mulZp (Zp a) (Zp b)
  | fitsInt64 = Zp (mod (a * b) p)
  | otherwise = Zp (fromInteger (mod wide (toInteger p)))
  where
    wide = toInteger a * toInteger b
    -- (p-1)^2 <= maxBound  <=>  p-1 <= maxBound `div` (p-1)   (p >= 2).
    -- Depends only on the constant 'p', so it is the same for every call.
    fitsInt64 = p - 1 <= maxBound `div` (p - 1)

-- | Multiplicative inverse modulo 'p', computed with the binary extended
-- Euclidean algorithm (only shifts, subtractions and comparisons on 'Int64',
-- no 'Integer' arithmetic).
--
-- The argument is assumed to be a residue in @[0, p)@ (the raw payload of a
-- 'Zp'). For @0@, which has no inverse, the result is @0@ rather than an
-- error. For @0 < a < p@ the result lies in @[0, p)@.
--
-- NOTE(review): if @a@ shares a factor with 'p' (impossible for prime 'p'
-- and @0 < a < p@, but possible for an out-of-range payload such as @a = p@),
-- the loop eventually drives @u@ or @v@ to 0, and 'stepU' / 'stepV' then
-- keep halving 0, so the function does not terminate.
invZp_euclid :: Int64 -> Int64
invZp_euclid a 
  | a == 0     = 0
  | otherwise  = go 1 0 a p
  where
  
    -- Reduce an 'Int64' into @[0, p)@.
    modp :: Int64 -> Int64
    modp n = mod n p

    -- @(p+1)/2@, the inverse of 2 modulo the odd prime 'p'. For odd @x@,
    -- @x/2 mod p = (x+p)/2 = shiftR x 1 + halfp1@; this is how odd
    -- coefficients are halved below without computing @x+p@.
    halfp1 = shiftR (p+1) 1

    -- Loop invariant (all mod p): @x1 * a == u@ and @x2 * a == v@, starting
    -- from @(x1, x2, u, v) = (1, 0, a, p)@. As soon as @u@ or @v@ equals 1,
    -- the matching coefficient is the inverse of @a@.
    go :: Int64 -> Int64 -> Int64 -> Int64 -> Int64
    go !x1 !x2 !u !v 
      | u==1       = x1
      | v==1       = x2
      | otherwise  = stepU x1 x2 u v

    -- Strip factors of 2 from @u@, halving @x1@ modulo 'p' alongside it so
    -- that @x1 * a == u@ keeps holding; once @u@ is odd, move on to 'stepV'.
    stepU :: Int64 -> Int64 -> Int64 -> Int64 -> Int64
    stepU !x1 !x2 !u !v = if even u 
      then let u'  = shiftR u 1
               x1' = if even x1 then shiftR x1 1 else shiftR x1 1 + halfp1
           in  stepU x1' x2 u' v
      else     stepV x1  x2 u  v

    -- Same as 'stepU' for the @(x2, v)@ pair; once @v@ is odd, move on to
    -- 'final'.
    stepV :: Int64 -> Int64 -> Int64 -> Int64 -> Int64
    stepV !x1 !x2 !u !v = if even v
      then let v'  = shiftR v 1
               x2' = if even x2 then shiftR x2 1 else shiftR x2 1 + halfp1
           in  stepV x1 x2' u v' 
      else     final x1 x2  u v

    -- Here both @u@ and @v@ are odd. Subtract the smaller from the larger
    -- (the result is even), apply the same subtraction to the matching
    -- coefficient modulo 'p' (adding 'p' first when it would go negative),
    -- and restart the loop.
    final :: Int64 -> Int64 -> Int64 -> Int64 -> Int64
    final !x1 !x2 !u !v = if u>=v

      then let u'  = u-v
               x1' = if x1 >= x2 then modp (x1-x2) else modp (x1+p-x2)               
           in  go x1' x2  u' v 

      else let v'  = v-u
               x2' = if x2 >= x1 then modp (x2-x1) else modp (x2+p-x1)
           in  go x1  x2' u  v'

--------------------------------------------------------------------------------