-------------------------------------------------------------------------------
-- |
-- Module    :  Data.CartPole
-- Copyright :  (c) Sentenai 2017
-- License   :  BSD3
-- Maintainer:  sam@sentenai.com
-- Stability :  experimental
-- Portability: non-portable
--
-- Data types shared between the Gym environments and the Haskell
-- implementation of CartPole.
-------------------------------------------------------------------------------
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DataKinds #-}
module Data.CartPole
  ( StateCP(..)
  , Action(..)
  , Event
  ) where

import Reinforce.Prelude
import qualified Reinforce.Spaces.State as Spaces

import Data.Aeson
import Data.Aeson.Types
import Control.Exception (AssertionFailed(..))
import Reinforce.Spaces
import Reinforce.Spaces.Action (Size)
import Numeric.LinearAlgebra.Static
import qualified Data.Logger as Logger
import qualified Data.Vector as V

-- | The logger's event type specialised for CartPole: 'Logger.Event' applied
-- to 'Double', 'StateCP' and 'Action'.
--
-- NOTE(review): the 'Double' parameter is presumably the reward type, with
-- 'StateCP' the observed state and 'Action' the action taken — confirm
-- against "Data.Logger".
type Event = Logger.Event Double StateCP Action

-- | Actions available in CartPole. The cart can only go left or right, so it
-- has an action space of "discrete 2" containing {0..n-1}.
--
-- FIXME: Migrate this to either a more generic "directions" actions
-- (would need things like "up", "down" versions as well) or a "discrete
-- actions" version. I'm a fan of the former.
data Action
  = GoLeft   -- ^ encoded as @0@ by the 'ToJSON' instance in this module
  | GoRight  -- ^ encoded as @1@ by the 'ToJSON' instance in this module
  deriving (Show, Eq, Enum, Bounded, Ord, Generic)

-- | Empty instance body: every method comes from the class defaults
-- (presumably the 'Generic'-based ones, since 'Action' derives 'Generic').
instance Hashable Action

-- | Registers 'Action' as a discrete action space whose type-level size is 2,
-- matching its two constructors ('GoLeft' and 'GoRight').
instance DiscreteActionSpace Action where
  type Size Action = 2

-- | Serialises an action as its integer index: 'GoLeft' becomes @0@ and
-- 'GoRight' becomes @1@.
instance ToJSON Action where
  toJSON :: Action -> Value
  toJSON action = toJSON code
    where
      -- Integer code sent over the wire for each action.
      code :: Int
      code = case action of
        GoLeft  -> 0
        GoRight -> 1


-- | The state of a CartPole environment: a cart on a track with a pole
-- attached to it.
--
-- The field order below is significant: the 'Spaces.StateSpace' instance
-- (@toVector@ / @fromVector@) and the 'FromJSON' instance in this module both
-- fill and read the constructor positionally in this order.
data StateCP = StateCP
  { position  :: Float  -- ^ position of the cart on the track
  , angle     :: Float  -- ^ angle of the pole with the vertical
  , velocity  :: Float  -- ^ cart velocity
  , angleRate :: Float  -- ^ rate of change of the angle
  } deriving (Show, Eq, Generic, Ord)

-- | Empty instance body: every method comes from the class defaults
-- (presumably the 'Generic'-based ones, since 'StateCP' derives 'Generic').
instance Hashable StateCP

-- | States combine by field-wise addition; the identity is the all-zero state.
--
-- NOTE(review): with base >= 4.11 'Monoid' has a 'Semigroup' superclass and no
-- 'Semigroup' instance for 'StateCP' is visible in this file — confirm the
-- targeted GHC version before upgrading.
instance Monoid StateCP where
  mempty = StateCP 0 0 0 0
  StateCP pos0 ang0 vel0 rate0 `mappend` StateCP pos1 ang1 vel1 rate1 =
    StateCP (pos0 + pos1) (ang0 + ang1) (vel0 + vel1) (rate0 + rate1)

-- | Conversion between 'StateCP' and its flat vector form. Components are
-- laid out in constructor order: position, angle, velocity, angle rate.
instance Spaces.StateSpace StateCP where
  -- Delegates to the list instance after listing the fields in order.
  toVector (StateCP pos ang vel rate) = Spaces.toVector [pos, ang, vel, rate]

  -- Reads indices 0..3 of the vector; if any of them is missing, throws
  -- 'AssertionFailed'. Components past index 3 are never inspected.
  fromVector vec =
    case decoded of
      Just state -> return state
      Nothing    -> throw (AssertionFailed "malformed vector found")
    where
      -- All four components, or Nothing if the vector is too short.
      decoded :: Maybe StateCP
      decoded = do
        pos  <- component 0
        ang  <- component 1
        vel  <- component 2
        rate <- component 3
        return (StateCP pos ang vel rate)

      -- Safe lookup of one component, narrowed from Double to Float.
      component :: Int -> Maybe Float
      component ix = double2Float <$> (vec V.!? ix)

-- | Decodes a state from a JSON array of four numbers, assigned positionally
-- to 'position', 'angle', 'velocity' and 'angleRate'. Any non-array value is
-- rejected with a @StateCP@ type mismatch.
--
-- NOTE(review): Gym documents the CartPole observation as
-- @[cart position, cart velocity, pole angle, pole angular velocity]@. If
-- this parser consumes that array directly, the second and third components
-- end up in 'angle' and 'velocity' respectively — confirm the wire order.
instance FromJSON StateCP where
  parseJSON :: Value -> Parser StateCP
  parseJSON value =
    case value of
      Array _ -> fromQuad <$> parseJSON value
      _       -> typeMismatch "StateCP" value
    where
      -- Rebuilds the record from the positionally decoded 4-tuple.
      fromQuad :: (Float, Float, Float, Float) -> StateCP
      fromQuad (p, a, v, r) = StateCP p a v r

-- | Statically sized view of the state: four components, matching the four
-- fields of 'StateCP'.
instance StateSpaceStatic StateCP where
  type Size StateCP = 4
  -- Goes through the dynamic vector form and repacks its elements with
  -- 'vector'.
  toR state = vector (V.toList (Spaces.toVector state))
  -- fromR = Spaces.fromVector . V.fromList . LA.toList . unwrap