-- | Free modules implemented as sorted lists of @(base,coeff)@ pairs.
-- The functions 'coeff', 'maxTerm', 'split' and 'unsafeJoin' are slow
-- (linear in the number of terms) in this implementation.

{-# LANGUAGE TypeFamilies, DeriveFunctor #-}
module Math.FreeModule.SortedList
  ( module Math.FreeModule.Class  
  , baseMap
  , coeffMap
  , FreeMod
  , ZModule
  , QModule
  )
  where

--------------------------------------------------------------------------------

import Data.List
import Data.Ord

import Math.FreeModule.Class hiding (baseMap,coeffMap)
import Math.FreeModule.PrettyPrint
import Math.FreeModule.Helper

--------------------------------------------------------------------------------

-- | A free module with basis @b@ and coefficients @c@, represented as a
-- list of @(base, coefficient)@ pairs.
--
-- The intended representation keeps the pairs sorted by increasing base
-- with no zero coefficients; the 'FreeModule' instance in this module
-- relies on that, while 'fromAscendingList' simply trusts its input.
newtype FreeMod b c
  = S [(b,c)]
  deriving (Eq, Ord, Show, Functor)

-- | Free module with integer coefficients.
type ZModule b = FreeMod b Integer

-- | Free module with rational coefficients.
type QModule b = FreeMod b Rational

--------------------------------------------------------------------------------

-- | Representation-specific 'baseMap'.
--
-- The generic 'baseMap' / 'coeffMap' from "Math.FreeModule.Class" are
-- hidden on import and these are exported in their place. Importing only
-- this module therefore picks up the sorted-list versions. This is a
-- somewhat hackish workaround.
baseMap :: Ord b => (a -> b) -> FreeMod a c -> FreeMod b c
baseMap f m = sortedlistBaseMap f m

-- | Representation-specific 'coeffMap' (exported for the same reason).
coeffMap :: (c -> d) -> FreeMod b c -> FreeMod b d
coeffMap g m = sortedlistCoeffMap g m

-- | Rename every base with @f@ and re-sort the terms by base. An arbitrary
-- @f@ need not preserve the order of bases, hence the sort. Coefficients
-- are carried over unchanged.
--
-- NOTE(review): when @f@ is not injective, distinct bases can collide.
-- Nothing here merges such terms or sums their coefficients, and this
-- signature has no 'Num' constraint to do so. Assuming 'sortByFst' only
-- sorts, the result may then hold repeated bases. Confirm whether callers
-- rely on an injective @f@.
sortedlistBaseMap :: Ord b => (a -> b) -> FreeMod a c -> FreeMod b c
sortedlistBaseMap f (S terms) = S (sortByFst renamed)
  where
    renamed = [ (f base, c) | (base, c) <- terms ]

-- | Apply @g@ to every coefficient. Bases, and therefore their order, are
-- unchanged.
--
-- No zero filtering happens here: if @g@ produces a zero coefficient, that
-- term stays in the list.
sortedlistCoeffMap :: (c -> d) -> FreeMod b c -> FreeMod b d
sortedlistCoeffMap g (S terms) = S [ (base, g c) | (base, c) <- terms ]

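-- NOTE: the two lines below have no effect. They are plain block comments,
-- not RULES pragmas (the '#' after the opening brace is missing). They also
-- name slBaseMap / slCoeffMap, which do not exist; the functions here are
-- sortedlistBaseMap / sortedlistCoeffMap. Kept for reference only.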
{- RULES "baseMap/SortedList"  baseMap  = slBaseMap  -}
{- RULES "coeffMap/SortedList" coeffMap = slCoeffMap -}

--------------------------------------------------------------------------------

-- | Free-module operations on the sorted-list representation.
--
-- Representation invariant:
--
-- * terms are sorted by strictly increasing base;
-- * no coefficient is zero.
--
-- 'fromList' sorts, merges equal bases and filters zeros. 'unionWith' drops
-- zero results. 'fromTerm' and 'scalarMul' also avoid creating zero terms.
-- 'fromAscendingList' trusts its caller.
--
-- As a consequence the zero element is exactly the empty list. 'isZero'
-- and the derived 'Eq' both depend on this.
instance (Ord b, Eq c, Num c) => FreeModule (FreeMod b c) where

  type Base  (FreeMod b c) = b
  type Coeff (FreeMod b c) = c

  isZero (S xs) = null xs
  zero = S []

  fromBase b = S [(b,1)]

  -- A zero coefficient would give a non-empty list that 'isZero' and '=='
  -- do not recognise as zero, so map it to 'zero' instead.
  fromTerm b c
    | c == 0    = zero
    | otherwise = S [(b,c)]

  -- Multiplication can create zero coefficients: always when c == 0, and
  -- also through zero divisors in rings such as Z/nZ. Drop them to keep the
  -- invariant. The product stays x * c (scalar on the right), which matters
  -- for non-commutative coefficient rings.
  scalarMul c (S xs) = S [ (b, d) | (b, x) <- xs, let d = x * c, d /= 0 ]

  -- Linear scan via 'lookup'; absent bases have coefficient 0.
  coeff b (S xs) = maybe 0 id (lookup b xs)

  unionWith f (S xs) (S ys) = S (unionWorker f xs ys)

  size (S xs) = length xs

  minTerm (S xs) = case xs of
    []    -> error "minTerm: empty"
    t : _ -> t
  -- 'last' is safe because the empty case is handled first; it is O(n).
  maxTerm (S xs) = case xs of
    [] -> error "maxTerm: empty"
    _  -> last xs

  -- Halves the term list; O(n) because of 'length'.
  split (S xs) = (S ys, S zs) where (ys,zs) = splitAt (length xs `div` 2) xs
  -- Plain concatenation with no check. The result satisfies the invariant
  -- only if every base of the first argument precedes every base of the
  -- second.
  unsafeJoin (S xs) (S ys) = S (xs++ys)

  toList (S xs) = xs

  -- Sort by base, sum the coefficients of equal bases, then drop zeros.
  fromList = S . filterNotZero . map collapse . groupBy (equating fst) . sortByFst
    where
      collapse grp@((b,_):_) = (b, sum (map snd grp))
      collapse []            = error "fromList: impossible empty group"

  -- Trusts the caller: no sorting, merging or zero filtering is done.
  fromAscendingList = S

--------------------------------------------------------------------------------
  
-- | Merge two term lists, each sorted by strictly increasing base, and
-- combine the coefficients with @f@.
--
-- A base that occurs on only one side is paired with an implicit @0@ on
-- the other side (@f x 0@ or @f 0 y@). Any term whose combined coefficient
-- is zero is dropped. This applies on every path, including the tail left
-- over once one list runs out, so the result keeps the "no zero
-- coefficients" invariant even for an @f@ with @f x 0 == 0@.
unionWorker :: (Ord b, Eq c, Num c) => (c -> c -> c) -> [(b,c)] -> [(b,c)] -> [(b,c)]
unionWorker f = go
  where
    go xs [] = foldr (\(b,x) rest -> keep b (f x 0) rest) [] xs
    go [] ys = foldr (\(b,y) rest -> keep b (f 0 y) rest) [] ys
    go xxs@((b1,c1):xs) yys@((b2,c2):ys) =
      case compare b1 b2 of
        LT -> keep b1 (f c1 0 ) (go xs  yys)
        GT -> keep b2 (f 0  c2) (go xxs ys )
        EQ -> keep b1 (f c1 c2) (go xs  ys )

    -- Prepend a term unless its coefficient is zero.
    keep b c rest
      | c == 0    = rest
      | otherwise = (b,c) : rest
        
--------------------------------------------------------------------------------