-- | Calculates the Thom polynomial of @Sigma^{ij}@ using localization
-- and the substitution trick.

{-# LANGUAGE ScopedTypeVariables, TypeFamilies, BangPatterns, PackageImports #-}
module Math.ThomPoly.SigmaIJ where

--------------------------------------------------------------------------------

import Control.Monad
import Control.Monad.ST
import Data.STRef

import Data.Array.IArray
import Data.Array.Unsafe
import Data.Array.ST

import Data.List
import Data.Ratio
import Data.Proxy

import Debug.Trace
import GHC.IO ( unsafeIOToST )

import System.Mem
import System.IO

import Math.Combinat.Classes
import Math.Combinat.Partitions.Integer
import Math.Combinat.Sets

import Math.FreeModule.Symbol
import Math.FreeModule.SortedList
import Math.FreeModule.PrettyPrint
import Math.FreeModule.PP
-- import Math.FreeModule.Parser

import Math.Algebra.ModP
import Math.Algebra.Schur

import Math.ThomPoly.Subs
import Math.ThomPoly.Shared

--------------------------------------------------------------------------------

instance Problem SigmaIJ where
  solve     = sigmaij
  calcStats = statsIJ
  -- file name stem encoding the three parameters, e.g. "sigmaij__i2_j1_n4"
  baseFName (SigmaIJ i j n) =
    concat [ "sigmaij__i", show i, "_j", show j, "_n", show n ]
  
--------------------------------------------------------------------------------
-- * @Sigma^{ij}@

-- | Parameters of a @Sigma^{ij}@ computation.
data SigmaIJ = SigmaIJ
  { _i :: !Int  -- ^ the index @i@
  , _j :: !Int  -- ^ the index @j@
  , _n :: !Int  -- ^ the source dimension @n@
  }
  deriving (Eq, Show)

-- | The smallest problem this method can handle for a given @(i,j)@.
-- The method needs @n >= mu@, so the source dimension is set to exactly
-- @mu = calcMu (i,j)@.
smallestIJ :: (Int,Int) -> SigmaIJ
smallestIJ (i,j) = SigmaIJ { _i = i, _j = j, _n = calcMu (i,j) }

-- | The codimension of @Sigma^{i,j}(n,m)@, where @m@ is the second
-- dimension parameter (presumably the target dimension):
-- @mu*(m-n+i) - j*(i-j)@ with @mu = calcMu (i,j)@.
codim :: SigmaIJ -> Int -> Int
codim (SigmaIJ i j n) m = mu * (m - n + i) - j * (i - j)
  where
    mu = calcMu (i,j)

-- | The global sign @(-1)^p@ appearing in the localization formula, where
-- @p = n*mu + i*(j-mu) - j*j@ and @mu = calcMu (i,j)@.
--
-- Only the parity of @p@ matters. Testing it with 'even' keeps the function
-- total when @p < 0@, which can happen for @n < mu@. Computing @(-1)^p@ with
-- '(^)' would instead fail with \"Negative exponent\".
signCorrection :: SigmaIJ -> Int
signCorrection (SigmaIJ i j n)
  | even p    = 1
  | otherwise = -1
  where
    p  = n*mu + i*(j-mu) - j*j
    mu = calcMu (i,j)

-- | The (shifted) algebraic multiplicity @mu@ of @Sigma^{ij}@: @i@ plus the
-- half-symmetric tensor product dimension @j `o` i@.
calcMu :: (Int,Int) -> Int
calcMu (i,j) = o j i + i

-- | Signed pairs of partitions appearing in the Thom polynomial of
-- @Sigma^{ij}@. They are obtained from 'partitionPairs' with parameters
-- @mu@, @n@, @i@ and @-j*(i-j)@.
listPosNeg :: SigmaIJ -> [(Partition,Partition)]
listPosNeg (SigmaIJ i j n) = partitionPairs mu n i (-j*(i-j))
  where
    -- reuse 'calcMu' instead of repeating the formula for mu
    mu = calcMu (i,j)

-- | Summary statistics of a @Sigma^{ij}@ problem:
--
-- * the multiplicity @mu@;
-- * the codimension evaluated at @m = n@;
-- * the number of partition pairs produced by 'listPosNeg'.
statsIJ :: SigmaIJ -> Stats
statsIJ prob@(SigmaIJ i j n) = Stats
  { _mu       = mu
  , _codim0   = codim0
  , _maxPairs = length pairs
  }
  where
    mu     = calcMu (i,j)
    codim0 = codim prob n
    pairs  = listPosNeg prob

--------------------------------------------------------------------------------

-- | A fixed point, described by the same kind of index data that the main
-- loop of @sigmaij'@ enumerates.
--
-- NOTE(review): this type is not used by the code visible in this module.
-- Also, '_ss' is declared as @[Int]@ even though it is described as a part
-- of @ioj :: [(Int,Int)]@ — confirm the intended type.
data Fixpoint2 = Fix2
  { _ii  :: [Int]        -- ^ the index set @ii@
  , _jj  :: [Int]        -- ^ the index set @jj@
  , _ioj :: [(Int,Int)]  -- ^ @ioj = jj `oo` ii@
  , _kk  :: [Int]        -- ^ the complement of @ii@ in @nn@
  , _ss  :: [Int]        -- ^ a part of @ioj@
  , _rr  :: [Int]        -- ^ a part of the complement of @ii@ in @nn@
  }
  deriving (Show)

-- | Dimension of a \"half-symmetric tensor product\":
-- @j `o` i = j*(j+1)/2 + j*(i-j)@. This equals the length of
-- @js `oo` is@ when @length js == j@ and @length is == i@.
-- Calls 'error' when @j > i@.
o :: Int -> Int -> Int
o j i
  | j > i     = error "half-symmetric tensor product [dim]: error"
  | otherwise = j * (j + 1) `div` 2 + j * (i - j)
    
-- | \"half-symmetric tensor product\" of index lists. It lists:
--
-- * the unordered pairs @(x,y)@ of distinct elements of @jj@, with @x@
--   before @y@ in @jj@;
-- * the diagonal pairs @(j,j)@ for @j@ in @jj@;
-- * the pairs @(j,i)@ with @j@ in @jj@ and @i@ in @ii@ but not in @jj@.
--
-- Every element of @jj@ must occur in @ii@, otherwise 'error' is called.
--
-- > length (js `oo` is) == (length js) `o` (length is)
--
oo :: [Int] -> [Int] -> [(Int,Int)]
jj `oo` ii
  | all (`elem` ii) jj = offDiagonal ++ diagonal ++ mixed
  | otherwise          = error "half-symmetric tensor product [list]: error"
  where
    -- same pairs in the same order as 'choose 2 jj', but without the
    -- partial lambda pattern \[x,y] -> (x,y)
    offDiagonal = [ (x,y) | (x:ys) <- tails jj, y <- ys ]
    diagonal    = [ (j,j) | j <- jj ]
    mixed       = [ (j,i) | j <- jj, i <- ii \\ jj ]

--------------------------------------------------------------------------------

-- | The part of the Thom polynomial of @Sigma^{ij}@ selected by the given
-- 'Batch'. The partition pairs from 'listPosNeg' are filtered with
-- 'selectBatch' and passed to @sigmaij'@.
sigmaij :: CoeffRing coeff => Proxy coeff -> Batch -> SigmaIJ -> FreeMod Schur (FieldOfFractions coeff)
sigmaij pxy batch problem =
  -- reuse 'listPosNeg' instead of duplicating its partitionPairs call
  sigmaij' pxy problem (selectBatch batch (listPosNeg problem))
    
-- | Worker for 'sigmaij'. Returns one Schur coefficient for each
-- @(pos,neg)@ partition pair in @posneg@.
--
-- It loops over the index data (@ii@, @jj@, @k@, @ss@, @rr@), presumably
-- one iteration per fixed point of the localization. Each iteration adds
-- @y / z@ to one accumulator per partition pair, where @z@ is the product
-- of all weights. Weights are 'Term's in the @alpha@ variables and are
-- turned into numbers by 'subs', which plugs in @getSubsNum n@ (the
-- \"substitution trick\").
--
-- NOTE(review): the @n >= mu@ check is commented out below, so callers
-- must ensure it themselves (see 'smallestIJ').
sigmaij' 
  :: forall coeff. CoeffRing coeff 
  => Proxy coeff -> SigmaIJ -> [(Partition,Partition)] -> FreeMod Schur (FieldOfFractions coeff)
sigmaij' _ problem@(SigmaIJ i j n) posneg = {- if n<mu then error "n<mu" else -} result where

  -- the whole computation runs in the ST monad
  result = runST stuff  

  -- weight of a pair (a,b) taken from ioj: alpha_a + alpha_b
  -- (the pattern variables shadow the problem's j and i)
  phi (j,i) = alpha j ^+^ alpha i

  stuff :: forall s. ST s (FreeMod Schur (FieldOfFractions coeff))
  stuff = do
    -- one accumulator per partition pair, indexed 1..nparts, all starting at 0
    starr <- newArray (1,nparts) 0 :: ST s (STArray s Int (FieldOfFractions coeff))
    
    -- ii: an i-element subset of [1..n]; ni: its complement
    forM_ (choose i nn) $ \ii -> do
      let ni = nn \\ ii
          tng1' = [ alpha b ^-^ alpha a | a<-ii, b<-ni ] 
          sol1' = [ alpha a | a<-ii] 
          tng1 = map subs tng1'
          sol1 = map subs sol1'
      -- jj: a j-element subset of ii; ij: the rest of ii
      forM_ (choose j ii) $ \jj -> do
        let ij = ii \\ jj    :: [Int]
            ioj = jj `oo` ii :: [(Int,Int)]          
            tng2' = [ alpha b ^-^ alpha a | a<-jj, b<-ij ]
            tng2  = map subs tng2'
        -- k runs from 0 to mu' = j `o` i, which is the length of ioj.
        -- ss and rr are k-element subsets of ioj and ni respectively;
        -- ker and coker are their complements.
        forM_ [0..mu'] $ \k -> do
          forM_ (choose k ioj) $ \ss -> do     -- ss is 'coim'
            forM_ (choose k ni) $ \rr -> do    -- rr is 'im'
              let ker   = ioj \\ ss
                  coker = ni \\ rr
                  tng3' =  [ alpha b ^-^ phi   a | a<-ss  , b<-rr    ]
                        ++ [ phi   a ^-^ phi   b | a<-ss  , b<-ker   ] -- the sign is here! (note the order: phi a minus phi b)
                        ++ [ alpha b ^-^ alpha a | a<-rr  , b<-coker ] 
                        ++ [ alpha b ^-^ phi   a | a<-ker , b<-coker ] 
                  tng3 = map subs tng3'

              -- z: the product of all substituted weights, which is the
              -- denominator of this contribution
              let tng123' = tng1' ++ tng2' ++ tng3'
                  tng123  = tng1  ++ tng2  ++ tng3
                  z = product tng123
                  sol2 = map subs 
                       $ [ phi a | a<-ker ] ++ [ alpha b | b<-rr]  

              -- Diagnostic only: prints the substitution table and the weights
              -- that vanish under it. NOTE(review): execution is NOT stopped
              -- here, so the later division by (embed z) divides by zero.
              when (z==0) $ unsafeIOToST $ do
                putStrLn $ "error: zero denominator!"
                putStrLn $ "substitution table: " ++ show (elems subsTable)
                forM_ (zip tng123 tng123') $ \(a,p) -> do
                  when (a==0) $ putStrLn (pretty p ++ " == 0")
                     
              -- segre: built by completeSymmArray from the substituted
              -- values sol, up to the degree given here (presumably the
              -- complete homogeneous symmetric polynomials of sol)
              let sol = sol1 ++ sol2
                  -- chern = elemSymmArray sol
                  segre = completeSymmArray (i*(n-i)+j*(i-j)+mu+(n-i)) sol

              -- cachedSchur <- makeSegreSchurCache segre
                  
              -- For every partition pair, add the contribution y/z, with the
              -- sign flipped for odd k. Here j is the pair index and shadows
              -- the problem's j.
              forM_ [1..nparts] $ \j -> do
                let clambda = complLambdaArr ! j
                -- let y = (if odd k then negate else id) (schurFromChernArray chern clambda)
                let y = (if odd k then negate else id) (schurFromSegreArray segre clambda)
                x <- readArray starr j 
                -- force x, y and z (to WHNF) before writing, presumably to
                -- keep thunk chains in the boxed array short
                x `seq` y `seq` z `seq` writeArray starr j (x + correctTheSign (embed y / embed z))
                return ()
    
    -- freeze the accumulators and label each one with its Schur partition
    arr <- unsafeFreeze starr :: ST s (Array Int (FieldOfFractions coeff))
    let g (j,x) = ( Schur (renormLambdaArr!j) , x ) 
        bcs = map g (assocs arr)
    return (fromList bcs)

  -- global sign of the localization formula (see signCorrection)
  correctTheSign :: FieldOfFractions coeff -> FieldOfFractions coeff
  correctTheSign = if signCorrection problem < 0 then negate else id
    
  nn  = [1..n] 
  mu' = j `o` i 
  mu  = i + mu'
  nparts = length posneg
    
  -- partitions labelling the Schur functions of the result (one per pair)
  renormLambdaArr = 
    listArray (1,nparts) 
      [ posnegPairToPartition (   i,mu) (pos,neg) | (pos,neg) <- posneg ]
        :: Array Int Partition
  -- partitions passed to schurFromSegreArray in the main loop (one per pair)
  complLambdaArr = 
    listArray (1,nparts) 
      [ posnegPairToPartition ( n-i,mu) (neg,pos) | (pos,neg) <- posneg ]
        :: Array Int Partition

  -- Evaluates a Term by replacing each variable alpha_i with the number
  -- subsTable ! i. NOTE(review): f only matches Symbol "alpha" (Just i), so
  -- any other symbol fails with a pattern-match error.
  subs :: Term coeff -> coeff
  subs = evaluate f where 
    f (Symbol "alpha" (Just i)) coeff = coeff * fromInteger (subsTable!i)

  -- numeric values substituted for the alpha variables
  -- (presumably an array indexed by 1..n — confirm against getSubsNum)
  subsTable = getSubsNum n

{-
  subs :: Term -> Integer
  subs = evaluate f where 
    f (Symbol "alpha" (Just i)) coeff = coeff * q^(i-1)
    q = 1 + fromIntegral n :: Integer
-}

--------------------------------------------------------------------------------