-------------------------------------------------------------------------------
-- |
-- Module    :  Control.MonadMWCRandom
-- Copyright :  (c) Sentenai 2017
-- License   :  BSD3
-- Maintainer:  sam@sentenai.com
-- Stability :  experimental
-- Portability: non-portable
--
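-- A typeclass for obtaining an mwc-random generator from the current monad,
-- plus wrappers so callers need not pass the generator to mwc-random functions
-- explicitly.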
-------------------------------------------------------------------------------
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Control.MonadMWCRandom
  ( MonadMWCRandom(..)
  , MonadMWCRandomIO
  , MWCRand
  , MWCRandT(..)
  , runMWCRand
  , runMWCRandT
  -- * re-exports from System.Random.MWC
  , GenIO
  -- * wrappers for System.Random.MWC
  , uniform
  , uniformR
  -- * wrappers for Statistics.Distribution
  , Control.MonadMWCRandom.genContVar
  -- * extras
  , sampleFrom
  , Variate
  , _uniform
  ) where

import Reinforce.Prelude
import Data.Functor.Identity (runIdentity)
import qualified System.Random.MWC as MWC
import qualified Statistics.Distribution as Stats
import Control.MonadEnv (MonadEnv(..), Obs, Initial)
import Control.Monad.Primitive (PrimState, PrimMonad)

-- | Direct alias for 'MWC.uniform' from "System.Random.MWC".
--
-- Use it when a generator is already in hand and 'MonadMWCRandom' should be
-- bypassed. It runs in any 'PrimMonad' and samples from the generator passed
-- in.
_uniform :: (PrimMonad m, Variate a) => MWC.Gen (PrimState m) -> m a
_uniform gen = MWC.uniform gen

-- | Monads that can supply an mwc-random generator.
--
-- The generator type is fixed to 'GenIO', so the sampling wrappers in this
-- module run their draws through 'liftIO'.
--
-- FIXME: generalise over 'PrimState' so that 'ST' can be used as well.
class Monad m => MonadMWCRandom m where
  -- | The generator to sample from. Whether it is shared between calls or
  -- freshly created depends on the instance.
  getGen :: m GenIO

-- | Constraint synonym for a monad that can both hand out an mwc-random
-- generator ('MonadMWCRandom') and run 'IO' actions ('MonadIO'). This is the
-- combination required by the sampling wrappers in this module.
type MonadMWCRandomIO m = (MonadMWCRandom m, MonadIO m)

-------------------------------------------------------------------------------

-- | Lift 'getGen' through 'StateT'; the generator comes from the inner monad.
instance MonadMWCRandom m => MonadMWCRandom (StateT s m) where
  getGen :: StateT s m GenIO
  getGen = lift getGen

-- | Lift 'getGen' through 'ReaderT'; the generator comes from the inner monad.
instance MonadMWCRandom m => MonadMWCRandom (ReaderT s m) where
  getGen :: ReaderT s m GenIO
  getGen = lift getGen

-- | Lift 'getGen' through 'WriterT'; the generator comes from the inner monad.
-- @Monoid w@ is required because 'lift' for 'WriterT' needs it.
instance (Monoid w, MonadMWCRandom m) => MonadMWCRandom (WriterT w m) where
  getGen :: WriterT w m GenIO
  getGen = lift getGen

-- | Lift 'getGen' through 'RWST'; the generator comes from the inner monad.
-- @Monoid w@ is required because 'lift' for 'RWST' needs it.
instance (Monoid w, MonadMWCRandom m) => MonadMWCRandom (RWST r w s m) where
  getGen :: RWST r w s m GenIO
  getGen = lift getGen


-- | Fallback instance for plain 'IO'.
--
-- 'IO' has nowhere to keep a generator, so every call to 'getGen' creates a
-- brand-new one with 'MWC.createSystemRandom'. Nothing is shared between calls.
-- When drawing many samples, prefer running inside 'MWCRandT' so that a single
-- generator is reused.
instance MonadMWCRandom IO where
  getGen :: IO GenIO
  getGen = MWC.createSystemRandom


-------------------------------------------------------------------------------

-- | Draw a uniformly distributed value using the generator from 'getGen'.
--
-- Same semantics as 'MWC.uniform': the full range for fixed-width integral
-- types, and (0, 1] for floating-point types.
uniform :: (MonadIO m, MonadMWCRandom m, Variate a) => m a
uniform = do
  gen <- getGen
  liftIO (MWC.uniform gen)


-- | Draw a uniformly distributed value from the given @(low, high)@ range,
-- using the generator from 'getGen'.
--
-- See 'MWC.uniformR' for which endpoints are included for each type.
uniformR :: (MonadIO m, MonadMWCRandom m, Variate a) => (a, a) -> m a
uniformR bounds = do
  gen <- getGen
  liftIO (MWC.uniformR bounds gen)


-- | Sample a 'Double' from the continuous distribution @dist@ (any
-- 'Stats.ContGen' instance), using the generator from 'getGen'.
genContVar :: (MonadIO m, MonadMWCRandom m, Stats.ContGen d) => d -> m Double
genContVar dist = do
  gen <- getGen
  liftIO (Stats.genContVar dist gen)


-- ========================================================================= --
-- * Utility functions

-- | Sample a single index from a list of weights.
--
-- How it works:
--
-- * The weights are normalised by their sum into a distribution.
-- * A draw @u@ is taken with 'uniform'.
-- * The result is the first index whose cumulative probability is not below
--   @u@, returned together with the normalised distribution.
--
-- Edge cases:
--
-- * 'uniform' can return exactly 1.0 for 'Double', while rounding can leave
--   the final cumulative sum slightly below 1.0. In that case the last index
--   is chosen instead of running off the end of the list.
-- * An empty weight list has no valid index and yields @-1@.
sampleFrom :: (MonadIO m, MonadMWCRandom m) => [Double] -> m (Int, [Double])
sampleFrom xs = fmap ((,dist) . choose) uniform
  where
    dist :: [Double]
    dist = fmap (/ total) xs

    total :: Double
    total = sum xs

    -- Each index paired with the running sum of the normalised weights.
    cumulative :: [(Int, Double)]
    cumulative = zip [0 ..] (scanl1 (+) dist)

    -- Used when no cumulative value reaches the draw (rounding), or -1 when
    -- there are no weights at all.
    lastIndex :: Int
    lastIndex = length xs - 1

    choose :: Double -> Int
    choose u = foldr pick lastIndex cumulative
      where
        -- Skip buckets whose cumulative probability is still below the draw;
        -- stop at the first one that is not.
        pick (i, c) rest
          | c < u     = rest
          | otherwise = i


-- ========================================================================= --
-- * A concrete type for MonadMWCRandom

-- | Concrete 'MonadMWCRandom' transformer.
--
-- It is a newtype over @'ReaderT' 'GenIO' m@ that threads a single generator
-- through a computation, so callers can share one generator without managing
-- a 'ReaderT' themselves. The listed instances are derived from the
-- underlying 'ReaderT' via GeneralizedNewtypeDeriving.
newtype MWCRandT m a = MWCRandT { getMWCRandT :: ReaderT GenIO m a }
  deriving (Functor, Applicative, Monad, MonadTrans, MonadThrow, MonadIO)


-- | Run an 'MWCRandT' computation in the underlying monad.
--
-- The supplied generator is the one that 'getGen' hands out for the whole run.
runMWCRandT :: MWCRandT m a -> GenIO -> m a
runMWCRandT action gen = runReaderT (getMWCRandT action) gen


-- | 'MWCRandT' over 'Identity', the variant without an underlying transformer.
--
-- NOTE(review): 'Identity' has no 'MonadIO' instance, so the
-- 'MonadIO'-constrained samplers in this module cannot be used in 'MWCRand'.
-- Only 'getGen' is available here.
type MWCRand = MWCRandT Identity


-- | Run an 'MWCRand' computation with the given generator and return its pure
-- result.
--
-- This runs the underlying 'MWCRandT' and unwraps the resulting 'Identity'.
-- The previous definition referred to itself and never terminated.
runMWCRand :: MWCRand a -> GenIO -> a
runMWCRand action gen = runIdentity (runMWCRandT action gen)


-- | 'getGen' returns the generator passed to 'runMWCRandT'. Every call within a
-- single run therefore yields the same shared generator.
instance Monad m => MonadMWCRandom (MWCRandT m) where
  getGen :: MWCRandT m GenIO
  getGen = MWCRandT ask


-- | Lets an environment stack hold a shared MWC-random generator.
--
-- 'reset' and 'step' are delegated to the wrapped monad via 'lift'. Neither
-- operation touches the generator held by 'MWCRandT'.
instance MonadEnv m s a r => MonadEnv (MWCRandT m) s a r where
  reset :: MWCRandT m (Initial s)
  reset = lift reset

  step :: a -> MWCRandT m (Obs r s)
  step = lift . step