-- | Determinants.
--
-- TODO: specialized prime fields; fast C implementation; pivoting for Bareiss
--

{-# LANGUAGE ScopedTypeVariables, TypeFamilies, BangPatterns, 
             FlexibleInstances, TypeSynonymInstances,
             ForeignFunctionInterface
  #-}
module Math.Algebra.Determinant where

--------------------------------------------------------------------------------

import Control.Monad
import Control.Monad.ST

import Data.Array.Base
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unsafe
import Data.Array.ST

import Data.List
import Data.Ratio
import Data.STRef

import Data.Bits
import Data.Word
import Data.Int

import Foreign.C
import Foreign.Ptr
import Foreign.Marshal
import System.IO.Unsafe as Unsafe

import System.Random

import Debug.Trace
import GHC.IO ( unsafeIOToST )

import Math.Algebra.ModP

--------------------------------------------------------------------------------
-- * matrices

-- | Immutable matrix, indexed by @(row, column)@.  Most functions in this
-- module expect square matrices whose indices start at @(1,1)@.
type Matrix a = Array (Int,Int) a

-- | Print a matrix to stdout, one bracketed row per line, followed by an
-- extra newline.
printMatrix :: Show a => Matrix a -> IO ()
printMatrix mat = putStrLn (showMatrix mat)

-- | Render a matrix as a single string, each row terminated by a newline.
showMatrix :: Show a => Matrix a -> String
showMatrix mat = concatMap (++ "\n") (showMatrix' mat)

-- | Render a matrix as a list of lines of the form @[ x y z ]@, one per row.
-- Every column is right-aligned to the width of its widest entry.
--
-- Works for arbitrary index bounds; previously the bounds had to start at
-- @(1,1)@, otherwise the irrefutable pattern on 'bounds' failed at runtime.
showMatrix' :: Show a => Matrix a -> [String]
showMatrix' mat = map mkRow (transpose cols) where
  ((r0,c0),(r1,c1)) = bounds mat

  -- one list of rendered (and padded) entries per column
  cols = [ padColumn [ show (mat!(i,j)) | i <- [r0..r1] ] | j <- [c0..c1] ]

  mkRow strs = "[ " ++ unwords strs ++ " ]"

  -- right-align all entries of a column to a common width
  padColumn :: [String] -> [String]
  padColumn xs = map pad xs where
    width = maximum (map length xs)
    pad s = replicate (width - length s) ' ' ++ s

--------------------------------------------------------------------------------
-- * a type class for determinants

-- | Scalar types for which a determinant routine for square matrices is
-- available.
class (Eq a, Num a, Show a) => Determinant a where
  determinant :: Matrix a -> a

-- Exact integers: fraction-free Bareiss elimination.
instance Determinant Integer where
  determinant = bareissDeterminantFullRank

-- Machine integers: also Bareiss (fixed-width, so intermediate products may
-- overflow).
instance Determinant Int where
  determinant = bareissDeterminantFullRank

-- Rationals: Gaussian elimination over the field.
instance Determinant Rational where
  determinant = gaussElimDeterminant

-- Prime field: delegated to the C routine via 64-bit integers.
instance Determinant Zp where
  determinant = gaussElimDeterminantInt64

--------------------------------------------------------------------------------
-- * C implementation of the determinant over a prime field (Gaussian elimination using 64-bit integer arithmetic)

-- | Binding to the C symbol @inv_modp@ declared in @c_det.h@ (pure call).
-- Not used in the code visible here; judging by the name it presumably
-- computes a modular inverse — NOTE(review): confirm argument order in c_det.c.
foreign import ccall "c_det.h inv_modp" c_inv_modp :: Int64 -> Int64 -> Int64
-- | Binding to the C symbol @det_modp@ declared in @c_det.h@.  Callers pass
-- the prime, the matrix size @n@, and a pointer to the @n*n@ entries in
-- row-major order, and use the result as the determinant mod the prime.
-- NOTE(review): whether the C code overwrites the buffer is not visible here;
-- callers always hand it a temporary copy, so either is safe.
foreign import ccall "c_det.h det_modp" c_det_modp :: Int64 -> CInt  -> Ptr Int64 -> IO Int64

-- | Pure wrapper around 'ioFastDetModP': determinant of a square matrix
-- modulo the given prime, computed in C.
--
-- 'Unsafe.unsafePerformIO' is used on the assumption that the C routine only
-- depends on its arguments and on the freshly allocated buffer it receives.
-- NOTE(review): confirm that c_det.c keeps no global state.
fastDetModP :: Int64 -> Matrix Int64 -> Int64
fastDetModP p mat = Unsafe.unsafePerformIO (ioFastDetModP p mat)

-- | Determinant modulo @p@ of a square matrix indexed from @(1,1)@, computed
-- by the C routine 'c_det_modp'.
--
-- The entries are copied into a temporary buffer by 'withArray'.  'elems'
-- enumerates the @(i,j)@ indices row by row, so the buffer is row-major.
-- Because the buffer is a copy, @mat@ is unaffected by anything the C code
-- does to it.
--
-- The empty matrix has determinant 1, as in the Haskell implementations.
-- It is answered here rather than passing @n = 0@ to the C code.
-- NOTE(review): entries are presumably expected to be reduced into @[0,p)@
-- already (the test reduces them before calling) — confirm against c_det.c.
ioFastDetModP :: Int64 -> Matrix Int64 -> IO Int64
ioFastDetModP p mat
  | n <= 0    = return 1
  | otherwise = withArray (elems mat) $ \ptr -> c_det_modp p (fromIntegral n :: CInt) ptr
  where
    ((1,1),(n,_)) = bounds mat

-- | Determinant over 'Zp', delegated to the C routine through 'fastDetModP'.
--
-- Each entry is converted to 'Int64' via 'fromZp'.  The call uses the
-- modulus @p@ imported from "Math.Algebra.ModP", and the C result is wrapped
-- back with the 'Zp' constructor.  Reusing 'fastDetModP' avoids duplicating
-- the FFI marshalling and the 'unsafePerformIO' call.
--
-- The empty matrix has determinant 1, consistent with the other
-- implementations in this module.
gaussElimDeterminantInt64 :: Matrix Zp -> Zp
gaussElimDeterminantInt64 mat
  | n <= 0    = 1
  | otherwise = Zp (fromIntegral d)
  where
    ((1,1),(n,_)) = bounds mat
    d = fastDetModP (fromIntegral p) (fmap (fromIntegral . fromZp) mat)

--------------------------------------------------------------------------------
-- * Bareiss determinant algorithm

-- | Mutable (boxed) counterpart of 'Matrix', used by the in-place
-- elimination algorithms below.
type STMatrix s a = STArray s (Int,Int) a

-- | Determinant of a square integral matrix by fraction-free Bareiss
-- elimination.  Every division performed is exact, so only integer
-- arithmetic is needed.
--
-- When the current pivot is zero, a lower row with a nonzero entry in the
-- pivot column is swapped in, and the sign is flipped.  If there is no such
-- row, the determinant is 0.  Consequently any square matrix is accepted.
-- Earlier versions required all leading principal minors to be nonzero and
-- hit a division by zero otherwise (hence the name).  When those minors are
-- nonzero, no exchange ever happens and the result is unchanged.
--
-- Array writes are strict, so no chain of thunks accumulates in the boxed
-- array.  For 'Int' the intermediate products may overflow.
--
-- The array must be indexed from @(1,1)@; the empty matrix has determinant 1.
{-# SPECIALIZE bareissDeterminantFullRank :: Matrix Integer -> Integer #-}
{-# SPECIALIZE bareissDeterminantFullRank :: Matrix Int     -> Int     #-}
bareissDeterminantFullRank :: forall a . Integral a => Matrix a -> a
bareissDeterminantFullRank mat
  | n <= 0    = 1  -- determinant of the empty matrix is 1
  | otherwise = runST $ do
      cur <- thaw mat      :: ST s (STMatrix s a)
      nxt <- newArray_ siz :: ST s (STMatrix s a)
      step cur nxt 1 False 1

  where

    siz@((1,1),(n,_)) = bounds mat

    unsafeReadArray :: STMatrix s a -> (Int,Int) -> ST s a
    unsafeReadArray ar ij = unsafeRead ar (index siz ij)

    -- strict in the value: keeps the boxed array free of thunk chains
    unsafeWriteArray :: STMatrix s a -> (Int,Int) -> a -> ST s ()
    unsafeWriteArray ar ij !x = unsafeWrite ar (index siz ij) x

    -- One elimination step on column k, reading @cur@ and writing @nxt@.
    -- @q@ is the previous (nonzero) pivot, 1 before the first step.
    -- @neg@ is True after an odd number of row exchanges.
    step :: STMatrix s a -> STMatrix s a -> a -> Bool -> Int -> ST s a
    step !cur !nxt !q !neg !k
      | k >= n = do
          x <- unsafeReadArray cur (n,n)
          return (if neg then negate x else x)
      | otherwise = do
          mpiv <- findPivot cur k
          case mpiv of
            Nothing -> return 0   -- column k vanishes from row k down: singular
            Just r  -> do
              let swapped = r /= k
              when swapped $ swapRows cur k r
              pk <- unsafeReadArray cur (k,k)
              forM_ [k+1..n] $ \(!i) -> do
                b <- unsafeReadArray cur (i,k)
                forM_ [k+1..n] $ \(!j) -> do
                  c <- unsafeReadArray cur (k,j)
                  d <- unsafeReadArray cur (i,j)
                  -- exact by Sylvester's identity; q /= 0 because pivots are nonzero
                  unsafeWriteArray nxt (i,j) ((pk*d - b*c) `div` q)
              step nxt cur pk (if swapped then not neg else neg) (k+1)

    -- first row r >= k whose entry in column k is nonzero
    findPivot :: STMatrix s a -> Int -> ST s (Maybe Int)
    findPivot ar k = go k where
      go !r
        | r > n     = return Nothing
        | otherwise = do
            x <- unsafeReadArray ar (r,k)
            if x /= 0 then return (Just r) else go (r+1)

    -- Exchange rows k and r.  Only columns k..n matter: earlier columns are
    -- never read again.
    swapRows :: STMatrix s a -> Int -> Int -> ST s ()
    swapRows ar k r = forM_ [k..n] $ \j -> do
      x <- unsafeReadArray ar (k,j)
      y <- unsafeReadArray ar (r,j)
      unsafeWriteArray ar (k,j) y
      unsafeWriteArray ar (r,j) x

--------------------------------------------------------------------------------
-- * Gaussian elimination

-- | Determinant over a field by Gaussian elimination.
--
-- Step @i@ looks for the first nonzero entry of row @i@ at or to the right
-- of the diagonal.
--
--   * If there is none, the row is entirely zero and the determinant is 0.
--     (Everything to the left of the diagonal has already been eliminated.)
--   * Otherwise that entry is moved onto the diagonal by exchanging two
--     /columns/, which flips the sign.  The entries below it are then
--     eliminated.
--
-- The result is the product of the diagonal, negated after an odd number of
-- exchanges.
--
-- The array must be indexed from @(1,1)@; the empty matrix has determinant 1.
{-# SPECIALIZE gaussElimDeterminant :: Matrix Rational -> Rational #-}
{-# SPECIALIZE gaussElimDeterminant :: Matrix Zp       -> Zp       #-}
gaussElimDeterminant :: forall a. (Eq a, Show a, Fractional a) => Matrix a -> a
gaussElimDeterminant mat =

  if n <= 0
    then 1             -- determinant of the empty matrix is 1
    else runST $ do
      neg <- newSTRef False   -- True after an odd number of column exchanges
      arr <- thaw mat :: ST s (STMatrix s a)
      worker neg arr 1

  where

    siz@((1,1),(n,_)) = bounds mat

    unsafeReadArray :: STMatrix s a -> (Int,Int) -> ST s a
    unsafeReadArray !ar !ij = unsafeRead ar (index siz ij)

    -- strict in the value, so no thunks pile up in the boxed array
    unsafeWriteArray :: STMatrix s a -> (Int,Int) -> a -> ST s ()
    unsafeWriteArray !ar !ij !x = unsafeWrite ar (index siz ij) x

    -- product of the diagonal of the (now upper triangular) array, with sign
    finish :: STRef s Bool -> STMatrix s a -> ST s a
    finish !neg !arr = do
      diag <- mapM (\i -> unsafeReadArray arr (i,i)) [1..n]
      b    <- readSTRef neg
      return $ if b
        then negate $ product diag
        else          product diag

    -- Process row i; the last row needs no elimination.
    worker :: STRef s Bool -> STMatrix s a -> Int -> ST s a
    worker !neg !arr !i = if i >= n
      then finish neg arr
      else do
        ps <- mapM (\j -> unsafeReadArray arr (i,j)) [i..n]
        case findIndex (/=0) ps of
          Nothing    -> return 0                    -- row i is entirely zero -> determinant is zero
          Just pivot -> cont neg arr i (i+pivot)

    -- Bring the pivot from column `pivot` to (i,i), then clear column i below it.
    cont :: STRef s Bool -> STMatrix s a -> Int -> Int -> ST s a
    cont !neg !arr !i !pivot = do
      when (pivot > i) $ swapColumns neg arr i pivot
      p <- unsafeReadArray arr (i,i)
      forM_ [i+1..n] $ \k -> do
        q <- unsafeReadArray arr (k,i)
        unsafeWriteArray arr (k,i) 0
        let z = q / p
        forM_ [i+1..n] $ \j -> do
          a <- unsafeReadArray arr (i,j)
          b <- unsafeReadArray arr (k,j)
          unsafeWriteArray arr (k,j) (b - a*z)
      worker neg arr (i+1)

    -- Exchange columns i and j, touching rows i..n only.  Rows above i are
    -- already final.  Of those rows, only the diagonal entries are read later,
    -- and they lie outside columns i and j (both >= i).
    swapColumns :: STRef s Bool -> STMatrix s a -> Int -> Int -> ST s ()
    swapColumns !neg !arr !i !j = do
      modifySTRef' neg not             -- exchanging two columns flips the sign of the determinant
      forM_ [i..n] $ \k -> do
        a <- unsafeReadArray arr (k,i)
        b <- unsafeReadArray arr (k,j)
        unsafeWriteArray arr (k,j) a
        unsafeWriteArray arr (k,i) b

--------------------------------------------------------------------------------
-- * naive determinant algorithm (for testing purposes)

-- | Determinant by cofactor (Laplace) expansion along the first remaining
-- row.  It takes factorial time and is meant only as a reference for testing
-- the faster algorithms.
--
-- The array must be indexed from @(1,1)@; the empty matrix has determinant 1.
naiveDeterminant :: forall a. (Num a) => Matrix a -> a
naiveDeterminant mat
  | n <= 0    = 1
  | n == 1    = at 1 1
  | n == 2    = at 1 1 * at 2 2 - at 1 2 * at 2 1
  | otherwise = minorDet [1..n] [1..n]
  where

    ((1,1),(n,_)) = bounds mat

    at i j = mat ! (i,j)

    -- determinant of the submatrix selected by the given rows and columns
    minorDet []      []      = 1
    minorDet [r]     [c]     = at r c
    minorDet [r1,r2] [c1,c2] = at r1 c1 * at r2 c2 - at r1 c2 * at r2 c1
    minorDet (r:rs)  cs      = foldl' (+) 0
      [ if even pos then term else negate term
      | (pos, c) <- zip [0 :: Int ..] cs
      , let term = at r c * minorDet rs (delete c cs)
      ]


--------------------------------------------------------------------------------
-- * random matrices

-- | The @n x n@ matrix, indexed from @(1,1)@, whose entry at @(i,j)@ is @f i j@.
mkSquareMatrix :: (Int -> Int -> a) -> Int -> Matrix a
mkSquareMatrix f n = listArray ((1,1),(n,n)) [ f i j | i <- [1..n], j <- [1..n] ]

-- | A deterministic @n x n@ matrix with irregular integer entries, for
-- reproducible experiments.
testMatrix :: Num a => Int -> Matrix a
testMatrix = mkSquareMatrix entry
  where
    entry i j = fromIntegral
      (3 + i*i*i - j*j + (4*i*j + 3*i + 5*j + 7) + xor (13+i) (17+j))

-- | A random @n x n@ matrix with entries drawn from @[-10, 10]@.
randomMatrix :: (Random a, Num a) => Int -> IO (Matrix a)
randomMatrix n = randomMatrix' 10 n

-- | A random @n x n@ matrix, indexed from @(1,1)@, whose entries are drawn
-- with 'randomRIO' from @[-bnd, bnd]@ and filled in row by row.
randomMatrix' :: (Random a, Num a) => a -> Int -> IO (Matrix a)
randomMatrix' bnd n =
  listArray ((1,1),(n,n)) <$> replicateM (n*n) (randomRIO (negate bnd, bnd))

-- | Debugging helper: print a value from inside 'ST' via 'unsafeIOToST'.
printST :: Show a => a -> ST s ()
printST = unsafeIOToST . print

--------------------------------------------------------------------------------
-- * testing

-- | Randomized cross-check of the determinant implementations.
--
-- For each size n = 1..10, 100 random integer matrices (entries in
-- @[-10,10]@, see 'randomMatrix') are checked:
--
--   * over 'Rational', cofactor expansion must agree with Gaussian
--     elimination;
--   * modulo p, Gaussian elimination over 'Zp', the C routine 'fastDetModP'
--     and the exact integer determinant reduced mod p must all agree.
--
-- Mismatches are reported on stdout; nothing is thrown.  'naiveDeterminant'
-- takes factorial time, so the larger sizes are slow.
test :: IO ()
test = do
  let modulus = fromIntegral p :: Integer
  forM_ [1..10] $ \n -> do
    putStrLn $ "testing matrices of size " ++ show n ++ " x " ++ show n ++ "..."
    replicateM_ 100 $ do
      imat <- randomMatrix n :: IO (Matrix Integer)
      let mat = fmap fromInteger imat :: Matrix Rational

      -- exact determinants over the rationals
      let a = naiveDeterminant     mat
          b = gaussElimDeterminant mat

      -- reference value mod p: the exact integer determinant, reduced
      let ia    = naiveDeterminant imat :: Integer
          amodp = mkZp $ fromIntegral (mod ia modulus)

      -- the two mod-p implementations under test
      let c  = gaussElimDeterminant (fmap mkZp imat)
          d0 = fastDetModP (fromIntegral p) (fmap (\x -> fromIntegral (mod x modulus)) imat)
          d  = fromIntegral d0 :: Zp

      when (a/=b) $ do
        putStrLn "\nERROR!"
        print (a,b)
        print imat

      when (c/=d || d/=amodp) $ do
        putStrLn "\nC ERROR!"
        print (c,d,amodp)
        print imat