-------------------------------------------------------------------------------
-- |
-- Module    :  Environments.Bandits
-- Copyright :  (c) Sentenai 2017
-- License   :  Proprietary
-- Maintainer:  sam@sentenai.com
-- Stability :  experimental
-- Portability: non-portable
--
-- Implementation of an n-armed bandit environment.
--
-- FIXME: currently this is only for a 10-armed bandit. This needs to be tied
-- to a config.
-------------------------------------------------------------------------------
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE DeriveGeneric #-}
{-# OPTIONS_GHC -Wno-unused-top-binds #-}
module Environments.Bandits
  ( Environment(..)
  , runEnvironment
  , Logger.Event(..)
  , Action
  , mkBandits
  , defaultBandits
  , mkAction
  ) where

import Control.MonadEnv
import Control.MonadMWCRandom
import qualified Data.Vector as V
import Data.Vector ((!))
import Data.DList
import qualified Data.Logger as Logger
import Control.Exception.Safe (assert)

import Reinforce.Prelude

-- | FIXME: only 10 arms for the time being. This is where a "discrete space"
-- would be nice
data Config = Config
  { nBandits :: Int
  , offset   :: Int
  , stdDev   :: Float
  , bandits  :: Vector NormalDistribution
  , gen      :: GenIO
  }

type Event = Logger.Event Reward () Action

-- | The slot machine index whose arm will be pulled
newtype Action = Action { unAction :: Int }
  deriving (Eq, Ord, Show, Enum, Generic)

instance Bounded Action where
  minBound = Action 0
  maxBound = Action 9

instance Hashable Action where

-- | Convert an Int to an Action  in the bandit environment. Throw if the Int
-- falls out of bounds.
mkAction :: Int -> Environment Action
mkAction i = Environment $ do
  n <- nBandits <$> ask
  assert (i > n || i < 0) (pure $ Action i)

-- | Monad for an n-armed bandit environment
newtype Environment a = Environment
  { getEnvironment :: RWST Config (DList Event) () IO a }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadIO
    , MonadThrow
    , MonadReader Config
    , MonadWriter (DList Event)
    , MonadState ()
    , MonadRWS Config (DList Event) ()
    )

-- | run an n-armed bandit environment
runEnvironment :: Config -> Environment () -> IO (DList Event)
runEnvironment c (Environment m) = snd <$> evalRWST m c ()

-- | Give the default config of a 10-armed bandit
defaultBandits :: GenIO -> Config
defaultBandits = mkBandits 10 2 0.1

-- | helper function to build a bandits config with normally-distributed
-- reward functions
mkBandits :: Int -> Int -> Float -> GenIO -> Config
mkBandits n offset' std = Config n offset' std $
  V.fromList $ fmap (`rewardDist` std) [offset' .. offset' + n - 1]
  where
    rewardDist :: Int -> Float -> NormalDistribution
    rewardDist m s = normalDistr (fromIntegral m) (realToFrac s)

instance MonadMWCRandom Environment where
  getGen = Environment $ fmap gen ask

instance MonadEnv Environment () Action Reward where
  -- this isn't an episodic environment... we'll have to split this out later
  reset :: Environment (Initial ())
  reset = return $ Initial ()

  step :: Action -> Environment (Obs Reward ())
  step (Action a) = do
    rwd <- genContVar =<< (! a) . bandits <$> ask
    tell . pure $ Logger.Event 0 rwd () (Action a)
    return $ Next rwd ()