{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Reinforce.Spaces.Action
( DiscreteActionSpace(..)
, oneHot
, oneHot'
, allActions
, randomChoice
) where
import Reinforce.Prelude
import Control.MonadMWCRandom
import Numeric.LinearAlgebra.Static (R)
import qualified Numeric.LinearAlgebra.Static as LA
import qualified Data.Vector as V
-- | Action types that are both 'Bounded' and 'Enum', i.e. spaces with a finite,
-- enumerable set of actions.
--
-- Both conversion methods have defaults that delegate to the 'Enum' instance;
-- an instance only needs to supply 'Size'.
class (Bounded a, Enum a) => DiscreteActionSpace a where
  -- | Type-level natural attached to the action type. 'oneHot' uses it as the
  -- dimension of the static vector it returns.
  type Size a :: Nat

  -- | Turn an 'Int' into an action. Defaults to 'toEnum'.
  toAction :: Int -> a
  toAction n = toEnum n

  -- | Turn an action into an 'Int'. Defaults to 'fromEnum'.
  fromAction :: a -> Int
  fromAction act = fromEnum act
-- | One-hot encode an action as a statically sized vector.
--
-- The encoding starts from a zero vector of length @'fromEnum' 'maxBound' + 1@
-- and sets the entry at index @'fromEnum' e@ to @1@.
--
-- The update is bounds-checked: an 'Enum' instance whose 'fromEnum' falls
-- outside @[0, fromEnum maxBound]@ (for example a negative value) raises an
-- index error instead of writing outside the vector.
--
-- NOTE(review): presumably @Size a@ must equal the length of that zero vector
-- for 'LA.vector' to accept the list -- confirm against hmatrix.
oneHot :: forall a . (KnownNat (Size a), DiscreteActionSpace a) => a -> R (Size a)
oneHot e = LA.vector (V.toList hot)
  where
    zeros = replicateZeros (Proxy :: Proxy a)
    -- (V.//) checks the index, unlike V.unsafeUpd.
    hot = zeros V.// [(fromEnum e, 1)]
-- | One-hot encode an action as a plain 'Vector' of 'Double's.
--
-- The result has length @'fromEnum' 'maxBound' + 1@, with a @1@ at index
-- @'fromEnum' e@ and @0@ everywhere else.
--
-- The update is bounds-checked: an 'Enum' instance whose 'fromEnum' falls
-- outside @[0, fromEnum maxBound]@ (for example a negative value) raises an
-- index error instead of writing outside the vector.
oneHot' :: forall a . (DiscreteActionSpace a) => a -> Vector Double
oneHot' e = zeros V.// [(fromEnum e, 1)]
  where
    zeros = replicateZeros (Proxy :: Proxy a)
-- | A vector of zeros with one slot per 'Int' in @[0 .. 'fromEnum' 'maxBound']@
-- for the proxied type, i.e. of length @fromEnum maxBound + 1@.
replicateZeros :: forall a . (Enum a, Bounded a) => Proxy a -> Vector Double
replicateZeros _ = V.replicate slots 0
  where
    -- The proxy argument only selects which type's 'maxBound' is consulted.
    slots :: Int
    slots = fromEnum (maxBound :: a) + 1
-- | Every action of the space, enumerated in order from 'minBound' up to
-- 'maxBound' (inclusive).
allActions :: DiscreteActionSpace a => [a]
allActions = enumFromTo minBound maxBound
-- | Pick an action at random, giving every action of the space the same
-- probability.
--
-- The weights handed to 'sampleFrom' are @1 / n@ for each of the @n@ actions
-- returned by 'allActions'. The first component of 'sampleFrom''s result is
-- converted to an action with 'toEnum'.
--
-- NOTE(review): presumably that first component is the index of the sampled
-- weight, which makes this correct only when @'fromEnum' 'minBound' == 0@ --
-- confirm against 'sampleFrom'.
randomChoice
  :: forall m a . (MonadIO m , MonadMWCRandom m, DiscreteActionSpace a)
  => m a
randomChoice = toEnum . fst <$> sampleFrom uniformDist
  where
    -- Pins the action type so the count below refers to @a@.
    actions :: [a]
    actions = allActions

    count :: Int
    count = length actions

    -- Equal weights that sum to 1. Weighting by 'fromEnum' instead would give
    -- the action with index 0 probability 0 and compute 0/0 for a
    -- single-action space.
    uniformDist :: [Double]
    uniformDist = replicate count (1 / fromIntegral count)