module Math.Algebra.Determinant where
import Control.Monad
import Control.Monad.ST
import Data.Array.Base
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unsafe
import Data.Array.ST
import Data.List
import Data.Ratio
import Data.STRef
import Data.Bits
import Data.Word
import Data.Int
import Foreign.C
import Foreign.Ptr
import Foreign.Marshal
import System.IO.Unsafe as Unsafe
import System.Random
import Debug.Trace
import GHC.IO ( unsafeIOToST )
import Math.Algebra.ModP
-- | Dense matrix indexed by (row, column). The functions in this module
-- pattern-match the bounds as @((1,1),(n,m))@, i.e. they assume 1-based
-- indices.
type Matrix a = Array (Int,Int) a
-- | Print a matrix to stdout, one bracketed row per line. Because the
-- rendered text already ends in a newline and 'putStrLn' adds another,
-- a blank line follows the matrix.
printMatrix :: Show a => Matrix a -> IO ()
printMatrix mat = putStrLn (showMatrix mat)
-- | Render a matrix as a single string: the rows produced by showMatrix',
-- each terminated by a newline.
showMatrix :: Show a => Matrix a -> String
showMatrix mat = concatMap (++ "\n") (showMatrix' mat)
-- | Render each row of a matrix as @"[ x y z ]"@. Every column is
-- right-aligned to the width of its widest entry (via 'extend'), so the
-- rows line up when printed. Bounds are assumed to start at @(1,1)@.
showMatrix' :: Show a => Matrix a -> [String]
showMatrix' mat = [ "[ " ++ unwords row ++ " ]" | row <- transpose paddedCols ]
  where
    ((1,1),(rowCount,colCount)) = bounds mat
    -- One list of shown entries per column, padded to a common width.
    paddedCols =
      [ extend [ show (mat ! (i,j)) | i <- [1..rowCount] ] | j <- [1..colCount] ]
-- | Right-align a list of strings to a common width by left-padding each
-- one with spaces up to the length of the longest string.
--
-- The empty list maps to the empty list; seeding the maximum with 0 keeps
-- the width computation total.
extend :: [String] -> [String]
extend xs = map pad xs
  where
    width = maximum (0 : map length xs)
    pad s = replicate (width - length s) ' ' ++ s
-- | Element types that know how to compute the determinant of a square
-- 'Matrix' over themselves; each instance picks an algorithm suited to the
-- element type.
class (Eq a, Num a, Show a) => Determinant a where
  -- | Determinant of a square matrix. The implementations in this module
  -- expect bounds of the form @((1,1),(n,n))@.
  determinant :: Matrix a -> a
-- Integer types use fraction-free (Bareiss) elimination so every division is
-- exact. For 'Int' the intermediate products can silently overflow.
instance Determinant Integer where determinant = bareissDeterminantFullRank
instance Determinant Int where determinant = bareissDeterminantFullRank
-- Rationals form a field, so ordinary Gaussian elimination is exact.
instance Determinant Rational where determinant = gaussElimDeterminant
-- Zp delegates to the foreign routine bound as 'c_det_modp'.
instance Determinant Zp where determinant = gaussElimDeterminantInt64
-- | Foreign function @inv_modp@ from @c_det.h@, imported as a pure function.
-- NOTE(review): presumably a modular inverse, but the argument order and
-- semantics are not visible here; confirm against the C source. It is not
-- referenced by the code in this module.
foreign import ccall "c_det.h inv_modp" c_inv_modp :: Int64 -> Int64 -> Int64
-- | Foreign function @det_modp@ from @c_det.h@. It takes the modulus @p@, the
-- dimension @n@, and a pointer to the @n*n@ entries; callers in this module
-- pass the entries in row-major order via 'withArray' and treat the result
-- as the determinant modulo @p@.
-- NOTE(review): whether the C code overwrites the buffer is not visible here;
-- the callers only ever hand it a temporary copy.
foreign import ccall "c_det.h det_modp" c_det_modp :: Int64 -> CInt -> Ptr Int64 -> IO Int64
-- | Pure wrapper around 'ioFastDetModP': determinant modulo @p@ of a square
-- 'Int64' matrix, computed by the foreign @det_modp@ routine.
--
-- 'unsafePerformIO' is sound only if the C routine is a deterministic
-- function of its inputs. It works on a temporary copy of the entries, so
-- the Haskell matrix itself is never mutated.
-- NOTE(review): confirm det_modp has no other side effects. Entries are
-- presumably expected to be already reduced modulo @p@.
fastDetModP :: Int64 -> Matrix Int64 -> Int64
fastDetModP p = Unsafe.unsafePerformIO . ioFastDetModP p
-- | Call the foreign @det_modp@ routine on a square matrix.
--
-- The entries are copied into a temporary C buffer in 'elems' order, which
-- for @(row, column)@ indices is row-major. The dimension is taken from the
-- row bound, and bounds are assumed to be @((1,1),(n,n))@.
ioFastDetModP :: Int64 -> Matrix Int64 -> IO Int64
ioFastDetModP p mat = withArray (elems mat) (c_det_modp p dim)
  where
    ((1,1),(n,_)) = bounds mat
    dim = fromIntegral n :: CInt
-- | Determinant over 'Zp', computed by the foreign routine.
--
-- Each entry is converted to 'Int64' via 'fromZp' and handed, together with
-- the modulus @p@ (from "Math.Algebra.ModP"), to 'fastDetModP'. The raw
-- result is wrapped back into 'Zp'.
gaussElimDeterminantInt64 :: Matrix Zp -> Zp
gaussElimDeterminantInt64 mat = raw `seq` Zp (fromIntegral raw)
  where
    -- Forced before building the result so the foreign call runs as soon as
    -- the determinant is demanded.
    raw = fastDetModP (fromIntegral p) (fmap (fromIntegral . fromZp) mat)
-- | Mutable counterpart of 'Matrix', used for in-place elimination inside 'ST'.
type STMatrix s a = STArray s (Int,Int) a
-- | Exact determinant of a square matrix over an integral type, using the
-- fraction-free Bareiss algorithm.
--
-- At step @k@, every entry @(i,j)@ with @i,j > k@ is replaced in place by
-- @a_kk * a_ij - a_ik * a_kj@, divided exactly by the pivot of the previous
-- step (1 at the first step). After the last step, entry @(n,n)@ holds the
-- determinant, so all intermediate values stay in @a@ without rationals.
--
-- Row pivoting: if the diagonal entry at step @k@ is zero, the first row
-- below it with a nonzero entry in column @k@ is swapped up, and each swap
-- flips the sign. If there is no such row, the matrix is singular and 0 is
-- returned. So, despite the name, singular matrices and zero leading minors
-- are handled rather than triggering a division by zero.
--
-- Bounds are assumed to be @((1,1),(n,n))@; an empty matrix (@n <= 0@) has
-- determinant 1. For 'Int' the intermediate products may overflow silently.
bareissDeterminantFullRank :: forall a . Integral a => Matrix a -> a
bareissDeterminantFullRank mat
  | n <= 0    = 1
  | otherwise = runST $ do
      arr <- thaw mat :: ST s (STMatrix s a)
      step arr 1 1 False
  where
    siz@((1,1),(n,_)) = bounds mat

    -- Element access through the known bounds ('index' still range-checks).
    rd :: STMatrix s a -> (Int,Int) -> ST s a
    rd ar ij = unsafeRead ar (index siz ij)

    wr :: STMatrix s a -> (Int,Int) -> a -> ST s ()
    wr ar ij x = unsafeWrite ar (index siz ij) x

    -- Elimination step k. 'prev' is the pivot used at step k-1 (always
    -- nonzero); 'flipped' records whether an odd number of rows was swapped.
    step :: STMatrix s a -> Int -> a -> Bool -> ST s a
    step !arr !k !prev !flipped
      | k >= n = do
          d <- rd arr (n,n)
          return (if flipped then negate d else d)
      | otherwise = do
          column <- forM [k..n] $ \i -> rd arr (i,k)
          case findIndex (/= 0) column of
            Nothing -> return 0    -- whole column is zero: singular
            Just off -> do
              let pivotRow = k + off
                  swapped  = pivotRow /= k
              when swapped $ swapRows arr k pivotRow
              akk <- rd arr (k,k)
              -- In-place update is safe: step k only writes (i,j) with i,j > k,
              -- while it reads row k and column k, which it never writes.
              forM_ [k+1..n] $ \i -> do
                aik <- rd arr (i,k)
                forM_ [k+1..n] $ \j -> do
                  akj <- rd arr (k,j)
                  aij <- rd arr (i,j)
                  wr arr (i,j) $! (akk*aij - aik*akj) `div` prev
              step arr (k+1) akk (flipped /= swapped)

    -- Swap rows r1 and r2 from column r1 onwards; columns to the left of r1
    -- are never read again.
    swapRows :: STMatrix s a -> Int -> Int -> ST s ()
    swapRows !arr !r1 !r2 =
      forM_ [r1..n] $ \j -> do
        x <- rd arr (r1,j)
        y <- rd arr (r2,j)
        wr arr (r1,j) y
        wr arr (r2,j) x
-- | Determinant over a field by Gaussian elimination, performed in place on
-- a mutable copy of the matrix.
--
-- For each row @i@ from 1 to @n-1@, the first nonzero entry at or to the
-- right of the diagonal is chosen as the pivot. If it is off the diagonal,
-- the two columns are swapped (rows @i..n@ only), which flips the sign. The
-- rows below are then reduced with that pivot. The determinant is the signed
-- product of the resulting diagonal. If row @i@ has no nonzero entry from
-- the diagonal onwards, the matrix is singular and 0 is returned.
--
-- Bounds are assumed to be @((1,1),(n,n))@; an empty matrix (@n <= 0@) has
-- determinant 1.
gaussElimDeterminant :: forall a. (Eq a, Show a, Fractional a) => Matrix a -> a
gaussElimDeterminant mat
  | n <= 0    = 1
  | otherwise = runST $ do
      arr <- thaw mat :: ST s (STMatrix s a)
      go arr 1 False
  where
    siz@((1,1),(n,_)) = bounds mat

    -- Element access through the known bounds ('index' still range-checks).
    -- Values are forced before being stored so no thunk chains build up.
    rd :: STMatrix s a -> (Int,Int) -> ST s a
    rd !arr !ij = unsafeRead arr (index siz ij)

    wr :: STMatrix s a -> (Int,Int) -> a -> ST s ()
    wr !arr !ij !x = unsafeWrite arr (index siz ij) x

    -- Process row i; 'flipped' records whether an odd number of column swaps
    -- has happened so far.
    go :: STMatrix s a -> Int -> Bool -> ST s a
    go !arr !i !flipped
      | i >= n = do
          diag <- forM [1..n] $ \k -> rd arr (k,k)
          let d = product diag
          return (if flipped then negate d else d)
      | otherwise = do
          row <- forM [i..n] $ \j -> rd arr (i,j)
          case findIndex (/= 0) row of
            Nothing -> return 0    -- row is zero from the diagonal on: singular
            Just off -> do
              let pivotCol = i + off
                  swapped  = pivotCol > i
              when swapped $ swapColumns arr i pivotCol
              piv <- rd arr (i,i)
              forM_ [i+1..n] $ \k -> do
                q <- rd arr (k,i)
                wr arr (k,i) 0
                let z = q / piv
                forM_ [i+1..n] $ \j -> do
                  aij <- rd arr (i,j)
                  bkj <- rd arr (k,j)
                  wr arr (k,j) (bkj - aij*z)
              go arr (i+1) (flipped /= swapped)

    -- Swap columns c1 and c2 in rows c1..n. Rows above c1 are finished and
    -- only their diagonal entries, which this swap does not touch, are read
    -- later.
    swapColumns :: STMatrix s a -> Int -> Int -> ST s ()
    swapColumns !arr !c1 !c2 =
      forM_ [c1..n] $ \k -> do
        x <- rd arr (k,c1)
        y <- rd arr (k,c2)
        wr arr (k,c1) y
        wr arr (k,c2) x
-- | Reference determinant by cofactor (Laplace) expansion along the first
-- remaining row. It takes factorial time and is intended only for
-- cross-checking the elimination-based routines on small matrices.
--
-- Bounds are assumed to be @((1,1),(n,n))@; an empty matrix (@n <= 0@) has
-- determinant 1.
naiveDeterminant :: forall a. (Num a) => Matrix a -> a
naiveDeterminant mat
  | n <= 0    = 1
  | otherwise = expand [1..n] [1..n]
  where
    ((1,1),(n,_)) = bounds mat

    -- Determinant of the submatrix selected by the given rows and columns.
    expand :: [Int] -> [Int] -> a
    expand [] _ = 1
    expand [r] [c] = mat!(r,c)
    expand [r1,r2] [c1,c2] = mat!(r1,c1) * mat!(r2,c2) - mat!(r1,c2) * mat!(r2,c1)
    expand (r:rs) cs = foldl' (+) 0 (zipWith cofactor (cycle [id, negate]) cs)
      where
        -- Signs alternate with the column's position among the remaining ones.
        cofactor sgn c = sgn (mat!(r,c) * expand rs (delete c cs))
-- | Build an @n x n@ matrix with bounds @((1,1),(n,n))@ whose entry at
-- @(i,j)@ is @f i j@.
mkSquareMatrix :: (Int -> Int -> a) -> Int -> Matrix a
mkSquareMatrix f n = listArray bnds (map (uncurry f) (range bnds))
  where
    bnds = ((1,1),(n,n))
-- | Deterministic @n x n@ test matrix. Each entry is a fixed polynomial in
-- the (1-based) indices plus an xor term, converted into the target numeric
-- type.
testMatrix :: Num a => Int -> Matrix a
testMatrix n = mkSquareMatrix entry n
  where
    entry i j = fromIntegral $
      3 + i*i*i - j*j + (4*i*j + 3*i + 5*j + 7) + xor (13+i) (17+j)
-- | Random @n x n@ matrix, built by randomMatrix' with bound 10.
randomMatrix :: (Random a, Num a) => Int -> IO (Matrix a)
randomMatrix n = randomMatrix' 10 n
-- | Random @n x n@ matrix with bounds @((1,1),(n,n))@. Each entry is drawn
-- independently with 'randomRIO' from the range @(-bnd, bnd)@, so entries
-- may be negative. Entries are filled in row-major order.
randomMatrix' :: (Random a, Num a) => a -> Int -> IO (Matrix a)
randomMatrix' bnd n = do
  xs <- replicateM (n*n) (randomRIO (negate bnd, bnd))
  return $ listArray ((1,1),(n,n)) xs
-- | Debugging helper: 'print' a value from inside 'ST' by lifting the IO
-- action with 'unsafeIOToST'.
printST :: Show a => a -> ST s ()
printST = unsafeIOToST . print
-- | Randomised consistency check of the determinant routines, run on 100
-- random matrices of each size from 1 to 10 (generated with 'randomMatrix').
-- It checks three things:
--
--   * over 'Rational', 'gaussElimDeterminant' must match the cofactor
--     expansion;
--   * over 'Integer', 'bareissDeterminantFullRank' must match the cofactor
--     expansion;
--   * modulo @p@, the Haskell elimination over 'Zp' and the foreign routine
--     must both match the reduced integer determinant.
--
-- Mismatches are printed together with the offending matrix. The cofactor
-- expansion takes factorial time, so the larger sizes are slow.
test :: IO ()
test =
  forM_ [1..10] $ \n -> do
    putStrLn $ "testing matrices of size " ++ show n ++ " x " ++ show n ++ "..."
    replicateM_ 100 $ do
      imat <- randomMatrix n :: IO (Matrix Integer)
      let mat = fmap fromInteger imat :: Matrix Rational
          a = naiveDeterminant mat
          b = gaussElimDeterminant mat
          ia = naiveDeterminant imat :: Integer
          ib = bareissDeterminantFullRank imat
          amodp = mkZp $ fromIntegral (mod ia (fromIntegral p))
          c = gaussElimDeterminant (fmap mkZp imat)
          d0 = fastDetModP (fromIntegral p) (fmap (\x -> fromIntegral (mod x (fromIntegral p))) imat)
          d = fromIntegral d0 :: Zp
      when (a/=b) $ do
        putStrLn "\nERROR!"
        print (a,b)
        print imat
      when (ia/=ib) $ do
        putStrLn "\nBAREISS ERROR!"
        print (ia,ib)
        print imat
      when (c/=d || d/=amodp) $ do
        putStrLn "\nC ERROR!"
        print (c,d,amodp)
        print imat