1
Fork 0

Higher order effects.

main
Joshua Potter 2022-03-14 07:48:30 -04:00
parent 02a3b4f715
commit 206f8c9594
6 changed files with 561 additions and 35 deletions

View File

@ -8,6 +8,6 @@ filesToFormat=$(
for path in $filesToFormat for path in $filesToFormat
do do
ormolu --mode inplace $path ormolu --ghc-opt -XTypeApplications --mode inplace $path
git add $path git add $path
done; done;

View File

@ -1,2 +0,0 @@
packages:
free-monads

View File

@ -77,4 +77,6 @@ common free-monads-common
library library
import: free-monads-common import: free-monads-common
hs-source-dirs: src hs-source-dirs: src
exposed-modules: Control.Monad.Free exposed-modules: Control.Monad.Free,
Control.Monad.Free.Compose,
Control.Monad.Free.Scoped

View File

@ -20,7 +20,6 @@ module Control.Monad.Free
-- * Free Monad -- * Free Monad
Free (..), Free (..),
monadicAp,
-- * Teletype -- * Teletype
Teletype (..), Teletype (..),
@ -77,17 +76,15 @@ instance (Functor f) => Functor (NonEmptyList' f) where
fmap f (Last' a) = Last' (f a) fmap f (Last' a) = Last' (f a)
fmap f (Cons' a g) = Cons' (f a) (fmap (fmap f) g) fmap f (Cons' a g) = Cons' (f a) (fmap (fmap f) g)
{- ORMOLU_DISABLE -}
twoPlusThree :: NonEmptyList' (Reader Int) Int twoPlusThree :: NonEmptyList' (Reader Int) Int
twoPlusThree = twoPlusThree =
Cons' Cons' 2 (reader (\a ->
2 Cons' 3 (reader (\b ->
( reader Last' (a + b)))))
( \a ->
Cons' {- ORMOLU_ENABLE -}
3
(reader (\b -> Last' (a + b)))
)
)
-- | -- |
-- --
@ -101,31 +98,28 @@ runNonEmptyList' (Cons' a f) = runNonEmptyList' (runReader f a)
-- Third Pass -- Third Pass
-- ======================================== -- ========================================
data Wrap a b c = Wrap a (b c) deriving (Functor) data Container a m k = Container a (m k) deriving (Functor)
data NonEmptyList'' f a = Last'' a | Cons'' (f (NonEmptyList'' f a)) data NonEmptyList'' f a = Last'' a | Cons'' (f (NonEmptyList'' f a))
deriving (Functor) deriving (Functor)
threePlusFour :: NonEmptyList'' (Wrap Int (Reader Int)) Int {- ORMOLU_DISABLE -}
threePlusFour :: NonEmptyList'' (Container Int (Reader Int)) Int
threePlusFour = threePlusFour =
Cons'' Cons'' (Container 3 (reader (\a ->
( Wrap Cons'' (Container 4 (reader (\b ->
3 Last'' (a + b)))))))
( reader
( \a -> {- ORMOLU_ENABLE -}
Cons''
(Wrap 4 (reader (\b -> Last'' (a + b))))
)
)
)
-- | -- |
-- --
-- >>> runNonEmptyList'' threePlusFour -- >>> runNonEmptyList'' threePlusFour
-- 5 -- 5
runNonEmptyList'' :: NonEmptyList'' (Wrap Int (Reader Int)) Int -> Int runNonEmptyList'' :: NonEmptyList'' (Container Int (Reader Int)) Int -> Int
runNonEmptyList'' (Last'' a) = a runNonEmptyList'' (Last'' a) = a
runNonEmptyList'' (Cons'' (Wrap a f)) = runNonEmptyList'' (runReader f a) runNonEmptyList'' (Cons'' (Container a f)) = runNonEmptyList'' (runReader f a)
instance (Functor f) => Applicative (NonEmptyList'' f) instance (Functor f) => Applicative (NonEmptyList'' f)
@ -148,12 +142,6 @@ instance (Functor f) => Functor (Free f) where
fmap f (Pure a) = Pure (f a) fmap f (Pure a) = Pure (f a)
fmap f (Free g) = Free (fmap (fmap f) g) fmap f (Free g) = Free (fmap (fmap f) g)
monadicAp :: forall f a b. Functor f => Free f (a -> b) -> Free f a -> Free f b
monadicAp f g = do
f' <- f
g' <- g
pure (f' g')
instance (Functor f) => Applicative (Free f) where instance (Functor f) => Applicative (Free f) where
pure = Pure pure = Pure
@ -168,7 +156,7 @@ instance (Functor f) => Monad (Free f) where
-- Teletype -- Teletype
-- ======================================== -- ========================================
data Teletype a = Read a | Write String a deriving (Functor, Show) data Teletype k = Read k | Write String k deriving (Functor, Show)
read :: Free Teletype String read :: Free Teletype String
read = Free (Read (Pure "hello")) read = Free (Read (Pure "hello"))

View File

@ -0,0 +1,274 @@
{-# LANGUAGE ViewPatterns #-}
module Control.Monad.Free.Compose
( -- * State
State (..),
increment,
runState,
-- * Sum
(:+:) (..),
runTwoState,
runState',
threadedState,
threadedState',
-- * Member
Member (..),
Void,
inject,
project,
get,
put,
run,
threadedState'',
threadedStateM'',
-- * Exceptions
Throw (..),
throw,
catch,
runThrow,
countDown,
countDown',
)
where
import Control.Monad.Free
import Data.Text (pack)
import qualified Text.Show as S
import Prelude hiding (State, Void, get, put, runState)
-- ========================================
-- State
-- ========================================
data State s k = Get (s -> k) | Put s k deriving (Functor)
instance (Show s, Show k) => Show (State s k) where
show (Get _) = "Get <function>"
show (Put s k) = "Put " <> S.show s <> " " <> S.show k
runState :: forall s a. s -> Free (State s) a -> (s, a)
runState s (Free (Get f)) = runState s (f s)
runState _ (Free (Put s' f)) = runState s' f
runState s (Pure a) = (s, a)
-- |
--
-- >>> runState 0 increment
-- (1, ())
increment :: Free (State Int) ()
increment = Free (Get (\s -> Free (Put (s + 1) (Pure ()))))
-- ========================================
-- Sum
-- ========================================
data (f :+: g) k = L (f k) | R (g k) deriving (Functor, Show)
infixr 4 :+:
runTwoState ::
forall s1 s2 a.
s1 ->
s2 ->
Free (State s1 :+: State s2) a ->
(s1, s2, a)
runTwoState s1 s2 (Free (L (Get f))) = runTwoState s1 s2 (f s1)
runTwoState s1 s2 (Free (R (Get f))) = runTwoState s1 s2 (f s2)
runTwoState _ s2 (Free (L (Put s1 f))) = runTwoState s1 s2 f
runTwoState s1 _ (Free (R (Put s2 f))) = runTwoState s1 s2 f
runTwoState s1 s2 (Pure a) = (s1, s2, a)
runState' ::
forall s a sig.
Functor sig =>
s ->
Free (State s :+: sig) a ->
Free sig (s, a)
runState' s (Pure a) = pure (s, a)
runState' s (Free (L (Get f))) = runState' s (f s)
runState' _ (Free (L (Put s f))) = runState' s f
runState' s (Free (R other)) = Free (fmap (runState' s) other)
{- ORMOLU_DISABLE -}
-- |
--
-- >>> runState "" (runState' 0 threadedState)
-- ("a",(1,()))
threadedState :: Free (State Int :+: State String) ()
threadedState =
Free (L (Get (\s1 ->
Free (R (Get (\s2 ->
Free (L (Put (s1 + 1)
(Free (R (Put (s2 ++ "a")
(Pure ()))))))))))))
threadedState' :: Free (State String :+: State Int) ()
threadedState' =
Free (R (Get (\s1 ->
Free (L (Get (\s2 ->
Free (R (Put (s1 + 1)
(Free (L (Put (s2 ++ "a")
(Pure ()))))))))))))
{- ORMOLU_ENABLE -}
-- ========================================
-- Membership
-- ========================================
class Member sub sup where
inj :: sub a -> sup a
prj :: sup a -> Maybe (sub a)
instance Member sig sig where
inj = id
prj = Just
instance
{-# OVERLAPPABLE #-}
Member sig (l1 :+: (l2 :+: r)) =>
Member sig ((l1 :+: l2) :+: r)
where
inj sub = case inj sub of
L l1 -> L (L l1)
R (L l2) -> L (R l2)
R (R r) -> R r
prj sup = case sup of
L (L l1) -> prj (L @l1 @(l2 :+: r) l1)
L (R l2) -> prj (R @l1 @(l2 :+: r) (L @l2 l2))
R r -> prj (R @l1 @(l2 :+: r) (R @l2 @r r))
instance {-# OVERLAPPABLE #-} Member sig (sig :+: r) where
inj = L
prj (L f) = Just f
prj _ = Nothing
instance {-# OVERLAPPABLE #-} (Member sig r) => Member sig (l :+: r) where
inj = R . inj
prj (R g) = prj g
prj _ = Nothing
data Void k deriving (Functor)
run :: forall a. Free Void a -> a
run (Pure a) = a
run _ = error (pack "impossible")
{- ORMOLU_DISABLE -}
threadedState'' ::
Functor sig =>
Member (State Int) sig =>
Member (State String) sig =>
Free sig ()
threadedState'' =
Free (inj (Get @Int (\s1 ->
Free (inj (Get (\s2 ->
Free (inj (Put (s1 + 1)
(Free (inj (Put (s2 ++ "a")
(Pure ()))))))))))))
{- ORMOLU_ENABLE -}
inject ::
forall a sub sup.
Member sub sup =>
sub (Free sup a) ->
Free sup a
inject = Free . inj
project ::
forall a sub sup.
Member sub sup =>
Free sup a ->
Maybe (sub (Free sup a))
project (Free s) = prj s
project _ = Nothing
get :: forall s sig. Functor sig => Member (State s) sig => Free sig s
get = inject (Get pure)
put :: forall s sig. Functor sig => Member (State s) sig => s -> Free sig ()
put s = inject (Put s (pure ()))
threadedStateM'' ::
Functor sig =>
Member (State Int) sig =>
Member (State String) sig =>
Free sig ()
threadedStateM'' = do
s1 <- get @Int
s2 <- get @String
put (s1 + 1)
put (s2 ++ "a")
pure ()
-- ========================================
-- Exceptions
-- ========================================
newtype Throw e k = Throw e deriving (Functor)
throw :: forall e a sig. Functor sig => Member (Throw e) sig => e -> Free sig a
throw e = inject (Throw e)
catch ::
forall e a sig.
Functor sig =>
Free (Throw e :+: sig) a ->
(e -> Free sig a) ->
Free sig a
catch (Pure a) _ = pure a
catch (Free (L (Throw e))) h = h e
catch (Free (R other)) h = Free (fmap (`catch` h) other)
runThrow ::
forall e a sig.
Functor sig =>
Free (Throw e :+: sig) a ->
Free sig (Either e a)
runThrow (Pure a) = pure (Right a)
runThrow (Free (L (Throw e))) = pure (Left e)
runThrow (Free (R other)) = Free (fmap runThrow other)
countDown ::
forall sig.
Functor sig =>
Member (State Int) sig =>
Member (Throw ()) sig =>
Free sig ()
countDown = do
decr
catch (decr >> decr) pure
where
decr ::
forall sig2.
Functor sig2 =>
Member (State Int) sig2 =>
Member (Throw ()) sig2 =>
Free sig2 ()
decr = do
x <- get @Int
if x > 0 then put (x - 1) else throw ()
{- ORMOLU_DISABLE -}
countDown' ::
Functor sig =>
Member (State Int) sig =>
Member (Throw ()) sig =>
Free sig ()
countDown' =
Free (inj (Get @Int (\x ->
let a = \k -> if x > 0 then Free (inj (Put (x - 1) k)) else throw ()
in a (catch (Free (inj (Get @Int (\y ->
let b = \k -> if y > 0 then Free (inj (Put (y - 1) k)) else throw ()
in b (Free (inj (Get @Int (\z ->
let c = \k -> if z > 0 then Free (inj (Put (z - 1) k)) else throw ()
in c (Pure ()))))))))) pure))))
{- ORMOLU_ENABLE -}

View File

@ -0,0 +1,264 @@
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE UndecidableInstances #-}
module Control.Monad.Free.Scoped
( -- * Free
HFunctor (..),
Syntax (..),
Free (..),
-- * Sum
(:+:) (..),
-- * Members
Member (..),
inject,
project,
-- * Lifting
Lift (..),
HState,
hIncrement,
runState,
get,
put,
HVoid,
run,
-- * Exceptions
Error (..),
throw,
catch,
runError,
countDown,
)
where
import Control.Monad (ap)
import Control.Monad.Free.Compose (State (..), Void)
import Data.Text (pack)
import Prelude hiding (State, Void, get, put, runState)
{-# ANN module "HLINT: ignore Use <&>" #-}
-- ========================================
-- Free
-- ========================================
class HFunctor f where
hmap ::
(Functor m, Functor n) =>
(forall x. m x -> n x) ->
(forall x. f m x -> f n x)
class HFunctor f => Syntax f where
emap :: (m a -> m b) -> (f m a -> f m b)
weave ::
(Monad m, Monad n, Functor ctx) =>
ctx () ->
Handler ctx m n ->
(f m a -> f n (ctx a))
type Handler ctx m n = forall x. ctx (m x) -> n (ctx x)
data Free f a = Pure a | Free (f (Free f) a)
instance Syntax f => Functor (Free f) where
fmap f m = m >>= pure . f
instance Syntax f => Applicative (Free f) where
pure = Pure
(<*>) = ap
instance Syntax f => Monad (Free f) where
Pure a >>= g = g a
Free f >>= g = Free (emap (>>= g) f)
-- ========================================
-- Sum
-- ========================================
data (f :+: g) (m :: Type -> Type) a = L (f m a) | R (g m a)
infixr 4 :+:
instance (HFunctor f, HFunctor g) => HFunctor (f :+: g) where
hmap t (L f) = L (hmap t f)
hmap t (R g) = R (hmap t g)
instance (Syntax f, Syntax g) => Syntax (f :+: g) where
emap t (L f) = L (emap t f)
emap t (R g) = R (emap t g)
weave ctx hdl (L f) = L (weave ctx hdl f)
weave ctx hdl (R g) = R (weave ctx hdl g)
-- ========================================
-- Members
-- ========================================
class (Syntax sub, Syntax sup) => Member sub sup where
inj :: sub m a -> sup m a
prj :: sup m a -> Maybe (sub m a)
instance (Syntax sig) => Member sig sig where
inj = id
prj = Just
instance
{-# OVERLAPPABLE #-}
( Syntax sig,
Syntax l1,
Syntax l2,
Syntax r,
Member sig (l1 :+: (l2 :+: r))
) =>
Member sig ((l1 :+: l2) :+: r)
where
inj sub = case inj sub of
L l1 -> L (L l1)
R (L l2) -> L (R l2)
R (R r) -> R r
prj sup = case sup of
L (L l1) -> prj (L @l1 @(l2 :+: r) l1)
L (R l2) -> prj (R @l1 @(l2 :+: r) (L @l2 l2))
R r -> prj (R @l1 @(l2 :+: r) (R @l2 @r r))
instance
{-# OVERLAPPABLE #-}
(Syntax sig, Syntax r) =>
Member sig (sig :+: r)
where
inj = L
prj (L f) = Just f
prj _ = Nothing
instance
{-# OVERLAPPABLE #-}
(Member sig r, Syntax l) =>
Member sig (l :+: r)
where
inj = R . inj
prj (R g) = prj g
prj _ = Nothing
inject ::
forall a sub sup.
Member sub sup =>
sub (Free sup) a ->
Free sup a
inject = Free . inj
project ::
forall a sub sup.
Member sub sup =>
Free sup a ->
Maybe (sub (Free sup) a)
project (Free s) = prj s
project _ = Nothing
-- ========================================
-- Lift
-- ========================================
newtype Lift sig (m :: Type -> Type) a = Lift (sig (m a))
type HState s = Lift (State s)
hIncrement :: Free (Lift (State Int)) ()
hIncrement = Free (Lift (Get (\s -> Free (Lift (Put (s + 1) (Pure ()))))))
instance Functor sig => HFunctor (Lift sig) where
hmap t (Lift f) = Lift (fmap t f)
instance Functor sig => Syntax (Lift sig) where
emap t (Lift f) = Lift (fmap t f)
weave ctx hdl (Lift f) = Lift (fmap (\p -> hdl (fmap (const p) ctx)) f)
runState ::
forall s a sig.
Syntax sig =>
s ->
Free (HState s :+: sig) a ->
Free sig (s, a)
runState s (Pure a) = pure (s, a)
runState s (Free (L (Lift (Get f)))) = runState s (f s)
runState _ (Free (L (Lift (Put s f)))) = runState s f
runState s (Free (R other)) = Free (weave (s, ()) hdl other)
where
hdl :: forall x. (s, Free (HState s :+: sig) x) -> Free sig (s, x)
hdl = uncurry runState
get :: forall s sig. HFunctor sig => Member (HState s) sig => Free sig s
get = inject (Lift (Get Pure))
put :: forall s sig. HFunctor sig => Member (HState s) sig => s -> Free sig ()
put s = inject (Lift (Put s (pure ())))
type HVoid = Lift Void
run :: Free HVoid a -> a
run (Pure a) = a
run _ = error (pack "impossible")
-- ========================================
-- Exceptions
-- ========================================
data Error e m a
= Throw e
| forall x. Catch (m x) (e -> m x) (x -> m a)
instance HFunctor (Error e) where
hmap _ (Throw x) = Throw x
hmap t (Catch p h k) = Catch (t p) (t . h) (t . k)
instance Syntax (Error e) where
emap _ (Throw e) = Throw e
emap f (Catch p h k) = Catch p h (f . k)
weave _ _ (Throw x) = Throw x
weave ctx hdl (Catch p h k) =
Catch
(hdl (fmap (const p) ctx))
(\e -> hdl (fmap (const (h e)) ctx))
(hdl . fmap k)
throw :: Member (Error e) sig => e -> Free sig a
throw e = inject (Throw e)
catch :: Member (Error e) sig => Free sig a -> (e -> Free sig a) -> Free sig a
catch p h = inject (Catch p h pure)
runError ::
forall e a sig.
Syntax sig =>
Free (Error e :+: sig) a ->
Free sig (Either e a)
runError (Pure a) = pure (Right a)
runError (Free (L (Throw e))) = pure (Left e)
runError (Free (L (Catch p h k))) =
runError p >>= \case
Left e ->
runError (h e) >>= \case
Left e' -> pure (Left e')
Right a -> runError (k a)
Right a -> runError (k a)
runError (Free (R other)) =
Free $ weave (Right ()) (either (pure . Left) runError) other
countDown ::
forall sig.
Syntax sig =>
Member (HState Int) sig =>
Member (Error ()) sig =>
Free sig ()
countDown = do
decr {- 1 -}
catch (decr {- 2 -} >> decr {- 3 -}) pure
where
decr = do
x <- get @Int
if x > 0 then put (x - 1) else throw ()