{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DataKinds #-}
module Data.CartPole
( StateCP(..)
, Action(..)
, Event
) where
import Reinforce.Prelude
import qualified Reinforce.Spaces.State as Spaces
import Data.Aeson
import Data.Aeson.Types
import Control.Exception (AssertionFailed(..))
import Reinforce.Spaces
import Reinforce.Spaces.Action (Size)
import Numeric.LinearAlgebra.Static
import qualified Data.Logger as Logger
import qualified Data.Vector as V
-- | Logger event specialised to this module's types: 'Logger.Event' applied
-- to 'Double', 'StateCP' and 'Action'.
--
-- NOTE(review): presumably the 'Double' parameter is the reward type and the
-- other two are observation and action -- confirm against "Data.Logger".
type Event = Logger.Event Double StateCP Action
-- | The two discrete actions defined by this module.
--
-- Constructor order is significant: the derived 'Enum', 'Bounded' and 'Ord'
-- instances all follow it, so 'GoLeft' is the minimum and 'GoRight' the
-- maximum.
data Action
  = GoLeft
  | GoRight
  deriving (Eq, Ord, Enum, Bounded, Show, Generic)
-- | Empty instance body: every method comes from the class defaults.
-- NOTE(review): presumably these defaults are the 'Generic'-based ones, which
-- is why 'Action' derives 'Generic' -- confirm which @Hashable@ is in scope.
instance Hashable Action
-- | Registers 'Action' as a discrete action space.
--
-- The type-level size is fixed at @2@, matching the two constructors of
-- 'Action'; it must be updated by hand if a constructor is ever added.
instance DiscreteActionSpace Action where
  type Size Action = 2
-- | Serialises an action as a JSON integer: 'GoLeft' becomes @0@ and
-- 'GoRight' becomes @1@.
instance ToJSON Action where
  toJSON :: Action -> Value
  toJSON action = toJSON code
    where
      -- Integer code for each constructor, spelled out explicitly so the
      -- encoding does not silently follow a reordering of the constructors.
      code :: Int
      code = case action of
        GoLeft  -> 0
        GoRight -> 1
-- | A CartPole state made of four 'Float' components.
--
-- The positional order of the constructor arguments (position, angle,
-- velocity, angle rate) is relied upon wherever 'StateCP' is built
-- positionally, so fields must not be reordered.
--
-- NOTE(review): the per-field meanings below are inferred from the field
-- names; units and sign conventions are not visible in this file.
data StateCP = StateCP
  { position  :: Float  -- ^ presumably the cart position
  , angle     :: Float  -- ^ presumably the pole angle
  , velocity  :: Float  -- ^ presumably the cart velocity
  , angleRate :: Float  -- ^ presumably the rate of change of 'angle'
  }
  deriving (Eq, Ord, Show, Generic)
-- | Empty instance body: every method comes from the class defaults.
-- NOTE(review): presumably these defaults are the 'Generic'-based ones, which
-- is why 'StateCP' derives 'Generic' -- confirm which @Hashable@ is in scope.
instance Hashable StateCP
-- | Component-wise additive monoid on 'StateCP'.
--
-- 'mempty' is the all-zero state and 'mappend' adds the two states field by
-- field.
--
-- NOTE(review): on base >= 4.11 'Monoid' has a 'Semigroup' superclass, so
-- this instance needs a matching @Semigroup StateCP@ instance to compile
-- there -- confirm which base version this module targets.
instance Monoid StateCP where
  mempty = StateCP 0 0 0 0

  mappend (StateCP pos0 ang0 vel0 rate0) (StateCP pos1 ang1 vel1 rate1) =
    StateCP
      (pos0 + pos1)
      (ang0 + ang1)
      (vel0 + vel1)
      (rate0 + rate1)
-- | Conversion between 'StateCP' and its flat vector representation.
--
-- Both directions use the same component order: position, angle, velocity,
-- angle rate (indices 0 to 3).
instance Spaces.StateSpace StateCP where
  -- Flatten the four fields, in constructor order, through the list
  -- instance of 'Spaces.toVector'.
  toVector (StateCP pos ang vel rate) = Spaces.toVector [pos, ang, vel, rate]

  -- Rebuild a state from indices 0..3 of the vector, narrowing each element
  -- with 'double2Float'. If any of those indices is missing the result is
  -- @throw (AssertionFailed "malformed vector found")@. Elements past index 3
  -- are never inspected, so longer vectors are accepted.
  fromVector vec =
    case StateCP <$> component 0 <*> component 1 <*> component 2 <*> component 3 of
      Just state -> return state
      Nothing    -> throw (AssertionFailed "malformed vector found")
    where
      -- Safe lookup of one component; 'Nothing' when the index is out of
      -- bounds.
      component :: Int -> Maybe Float
      component ix = double2Float <$> (vec V.!? ix)
-- | Decodes a 'StateCP' from a JSON array.
--
-- The array is parsed as a 4-tuple of 'Float's, whose elements become
-- position, angle, velocity and angle rate in that order. Any non-array
-- value is rejected with 'typeMismatch'.
--
-- NOTE(review): if the payload is an OpenAI Gym CartPole observation, Gym
-- documents its order as position, velocity, angle, angular velocity --
-- confirm that the element order assumed here matches the producer.
instance FromJSON StateCP where
  parseJSON :: Value -> Parser StateCP
  parseJSON value = case value of
    Array _ -> fromQuad <$> (parseJSON value :: Parser (Float, Float, Float, Float))
    _       -> typeMismatch "StateCP" value
    where
      -- Apply the constructor to the tuple's elements positionally.
      fromQuad :: (Float, Float, Float, Float) -> StateCP
      fromQuad (pos, ang, vel, rate) = StateCP pos ang vel rate
-- | Statically-sized view of a 'StateCP'.
--
-- The type-level size is @4@, one per field. 'toR' goes through
-- 'Spaces.toVector', so the components appear in the same order as in that
-- conversion.
instance StateSpaceStatic StateCP where
  type Size StateCP = 4

  -- Flatten to a vector, convert it to a list, then hand the list to
  -- 'vector' to build the static value.
  toR state = vector (V.toList (Spaces.toVector state))