diff --git a/.githooks/pre-commit b/.githooks/pre-commit index 2606d55..e2be07f 100755 --- a/.githooks/pre-commit +++ b/.githooks/pre-commit @@ -8,6 +8,6 @@ filesToFormat=$( for path in $filesToFormat do - ormolu --mode inplace $path + ormolu --ghc-opt -XTypeApplications --mode inplace $path git add $path done; diff --git a/cabal.project b/cabal.project deleted file mode 100644 index 7e47e83..0000000 --- a/cabal.project +++ /dev/null @@ -1,2 +0,0 @@ -packages: - free-monads diff --git a/free-monads/free-monads.cabal b/free-monads.cabal similarity index 95% rename from free-monads/free-monads.cabal rename to free-monads.cabal index c63e4bf..c94ce41 100644 --- a/free-monads/free-monads.cabal +++ b/free-monads.cabal @@ -77,4 +77,6 @@ common free-monads-common library import: free-monads-common hs-source-dirs: src - exposed-modules: Control.Monad.Free + exposed-modules: Control.Monad.Free, + Control.Monad.Free.Compose, + Control.Monad.Free.Scoped diff --git a/free-monads/src/Control/Monad/Free.hs b/src/Control/Monad/Free.hs similarity index 82% rename from free-monads/src/Control/Monad/Free.hs rename to src/Control/Monad/Free.hs index 78d1496..e05a5ce 100644 --- a/free-monads/src/Control/Monad/Free.hs +++ b/src/Control/Monad/Free.hs @@ -20,7 +20,6 @@ module Control.Monad.Free -- * Free Monad Free (..), - monadicAp, -- * Teletype Teletype (..), @@ -77,17 +76,15 @@ instance (Functor f) => Functor (NonEmptyList' f) where fmap f (Last' a) = Last' (f a) fmap f (Cons' a g) = Cons' (f a) (fmap (fmap f) g) +{- ORMOLU_DISABLE -} + twoPlusThree :: NonEmptyList' (Reader Int) Int twoPlusThree = - Cons' - 2 - ( reader - ( \a -> - Cons' - 3 - (reader (\b -> Last' (a + b))) - ) - ) + Cons' 2 (reader (\a -> + Cons' 3 (reader (\b -> + Last' (a + b))))) + +{- ORMOLU_ENABLE -} -- | -- @@ -101,31 +98,28 @@ runNonEmptyList' (Cons' a f) = runNonEmptyList' (runReader f a) -- Third Pass -- ======================================== -data Wrap a b c = Wrap a (b c) deriving (Functor) +data Container a m k = Container a (m k) deriving (Functor) data NonEmptyList'' f a = Last'' a | Cons'' (f (NonEmptyList'' f a)) deriving (Functor) -threePlusFour :: NonEmptyList'' (Wrap Int (Reader Int)) Int +{- ORMOLU_DISABLE -} + +threePlusFour :: NonEmptyList'' (Container Int (Reader Int)) Int threePlusFour = - Cons'' - ( Wrap - 3 - ( reader - ( \a -> - Cons'' - (Wrap 4 (reader (\b -> Last'' (a + b)))) - ) - ) - ) + Cons'' (Container 3 (reader (\a -> + Cons'' (Container 4 (reader (\b -> + Last'' (a + b))))))) + +{- ORMOLU_ENABLE -} -- | -- -- >>> runNonEmptyList'' threePlusFour -- 5 -runNonEmptyList'' :: NonEmptyList'' (Wrap Int (Reader Int)) Int -> Int +runNonEmptyList'' :: NonEmptyList'' (Container Int (Reader Int)) Int -> Int runNonEmptyList'' (Last'' a) = a -runNonEmptyList'' (Cons'' (Wrap a f)) = runNonEmptyList'' (runReader f a) +runNonEmptyList'' (Cons'' (Container a f)) = runNonEmptyList'' (runReader f a) instance (Functor f) => Applicative (NonEmptyList'' f) @@ -148,12 +142,6 @@ instance (Functor f) => Functor (Free f) where fmap f (Pure a) = Pure (f a) fmap f (Free g) = Free (fmap (fmap f) g) -monadicAp :: forall f a b. Functor f => Free f (a -> b) -> Free f a -> Free f b -monadicAp f g = do - f' <- f - g' <- g - pure (f' g') - instance (Functor f) => Applicative (Free f) where pure = Pure @@ -168,7 +156,7 @@ instance (Functor f) => Monad (Free f) where -- Teletype -- ======================================== -data Teletype a = Read a | Write String a deriving (Functor, Show) +data Teletype k = Read k | Write String k deriving (Functor, Show) read :: Free Teletype String read = Free (Read (Pure "hello")) diff --git a/src/Control/Monad/Free/Compose.hs b/src/Control/Monad/Free/Compose.hs new file mode 100644 index 0000000..5551b38 --- /dev/null +++ b/src/Control/Monad/Free/Compose.hs @@ -0,0 +1,274 @@ +{-# LANGUAGE ViewPatterns #-} + +module Control.Monad.Free.Compose + ( -- * State + State (..), + increment, + runState, + + -- * Sum + (:+:) (..), + runTwoState, + runState', + threadedState, + threadedState', + + -- * Member + Member (..), + Void, + inject, + project, + get, + put, + run, + threadedState'', + threadedStateM'', + + -- * Exceptions + Throw (..), + throw, + catch, + runThrow, + countDown, + countDown', + ) +where + +import Control.Monad.Free +import Data.Text (pack) +import qualified Text.Show as S +import Prelude hiding (State, Void, get, put, runState) + +-- ======================================== +-- State +-- ======================================== + +data State s k = Get (s -> k) | Put s k deriving (Functor) + +instance (Show s, Show k) => Show (State s k) where + show (Get _) = "Get " + show (Put s k) = "Put " <> S.show s <> " " <> S.show k + +runState :: forall s a. s -> Free (State s) a -> (s, a) +runState s (Free (Get f)) = runState s (f s) +runState _ (Free (Put s' f)) = runState s' f +runState s (Pure a) = (s, a) + +-- | +-- +-- >>> runState 0 increment +-- (1, ()) +increment :: Free (State Int) () +increment = Free (Get (\s -> Free (Put (s + 1) (Pure ())))) + +-- ======================================== +-- Sum +-- ======================================== + +data (f :+: g) k = L (f k) | R (g k) deriving (Functor, Show) + +infixr 4 :+: + +runTwoState :: + forall s1 s2 a. + s1 -> + s2 -> + Free (State s1 :+: State s2) a -> + (s1, s2, a) +runTwoState s1 s2 (Free (L (Get f))) = runTwoState s1 s2 (f s1) +runTwoState s1 s2 (Free (R (Get f))) = runTwoState s1 s2 (f s2) +runTwoState _ s2 (Free (L (Put s1 f))) = runTwoState s1 s2 f +runTwoState s1 _ (Free (R (Put s2 f))) = runTwoState s1 s2 f +runTwoState s1 s2 (Pure a) = (s1, s2, a) + +runState' :: + forall s a sig. + Functor sig => + s -> + Free (State s :+: sig) a -> + Free sig (s, a) +runState' s (Pure a) = pure (s, a) +runState' s (Free (L (Get f))) = runState' s (f s) +runState' _ (Free (L (Put s f))) = runState' s f +runState' s (Free (R other)) = Free (fmap (runState' s) other) + +{- ORMOLU_DISABLE -} + +-- | +-- +-- >>> runState "" (runState' 0 threadedState) +-- ("a",(1,())) +threadedState :: Free (State Int :+: State String) () +threadedState = + Free (L (Get (\s1 -> + Free (R (Get (\s2 -> + Free (L (Put (s1 + 1) + (Free (R (Put (s2 ++ "a") + (Pure ())))))))))))) + +threadedState' :: Free (State String :+: State Int) () +threadedState' = + Free (R (Get (\s1 -> + Free (L (Get (\s2 -> + Free (R (Put (s1 + 1) + (Free (L (Put (s2 ++ "a") + (Pure ())))))))))))) + +{- ORMOLU_ENABLE -} + +-- ======================================== +-- Membership +-- ======================================== + +class Member sub sup where + inj :: sub a -> sup a + prj :: sup a -> Maybe (sub a) + +instance Member sig sig where + inj = id + prj = Just + +instance + {-# OVERLAPPABLE #-} + Member sig (l1 :+: (l2 :+: r)) => + Member sig ((l1 :+: l2) :+: r) + where + inj sub = case inj sub of + L l1 -> L (L l1) + R (L l2) -> L (R l2) + R (R r) -> R r + prj sup = case sup of + L (L l1) -> prj (L @l1 @(l2 :+: r) l1) + L (R l2) -> prj (R @l1 @(l2 :+: r) (L @l2 l2)) + R r -> prj (R @l1 @(l2 :+: r) (R @l2 @r r)) + +instance {-# OVERLAPPABLE #-} Member sig (sig :+: r) where + inj = L + prj (L f) = Just f + prj _ = Nothing + +instance {-# OVERLAPPABLE #-} (Member sig r) => Member sig (l :+: r) where + inj = R . inj + prj (R g) = prj g + prj _ = Nothing + +data Void k deriving (Functor) + +run :: forall a. Free Void a -> a +run (Pure a) = a +run _ = error (pack "impossible") + +{- ORMOLU_DISABLE -} + +threadedState'' :: + Functor sig => + Member (State Int) sig => + Member (State String) sig => + Free sig () +threadedState'' = + Free (inj (Get @Int (\s1 -> + Free (inj (Get (\s2 -> + Free (inj (Put (s1 + 1) + (Free (inj (Put (s2 ++ "a") + (Pure ())))))))))))) + +{- ORMOLU_ENABLE -} + +inject :: + forall a sub sup. + Member sub sup => + sub (Free sup a) -> + Free sup a +inject = Free . inj + +project :: + forall a sub sup. + Member sub sup => + Free sup a -> + Maybe (sub (Free sup a)) +project (Free s) = prj s +project _ = Nothing + +get :: forall s sig. Functor sig => Member (State s) sig => Free sig s +get = inject (Get pure) + +put :: forall s sig. Functor sig => Member (State s) sig => s -> Free sig () +put s = inject (Put s (pure ())) + +threadedStateM'' :: + Functor sig => + Member (State Int) sig => + Member (State String) sig => + Free sig () +threadedStateM'' = do + s1 <- get @Int + s2 <- get @String + put (s1 + 1) + put (s2 ++ "a") + pure () + +-- ======================================== +-- Exceptions +-- ======================================== + +newtype Throw e k = Throw e deriving (Functor) + +throw :: forall e a sig. Functor sig => Member (Throw e) sig => e -> Free sig a +throw e = inject (Throw e) + +catch :: + forall e a sig. + Functor sig => + Free (Throw e :+: sig) a -> + (e -> Free sig a) -> + Free sig a +catch (Pure a) _ = pure a +catch (Free (L (Throw e))) h = h e +catch (Free (R other)) h = Free (fmap (`catch` h) other) + +runThrow :: + forall e a sig. + Functor sig => + Free (Throw e :+: sig) a -> + Free sig (Either e a) +runThrow (Pure a) = pure (Right a) +runThrow (Free (L (Throw e))) = pure (Left e) +runThrow (Free (R other)) = Free (fmap runThrow other) + +countDown :: + forall sig. + Functor sig => + Member (State Int) sig => + Member (Throw ()) sig => + Free sig () +countDown = do + decr + catch (decr >> decr) pure + where + decr :: + forall sig2. + Functor sig2 => + Member (State Int) sig2 => + Member (Throw ()) sig2 => + Free sig2 () + decr = do + x <- get @Int + if x > 0 then put (x - 1) else throw () + +{- ORMOLU_DISABLE -} + +countDown' :: + Functor sig => + Member (State Int) sig => + Member (Throw ()) sig => + Free sig () +countDown' = + Free (inj (Get @Int (\x -> + let a = \k -> if x > 0 then Free (inj (Put (x - 1) k)) else throw () + in a (catch (Free (inj (Get @Int (\y -> + let b = \k -> if y > 0 then Free (inj (Put (y - 1) k)) else throw () + in b (Free (inj (Get @Int (\z -> + let c = \k -> if z > 0 then Free (inj (Put (z - 1) k)) else throw () + in c (Pure ()))))))))) pure)))) + +{- ORMOLU_ENABLE -} diff --git a/src/Control/Monad/Free/Scoped.hs b/src/Control/Monad/Free/Scoped.hs new file mode 100644 index 0000000..6d3a3c4 --- /dev/null +++ b/src/Control/Monad/Free/Scoped.hs @@ -0,0 +1,264 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE UndecidableInstances #-} + +module Control.Monad.Free.Scoped + ( -- * Free + HFunctor (..), + Syntax (..), + Free (..), + + -- * Sum + (:+:) (..), + + -- * Members + Member (..), + inject, + project, + + -- * Lifting + Lift (..), + HState, + hIncrement, + runState, + get, + put, + HVoid, + run, + + -- * Exceptions + Error (..), + throw, + catch, + runError, + countDown, + ) +where + +import Control.Monad (ap) +import Control.Monad.Free.Compose (State (..), Void) +import Data.Text (pack) +import Prelude hiding (State, Void, get, put, runState) + +{-# ANN module "HLINT: ignore Use <&>" #-} + +-- ======================================== +-- Free +-- ======================================== + +class HFunctor f where + hmap :: + (Functor m, Functor n) => + (forall x. m x -> n x) -> + (forall x. f m x -> f n x) + +class HFunctor f => Syntax f where + emap :: (m a -> m b) -> (f m a -> f m b) + + weave :: + (Monad m, Monad n, Functor ctx) => + ctx () -> + Handler ctx m n -> + (f m a -> f n (ctx a)) + +type Handler ctx m n = forall x. ctx (m x) -> n (ctx x) + +data Free f a = Pure a | Free (f (Free f) a) + +instance Syntax f => Functor (Free f) where + fmap f m = m >>= pure . f + +instance Syntax f => Applicative (Free f) where + pure = Pure + (<*>) = ap + +instance Syntax f => Monad (Free f) where + Pure a >>= g = g a + Free f >>= g = Free (emap (>>= g) f) + +-- ======================================== +-- Sum +-- ======================================== + +data (f :+: g) (m :: Type -> Type) a = L (f m a) | R (g m a) + +infixr 4 :+: + +instance (HFunctor f, HFunctor g) => HFunctor (f :+: g) where + hmap t (L f) = L (hmap t f) + hmap t (R g) = R (hmap t g) + +instance (Syntax f, Syntax g) => Syntax (f :+: g) where + emap t (L f) = L (emap t f) + emap t (R g) = R (emap t g) + + weave ctx hdl (L f) = L (weave ctx hdl f) + weave ctx hdl (R g) = R (weave ctx hdl g) + +-- ======================================== +-- Members +-- ======================================== + +class (Syntax sub, Syntax sup) => Member sub sup where + inj :: sub m a -> sup m a + prj :: sup m a -> Maybe (sub m a) + +instance (Syntax sig) => Member sig sig where + inj = id + prj = Just + +instance + {-# OVERLAPPABLE #-} + ( Syntax sig, + Syntax l1, + Syntax l2, + Syntax r, + Member sig (l1 :+: (l2 :+: r)) + ) => + Member sig ((l1 :+: l2) :+: r) + where + inj sub = case inj sub of + L l1 -> L (L l1) + R (L l2) -> L (R l2) + R (R r) -> R r + prj sup = case sup of + L (L l1) -> prj (L @l1 @(l2 :+: r) l1) + L (R l2) -> prj (R @l1 @(l2 :+: r) (L @l2 l2)) + R r -> prj (R @l1 @(l2 :+: r) (R @l2 @r r)) + +instance + {-# OVERLAPPABLE #-} + (Syntax sig, Syntax r) => + Member sig (sig :+: r) + where + inj = L + prj (L f) = Just f + prj _ = Nothing + +instance + {-# OVERLAPPABLE #-} + (Member sig r, Syntax l) => + Member sig (l :+: r) + where + inj = R . inj + prj (R g) = prj g + prj _ = Nothing + +inject :: + forall a sub sup. + Member sub sup => + sub (Free sup) a -> + Free sup a +inject = Free . inj + +project :: + forall a sub sup. + Member sub sup => + Free sup a -> + Maybe (sub (Free sup) a) +project (Free s) = prj s +project _ = Nothing + +-- ======================================== +-- Lift +-- ======================================== + +newtype Lift sig (m :: Type -> Type) a = Lift (sig (m a)) + +type HState s = Lift (State s) + +hIncrement :: Free (Lift (State Int)) () +hIncrement = Free (Lift (Get (\s -> Free (Lift (Put (s + 1) (Pure ())))))) + +instance Functor sig => HFunctor (Lift sig) where + hmap t (Lift f) = Lift (fmap t f) + +instance Functor sig => Syntax (Lift sig) where + emap t (Lift f) = Lift (fmap t f) + + weave ctx hdl (Lift f) = Lift (fmap (\p -> hdl (fmap (const p) ctx)) f) + +runState :: + forall s a sig. + Syntax sig => + s -> + Free (HState s :+: sig) a -> + Free sig (s, a) +runState s (Pure a) = pure (s, a) +runState s (Free (L (Lift (Get f)))) = runState s (f s) +runState _ (Free (L (Lift (Put s f)))) = runState s f +runState s (Free (R other)) = Free (weave (s, ()) hdl other) + where + hdl :: forall x. (s, Free (HState s :+: sig) x) -> Free sig (s, x) + hdl = uncurry runState + +get :: forall s sig. HFunctor sig => Member (HState s) sig => Free sig s +get = inject (Lift (Get Pure)) + +put :: forall s sig. HFunctor sig => Member (HState s) sig => s -> Free sig () +put s = inject (Lift (Put s (pure ()))) + +type HVoid = Lift Void + +run :: Free HVoid a -> a +run (Pure a) = a +run _ = error (pack "impossible") + +-- ======================================== +-- Exceptions +-- ======================================== + +data Error e m a + = Throw e + | forall x. Catch (m x) (e -> m x) (x -> m a) + +instance HFunctor (Error e) where + hmap _ (Throw x) = Throw x + hmap t (Catch p h k) = Catch (t p) (t . h) (t . k) + +instance Syntax (Error e) where + emap _ (Throw e) = Throw e + emap f (Catch p h k) = Catch p h (f . k) + + weave _ _ (Throw x) = Throw x + weave ctx hdl (Catch p h k) = + Catch + (hdl (fmap (const p) ctx)) + (\e -> hdl (fmap (const (h e)) ctx)) + (hdl . fmap k) + +throw :: Member (Error e) sig => e -> Free sig a +throw e = inject (Throw e) + +catch :: Member (Error e) sig => Free sig a -> (e -> Free sig a) -> Free sig a +catch p h = inject (Catch p h pure) + +runError :: + forall e a sig. + Syntax sig => + Free (Error e :+: sig) a -> + Free sig (Either e a) +runError (Pure a) = pure (Right a) +runError (Free (L (Throw e))) = pure (Left e) +runError (Free (L (Catch p h k))) = + runError p >>= \case + Left e -> + runError (h e) >>= \case + Left e' -> pure (Left e') + Right a -> runError (k a) + Right a -> runError (k a) +runError (Free (R other)) = + Free $ weave (Right ()) (either (pure . Left) runError) other + +countDown :: + forall sig. + Syntax sig => + Member (HState Int) sig => + Member (Error ()) sig => + Free sig () +countDown = do + decr {- 1 -} + catch (decr {- 2 -} >> decr {- 3 -}) pure + where + decr = do + x <- get @Int + if x > 0 then put (x - 1) else throw ()