{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
module Control.MonadMWCRandom
( MonadMWCRandom(..)
, MonadMWCRandomIO
, MWCRand
, MWCRandT(..)
, runMWCRand
, runMWCRandT
, GenIO
, uniform
, uniformR
, Control.MonadMWCRandom.genContVar
, sampleFrom
, Variate
, _uniform
) where
import Reinforce.Prelude

import Control.Monad.Primitive (PrimState, PrimMonad)
import Data.Functor.Identity (runIdentity)
import qualified Statistics.Distribution as Stats
import qualified System.Random.MWC as MWC

import Control.MonadEnv (MonadEnv(..), Obs, Initial)
-- | Draw a uniform variate with an explicitly supplied generator.
--
-- Unlike 'uniform', this does not consult 'getGen' and works in any
-- 'PrimMonad' whose state token matches the generator.
_uniform :: (PrimMonad m, Variate a) => MWC.Gen (PrimState m) -> m a
_uniform gen = MWC.uniform gen
-- | Monads that can supply an MWC random generator ('GenIO') to sample from.
class (Monad m) => MonadMWCRandom m where
  -- | Fetch the generator that sampling helpers should use.
  getGen :: m GenIO
-- | Shorthand for a monad that can both fetch a generator and run 'IO'
-- actions, which is the combination the sampling helpers here require.
type MonadMWCRandomIO m = (MonadMWCRandom m, MonadIO m)
-- | Lift 'getGen' through 'StateT'; the state is not touched.
instance MonadMWCRandom n => MonadMWCRandom (StateT st n) where
  getGen :: StateT st n GenIO
  getGen = lift getGen
-- | Lift 'getGen' through 'ReaderT'; the environment is ignored.
instance MonadMWCRandom n => MonadMWCRandom (ReaderT r n) where
  getGen :: ReaderT r n GenIO
  getGen = lift getGen
-- | Lift 'getGen' through 'WriterT' without writing any output.
-- ('Monoid' is needed for 'lift' into 'WriterT'.)
instance (MonadMWCRandom n, Monoid w) => MonadMWCRandom (WriterT w n) where
  getGen :: WriterT w n GenIO
  getGen = lift getGen
-- | Lift 'getGen' through 'RWST'; environment, output and state are left
-- alone. ('Monoid' is needed for 'lift' into 'RWST'.)
instance (MonadMWCRandom n, Monoid w) => MonadMWCRandom (RWST env w st n) where
  getGen :: RWST env w st n GenIO
  getGen = lift getGen
-- | In plain 'IO', every call to 'getGen' builds a brand-new generator
-- seeded from system randomness via 'MWC.createSystemRandom'.
--
-- NOTE(review): nothing is cached, so each 'uniform' / 'uniformR' /
-- 'genContVar' call made directly in 'IO' re-seeds a fresh generator.
-- Sampling loops are likely cheaper in 'MWCRandT' with one shared generator.
instance MonadMWCRandom IO where
  getGen :: IO GenIO
  getGen = liftIO MWC.createSystemRandom
-- | Draw a uniformly distributed value using the generator from 'getGen'.
--
-- Per the mwc-random documentation, floating-point types are sampled from
-- the interval (0, 1].
uniform :: (MonadIO m, MonadMWCRandom m, Variate a) => m a
uniform = do
  gen <- getGen
  liftIO (MWC.uniform gen)
-- | Draw a uniformly distributed value within the given @(lo, hi)@ bounds,
-- using the generator from 'getGen' and mwc-random's range conventions.
uniformR :: (MonadIO m, MonadMWCRandom m, Variate a) => (a, a) -> m a
uniformR bounds = liftIO . MWC.uniformR bounds =<< getGen
-- | Sample a 'Double' from the continuous distribution @d@ (any
-- 'Stats.ContGen' instance), using the generator from 'getGen'.
genContVar :: (MonadIO m, MonadMWCRandom m, Stats.ContGen d) => d -> m Double
genContVar d = do
  gen <- getGen
  liftIO (Stats.genContVar d gen)
-- | Sample an index from a list of (unnormalised) weights.
--
-- Returns the chosen index together with the weights normalised by their
-- sum. A value @u@ is drawn with 'uniform', and the first index whose
-- cumulative normalised weight is not below @u@ is returned.
--
-- The draw can be exactly 1.0, while the normalised cumulative sum may round
-- to slightly less. To keep the search from running off the end of the list,
-- the draw is clamped to the cumulative mass actually reached.
--
-- Precondition: the weight list must be non-empty ('unsafeHead' fails on an
-- empty list). Weights are assumed non-negative; this is not checked.
sampleFrom :: (MonadIO m, MonadMWCRandom m) => [Double] -> m (Int, [Double])
sampleFrom xs = fmap ((, dist) . choose) uniform
  where
    dist :: [Double]
    dist = fmap (/ total) xs

    total :: Double
    total = sum xs

    -- Each index paired with the running total of normalised weight.
    cumulative :: [(Int, Double)]
    cumulative = zip [0 ..] (scanl1 (+) dist)

    choose :: Double -> Int
    choose u = fst . unsafeHead . dropWhile ((< target) . snd) $ cumulative
      where
        -- Clamp the draw to the mass the cumulative sum really reaches, so
        -- rounding in the normalisation cannot push it past the last entry.
        target = case reverse cumulative of
          (_, reached) : _ -> min u reached
          []               -> u
-- | Transformer that threads a single MWC generator through a computation as
-- a 'ReaderT' environment. All listed instances are newtype-derived from the
-- underlying 'ReaderT'.
newtype MWCRandT m a = MWCRandT { getMWCRandT :: ReaderT GenIO m a }
  deriving (Functor, Applicative, Monad, MonadIO, MonadThrow, MonadTrans)
-- | Run an 'MWCRandT' computation, supplying the generator it reads from.
runMWCRandT :: MWCRandT m a -> GenIO -> m a
runMWCRandT action gen = runReaderT (getMWCRandT action) gen
-- | 'MWCRandT' specialised to the 'Identity' base monad.
--
-- NOTE(review): 'Identity' has no 'MonadIO' instance, so the IO-based
-- samplers in this module ('uniform', 'uniformR', 'genContVar', 'sampleFrom')
-- cannot run in this monad; it can only carry the generator around.
type MWCRand = MWCRandT Identity
-- | Run an 'MWCRand' computation with the given generator and return its
-- result. Delegates to 'runMWCRandT' and unwraps the resulting 'Identity'
-- (the earlier self-referential definition never terminated).
runMWCRand :: MWCRand a -> GenIO -> a
runMWCRand action = runIdentity . runMWCRandT action
-- | The generator is the 'ReaderT' environment, so fetching it is just 'ask'.
instance (Monad n) => MonadMWCRandom (MWCRandT n) where
  getGen :: MWCRandT n GenIO
  getGen = MWCRandT ask
-- | 'MWCRandT' is transparent to the environment: 'reset' and 'step' are
-- lifted unchanged from the underlying monad.
instance MonadEnv m s a r => MonadEnv (MWCRandT m) s a r where
  reset :: MWCRandT m (Initial s)
  reset = lift reset

  step :: a -> MWCRandT m (Obs r s)
  step = lift . step