{-# LANGUAGE CPP, BangPatterns #-}


#ifndef BITSTRING_BIGENDIAN

-- | Lazy bitstrings, somewhat similar to lazy bytestrings.
-- This module is intended to be imported qualified.
module Data.BitString

#else

-- | Big-endian bitstrings. In this context, \"big-endian\" means that
-- the bits within each byte are stored most-significant-bit first, i.e.
-- in the opposite order to the default (little-endian) variant of this
-- module.
-- This module is intended to be imported qualified.
module Data.BitString.BigEndian

#endif
  ( empty
  , bitString
  , bitStringLazy
  , unsafeBitString'
  , take
  , drop
  , splitAt
  , append
  , concat
  , toList
  , fromList
  , to01List
  , from01List
  , null
  , length
  , foldl'
  , findSubstring
  , realizeBitStringLazy
  , realizeBitStringStrict
  , realizeBitString'
    -- testing / debugging support
  , runAllTest
  , BitChunk (..)
  , BitString ( BitString )
  , mypack
  , prop_fromToList
  , prop_toFromList
  , prop_append
  , prop_drop
  , prop_take
  , prop_dropChunk
  , prop_takeChunk
  , prop_realign
  , prop_realizeChunk
  , prop_realize
  , prop_realizeLen
  , prop_findSubstring1
  , prop_findSubstring1a
  , prop_findSubstring1b
  , prop_findSubstring2
  ) where

import Prelude hiding (take,drop,last,length,splitAt,concat,null,rem,init)

import Control.Monad
import Control.Applicative hiding ( empty )

import Data.Bits ()
import Data.Int  ()
import Data.Word ()
import Data.Maybe
import qualified Data.List as List

import Data.ByteString (ByteString)
import qualified Data.ByteString          as B
import qualified Data.ByteString.Internal as B
-- import qualified Data.ByteString.Unsafe   as U
import qualified Data.ByteString.Lazy     as L

import Test.QuickCheck hiding ( (.&.) )
import qualified Data.ByteString.Char8 as BC
import Data.Char (ord)

import Foreign
import System.IO.Unsafe


-- | 'foldM_' with the step function as the last argument, so that a lambda
-- can be written after the seed and the list. The final accumulator is
-- discarded.
flippedFoldM_ :: Monad m => a -> [b] -> (a -> b -> m a) -> m () 
flippedFoldM_ seed items step = foldM step seed items >> return ()

-- | 'foldM' with the step function as the last argument; returns the final
-- accumulator.
flippedFoldM ::  Monad m => a -> [b] -> (a -> b -> m a) -> m a 
flippedFoldM seed items step = foldM step seed items



-- | Reverse the order of the eight bits of a byte (bit 0 swaps with bit 7,
-- bit 1 with bit 6, and so on), one bit at a time.
bitReverseWord8_naive :: Word8 -> Word8
bitReverseWord8_naive x = List.foldl' mirror 0 [0..7] where
  -- copy bit @i@ of the input into bit @7-i@ of the accumulator
  mirror acc i
    | testBit x i = setBit acc (7 - i)
    | otherwise   = acc
-- | Lookup table with 256 entries: the byte at index @i@ is @i@ with its
-- bits reversed, as computed by 'bitReverseWord8_naive'.
bitReverseWord8_table :: ByteString     
bitReverseWord8_table = B.pack $ map bitReverseWord8_naive [0..255]

-- | Reverse the bit order of a byte by looking it up in
-- 'bitReverseWord8_table'.
--
-- Uses the bounds-checked 'B.index' because the @Data.ByteString.Unsafe@
-- import is not active in this module; every 'Word8' is a valid index into
-- the 256-entry table.
bitReverseWord8 :: Word8 -> Word8
bitReverseWord8 = B.index bitReverseWord8_table . fromIntegral

-- | Reverse the order of the four bytes of a 'Word32'.
byteReverseWord32 :: Word32 -> Word32
byteReverseWord32 w = topToBottom + upperMid + lowerMid + bottomToTop where
  topToBottom = shiftR w 24                    -- byte 3 -> byte 0
  upperMid    = shiftR w 8 .&. 0x0000ff00      -- byte 2 -> byte 1
  lowerMid    = shiftL w 8 .&. 0x00ff0000      -- byte 1 -> byte 2
  bottomToTop = shiftL w 24                    -- byte 0 -> byte 3



-- | A run of bits stored inside a strict 'ByteString'.
data BitChunk = BitChunk
  { bitChunkOffset :: !Int64       -- ^ position (in bits) of the first bit inside 'bitChunkData'
  , bitChunkLength :: !Int64       -- ^ number of bits in the chunk
  , bitChunkData   :: !ByteString  -- ^ the bytes that hold the bits
  }


-- | Pack a 'String' into a 'ByteString', converting each character's code
-- point with 'fromIntegral' (so code points above 255 wrap around).
mypack :: String -> ByteString
mypack str = B.pack [ fromIntegral (ord ch) | ch <- str ]

-- | Shows the chunk as @BitChunk \<...\>@ where the dots are the bits of
-- the chunk, in order, written as @0@ and @1@ characters.
instance Show BitChunk where
  show chunk = "BitChunk <" ++ map f (bitChunkTo01List chunk) ++ ">" where 
    f 0 = '0'
    f 1 = '1'
    -- total fallback, mirroring the 'Show' instance of 'BitString'
    f _ = error "BitChunk/show: impossible"


-- | The chunk containing no bits: offset 0, length 0, no data.
emptyBitChunk :: BitChunk
emptyBitChunk = BitChunk noOffset noBits B.empty where
  noOffset = 0
  noBits   = 0

-- | A chunk covering the bits of the given 'ByteString' from bit offset
-- @ofs@ up to the end of the data. The offset is not range-checked.
bitChunk' :: Int64 -> ByteString -> BitChunk  
bitChunk' ofs bs = BitChunk ofs (len-ofs) bs where 
  -- total number of bits available in the source
  len = 8 * fromIntegral (B.length bs)

-- | Build a 'BitChunk' directly from an offset, a length and the source
-- bytes. Warning! No boundary checks are performed.
unsafeBitChunk'
  :: Int64      -- ^ offset 
  -> Int64      -- ^ length
  -> ByteString  -- ^ source
  -> BitChunk 
unsafeBitChunk' ofs len dat = BitChunk ofs len dat

-- | A chunk covering every bit of the given strict 'ByteString'.
bitChunk :: ByteString -> BitChunk  
bitChunk bs = unsafeBitChunk' 0 totalBits bs where
  totalBits = 8 * fromIntegral (B.length bs)


-- | Drop the first @k@ bits of a chunk. Dropping at least as many bits as
-- the chunk holds yields 'emptyBitChunk'.
bitChunkDrop :: Int64 -> BitChunk -> BitChunk
bitChunkDrop k (BitChunk ofs len dat)
  | k < len   = BitChunk (ofs+k) (len-k) dat
  | otherwise = emptyBitChunk

-- | Keep only the first @k@ bits of a chunk. Taking zero bits yields
-- 'emptyBitChunk'; taking more bits than available returns the chunk
-- unchanged.
bitChunkTake :: Int64 -> BitChunk -> BitChunk
bitChunkTake k bc@(BitChunk ofs len dat) =
  if k == 0
    then emptyBitChunk
    else if k > len
      then bc
      else BitChunk ofs k dat

-- | Split a chunk into its first @k@ bits and the remainder.
--
-- TODO: better implementation
splitBitChunkAt :: Int64 -> BitChunk -> (BitChunk,BitChunk)    
splitBitChunkAt k b = (front, back) where
  front = bitChunkTake k b
  back  = bitChunkDrop k b


-- | 'True' becomes 1, 'False' becomes 0.
{-# INLINE boolToWord8 #-}
boolToWord8 :: Bool -> Word8
boolToWord8 True  = 1
boolToWord8 False = 0

-- | Zero becomes 'False', every other value becomes 'True'.
{-# INLINE word8ToBool #-}
word8ToBool :: Word8 -> Bool
word8ToBool w = not (w == 0)


-- | Read bit @j@ of a chunk as a 'Bool', via 'unsafeLookupBitChunk01'.
unsafeLookupBitChunk :: BitChunk -> Int64 -> Bool
unsafeLookupBitChunk chunk = (/= 0) . unsafeLookupBitChunk01 chunk

-- | Read bit @j@ of a chunk as a 'Word8' that is either 0 or 1.
--
-- The index is not checked against the chunk length; only the byte access
-- itself is bounds-checked by 'B.index'.
unsafeLookupBitChunk01 :: BitChunk -> Int64 -> Word8
unsafeLookupBitChunk01 (BitChunk ofs _ dat) j = shiftR byte bitPos .&. 1 where
  -- n: index of the byte holding the bit, k: position of the bit in that byte
  (n,k) = divMod (ofs+j) 8
  byte  = B.index dat (fromIntegral n)
  bitPos :: Int
#ifndef BITSTRING_BIGENDIAN
  bitPos = fromIntegral k          -- least significant bit comes first
#else
  bitPos = fromIntegral (7-k)      -- most significant bit comes first
#endif

-- | All bits of a chunk, in order, as a list of 'Bool's.
bitChunkToList :: BitChunk -> [Bool]
bitChunkToList chunk@(BitChunk _ len _) =
  map (unsafeLookupBitChunk chunk) [0 .. len-1]

-- | Pack a list of bits into a chunk with offset 0. The bits are grouped
-- eight to a byte; a final incomplete group is padded with zero bits.
bitChunkFromList :: [Bool] -> BitChunk
bitChunkFromList bits = BitChunk 0 (fromIntegral len) (B.pack bytes) where
  (len,bytes) = worker bits
  -- returns the number of bits consumed together with the packed bytes
  worker :: [Bool] -> (Int, [Word8])
  worker []      = ( 0, [] )
  worker pending = ( len' + List.length this , byte:ys ) where
    (this,rest) = List.splitAt 8 pending
    byte = List.foldl' (+) 0 $ zipWith shiftL (map boolToWord8 this) bitPositions
    (len' , ys) = worker rest
  -- position inside the byte of the 1st, 2nd, ... bit of each group
  bitPositions :: [Int]
#ifndef BITSTRING_BIGENDIAN
  bitPositions = [0..7]
#else
  bitPositions = [7,6..0]
#endif

-- | All bits of a chunk, in order, as a list of 0\/1 valued 'Word8's.
bitChunkTo01List :: BitChunk -> [Word8]
bitChunkTo01List chunk@(BitChunk _ len _) =
  map (unsafeLookupBitChunk01 chunk) [0 .. len-1]

-- | Pack a list of 0\/1 values into a chunk with offset 0. Only the lowest
-- bit of every list element is used. The bits are grouped eight to a byte;
-- a final incomplete group is padded with zero bits.
bitChunkFrom01List :: [Word8] -> BitChunk
bitChunkFrom01List bits = BitChunk 0 (fromIntegral len) (B.pack bytes) where
  (len,bytes) = worker bits
  -- returns the number of bits consumed together with the packed bytes
  worker :: [Word8] -> (Int, [Word8])
  worker []      = ( 0, [] )
  worker pending = ( len' + List.length this , byte:ys ) where
    (this,rest) = List.splitAt 8 pending
    byte = List.foldl' (+) 0 $ zipWith shiftL (map (.&. 1) this) bitPositions
    (len' , ys) = worker rest
  -- position inside the byte of the 1st, 2nd, ... bit of each group
  bitPositions :: [Int]
#ifndef BITSTRING_BIGENDIAN
  bitPositions = [0..7]
#else
  bitPositions = [7,6..0]
#endif
-- | Two chunks are equal when they contain the same sequence of bits,
-- regardless of offset or underlying bytes.
-- Warning! Very slow (goes through lists). TODO: make a better routine
instance Eq BitChunk where
  lhs == rhs = bitChunkToList lhs == bitChunkToList rhs

-- | Creates a new 'BitChunk' with offset field 0 that holds the same bits.
--
-- Whole bytes of offset are dropped; a remaining sub-byte offset is removed
-- by combining each byte with its successor (a zero byte is appended so the
-- last byte has a successor).
realignBitChunk :: BitChunk -> BitChunk 
realignBitChunk (BitChunk ofs len dat) = 
  BitChunk 0 len $ case ofsFrac of 
    0 -> dat'
    -- 'B.drop 1' rather than 'B.tail' so an empty @dat'@ cannot raise an error
    _ -> B.pack $ B.zipWith f dat' (B.snoc (B.drop 1 dat') 0) 
  where
    (ofsInt, ofsFrac) = divMod ofs 8 
    ofsFrac2 = 8 - ofsFrac
    dat' = B.drop (fromIntegral ofsInt) dat
    f :: Word8 -> Word8 -> Word8
#ifndef BITSTRING_BIGENDIAN
    f b1 b2 = shiftR b1 (fromIntegral ofsFrac) + shiftL b2 (fromIntegral ofsFrac2)
#else
    f b1 b2 = shiftL b1 (fromIntegral ofsFrac) + shiftR b2 (fromIntegral ofsFrac2)
#endif
-- | Convert a chunk into the 'ByteString' of its complete bytes, plus the
-- trailing incomplete byte (if any) together with the number of valid bits
-- in it. The unused bits of the trailing byte are cleared.
realizeBitChunk :: BitChunk -> (ByteString, Maybe (Word8,Int))
realizeBitChunk orig = (whole, end) where
  BitChunk _ len dat = realignBitChunk orig
  -- n complete bytes followed by k leftover bits
  (n,k) = divMod len 8
  whole = B.take (fromIntegral n) dat
  end = case k of
    0 -> Nothing
    _ -> Just (B.index dat (fromIntegral n) .&. mask, kk)
  kk = fromIntegral k :: Int
  -- selects the kk valid bits of the trailing byte
  mask :: Word8
#ifndef BITSTRING_BIGENDIAN
  mask = 2^kk - 1
#else
  -- in the big-endian layout the valid bits are the top kk bits
  mask = shiftL (2^kk-1) (8-kk)
#endif

-- | The list of chunks a 'BitString' is made of.
unBitString :: BitString -> [BitChunk]
unBitString bits = case bits of
  BitString chunks -> chunks
-- | A lazy sequence of bits, represented as a list of 'BitChunk's.
newtype BitString = BitString [BitChunk] 

-- | Shows the string as @BitString \<...\>@ where the dots are all bits, in
-- order, written as @0@ and @1@ characters.
--
-- This hand-written instance is the only 'Show' instance: combining it with
-- @deriving Show@ would declare the instance twice.
instance Show BitString where
  show bits = "BitString <" ++ map f (to01List bits) ++ ">" where 
    f 0 = '0'
    f 1 = '1'
    f _ = error "BitString/show: impossible"

-- | The bitstring without any chunks (and hence without any bits).
empty :: BitString
empty = BitString noChunks where
  noChunks = []

-- | Create a 'BitString' from a portion of a 'ByteString'.
-- Warning! No boundary checks are performed!
unsafeBitString'
  :: Int64      -- ^ offset 
  -> Int64      -- ^ length
  -> ByteString  -- ^ source 
  -> BitString
unsafeBitString' ofs len bs = BitString [unsafeBitChunk' ofs len bs] 

-- | Create a 'BitString' holding every bit of a strict 'ByteString'.
bitString :: ByteString -> BitString  
bitString bs = unsafeBitString' 0 totalBits bs where
  totalBits = 8 * fromIntegral (B.length bs)

-- | Create a 'BitString' from a lazy 'ByteString', one 'bitString' per
-- strict chunk.
bitStringLazy :: L.ByteString -> BitString  
bitStringLazy lazyBytes = concat (map bitString (L.toChunks lazyBytes))

-- | Remove the first @k@ bits. Chunks that are consumed completely are
-- discarded; the first partially consumed chunk gets a larger offset.
drop :: Int64 -> BitString -> BitString
drop k (BitString cs) = BitString (go k cs) where
  go _ [] = []
  go n (BitChunk ofs len dat : rest)
    | n < len   = BitChunk (ofs+n) (len-n) dat : rest
    | otherwise = go (n-len) rest

-- | Keep only the first @k@ bits. Whole chunks are kept while they fit; the
-- chunk in which the limit falls is shortened and ends the result.
take :: Int64 -> BitString -> BitString
take k (BitString cs) = BitString (go k cs) where
  go 0 _  = []
  go _ [] = []
  go n (c@(BitChunk ofs len dat) : rest)
    | n <= len  = [ BitChunk ofs n dat ]
    | otherwise = c : go (n-len) rest

-- | Split into the first @k@ bits and the rest.
--
-- TODO: better implementation
splitAt :: Int64 -> BitString -> (BitString,BitString)    
splitAt k b = (prefix, suffix) where
  prefix = take k b
  suffix = drop k b

-- | Concatenate two bitstrings by joining their chunk lists.
append :: BitString -> BitString -> BitString    
append (BitString front) (BitString back) = BitString joined where
  joined = front ++ back

-- | Concatenate a list of bitstrings by joining all their chunk lists.
-- (Open question from the original author: how strict or lazy should
-- this be?)
concat :: [BitString] -> BitString
concat xs = BitString (List.concat [ unBitString x | x <- xs ])

-- | All bits of the bitstring, in order, as 'Bool's.
toList :: BitString -> [Bool]
toList (BitString chunks) = List.concat [ bitChunkToList c | c <- chunks ]

-- | Build a single-chunk bitstring from a list of 'Bool's.
fromList :: [Bool] -> BitString
fromList digits = BitString [onlyChunk] where
  onlyChunk = bitChunkFromList digits

-- | All bits of the bitstring, in order, as 0\/1 valued 'Word8's.
to01List :: BitString -> [Word8]
to01List (BitString chunks) = List.concat [ bitChunkTo01List c | c <- chunks ]

-- | Build a single-chunk bitstring from a list of 0\/1 values.
from01List :: [Word8] -> BitString
from01List digits = BitString [onlyChunk] where
  onlyChunk = bitChunkFrom01List digits


-- | Total number of bits: the sum of the lengths of all chunks.
length :: BitString -> Int64
length (BitString chunks) = List.foldl' addChunk 0 chunks where
  addChunk total c = total + bitChunkLength c

-- | 'True' exactly when the bitstring has 'length' zero.
null :: BitString -> Bool
null = (== 0) . length

-- | Equality of the bit contents, delegated to 'fallbackEqual'.
-- Warning! Very slow! TODO: make a better routine
instance Eq BitString where
  lhs == rhs = fallbackEqual lhs rhs
-- | Slow fallback equality test: both bitstrings are converted to lists of
-- 'Bool's, which are then compared.
fallbackEqual :: BitString -> BitString -> Bool
fallbackEqual x y = bitsOfX == bitsOfY where
  bitsOfX = toList x
  bitsOfY = toList y


-- | Strict left fold over the bits, implemented on top of 'toList'.
foldl' :: (a -> Bool -> a) -> a -> BitString -> a
foldl' step start = List.foldl' step start . toList


  :: BitString    -- ^ the string to search for
  -> BitString    -- ^ the string to search in
  -> Maybe Int64  -- ^ the index of the first substring, if exists
findSubstring = findSubstring32

-- | Bit-level substring search; the basic unit is a 'Word32'.
--
-- The needle is realized into a buffer @p@ of @k+1@ words. A second buffer
-- @q@ of the same size acts as a shift register through which the haystack
-- is fed one bit at a time. Before each bit is consumed the buffers are
-- compared (the last word only under @mask@); a match counts once at least
-- @m@ bits have been consumed.
--
-- An empty needle is reported as found at index 0. Without that guard @k@
-- would be -1 and the buffers would be accessed out of bounds.
findSubstring32
  :: BitString    -- ^ the string to search for
  -> BitString    -- ^ the string to search in
  -> Maybe Int64  -- ^ the index of the first substring, if exists
findSubstring32 small large
  | m == 0    = Just 0
  | otherwise =
      -- unsafePerformIO: the action only reads the realized needle and
      -- writes to buffers that 'allocaArray' allocates locally.
      unsafePerformIO $
        withForeignPtr fptr_b_small $ \p'' -> do
          let p' = (plusPtr p'' ofs_b_small) :: Ptr Word8
          allocaArray (k+1) $ \q -> allocaArray (k+1) $ \p -> do
            -- we store in 'q' the last 'm' bits
            -- and in 'p' the bits we are searching for
            pokeElemOff p k 0   -- the last word may be filled only partially
            copyBytes ((castPtr :: Ptr Word32 -> Ptr Word8) p) p' len_b_small
#ifdef BITSTRING_BIGENDIAN
            forM_ [0..k] $ \j -> do { y <- peekElemOff p j ; pokeElemOff p j (byteReverseWord32 y) }
#endif
            peekElemOff p k >>= \x -> pokeElemOff p k (x .&. mask)
            -- zeroing the last word is quite important, because of the
            -- shifts; the other words are zeroed so that no uninitialised
            -- memory is ever read
            forM_ [0..k] $ \j -> pokeElemOff q j 0
            worker p q 0 (to01List large)
  where
    m   = length small
    m32 = fromIntegral (mod m 32) :: Int    
    d32 = fromIntegral (div m 32) :: Int

    -- k:         index of the last Word32 of the buffers
    -- mask:      the bits of that last word which belong to the needle
    -- initShift: position in the last word at which a new bit is inserted
    k         :: Int
    mask      :: Word32
    initShift :: Int
    (k,mask,initShift) = case m32 of
#ifndef BITSTRING_BIGENDIAN
      0 -> ( d32 - 1  , 0xffffffff     , 31                          )
      _ -> ( d32      , 2^m32  - 1     , fromIntegral (mod (m-1) 32) )
#else
      0 -> ( d32 - 1  , 0xffffffff                , 0                              )  
      _ -> ( d32      , shiftL (2^m32-1) (32-m32) , fromIntegral ( 32 - mod m 32 ) )
#endif

    b_small = realizeBitStringStrict small
    (fptr_b_small, ofs_b_small, len_b_small) = B.toForeignPtr b_small

    -- pos is the number of haystack bits consumed so far
    worker :: Ptr Word32 -> Ptr Word32 -> Int64 -> [Word8] -> IO (Maybe Int64)
    worker !p !q !pos !bits = do
      conds <- forM [0..k-1] $ \j -> do { x <- peekElemOff p j ; y <- peekElemOff q j ; return (x==y) }
      cond  <- do { x <- peekElemOff p k ; y <- peekElemOff q k ; return (x .&. mask == y .&. mask) }
      if and (cond:conds) && pos >= m
        then return (Just (pos - m))
        else case bits of
          [] -> return Nothing
          (b:bs) -> do
            -- shift the whole register by one bit, from the last word down
            -- to the first; (c,r) is the bit carried into the current word
            -- and the position at which it is inserted
            let init_cr = (fromIntegral b , initShift)
            flippedFoldM_ init_cr [k,k-1..0] $ \(c,r) j -> do
              y <- peekElemOff q j
#ifndef BITSTRING_BIGENDIAN
              let cr' = ( y .&. 1 , 31 )
              pokeElemOff q j (shiftR y 1 + shiftL c r)
#else
              let cr' = ( shiftR y 31 , 0 )
              pokeElemOff q j (shiftL y 1 + shiftL c r)
#endif
              return cr'
            worker p q (pos+1) bs 

-- | Convert a bitstring into a list of strict 'ByteString's holding the
-- same bits back to back.
--
-- The worker carries the trailing incomplete byte of the previous chunk
-- (with its number of valid bits) and tops it up with the first bits of the
-- next chunk. A final incomplete byte is emitted as it is.
realizeBitString' :: BitString -> [ByteString]   
realizeBitString' (BitString chunks) = worker Nothing chunks where
  worker :: Maybe (Word8,Int) -> [BitChunk] -> [ByteString]
  worker rem (b:bs) = 
    case rem of 
      Nothing -> 
        let (s, rem') = realizeBitChunk b
        in  s : worker rem' bs
      Just (w,k) -> 
        if r >= q
          -- the chunk completes the pending byte: emit it, then the rest
          then B.singleton t : s : worker rem' bs
               -- (B.cons t s) : worker rem' bs
          -- the chunk is too short: keep accumulating into the pending byte
          else worker (Just (t, k+fromIntegral r)) bs
        where
          q = 8 - fromIntegral k        -- bits still missing from the pending byte
          r = bitChunkLength b
          (x,y) = splitBitChunkAt q b
          (s, rem') = realizeBitChunk y
#ifndef BITSTRING_BIGENDIAN
          t = List.foldl' (+) w
            $ zipWith shiftL (bitChunkTo01List x) [k..]
#else
          -- big-endian: the next free position is counted from the top
          u = 7-k
          t = List.foldl' (+) w
            $ zipWith shiftL (bitChunkTo01List x) [u,u-1..]
#endif
  worker rem [] = case rem of
    Nothing    -> []
    Just (w,_) -> [B.singleton w] 

-- | Convert a bitstring into a lazy 'L.ByteString' whose chunks are the
-- pieces produced by 'realizeBitString''.
realizeBitStringLazy :: BitString -> L.ByteString
realizeBitStringLazy bits = L.fromChunks (realizeBitString' bits)

-- | Convert a bitstring into one strict 'B.ByteString' by concatenating the
-- pieces produced by 'realizeBitString''.
realizeBitStringStrict :: BitString -> B.ByteString
realizeBitStringStrict bits = B.concat (realizeBitString' bits)



-- | A bit count used as QuickCheck test input.
newtype Size = Size Int64
  deriving (Show)
-- | A list of bits used as QuickCheck test input.
newtype BoolList = BoolList [Bool]
  deriving (Show)

-- | A bitstring used as the needle in the 'findSubstring' properties.
newtype SearchFor = SearchFor BitString
  deriving (Show)

-- | Sizes are drawn from the range 0 to 64 (inclusive).
instance Arbitrary Size where
  arbitrary = fmap (Size . fromIntegral) (choose (0,64) :: Gen Int) -- 192)

-- | A random 'Size' decides the length, then that many random bits are
-- generated.
instance Arbitrary BoolList where
  arbitrary = do
    Size k <- arbitrary 
    bits <- vector (fromIntegral k)
    return (BoolList bits)

-- | Packs a random bit list into a chunk, then trims up to 24 bits from the
-- front and up to 15 bits from the back, so that chunks with non-zero
-- offsets and unused trailing bits are produced.
instance Arbitrary BitChunk where
  arbitrary = do
    front <- choose (0,24) :: Gen Int
    back  <- choose (0,15) :: Gen Int
    BoolList list <- arbitrary
    let trimmed = bitChunkDrop (fromIntegral front) (bitChunkFromList list)
        keep    = max 0 (bitChunkLength trimmed - fromIntegral back)
    return (bitChunkTake keep trimmed)

-- | Generates needles of at least 48 and fewer than 96 bits, retrying until
-- a random bitstring falls into that range.
-- With 48 bits, it's unlikely that there are other random appearances.
instance Arbitrary SearchFor where
  arbitrary = arbitrary >>= accept where
    accept candidate
      | n >= 48 && n < 96 = return (SearchFor candidate)
      | otherwise         = arbitrary
      where n = length candidate
-- | A bitstring made of zero to seven random chunks.
instance Arbitrary BitString where
  arbitrary = choose (0,7) >>= \chunkCount -> fmap BitString (vector chunkCount)

-- | Run every QuickCheck property of this module, printing the name of each
-- property before checking it.
runAllTest :: IO ()     
runAllTest = mapM_ runOne
  [ ("fromToList"      , property prop_fromToList      )
  , ("toFromList"      , property prop_toFromList      )
  , ("append"          , property prop_append          )
  , ("drop"            , property prop_drop            )
  , ("take"            , property prop_take            )
  , ("dropChunk"       , property prop_dropChunk       )
  , ("takeChunk"       , property prop_takeChunk       )
  , ("realign"         , property prop_realign         )
  , ("realizeChunk"    , property prop_realizeChunk    )
  , ("realize"         , property prop_realize         )
  , ("realizeLen"      , property prop_realizeLen      )
  , ("findSubstring1"  , property prop_findSubstring1  )
  , ("findSubstring1a" , property prop_findSubstring1a )
  , ("findSubstring1b" , property prop_findSubstring1b )
  , ("findSubstring2"  , property prop_findSubstring2  )
  ]
  where
    runOne :: (String, Property) -> IO ()
    runOne (label, prop) = print label >> quickCheck prop

-- | Converting to a list of bits and back gives an equal bitstring.
prop_fromToList :: BitString -> Bool
prop_fromToList original = roundTripped == original where
  roundTripped = fromList (toList original)

-- | Building a bitstring from a list and reading it back gives the list.
prop_toFromList :: BoolList -> Bool
prop_toFromList (BoolList list) = roundTripped == list where
  roundTripped = toList (fromList list)

-- | 'concat' agrees with concatenating the bit lists.
prop_append :: [BitString] -> Bool
prop_append xs = viaBitString == viaLists where
  viaBitString = toList (concat xs)
  viaLists     = List.concat (map toList xs)

-- | 'drop' agrees with dropping from the bit list.
prop_drop :: Size -> BitString -> Bool
prop_drop (Size k) xs = viaBitString == viaList where
  viaBitString = toList (drop k xs)
  viaList      = List.drop (fromIntegral k) (toList xs)

-- | 'take' agrees with taking from the bit list.
prop_take :: Size -> BitString -> Bool
prop_take (Size k) xs = viaBitString == viaList where
  viaBitString = toList (take k xs)
  viaList      = List.take (fromIntegral k) (toList xs)

-- | 'bitChunkDrop' agrees with dropping from the chunk's bit list.
prop_dropChunk :: Size -> BitChunk -> Bool
prop_dropChunk (Size k) xs = viaChunk == viaList where
  viaChunk = bitChunkToList (bitChunkDrop k xs)
  viaList  = List.drop (fromIntegral k) (bitChunkToList xs)

-- | 'bitChunkTake' agrees with taking from the chunk's bit list.
prop_takeChunk :: Size -> BitChunk -> Bool
prop_takeChunk (Size k) xs = viaChunk == viaList where
  viaChunk = bitChunkToList (bitChunkTake k xs)
  viaList  = List.take (fromIntegral k) (bitChunkToList xs)

-- | Realigning a chunk does not change the bits it contains.
prop_realign :: BitChunk -> Bool
prop_realign chunk = realigned == chunk where
  realigned = realignBitChunk chunk

-- | The whole bytes and the trailing partial byte returned by
-- 'realizeBitChunk' together contain the bits of the original chunk.
prop_realizeChunk :: BitChunk -> Bool
prop_realizeChunk chunk = rebuilt == BitString [chunk] where
  (whole,remain) = realizeBitChunk chunk
  rebuilt = append (bitString whole) (BitString [end])
  end = maybe emptyBitChunk partial remain
  partial (w,k) = BitChunk 0 (fromIntegral k) (B.singleton w)

-- | Reading back the realized bytes with the original length gives an equal
-- bitstring.
prop_realize :: BitString -> Bool
prop_realize bits = reread == bits where
  reread = unsafeBitString' 0 (length bits) (realizeBitStringStrict bits)

-- | The realized 'ByteString' has as many bytes as are needed for the bits,
-- rounded up.
prop_realizeLen :: BitString -> Bool
prop_realizeLen bits = expectedBytes == actualBytes where
  expectedBytes = div (length bits + 7) 8
  actualBytes   = fromIntegral (B.length (realizeBitStringStrict bits))
-- | A needle placed between a random prefix and a random suffix is found at
-- the end of the prefix.
prop_findSubstring1 :: SearchFor -> BitString -> BitString -> Bool
prop_findSubstring1 (SearchFor what) pre post = found == Just (length pre) where
  haystack = concat [ pre , what , post ] 
  found    = findSubstring what haystack

-- | A needle placed right after a random prefix is found at the end of the
-- prefix.
prop_findSubstring1a :: SearchFor -> BitString -> Bool
prop_findSubstring1a (SearchFor what) pre = found == Just (length pre) where
  haystack = concat [ pre , what ] 
  found    = findSubstring what haystack

-- | A needle placed in front of a random suffix is found at index 0.
prop_findSubstring1b :: SearchFor -> BitString -> Bool
prop_findSubstring1b (SearchFor what) post = found == Just 0 where
  haystack = concat [ what , post  ] 
  found    = findSubstring what haystack

-- | A random needle is expected not to occur in an unrelated random
-- haystack.
prop_findSubstring2 :: SearchFor -> BitString -> BitString -> Bool
prop_findSubstring2 (SearchFor what) pre post = found == Nothing where
  haystack = concat [ pre , post ] 
  found    = findSubstring what haystack

-- | Sample bitstring built from the chunks in 'xxx'.
test :: BitString
test = BitString xxx

-- | Sample chunks with assorted offsets and lengths, including empty ones.
xxx :: [BitChunk]
xxx =
  [ BitChunk 17 49 (mypack "9\176N\152%\f\STX\144\ETX")
  , BitChunk 22 50 (mypack "\148\&6\184\RS\134\144+\241\210")
  , BitChunk 0 0 (mypack "")
  , BitChunk 0 0 (mypack "")
  , BitChunk 0 9 (mypack "$\NUL")
  ]
-- | The realized pieces of 'test', each wrapped as a whole-byte chunk.
yyy :: [BitChunk]
yyy = [ bitChunk piece | piece <- realizeBitString' test ]

-- GHCi-style sample bindings. A top-level @let@ is not valid inside a
-- module, so they are kept here as a comment. NOTE(review): presumably
-- sample inputs for the 'findSubstring' properties -- confirm.
--
-- let what = (BitString [BitChunk 0 16 (mypack "c\229W"),BitChunk 0 0 (mypack ""),BitChunk 20 13 (mypack "\169\188\"\DLEf\EOT"),BitChunk 0 0 (mypack ""),BitChunk 3 9 (mypack "/\FS+"),BitChunk 21 26 (mypack "&,\159\249\158&h\STX"),BitChunk 0 0 (mypack "")])
--
-- let pre = BitString [BitChunk 3 13 (mypack "\245\200\180\NUL"),BitChunk 17 4 (mypack "{\SI\DC2"),BitChunk 24 33 (mypack "\207\135W\RS\240\180{\SOH")]
--
-- let post = BitString [BitChunk 16 19 (mypack "e\142\&0z\209m\NUL"),BitChunk 14 27 (mypack "\248m\207\146\&8\224\NUL")]

