{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DataKinds #-}
module Data.CartPole
  ( StateCP(..)
  , Action(..)
  , Event
  ) where
import Reinforce.Prelude
import qualified Reinforce.Spaces.State as Spaces
import Data.Aeson
import Data.Aeson.Types
import Control.Exception (AssertionFailed(..))
import Reinforce.Spaces
import Reinforce.Spaces.Action (Size)
import Numeric.LinearAlgebra.Static
import qualified Data.Logger as Logger
import qualified Data.Vector as V
-- | Log event specialised to this environment: a 'Logger.Event' whose first
-- type argument is 'Double', with 'StateCP' as the state type and 'Action'
-- as the action type.
--
-- NOTE(review): the role of the 'Double' slot (presumably the reward) is
-- defined in "Data.Logger" — confirm there.
type Event
  = Logger.Event
      Double
      StateCP
      Action
-- | The two discrete moves available to the agent.
data Action
  = GoLeft   -- ^ serialised as @0@ by the 'ToJSON' instance
  | GoRight  -- ^ serialised as @1@ by the 'ToJSON' instance
  deriving (Bounded, Enum, Eq, Generic, Ord, Show)

-- | Hashing relies entirely on the class's default methods. In the hashable
-- package these defaults are presumably the 'Generic'-based ones, which is
-- why 'Generic' is derived above.
instance Hashable Action
-- | 'Action' is a discrete action space declared to have two elements.
instance DiscreteActionSpace Action where
  type Size Action = 2

-- | Serialise an action as a JSON integer: 'GoLeft' becomes @0@ and
-- 'GoRight' becomes @1@. The mapping is spelled out instead of derived from
-- 'fromEnum', so reordering the constructors cannot silently change the
-- wire format.
instance ToJSON Action where
  toJSON :: Action -> Value
  toJSON act = toJSON code
    where
      code :: Int
      code = case act of
        GoLeft  -> 0
        GoRight -> 1
-- | One observation of the cart-pole system.
--
-- The fields are strict. The 'Monoid' instance combines states field by
-- field, and with lazy 'Float' fields, folding many states together (for
-- example with 'mconcat' or 'foldMap') would build long chains of
-- unevaluated additions instead of plain numbers.
--
-- This field order (position, angle, velocity, angleRate) is also the order
-- used by the vector and JSON encodings in this module.
-- NOTE(review): units and the exact physical quantities (cart vs. pole) are
-- not established here — confirm against the environment producing states.
data StateCP = StateCP
  { position  :: !Float  -- ^ position (presumably of the cart)
  , angle     :: !Float  -- ^ angle (presumably of the pole)
  , velocity  :: !Float  -- ^ velocity (presumably of the cart)
  , angleRate :: !Float  -- ^ rate of change of 'angle'
  } deriving (Show, Eq, Generic, Ord)

-- | Hashing relies on the class's default methods; presumably the
-- 'Generic'-based defaults from the hashable package.
instance Hashable StateCP
-- | Field-wise addition of states. 'mempty' is the all-zero state, and
-- 'mappend' adds the corresponding fields of its two arguments, matching on
-- both constructors.
--
-- NOTE(review): from GHC 8.4 (base 4.11) onwards, 'Semigroup' is a
-- superclass of 'Monoid', so a @Semigroup StateCP@ instance with the same
-- operation is required there. Confirm the target compiler, or whether such
-- an instance is provided elsewhere.
instance Monoid StateCP where
  mempty = StateCP { position = 0, angle = 0, velocity = 0, angleRate = 0 }
  mappend = fieldwise (+)
    where
      fieldwise op (StateCP p0 a0 v0 r0) (StateCP p1 a1 v1 r1) =
        StateCP (op p0 p1) (op a0 a1) (op v0 v1) (op r0 r1)
-- | Vector encoding of a 'StateCP' in field order
-- @[position, angle, velocity, angleRate]@.
instance Spaces.StateSpace StateCP where
  toVector (StateCP pos ang vel rate) = Spaces.toVector [pos, ang, vel, rate]

  -- Rebuild a state from entries 0..3 of the vector, narrowing each 'Double'
  -- to 'Float'. Entries past index 3 are ignored. If any of the first four
  -- is missing, 'AssertionFailed' is raised with 'throw'.
  -- NOTE(review): 'throw' yields an imprecise exception that surfaces when
  -- the result is evaluated, rather than a failure inside the monad. If the
  -- class method's monad supports a monadic throw, prefer it — confirm
  -- against the StateSpace class signature.
  fromVector vec =
    maybe (throw $ AssertionFailed "malformed vector found") return $
      StateCP <$> entry 0 <*> entry 1 <*> entry 2 <*> entry 3
    where
      entry :: Int -> Maybe Float
      entry i = double2Float <$> (vec V.!? i)
-- | Decode a 'StateCP' from a JSON array parsed as a 4-tuple of 'Float',
-- read in field order (position, angle, velocity, angleRate). Any value that
-- is not an array is rejected with a "StateCP" type mismatch.
--
-- NOTE(review): this assumes the producer emits fields in the order above.
-- OpenAI Gym documents its CartPole observation as
-- [cart position, cart velocity, pole angle, pole angular velocity].
-- If this JSON comes straight from Gym, 'angle' and 'velocity' would be
-- swapped — confirm against the sender.
instance FromJSON StateCP where
  parseJSON :: Value -> Parser StateCP
  parseJSON json = case json of
      Array _ ->
        fromTuple <$> (parseJSON json :: Parser (Float, Float, Float, Float))
      _ ->
        typeMismatch "StateCP" json
    where
      fromTuple (p, a, v, r) = StateCP p a v r
-- | Type-level-sized view of a 'StateCP'. It declares size 4 and builds the
-- static vector from the same entries as its 'Spaces.toVector' encoding.
instance StateSpaceStatic StateCP where
  type Size StateCP = 4
  toR s = vector (V.toList (Spaces.toVector s))