{-# LANGUAGE CPP #-}
#define Flt Double
#define VECT_Double

-- TODO: the pointer versions of these functions should really be implemented
-- via the pointer versions of the original OpenGL functions...

-- | OpenGL support, including 'Vertex', 'TexCoord', etc instances for 'Vec2', 'Vec3' and 'Vec4'.

module Data.Vect.Flt.OpenGL where

--------------------------------------------------------------------------------

import Control.Monad
import Data.Vect.Flt.Base
import Data.Vect.Flt.Util.Projective
import qualified Graphics.Rendering.OpenGL as GL

import Foreign
import Unsafe.Coerce

import Graphics.Rendering.OpenGL hiding 
  ( Normal3 , rotate , translate , scale
  , matrix , currentMatrix , withMatrix , multMatrix 
  )

--------------------------------------------------------------------------------

-- the new opengl bindings ruin my day...

-- The OpenGL scalar type that this instantiation of the module talks to.
-- Which of the two aliases is compiled in is decided by the precision switch
-- defined in the CPP header at the top of the file.
#ifdef VECT_Float
type GLflt = GLfloat
#endif

#ifdef VECT_Double
type GLflt = GLdouble
#endif

-- | Reinterprets a library scalar as the matching OpenGL scalar.
--
-- 'unsafeCoerce' is used deliberately instead of @realToFrac@: it is chosen
-- for speed, and it avoids having to check the OpenGL binding version.
-- NOTE(review): this presumably relies on both types sharing one runtime
-- representation -- confirm against the OpenGL binding in use.
glflt :: Flt -> GLflt
glflt = unsafeCoerce

-- | The opposite direction of 'glflt': reinterprets an OpenGL scalar as a
-- library scalar, again through 'unsafeCoerce' (same caveat as above).
unflt :: GLflt -> Flt
unflt = unsafeCoerce

-- Argument-list shorthands used by the instances below. Each macro expands to
-- a sequence of in-scope variables (x y z w, r g b a or u v w z), every one of
-- them wrapped in glflt -- or in unflt for the UN_ group -- so a use site has
-- to bind exactly those variable names.
#define GL_XY   (glflt x) (glflt y) 
#define GL_XYZ  (glflt x) (glflt y) (glflt z)
#define GL_XYZW (glflt x) (glflt y) (glflt z) (glflt w)

#define GL_RGB  (glflt r) (glflt g) (glflt b)
#define GL_RGBA (glflt r) (glflt g) (glflt b) (glflt a)

#define GL_UV   (glflt u) (glflt v) 
#define GL_UVW  (glflt u) (glflt v) (glflt w) 
#define GL_UVWZ (glflt u) (glflt v) (glflt w) (glflt z)

#define UN_XY   (unflt x) (unflt y) 
#define UN_XYZ  (unflt x) (unflt y) (unflt z)
#define UN_XYZW (unflt x) (unflt y) (unflt z) (unflt w)

--------------------------------------------------------------------------------

-- | Types that can be converted into an OpenGL matrix object.
--
-- There should be a big warning here about the different conventions,
-- hidden transpositions, and all the confusion this will inevitably cause...
--
-- As it stands,
--
-- > glRotate t1 axis1 >> glRotate t2 axis2 >> glRotate t3 axis3
--
-- has the same result as
--
-- > multMatrix (rotMatrixProj4 t3 axis3 .*. rotMatrixProj4 t2 axis2 .*. rotMatrixProj4 t1 axis1)
--
-- because at the interface of OpenGL and this library there is a transposition
-- to compensate for the different conventions. (This transposition is implicit
-- in the code, because the way the matrices are stored in memory is also
-- different: OpenGL stores them column-major, and we store them row-major.)

class ToOpenGLMatrix m where
  -- | Produces an OpenGL matrix holding the contents of the argument.
  makeGLMatrix :: m -> IO (GLmatrix GLflt)

-- | Types that can be read back out of an OpenGL matrix object.
class FromOpenGLMatrix m where
  -- | Reads the contents of the given OpenGL matrix.
  peekGLMatrix :: GLmatrix GLflt -> IO m
  
-- | Converts the value with 'makeGLMatrix' and assigns the result to the
-- OpenGL matrix state variable selected by the mode argument.
setMatrix :: ToOpenGLMatrix m => Maybe MatrixMode -> m -> IO ()
setMatrix mode m = do
  glm <- makeGLMatrix m
  GL.matrix mode $= glm
 
-- | Reads the OpenGL matrix state variable selected by the mode argument and
-- converts it with 'peekGLMatrix'.
getMatrix :: FromOpenGLMatrix m => Maybe MatrixMode -> IO m
getMatrix mode = peekGLMatrix =<< get (GL.matrix mode)

-- | State variable over the matrix selected by the mode argument; reading goes
-- through 'getMatrix' and writing through 'setMatrix'.
matrix :: (ToOpenGLMatrix m, FromOpenGLMatrix m) => Maybe MatrixMode -> StateVar m
matrix mode = makeStateVar readBack writeOut
  where
    readBack = getMatrix mode
    writeOut = setMatrix mode

-- | State variable over the matrix selected by passing @Nothing@ as the mode,
-- i.e. the same as @matrix Nothing@.
currentMatrix :: (ToOpenGLMatrix m, FromOpenGLMatrix m) => StateVar m
currentMatrix = makeStateVar (getMatrix Nothing) (setMatrix Nothing)

-- | Converts the value with 'makeGLMatrix' and hands the result to
-- @GL.multMatrix@.
multMatrix :: ToOpenGLMatrix m => m -> IO ()
multMatrix m = do
  glm <- makeGLMatrix m
  GL.multMatrix glm

-- | A 'Mat4' is written unchanged into a new OpenGL matrix that is declared
-- as column-major.
instance ToOpenGLMatrix Mat4 where
  makeGLMatrix m = GL.withNewMatrix GL.ColumnMajor fill
    where
      -- the component buffer is reinterpreted as one Mat4 and filled by poke
      fill :: Ptr GLflt -> IO ()
      fill buf = poke (castPtr buf) m
 
-- | Reads a 'Mat4' back out of an OpenGL matrix, honouring the storage order
-- that the matrix reports.
instance FromOpenGLMatrix Mat4 where
  -- GL.withMatrix passes the storage order of the matrix along with a pointer
  -- to its components. makeGLMatrix always writes a Mat4 verbatim into a
  -- column-major matrix, so column-major storage is read back verbatim too;
  -- row-major storage holds the transposed layout and therefore has to be
  -- transposed after the read to yield the same Mat4.
  peekGLMatrix x = GL.withMatrix x $ \order p -> do
    m <- peek (castPtr p)
    return $ case order of
      GL.ColumnMajor -> m
      GL.RowMajor    -> transpose m
  
-- | A 'Mat3' is first extended to a 'Mat4' with @extendWith 1@.
instance ToOpenGLMatrix Mat3 where
  makeGLMatrix = makeGLMatrix . widen
    where
      widen :: Mat3 -> Mat4
      widen = extendWith 1
 
-- | A 'Mat2' is first extended to a 'Mat4' with @extendWith 1@.
instance ToOpenGLMatrix Mat2 where
  makeGLMatrix small = makeGLMatrix big
    where
      big :: Mat4
      big = extendWith 1 small

-- | An 'Ortho4' is uploaded as the 'Mat4' obtained from @fromOrtho@.
instance ToOpenGLMatrix Ortho4 where
  makeGLMatrix = makeGLMatrix . (fromOrtho :: Ortho4 -> Mat4)

-- | An 'Ortho3' is uploaded through the 'Mat3' obtained from @fromOrtho@.
instance ToOpenGLMatrix Ortho3 where
  makeGLMatrix = makeGLMatrix . (fromOrtho :: Ortho3 -> Mat3)

-- | An 'Ortho2' is uploaded through the 'Mat2' obtained from @fromOrtho@.
instance ToOpenGLMatrix Ortho2 where
  makeGLMatrix = makeGLMatrix . (fromOrtho :: Ortho2 -> Mat2)

-- | A 'Proj4' is uploaded as the 'Mat4' obtained from @fromProjective@.
instance ToOpenGLMatrix Proj4 where
  makeGLMatrix = makeGLMatrix . (fromProjective :: Proj4 -> Mat4)

-- | A 'Proj3' is uploaded through the 'Mat3' obtained from @fromProjective@.
instance ToOpenGLMatrix Proj3 where
  makeGLMatrix = makeGLMatrix . (fromProjective :: Proj3 -> Mat3)
  
--------------------------------------------------------------------------------

{-# SPECIALISE radianToDegrees :: Float  -> Float  #-}
{-# SPECIALISE radianToDegrees :: Double -> Double #-}
-- | Converts an angle from radians to degrees by multiplying with the
-- constant 180 divided by pi.
--
-- Only a multiplication by a fractional literal is involved, so the
-- constraint is just 'Fractional'; every type accepted under the former,
-- stricter @RealFrac@ constraint is still accepted.
radianToDegrees :: Fractional a => a -> a
radianToDegrees x = x * 57.295779513082322

{-# SPECIALIZE degreesToRadian :: Float  -> Float  #-}
{-# SPECIALIZE degreesToRadian :: Double -> Double #-}
-- | Converts an angle from degrees to radians by multiplying with the
-- constant pi divided by 180.
--
-- Only a multiplication by a fractional literal is involved, so the
-- constraint is just 'Fractional'; every type accepted under the former,
-- stricter @Floating@ constraint is still accepted.
degreesToRadian :: Fractional a => a -> a
degreesToRadian x = x * 1.7453292519943295e-2

-- | Rotation around the given axis. The angle is in radians and is converted
-- with 'radianToDegrees' first. (WARNING: OpenGL itself uses degrees!)
glRotate :: Flt -> Vec3 -> IO ()
glRotate angle (Vec3 x y z) = GL.rotate degrees axis
  where
    degrees = glflt (radianToDegrees angle)
    axis    = Vector3 GL_XYZ

-- | Passes the vector to @GL.translate@.
glTranslate :: Vec3 -> IO ()
glTranslate (Vec3 x y z) = GL.translate offset
  where
    offset = Vector3 GL_XYZ

-- | Passes the three components of the vector, in order, to @GL.scale@.
glScale3 :: Vec3 -> IO ()
glScale3 (Vec3 x y z) = GL.scale sx sy sz
  where
    sx = glflt x
    sy = glflt y
    sz = glflt z

-- | Passes the same factor to @GL.scale@ for all three arguments.
glScale :: Flt -> IO ()
glScale x = GL.scale s s s
  where
    s = glflt x

--------------------------------------------------------------------------------
 
-- | \"Orthogonal projection\" matrix, a la OpenGL
-- (the corresponding functionality is removed in OpenGL 3.1).
orthoMatrix
  :: (Flt,Flt)   -- ^ (left,right)
  -> (Flt,Flt)   -- ^ (bottom,top)
  -> (Flt,Flt)   -- ^ (near,far)
  -> Mat4
orthoMatrix (l,r) (b,t) (n,f) = Mat4 row1 row2 row3 row4
  where
    -- extents of the box along the three axes
    dx = r - l
    dy = t - b
    dz = f - n
    row1 = Vec4 (2/dx) 0 0 0
    row2 = Vec4 0 (2/dy) 0 0
    row3 = Vec4 0 0 (-2/dz) 0
    row4 = Vec4 (-(r+l)/dx) (-(t+b)/dy) (-(f+n)/dz) 1
  
-- | The same as 'orthoMatrix', but with the six bounds packed into two vectors.
-- (The name ends in 2 rather than in a prime because CPP is sensitive to primes.)
orthoMatrix2
  :: Vec3     -- ^ (left,top,near)
  -> Vec3     -- ^ (right,bottom,far)
  -> Mat4
orthoMatrix2 (Vec3 left top zNear) (Vec3 right bottom zFar) =
  orthoMatrix (left, right) (bottom, top) (zNear, zFar)

-- | \"Perspective projection\" matrix, a la OpenGL
-- (the corresponding functionality is removed in OpenGL 3.1).
frustumMatrix
  :: (Flt,Flt)   -- ^ (left,right)
  -> (Flt,Flt)   -- ^ (bottom,top)
  -> (Flt,Flt)   -- ^ (near,far)
  -> Mat4
frustumMatrix (l,r) (b,t) (n,f) = Mat4 row1 row2 row3 row4
  where
    -- extents of the frustum along the three axes
    dx = r - l
    dy = t - b
    dz = f - n
    row1 = Vec4 (2*n/dx) 0 0 0
    row2 = Vec4 0 (2*n/dy) 0 0
    row3 = Vec4 ((r+l)/dx) ((t+b)/dy) (-(f+n)/dz) (-1)
    row4 = Vec4 0 0 (-2*f*n/dz) 0
  
-- | The same as 'frustumMatrix', but with the six bounds packed into two vectors.
-- (The name ends in 2 rather than in a prime because CPP is sensitive to primes.)
frustumMatrix2
  :: Vec3     -- ^ (left,top,near)
  -> Vec3     -- ^ (right,bottom,far)
  -> Mat4
frustumMatrix2 (Vec3 left top zNear) (Vec3 right bottom zFar) =
  frustumMatrix (left, right) (bottom, top) (zNear, zFar)

-- | Inverse of 'frustumMatrix'; it takes the same three pairs of bounds.
inverseFrustumMatrix
  :: (Flt,Flt)   -- ^ (left,right)
  -> (Flt,Flt)   -- ^ (bottom,top)
  -> (Flt,Flt)   -- ^ (near,far)
  -> Mat4
inverseFrustumMatrix (l,r) (b,t) (n,f) = Mat4 row1 row2 row3 row4
  where
    fn   = f*n
    row1 = Vec4 (0.5*(r-l)/n) 0 0 0
    row2 = Vec4 0 (0.5*(t-b)/n) 0 0
    row3 = Vec4 0 0 0 (0.5*(n-f)/fn)
    row4 = Vec4 (0.5*(r+l)/n) (0.5*(t+b)/n) (-1) (0.5*(f+n)/fn)

--------------------------------------------------------------------------------
-- Vertex instances

-- | A 'Vec2' is submitted as a @GL.Vertex2@.
instance GL.Vertex Vec2 where
  vertex (Vec2 x y) = GL.vertex glVertex
    where glVertex = GL.Vertex2 GL_XY
  -- pointer version: load the vector, then go through vertex
  vertexv ptr = vertex =<< peek ptr
  
-- | A 'Vec3' is submitted as a @GL.Vertex3@.
instance GL.Vertex Vec3 where
  vertex (Vec3 x y z) = GL.vertex glVertex
    where glVertex = GL.Vertex3 GL_XYZ
  -- pointer version: load the vector, then go through vertex
  vertexv ptr = vertex =<< peek ptr
  
-- | A 'Vec4' is submitted as a @GL.Vertex4@.
instance GL.Vertex Vec4 where
  vertex (Vec4 x y z w) = GL.vertex glVertex
    where glVertex = GL.Vertex4 GL_XYZW
  -- pointer version: load the vector, then go through vertex
  vertexv ptr = vertex =<< peek ptr

--------------------------------------------------------------------------------
-- the Normal instance
-- note that there is no Normal2\/Normal4 in the OpenGL binding

-- | A 'Normal3' is unwrapped with @fromNormal@ and submitted as a @GL.Normal3@.
instance GL.Normal Normal3 where
  normal u =
    let Vec3 x y z = fromNormal u
    in  GL.normal (GL.Normal3 GL_XYZ)
  -- pointer version: load the value, then go through normal
  normalv ptr = normal =<< peek ptr

-- | A plain 'Vec3' is submitted as a @GL.Normal3@ as it is.
instance GL.Normal Vec3 where
  normal (Vec3 x y z) = GL.normal glNormal
    where glNormal = GL.Normal3 GL_XYZ
  -- pointer version: load the vector, then go through normal
  normalv ptr = normal =<< peek ptr

--------------------------------------------------------------------------------
-- Color instances
  
-- | A 'Vec3' is submitted as a @GL.Color3@ (components taken as r, g, b).
instance GL.Color Vec3 where
  color (Vec3 r g b) = GL.color rgb
    where rgb = GL.Color3 GL_RGB
  -- pointer version: load the vector, then go through color
  colorv ptr = color =<< peek ptr

-- | A 'Vec4' is submitted as a @GL.Color4@ (components taken as r, g, b, a).
instance GL.Color Vec4 where
  color (Vec4 r g b a) = GL.color rgba
    where rgba = GL.Color4 GL_RGBA
  -- pointer version: load the vector, then go through color
  colorv ptr = color =<< peek ptr

-- | A 'Vec3' is submitted as a secondary colour through @GL.Color3@.
instance GL.SecondaryColor Vec3 where
  secondaryColor (Vec3 r g b) = GL.secondaryColor rgb
    where rgb = GL.Color3 GL_RGB
  -- pointer version: load the vector, then go through secondaryColor
  secondaryColorv ptr = secondaryColor =<< peek ptr

{-
-- there is no such thing?
instance GL.SecondaryColor Vec4 where
  secondaryColor (Vec4 r g b a) = GL.secondaryColor (GL.Color4 r g b a)
  secondaryColorv p = peek p >>= secondaryColor
-}

--------------------------------------------------------------------------------
-- TexCoord instances

-- | A 'Vec2' is submitted as a @GL.TexCoord2@.
instance GL.TexCoord Vec2 where
  texCoord (Vec2 u v) = GL.texCoord coord
    where coord = GL.TexCoord2 GL_UV
  texCoordv ptr = texCoord =<< peek ptr
  multiTexCoord unit (Vec2 u v) = GL.multiTexCoord unit coord
    where coord = GL.TexCoord2 GL_UV
  -- pointer versions load the vector and reuse the value versions
  multiTexCoordv unit ptr = multiTexCoord unit =<< peek ptr

-- | A 'Vec3' is submitted as a @GL.TexCoord3@.
instance GL.TexCoord Vec3 where
  texCoord (Vec3 u v w) = GL.texCoord coord
    where coord = GL.TexCoord3 GL_UVW
  texCoordv ptr = texCoord =<< peek ptr
  multiTexCoord unit (Vec3 u v w) = GL.multiTexCoord unit coord
    where coord = GL.TexCoord3 GL_UVW
  -- pointer versions load the vector and reuse the value versions
  multiTexCoordv unit ptr = multiTexCoord unit =<< peek ptr

-- | A 'Vec4' is submitted as a @GL.TexCoord4@.
instance GL.TexCoord Vec4 where
  texCoord (Vec4 u v w z) = GL.texCoord coord
    where coord = GL.TexCoord4 GL_UVWZ
  texCoordv ptr = texCoord =<< peek ptr
  multiTexCoord unit (Vec4 u v w z) = GL.multiTexCoord unit coord
    where coord = GL.TexCoord4 GL_UVWZ
  -- pointer versions load the vector and reuse the value versions
  multiTexCoordv unit ptr = multiTexCoord unit =<< peek ptr

--------------------------------------------------------------------------------
-- Vertex Attributes (experimental)

-- | Experimental: values that can be passed as a vertex attribute.
class VertexAttrib' a where
  -- | Sets the attribute at the given location to the given value.
  vertexAttrib :: GL.AttribLocation -> a -> IO ()
  
-- | A single scalar goes through @GL.vertexAttrib1@.
instance VertexAttrib' {- ' CPP is sensitive to primes -} Flt where
  vertexAttrib loc = GL.vertexAttrib1 loc . glflt

-- | A two-component vector goes through @GL.vertexAttrib2@.
instance VertexAttrib' Vec2 where
  vertexAttrib loc (Vec2 x y) = GL.vertexAttrib2 loc gx gy
    where
      gx = glflt x
      gy = glflt y

-- | A three-component vector goes through @GL.vertexAttrib3@.
instance VertexAttrib' Vec3 where
  vertexAttrib loc (Vec3 x y z) = GL.vertexAttrib3 loc gx gy gz
    where
      gx = glflt x
      gy = glflt y
      gz = glflt z

-- | A four-component vector goes through @GL.vertexAttrib4@.
instance VertexAttrib' Vec4 where
  vertexAttrib loc (Vec4 x y z w) = GL.vertexAttrib4 loc gx gy gz gw
    where
      gx = glflt x
      gy = glflt y
      gz = glflt z
      gw = glflt w

-- | The value is unwrapped with @fromNormal@ and sent through @GL.vertexAttrib2@.
instance VertexAttrib' Normal2 where
  vertexAttrib loc u =
    let Vec2 x y = fromNormal u
    in  GL.vertexAttrib2 loc GL_XY

-- | The value is unwrapped with @fromNormal@ and sent through @GL.vertexAttrib3@.
instance VertexAttrib' Normal3 where
  vertexAttrib loc u =
    let Vec3 x y z = fromNormal u
    in  GL.vertexAttrib3 loc GL_XYZ

-- | The value is unwrapped with @fromNormal@ and sent through @GL.vertexAttrib4@.
instance VertexAttrib' Normal4 where
  vertexAttrib loc u =
    let Vec4 x y z w = fromNormal u
    in  GL.vertexAttrib4 loc GL_XYZW
   
--------------------------------------------------------------------------------
-- Uniform (again, experimental)

-- (note that the uniform location code in the OpenGL 2.2.1.1 is broken; 
-- a work-around is to put a zero character at the end of uniform names)

{-
toFloat :: Flt -> Float
toFloat = realToFrac

fromFloat :: Float -> Flt
fromFloat = realToFrac
-}

-- Uniforms are always floats...
#ifdef VECT_Float

-- | A scalar uniform, routed through the instance for @Index1 GLflt@.
instance GL.Uniform Flt where
  uniform loc = GL.makeStateVar fetch store
    where
      -- the underlying state variable of the OpenGL binding
      glVar :: StateVar (Index1 GLflt)
      glVar = uniform loc
      fetch = fmap fromGL (get glVar)
      fromGL (GL.Index1 s) = unflt s
      store s = glVar $= Index1 (glflt s)
  uniformv loc cnt ptr = uniformv loc cnt glPtr
    where
      -- the pointer is only retyped, the count is passed on unchanged
      glPtr :: Ptr (Index1 GLflt)
      glPtr = castPtr ptr

-- | A two-component uniform, routed through the instance for @Vertex2 GLflt@.
instance GL.Uniform Vec2 where
  uniform loc = GL.makeStateVar fetch store
    where
      -- the underlying state variable of the OpenGL binding
      glVar :: StateVar (Vertex2 GLflt)
      glVar = uniform loc
      fetch = fmap fromGL (get glVar)
      fromGL (GL.Vertex2 x y) = Vec2 UN_XY
      store (Vec2 x y) = glVar $= Vertex2 GL_XY
  -- (an earlier variant went through Index1 with twice the count instead)
  uniformv loc cnt ptr = uniformv loc cnt glPtr
    where
      glPtr :: Ptr (Vertex2 GLflt)
      glPtr = castPtr ptr

-- | A three-component uniform, routed through the instance for @Vertex3 GLflt@.
instance GL.Uniform Vec3 where
  uniform loc = GL.makeStateVar fetch store
    where
      -- the underlying state variable of the OpenGL binding
      glVar :: StateVar (Vertex3 GLflt)
      glVar = uniform loc
      fetch = fmap fromGL (get glVar)
      fromGL (GL.Vertex3 x y z) = Vec3 UN_XYZ
      store (Vec3 x y z) = glVar $= Vertex3 GL_XYZ
  -- (an earlier variant went through Index1 with three times the count instead)
  uniformv loc cnt ptr = uniformv loc cnt glPtr
    where
      glPtr :: Ptr (Vertex3 GLflt)
      glPtr = castPtr ptr

-- | A four-component uniform, routed through the instance for @Vertex4 GLflt@.
instance GL.Uniform Vec4 where
  uniform loc = GL.makeStateVar fetch store
    where
      -- the underlying state variable of the OpenGL binding
      glVar :: StateVar (Vertex4 GLflt)
      glVar = uniform loc
      fetch = fmap fromGL (get glVar)
      fromGL (GL.Vertex4 x y z w) = Vec4 UN_XYZW
      store (Vec4 x y z w) = glVar $= Vertex4 GL_XYZW
  -- (an earlier variant went through Index1 with four times the count instead)
  uniformv loc cnt ptr = uniformv loc cnt glPtr
    where
      glPtr :: Ptr (Vertex4 GLflt)
      glPtr = castPtr ptr
    
#endif