Keep same `unliftio` exception semantics.

main
Joshua Potter 2022-04-03 12:13:20 -04:00
parent e4aa4c78be
commit 54b92afc06
3 changed files with 73 additions and 254 deletions

View File

@ -32,14 +32,13 @@ library
hs-source-dirs: src hs-source-dirs: src
exposed-modules: exposed-modules:
Control.Effect.Exception Control.Effect.Exception
Control.Effect.Exception.UnliftIO Control.Effect.UnliftIO.Exception
other-modules:
Control.Effect.Exception.Internal
build-depends: build-depends:
base >= 4.7 && < 5 base >= 4.7 && < 5
, fused-effects >= 1.1 , fused-effects >= 1.1
, transformers >= 0.4 && < 0.6 , transformers >= 0.4 && < 0.6
, unliftio-core >= 0.2 && < 0.3 , unliftio-core >= 0.2 && < 0.3
, unliftio >= 0.2 && < 0.3
test-suite test test-suite test
import: common import: common

View File

@ -1,132 +0,0 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RankNTypes #-}
-- | Operations from "UnliftIO.Exception" lifted into effectful contexts using 'Control.Effect.Lift.Lift'.
--
-- These methods are shamelessly copied from the @unliftio@ module in an effort to keep dependencies small.
-- @unliftio-core@ is assumed fair game though considering it's included in core @fused-effects@ package.
--
-- @since <version>
module Control.Effect.Exception.Internal
( -- * Throwing
throwIO
-- * Catching (with recovery)
, catch
, catchJust
, Handler (..)
, catches
-- * Evaluation
, evaluate
-- * Masking
, mask
, uninterruptibleMask
-- * Reexports
, Exception (..)
, Typeable
, SomeException (..)
, SomeAsyncException (..)
, IOException
) where
import Control.Exception (Exception(..), IOException, SomeAsyncException(..), SomeException(..))
import qualified Control.Exception as EUnsafe
import Control.Monad.IO.Unlift
import Data.Typeable (Typeable, cast)
-- | Catch a synchronous (but not asynchronous) exception and recover from it.
--
-- This is parameterized on the exception type. To catch all synchronous exceptions,
-- use 'catchAny'.
catch
:: (MonadUnliftIO m, Exception e)
=> m a -- ^ action
-> (e -> m a) -- ^ handler
-> m a
catch f g = withRunInIO $ \run -> run f `EUnsafe.catch` \e ->
if isSyncException e
then run (g e)
-- intentionally rethrowing an async exception synchronously,
-- since we want to preserve async behavior
else EUnsafe.throwIO e
-- | 'catchJust' is like 'catch' but it takes an extra argument which
-- is an exception predicate, a function which selects which type of
-- exceptions we're interested in.
catchJust :: (MonadUnliftIO m, Exception e) => (e -> Maybe b) -> m a -> (b -> m a) -> m a
catchJust f a b = a `catch` \e -> maybe (liftIO (throwIO e)) b $ f e
-- | A helper data type for usage with 'catches' and similar functions.
data Handler m a = forall e . Exception e => Handler (e -> m a)
-- | Internal.
catchesHandler :: MonadIO m => [Handler m a] -> SomeException -> m a
catchesHandler handlers e = foldr tryHandler (liftIO (EUnsafe.throwIO e)) handlers
where tryHandler (Handler handler) res
= case fromException e of
Just e' -> handler e'
Nothing -> res
-- | Similar to 'catch', but provides multiple different handler functions.
--
-- For more information on motivation, see @base@'s 'EUnsafe.catches'. Note that,
-- unlike that function, this function will not catch asynchronous exceptions.
catches :: MonadUnliftIO m => m a -> [Handler m a] -> m a
catches io handlers = io `catch` catchesHandler handlers
-- | Lifted version of 'EUnsafe.evaluate'.
evaluate :: MonadIO m => a -> m a
evaluate = liftIO . EUnsafe.evaluate
-- | Synchronously throw the given exception.
--
-- Note that, if you provide an exception value which is of an asynchronous
-- type, it will be wrapped up in 'SyncExceptionWrapper'. See 'toSyncException'.
throwIO :: (MonadIO m, Exception e) => e -> m a
throwIO = liftIO . EUnsafe.throwIO . toSyncException
-- | Wrap up an asynchronous exception to be treated as a synchronous
-- exception.
--
-- This is intended to be created via 'toSyncException'.
data SyncExceptionWrapper = forall e. Exception e => SyncExceptionWrapper e
deriving Typeable
instance Show SyncExceptionWrapper where
show (SyncExceptionWrapper e) = show e
instance Exception SyncExceptionWrapper where
#if MIN_VERSION_base(4,8,0)
displayException (SyncExceptionWrapper e) = displayException e
#endif
-- | Convert an exception into a synchronous exception.
--
-- For synchronous exceptions, this is the same as 'toException'.
-- For asynchronous exceptions, this will wrap up the exception with
-- 'SyncExceptionWrapper'.
toSyncException :: Exception e => e -> SomeException
toSyncException e =
case fromException se of
Just (SomeAsyncException _) -> toException (SyncExceptionWrapper e)
Nothing -> se
where
se = toException e
-- | Check if the given exception is synchronous.
isSyncException :: Exception e => e -> Bool
isSyncException e =
case fromException (toException e) of
Just (SomeAsyncException _) -> False
Nothing -> True
-- | Unlifted version of 'EUnsafe.mask'.
mask :: MonadUnliftIO m => ((forall a. m a -> m a) -> m b) -> m b
mask f = withRunInIO $ \run -> EUnsafe.mask $ \unmask ->
run $ f $ liftIO . unmask . run
-- | Unlifted version of 'EUnsafe.uninterruptibleMask'.
uninterruptibleMask :: MonadUnliftIO m => ((forall a. m a -> m a) -> m b) -> m b
uninterruptibleMask f = withRunInIO $ \run -> EUnsafe.uninterruptibleMask $ \unmask ->
run $ f $ liftIO . unmask . run

