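-- | Schur polynomials, evaluated via the Jacobi-Trudi determinant formulas
-- from elementary (\"Chern\") or complete (\"Segre\") symmetric polynomials.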

{-# LANGUAGE ScopedTypeVariables, TypeFamilies, BangPatterns #-}
module Math.Algebra.Schur where


import Control.Monad
import Control.Monad.ST

import Data.Array.Base
import Data.Array.IArray
import Data.Array.MArray
import Data.Array.Unsafe
import Data.Array.ST

import Data.List
import Data.Ratio
import Data.STRef

import Math.Combinat.Classes
import Math.Combinat.Partitions.Integer
import Math.Combinat.Sets

import qualified Data.Map as Map

import Debug.Trace
import GHC.IO ( unsafeIOToST )

import Math.Algebra.Determinant
import Math.Algebra.ModP


-- segre :: Num a => Int -> [a] -> a
-- segre k xs = sum $ map product $ combine k xs

-- * Elementary and complete symmetric polynomials

-- | Precalc chern classes: the elementary symmetric polynomials of @xs@.
--
-- The result has bounds @(1, length xs)@ and its @k@-th entry is
-- @e_k(xs)@, the sum over all @k@-element sub-multisets of @xs@ of their
-- products. The trivial @e_0 = 1@ is not stored.
--
-- >>> elems (elemSymmArray [1,2,3 :: Integer])
-- [6,11,6]
elemSymmArray :: forall a . Num a => [a] -> Array Int a
elemSymmArray xs =
  runST $ do
    ar <- newArray (1,n) 0 :: ST s (STArray s Int a)
    mapM_ (worker ar) (zip [1..n] xs)
    unsafeFreeze ar
  where
    n = length xs

    -- Adjoin the i-th value x: e_j <- e_j + x * e_{j-1} for j = i down to 1.
    -- Iterating downwards means e_{j-1} still holds its value from before x
    -- was adjoined, so the update can be done in place in one array.
    worker :: STArray s Int a -> (Int, a) -> ST s ()
    worker ar (i,x) =
      forM_ [i,i-1..1] $ \j -> do
        a  <- lkp ar    j
        b  <- lkp ar (  j - 1 )
        -- force each entry so long inputs do not build deep thunk chains
        writeArray ar j $! a + x*b

    -- Read e_j, with the implicit e_0 = 1 (index 0 is not stored).
    lkp ar j = if j>=1
      then  readArray ar j
      else  return 1

-- | Precalc segre classes: the complete homogeneous symmetric polynomials
-- of @xs@ up to degree @m@.
--
-- The result has bounds @(1, m)@ and its @j@-th entry is @h_j(xs)@, the sum
-- of all degree-@j@ monomials in the values of @xs@. The trivial @h_0 = 1@
-- is not stored. For an empty @xs@ every entry is @0@.
--
-- >>> elems (completeSymmArray 3 [1,2 :: Integer])
-- [3,7,15]
completeSymmArray :: forall a . Num a => Int -> [a] -> Array Int a
completeSymmArray m xs =
  runST $ do
    ar <- newArray (1,m) 0 :: ST s (STArray s Int a)
    mapM_ (worker ar) xs
    unsafeFreeze ar
  where
    -- Adjoin one value x: h_j <- h_j + x * h_{j-1} for j = 1 up to m.
    -- Iterating upwards means h_{j-1} already includes x when h_j is
    -- updated, which is exactly the recurrence
    --   h_j(x_1..x_i) = h_j(x_1..x_{i-1}) + x_i * h_{j-1}(x_1..x_i),
    -- so a single row of length m suffices.
    -- The explicit signature keeps 'worker' polymorphic in 's': it mentions
    -- 'm', so MonoLocalBinds (implied by TypeFamilies) would not generalise it.
    worker :: STArray s Int a -> a -> ST s ()
    worker ar x =
      forM_ [1..m] $ \j -> do
        hj    <- readArray ar j
        hjm1  <- lkp ar (j-1)
        -- force each entry so long inputs do not build deep thunk chains
        writeArray ar j $! hj + x*hjm1

    -- Read h_j, with the implicit h_0 = 1 (index 0 is not stored).
    lkp ar j
      | j >= 1    = readArray ar j
      | otherwise = return 1

-- * Schur polynomials
-- | Dual Jacobi-Trudi matrix: the 'schurMatrixSegre' construction applied
-- to the conjugate ('dualPartition') of @shape@, with the Chern classes
-- @chern k@ standing in for the Segre classes.
schurMatrixChern :: Num a => (Int -> a) -> Partition -> Matrix a
schurMatrixChern chern = schurMatrixSegre chern . dualPartition

-- | Jacobi-Trudi matrix of @shape@ built from the Segre classes @segre k@.
--
-- The matrix is square of size @height (dualPartition shape)@ (presumably
-- the number of parts of @shape@ under combinat's conventions - confirm).
-- Row @i@ uses the @i@-th part @p_i@ of @shape@ (0 past the last part), and
-- entry @(i,j)@ is @segre (p_i + j - i)@, where index 0 gives 1 and negative
-- indices give 0; @segre@ is only ever called with positive arguments.
schurMatrixSegre :: Num a => (Int -> a) -> Partition -> Matrix a
schurMatrixSegre segre shape = array ((1,1),(size,size)) cells
  where
    size  = height (dualPartition shape)
    parts = fromPartition shape ++ repeat 0

    entry k = case compare k 0 of
      LT -> 0
      EQ -> 1
      GT -> segre k

    cells =
      [ ((row,col), entry (part + col - row))
      | (row,part) <- zip [1..size] parts
      , col        <- [1..size]
      ]



{-# SPECIALIZE schurDeterminantChern :: (Int -> Integer) -> Partition -> Integer #-}
{-# SPECIALIZE schurDeterminantSegre :: (Int -> Integer) -> Partition -> Integer #-}  

-- | Jacobi-Trudi formula (dual form): the Schur polynomial of @shape@ as the
-- determinant of 'schurMatrixChern', evaluated with
-- 'bareissDeterminantFullRank'.
schurDeterminantChern :: Integral a => (Int -> a) -> Partition -> a
schurDeterminantChern chern shape =
  bareissDeterminantFullRank (schurMatrixChern chern shape)

-- | Jacobi-Trudi formula: the Schur polynomial of @shape@ as the determinant
-- of 'schurMatrixSegre', evaluated with 'bareissDeterminantFullRank'.
schurDeterminantSegre :: Integral a => (Int -> a) -> Partition -> a
schurDeterminantSegre segre shape =
  bareissDeterminantFullRank (schurMatrixSegre segre shape)

-- | Schur polynomial of @part@ from an array of Chern classes @c_1..c_n@
-- (e.g. as produced by 'elemSymmArray').
--
-- Classes with index above the array's upper bound are taken to be 0. The
-- array's lower bound must be 1: the pattern match on 'bounds' fails
-- otherwise, as soon as a class is looked up.
schurFromChernArray :: Integral a => Array Int a -> Partition -> a
schurFromChernArray ar part = schurDeterminantChern chern part
  where
    (1,n) = bounds ar
    chern k
      | k > n     = 0
      | otherwise = ar ! k

-- | Schur polynomial of @part@ from an array of Segre classes @s_1..s_n@
-- (e.g. as produced by 'completeSymmArray').
--
-- Unlike the Chern case, a class beyond the array's upper bound is an error;
-- the message reports the requested index, the upper bound and @part@.
-- The array's lower bound must be 1 (the pattern match on 'bounds' fails
-- otherwise).
schurFromSegreArray :: Integral a => Array Int a -> Partition -> a
schurFromSegreArray ar part = schurDeterminantSegre segre part
  where
    (1,n) = bounds ar
    segre k
      | k > n     = error $ "schur-segre " ++ show k ++ " " ++ show n ++ " " ++ show part
      | otherwise = ar ! k


-- * Caching

-- | Build a memoised version of @'schurFromSegreArray' ar@.
--
-- The returned action keeps a 'Map.Map' from partitions to results in an
-- 'STRef' private to this call. A partition not yet in the map gets its
-- result inserted before being returned. Results are stored unevaluated
-- (lazy "Data.Map"), so the determinant itself is only computed when the
-- value is demanded; the shared thunk means it is computed at most once.
makeSegreSchurCache :: forall s. Array Int Integer -> ST s (Partition -> ST s Integer)
makeSegreSchurCache ar = do
  memoRef <- newSTRef Map.empty :: ST s (STRef s (Map.Map Partition Integer))
  let lookupOrCompute !part = do
        memo <- readSTRef memoRef
        maybe (compute memo part) return (Map.lookup part memo)
      -- cache miss: record the (lazy) result, then hand it back
      compute memo part = do
        let result = schurFromSegreArray ar part
        writeSTRef memoRef $! Map.insert part result memo
        return result
  return lookupOrCompute