-- | Calculates the Thom polynomial of @Sigma^{i}@ with localization and the substitution trick
-- (for sanity testing only, as we know the answer anyway)
--

{-# LANGUAGE ScopedTypeVariables, TypeFamilies, BangPatterns, PackageImports #-}
module Math.ThomPoly.SigmaI where

--------------------------------------------------------------------------------

import Control.Monad
import Control.Monad.ST
import Data.STRef

import Data.Array.IArray
import Data.Array.Unsafe
import Data.Array.ST

import Data.List
import Data.Ratio
import Data.Proxy

import Debug.Trace
import GHC.IO ( unsafeIOToST )

import Math.Combinat.Classes
import Math.Combinat.Partitions.Integer
import Math.Combinat.Sets

import Math.FreeModule.Symbol
import Math.FreeModule.SortedList
-- import Math.FreeModule.PrettyPrint
-- import Math.FreeModule.PP
-- import Math.FreeModule.Parser

import Math.Algebra.ModP
import Math.Algebra.Schur

import Math.ThomPoly.Subs
import Math.ThomPoly.Shared

--------------------------------------------------------------------------------

-- 'SigmaI' as a 'Problem': statistics come from 'statsI', the actual
-- computation from 'sigmai', and output files are named after @i@ and @n@.
instance Problem SigmaI where
  solve     = sigmai
  calcStats = statsI
  baseFName prob = concat [ "sigmai__i", show (_i prob), "_n", show (_n prob) ]
  
--------------------------------------------------------------------------------
-- * @Sigma^{i}@

-- | Parameters of the singularity class @Sigma^{i}@: the corank of the
-- differential and the dimension of the source.
data SigmaI
  = SigmaI
      { _i :: !Int   -- ^ corank @i@ of the differential
      , _n :: !Int   -- ^ dimension @n@ of the source
      }
  deriving (Show, Eq)

-- | The smallest admissible problem for a given corank. This method needs
-- @n >= mu@, and for @Sigma^{i}@ we have @mu = i@, so the source dimension
-- is taken to be the corank itself.
smallestI :: Int -> SigmaI
smallestI corank = SigmaI { _i = corank, _n = corank }

-- | Codimension of @Sigma^{i}(n,m)@, given the problem (which fixes @i@ and
-- the source dimension @n@) and @m@: it equals @i*(m-n+i)@.
codim :: SigmaI -> Int -> Int
codim prob m = corank * (m - _n prob + corank)
  where
    corank = _i prob

-- | Integer controlling the sign in the localization formula: @i*(n-i)@.
-- This is also the number of tangent weights at each fixed point (@i@
-- chosen indices times @n-i@ remaining ones). Presumably it is meant as the
-- exponent @e@ of a factor @(-1)^e@; check how callers interpret it.
signCorrection :: SigmaI -> Int
signCorrection (SigmaI corank srcDim) = corank * (srcDim - corank)

-- | Problem statistics for @Sigma^{i}@. Here @mu = i@, the reported base
-- codimension is 'codim' at @m = n@, and the number of pairs is the count
-- of @(pos,neg)@ partition pairs produced by 'partitionPairs'.
statsI :: SigmaI -> Stats
statsI prob@(SigmaI i n) =
  let mu    = i
      pairs = partitionPairs mu n i 0
  in  Stats { _mu       = mu
            , _codim0   = codim prob n
            , _maxPairs = length pairs
            }


--------------------------------------------------------------------------------
-- @Sigma^i@  

-- | A fixed point of the localization sum, represented as a list of indices
-- from @[1..n]@. 'sigmai'' enumerates these as @choose i [1..n]@, so each one
-- presumably lists the @i@ chosen indices.
type Fixpoint1 = [Int]
 
-- | Entry point used by the 'Problem' instance. It builds all @(pos,neg)@
-- partition pairs for the problem (with @mu = i@), keeps only those selected
-- by the given batch, and runs the localization sum 'sigmai'' on them.
sigmai :: CoeffRing coeff => Proxy coeff -> Batch -> SigmaI -> FreeMod Schur (FieldOfFractions coeff)
sigmai pxy batch problem = case problem of
  SigmaI corank srcDim ->
    let allPairs = partitionPairs corank srcDim corank 0   -- mu = corank
    in  sigmai' pxy problem (selectBatch batch allPairs)

-- | Localization sum for @Sigma^{i}@ over the given @(pos,neg)@ partition
-- pairs (normally one batch of 'partitionPairs').
--
-- The fixed points are @choose i [1..n]@. At every fixed point each symbol
-- @alpha k@ is replaced by the integer @getSubs n ! k@ (the substitution
-- trick). Then, for each pair @j@, the value
--
-- > schurFromChernArray (elemSymmArray solution) complLambda_j / product tangent
--
-- is added to slot @j@ of an accumulator array. Finally, slot @j@ becomes the
-- coefficient of @Schur renormLambda_j@ in the result.
sigmai'
  :: forall coeff. CoeffRing coeff
  => Proxy coeff -> SigmaI -> [(Partition,Partition)] -> FreeMod Schur (FieldOfFractions coeff)
sigmai' _ problem@(SigmaI i n) posneg = result where

  result = runST stuff

  -- Accumulate all fixed-point contributions in a mutable array (one slot
  -- per partition pair), then freeze it into a free module.
  stuff :: forall s. ST s (FreeMod Schur (FieldOfFractions coeff))
  stuff = do
    starr <- newArray (1,nparts) 0 :: ST s (STArray s Int (FieldOfFractions coeff))
    forM_ fixpoints (worker starr)
    -- unsafeFreeze is fine here: 'starr' is never read or written afterwards
    arr <- unsafeFreeze starr :: ST s (Array Int (FieldOfFractions coeff))
    let g (j,x) = ( Schur (renormLambdaArr!j) , x )
    return (fromList (map g (assocs arr)))

  -- Partition labelling the output Schur basis element of pair @j@.
  renormLambdaArr =
    listArray (1,nparts)
      [ posnegPairToPartition (   i,mu) (pos,neg) | (pos,neg) <- posneg ]
        :: Array Int Partition
  -- Partition whose Schur function is evaluated at each fixed point for pair
  -- @j@ (note the swapped @(neg,pos)@ order and the @n-i@ parameter).
  complLambdaArr =
    listArray (1,nparts)
      [ posnegPairToPartition ( n-i,mu) (neg,pos) | (pos,neg) <- posneg ]
        :: Array Int Partition

{-
  subs :: Term -> Integer
  subs = evaluate f where 
    f (Symbol "alpha" (Just i)) coeff = coeff * q^(i-1)
    q = 1 + fromIntegral n :: Integer
-}

  -- Substitute the concrete integer @subsTable ! k@ for every @alpha k@.
  -- Only @alpha@ symbols are expected; anything else is a programming error.
  subs :: Term coeff -> coeff
  subs = evaluate f where
    f (Symbol "alpha" (Just k)) c = c * fromInteger (subsTable ! k)
    f _                         _ = error "Math.ThomPoly.SigmaI.sigmai': unexpected symbol (only alpha_k can be substituted)"

  subsTable = getSubs n

  -- Add the contribution of a single fixed point to every slot of the array.
  worker :: STArray s Int (FieldOfFractions coeff) -> Fixpoint1 -> ST s ()
  worker arr fixpoint = do
    let sol   = map subs (solution fixpoint)
        tng   = map subs (tangent  fixpoint)
        chern = elemSymmArray sol
        -- the denominator depends only on the fixed point, not on @j@,
        -- so it is embedded once here rather than inside the loop
        denom = embed (product tng)
    forM_ [1..nparts] $ \j -> do
      let y = schurFromChernArray chern (complLambdaArr ! j)
      x <- readArray arr j
      -- force the sum (to WHNF) before storing it: a boxed STArray would
      -- otherwise hold a growing chain of (+) thunks, one per fixed point,
      -- until the final freeze
      writeArray arr j $! x + correctTheSign (embed y / denom)

  -- NOTE(review): 'signCorrection' is @i*(n-i)@, which is never negative when
  -- @0 <= i <= n@, so this is currently always 'id'. If a factor of
  -- @(-1)^(i*(n-i))@ was intended, the test should be @odd@ rather than
  -- @< 0@. Confirm against the known Thom polynomial before changing it.
  correctTheSign :: FieldOfFractions coeff -> FieldOfFractions coeff
  correctTheSign = if signCorrection problem < 0 then negate else id

  -- Symbolic values at a fixed point: one @alpha k@ per chosen index.
  solution :: Fixpoint1 -> [Term coeff]
  solution = map alpha

  -- Tangent weights at a fixed point: @alpha b - alpha a@ for every chosen
  -- index @a@ and every index @b@ of @[1..n]@ not chosen.
  tangent :: Fixpoint1 -> [Term coeff]
  tangent xs = [ alpha b ^-^ alpha a | a <- xs, b <- ys ] where ys = [1..n] \\ xs

  mu     = i
  nparts = length posneg
  fixpoints = choose i [1..n]
  
--------------------------------------------------------------------------------