View File

@ -5,13 +5,12 @@
{-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeApplications #-}
-- | Operations from "Control.Exception" lifted into effectful contexts using 'Control.Effect.Lift.Lift'. -- | Operations from "UnliftIO.Exception" lifted into effectful contexts using 'Control.Effect.Lift.Lift'.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
module Control.Effect.Exception.UnliftIO module Control.Effect.UnliftIO.Exception
( -- * Lifted "Control.Exception" operations ( -- * Lifted "UnliftIO.Exception" operations
throwIO throwIO
, ioError
, throwTo , throwTo
, catch , catch
, catches , catches
@ -26,9 +25,6 @@ module Control.Effect.Exception.UnliftIO
, mask_ , mask_
, uninterruptibleMask , uninterruptibleMask
, uninterruptibleMask_ , uninterruptibleMask_
, getMaskingState
, interruptible
, allowInterrupt
, bracket , bracket
, bracket_ , bracket_
, bracketOnError , bracketOnError
@ -46,11 +42,9 @@ module Control.Effect.Exception.UnliftIO
) where ) where
import Control.Concurrent (ThreadId) import Control.Concurrent (ThreadId)
import qualified Control.Effect.Exception.Internal as Exc
import Control.Effect.Lift import Control.Effect.Lift
import Control.Exception hiding import Control.Exception hiding
( Handler ( Handler
, allowInterrupt
, bracket , bracket
, bracketOnError , bracketOnError
, bracket_ , bracket_
@ -59,11 +53,8 @@ import Control.Exception hiding
, catches , catches
, evaluate , evaluate
, finally , finally
, getMaskingState
, handle , handle
, handleJust , handleJust
, interruptible
, ioError
, mask , mask
, mask_ , mask_
, onException , onException
@ -74,44 +65,34 @@ import Control.Exception hiding
, uninterruptibleMask , uninterruptibleMask
, uninterruptibleMask_ , uninterruptibleMask_
) )
import qualified Control.Exception as EUnsafe
import Control.Monad.IO.Unlift (MonadIO, MonadUnliftIO, liftIO, withRunInIO) import Control.Monad.IO.Unlift (MonadIO, MonadUnliftIO, liftIO, withRunInIO)
import Prelude hiding (ioError) import Prelude hiding (ioError)
import qualified UnliftIO.Exception as Exc
-- | See @"Control.Exception".throwIO@. -- | See @"UnliftIO.Exception".'Exc.throwIO'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
throwIO throwIO
:: forall n e sig m a :: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
=> e => e
-> m a -> m a
throwIO = sendM @n . Exc.throwIO throwIO = sendM @n . liftIO . Exc.throwIO
-- | See @"Control.Exception".'Exc.ioError'@. -- | See @"UnliftIO.Exception".'Exc.throwTo'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
ioError
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
=> IOError
-> m a
ioError = sendM @n . liftIO . EUnsafe.ioError
-- | See @"Control.Exception".'Exc.throwTo'@.
--
-- @since 1.1.2.0
throwTo throwTo
:: forall n e sig m a :: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
=> ThreadId => ThreadId
-> e -> e
-> m () -> m ()
throwTo thread = sendM @n . liftIO . EUnsafe.throwTo thread throwTo thread = sendM @n . liftIO . Exc.throwTo thread
-- | See @"Control.Exception".catch@. -- | See @"UnliftIO.Exception".'Exc.catch'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
catch catch
:: forall n e sig m a :: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -119,33 +100,33 @@ catch
-> (e -> m a) -> (e -> m a)
-> m a -> m a
catch m h = liftWith @n $ catch m h = liftWith @n $
\run ctx -> run (m <$ ctx) `Exc.catch` (run . (<$ ctx) . h) \hdl ctx -> hdl (m <$ ctx) `Exc.catch` (hdl . (<$ ctx) . h)
-- | See @"Control.Exception".catches@. -- | See @"UnliftIO.Exception".'Exc.catches'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
catches catches
:: forall n sig m a :: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
=> m a => m a
-> [Handler m a] -> [Exc.Handler m a]
-> m a -> m a
catches m hs = liftWith @n $ catches m hs = liftWith @n $
\ run ctx -> Exc.catches \ hdl ctx -> Exc.catches
(run (m <$ ctx)) (hdl (m <$ ctx))
(map (\ (Handler h) -> Exc.Handler (run . (<$ ctx) . h)) hs) (map (\ (Exc.Handler h) -> Exc.Handler (hdl . (<$ ctx) . h)) hs)
-- | See @"Control.Exception".'Exc.Handler'@. -- | See @"UnliftIO.Exception".'Exc.Handler'@.
-- --
-- @since <version> -- @since 1.1.2.1
data Handler m a data Handler m a
= forall e . Exc.Exception e => Handler (e -> m a) = forall e . Exc.Exception e => Handler (e -> m a)
deriving instance Functor m => Functor (Handler m) deriving instance Functor m => Functor (Handler m)
-- | See @"Control.Exception".catchJust@. -- | See @"UnliftIO.Exception".'Exc.catchJust'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
catchJust catchJust
:: forall n e sig m a b :: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -154,11 +135,11 @@ catchJust
-> (b -> m a) -> (b -> m a)
-> m a -> m a
catchJust p m h = liftWith @n $ catchJust p m h = liftWith @n $
\ run ctx -> Exc.catchJust p (run (m <$ ctx)) (run . (<$ ctx) . h) \ hdl ctx -> Exc.catchJust p (hdl (m <$ ctx)) (hdl . (<$ ctx) . h)
-- | See @"Control.Exception".'Exc.handle'@. -- | See @"UnliftIO.Exception".'Exc.handle'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
handle handle
:: forall n e sig m a :: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -167,9 +148,9 @@ handle
-> m a -> m a
handle = flip $ catch @n handle = flip $ catch @n
-- | See @"Control.Exception".'Exc.handleJust'@. -- | See @"UnliftIO.Exception".'Exc.handleJust'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
handleJust handleJust
:: forall n e sig m a b :: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -179,9 +160,9 @@ handleJust
-> m a -> m a
handleJust p = flip (catchJust @n p) handleJust p = flip (catchJust @n p)
-- | See @"Control.Exception".'Exc.try'@. -- | See @"UnliftIO.Exception".'Exc.try'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
try try
:: forall n e sig m a b :: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -189,9 +170,9 @@ try
-> m (Either e a) -> m (Either e a)
try m = catch @n (Right <$> m) (pure . Left) try m = catch @n (Right <$> m) (pure . Left)
-- | See @"Control.Exception".'Exc.tryJust'@. -- | See @"UnliftIO.Exception".'Exc.tryJust'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
tryJust tryJust
:: forall n e sig m a b :: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m) . (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -200,26 +181,26 @@ tryJust
-> m (Either b a) -> m (Either b a)
tryJust p m = catchJust @n p (Right <$> m) (pure . Left) tryJust p m = catchJust @n p (Right <$> m) (pure . Left)
-- | See @"Control.Exception".evaluate@. -- | See @"UnliftIO.Exception".evaluate@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
evaluate :: forall n sig m a. (MonadUnliftIO n, Has (Lift n) sig m) => a -> m a evaluate :: forall n sig m a. (MonadUnliftIO n, Has (Lift n) sig m) => a -> m a
evaluate = sendM @n . Exc.evaluate evaluate = sendM @n . Exc.evaluate
-- | See @"Control.Exception".mask@. -- | See @"UnliftIO.Exception".mask@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
mask mask
:: forall n sig m a b :: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
=> ((forall a . m a -> m a) -> m b) => ((forall a. m a -> m a) -> m b)
-> m b -> m b
mask with = liftWith @n $ \ run ctx -> Exc.mask $ \ restore -> mask with = liftWith @n $ \ hdl ctx -> Exc.mask $ \ restore ->
run (with (\ m -> liftWith $ \ run' ctx' -> restore (run' (m <$ ctx'))) <$ ctx) hdl (with (\ m -> liftWith $ \ hdl' ctx' -> restore (hdl' (m <$ ctx'))) <$ ctx)
-- | See @"Control.Exception".'Exc.mask_'@. -- | See @"UnliftIO.Exception".'Exc.mask_'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
mask_ mask_
:: forall n sig m a :: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
@ -227,22 +208,22 @@ mask_
-> m a -> m a
mask_ m = mask @n (const m) mask_ m = mask @n (const m)
-- | See @"Control.Exception".uninterruptibleMask@. -- | See @"UnliftIO.Exception".uninterruptibleMask@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
uninterruptibleMask uninterruptibleMask
:: forall n sig m a b :: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
=> ((forall a . m a -> m a) -> m b) => ((forall a. m a -> m a) -> m b)
-> m b -> m b
uninterruptibleMask with = liftWith @n $ uninterruptibleMask with = liftWith @n $
\ run ctx -> Exc.uninterruptibleMask $ \ restore -> \ hdl ctx -> Exc.uninterruptibleMask $ \ restore ->
run (with (\ m -> liftWith $ hdl (with (\ m -> liftWith $
\ run' ctx' -> restore (run' (m <$ ctx'))) <$ ctx) \ hdl' ctx' -> restore (hdl' (m <$ ctx'))) <$ ctx)
-- | See @"Control.Exception".'Exc.uninterruptibleMask_'@. -- | See @"UnliftIO.Exception".'Exc.uninterruptibleMask_'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
uninterruptibleMask_ uninterruptibleMask_
:: forall n sig m a :: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
@ -250,38 +231,9 @@ uninterruptibleMask_
-> m a -> m a
uninterruptibleMask_ m = uninterruptibleMask @n (const m) uninterruptibleMask_ m = uninterruptibleMask @n (const m)
-- | See @"Control.Exception".'Exc.getMaskingState'@. -- | See @"UnliftIO.Exception".'Exc.bracket'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
getMaskingState
:: forall n sig m
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m EUnsafe.MaskingState
getMaskingState = sendM @n (liftIO EUnsafe.getMaskingState)
-- | See @"Control.Exception".'Exc.interruptible'@.
--
-- @since 1.1.2.0
interruptible
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m a
-> m a
interruptible m = liftWith @n $ \ run ctx -> withRunInIO $ \runInIO ->
EUnsafe.interruptible (runInIO $ run (m <$ ctx))
-- | See @"Control.Exception".'Exc.allowInterrupt'@.
--
-- @since 1.1.2.0
allowInterrupt
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m ()
allowInterrupt = sendM @n (liftIO EUnsafe.allowInterrupt)
-- | See @"Control.Exception".'Exc.bracket'@.
--
-- @since 1.1.2.0
bracket bracket
:: forall n sig m a b c :: forall n sig m a b c
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
@ -289,14 +241,12 @@ bracket
-> (a -> m b) -> (a -> m b)
-> (a -> m c) -> (a -> m c)
-> m c -> m c
bracket acquire release m = mask @n $ \ restore -> do bracket acquire release m = liftWith @n $ \ hdl ctx ->
a <- acquire Exc.bracket (hdl (acquire <$ ctx)) (hdl . (release <$>)) (hdl . (m <$>))
r <- onException @n (restore $ m a) (release a)
r <$ release a
-- | See @"Control.Exception".'Exc.bracket_'@. -- | See @"UnliftIO.Exception".'Exc.bracket_'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
bracket_ bracket_
:: forall n sig m a b c :: forall n sig m a b c
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
@ -306,9 +256,9 @@ bracket_
-> m c -> m c
bracket_ before after thing = bracket @n before (const after) (const thing) bracket_ before after thing = bracket @n before (const after) (const thing)
-- | See @"Control.Exception".'Exc.bracketOnError'@. -- | See @"UnliftIO.Exception".'Exc.bracketOnError'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
bracketOnError bracketOnError
:: forall n sig m a b c :: forall n sig m a b c
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
@ -316,30 +266,32 @@ bracketOnError
-> (a -> m b) -> (a -> m b)
-> (a -> m c) -> (a -> m c)
-> m c -> m c
bracketOnError acquire release m = mask @n $ \ restore -> do bracketOnError acquire release m = liftWith @n $ \ hdl ctx ->
a <- acquire Exc.bracketOnError
onException @n (restore $ m a) (release a) (hdl (acquire <$ ctx))
(hdl . (release <$>))
(hdl . (m <$>))
-- | See @"Control.Exception".'Exc.finally'@. -- | See @"UnliftIO.Exception".'Exc.finally'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
finally finally
:: forall n sig m a b :: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
=> m a => m a
-> m b -> m b
-> m a -> m a
finally m sequel = mask @n $ finally m sequel = liftWith @n $ \ hdl ctx ->
\ restore -> onException @n (restore m) sequel <* sequel Exc.finally (hdl (m <$ ctx)) (hdl (sequel <$ ctx))
-- | See @"Control.Exception".'Exc.onException'@. -- | See @"UnliftIO.Exception".'Exc.onException'@.
-- --
-- @since 1.1.2.0 -- @since 1.1.2.1
onException onException
:: forall n sig m a b :: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m) . (MonadUnliftIO n, Has (Lift n) sig m)
=> m a => m a
-> m b -> m b
-> m a -> m a
onException io what = catch @n io $ onException io what = liftWith @n $ \ hdl ctx ->
\e -> what >> throwIO @n @Exc.SomeException e Exc.onException (hdl (io <$ ctx)) (hdl (what <$ ctx))