Keep same `unliftio` exception semantics.

main
Joshua Potter 2022-04-03 12:13:20 -04:00
parent e4aa4c78be
commit 54b92afc06
3 changed files with 73 additions and 254 deletions

View File

@ -32,14 +32,13 @@ library
hs-source-dirs: src
exposed-modules:
Control.Effect.Exception
Control.Effect.Exception.UnliftIO
other-modules:
Control.Effect.Exception.Internal
Control.Effect.UnliftIO.Exception
build-depends:
base >= 4.7 && < 5
, fused-effects >= 1.1
, transformers >= 0.4 && < 0.6
, unliftio-core >= 0.2 && < 0.3
, unliftio >= 0.2 && < 0.3
test-suite test
import: common

View File

@ -1,132 +0,0 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RankNTypes #-}
-- | Operations from "UnliftIO.Exception" lifted into effectful contexts using 'Control.Effect.Lift.Lift'.
--
-- These methods are shamelessly copied from the @unliftio@ module in an effort to keep dependencies small.
-- @unliftio-core@ is assumed fair game though considering it's included in core @fused-effects@ package.
--
-- @since <version>
module Control.Effect.Exception.Internal
( -- * Throwing
throwIO
-- * Catching (with recovery)
, catch
, catchJust
, Handler (..)
, catches
-- * Evaluation
, evaluate
-- * Masking
, mask
, uninterruptibleMask
-- * Reexports
, Exception (..)
, Typeable
, SomeException (..)
, SomeAsyncException (..)
, IOException
) where
import Control.Exception (Exception(..), IOException, SomeAsyncException(..), SomeException(..))
import qualified Control.Exception as EUnsafe
import Control.Monad.IO.Unlift
import Data.Typeable (Typeable, cast)
-- | Catch a synchronous (but not asynchronous) exception and recover from it.
--
-- This is parameterized on the exception type. To catch all synchronous exceptions,
-- use 'catchAny'.
catch
:: (MonadUnliftIO m, Exception e)
=> m a -- ^ action
-> (e -> m a) -- ^ handler
-> m a
catch f g = withRunInIO $ \run -> run f `EUnsafe.catch` \e ->
if isSyncException e
then run (g e)
-- intentionally rethrowing an async exception synchronously,
-- since we want to preserve async behavior
else EUnsafe.throwIO e
-- | 'catchJust' is like 'catch' but it takes an extra argument which
-- is an exception predicate, a function which selects which type of
-- exceptions we're interested in.
catchJust :: (MonadUnliftIO m, Exception e) => (e -> Maybe b) -> m a -> (b -> m a) -> m a
catchJust f a b = a `catch` \e -> maybe (liftIO (throwIO e)) b $ f e
-- | A helper data type for usage with 'catches' and similar functions.
data Handler m a = forall e . Exception e => Handler (e -> m a)
-- | Internal.
catchesHandler :: MonadIO m => [Handler m a] -> SomeException -> m a
catchesHandler handlers e = foldr tryHandler (liftIO (EUnsafe.throwIO e)) handlers
where tryHandler (Handler handler) res
= case fromException e of
Just e' -> handler e'
Nothing -> res
-- | Similar to 'catch', but provides multiple different handler functions.
--
-- For more information on motivation, see @base@'s 'EUnsafe.catches'. Note that,
-- unlike that function, this function will not catch asynchronous exceptions.
catches :: MonadUnliftIO m => m a -> [Handler m a] -> m a
catches io handlers = io `catch` catchesHandler handlers
-- | Lifted version of 'EUnsafe.evaluate'.
evaluate :: MonadIO m => a -> m a
evaluate = liftIO . EUnsafe.evaluate
-- | Synchronously throw the given exception.
--
-- Note that, if you provide an exception value which is of an asynchronous
-- type, it will be wrapped up in 'SyncExceptionWrapper'. See 'toSyncException'.
throwIO :: (MonadIO m, Exception e) => e -> m a
throwIO = liftIO . EUnsafe.throwIO . toSyncException
-- | Wrap up an asynchronous exception to be treated as a synchronous
-- exception.
--
-- This is intended to be created via 'toSyncException'.
data SyncExceptionWrapper = forall e. Exception e => SyncExceptionWrapper e
deriving Typeable
instance Show SyncExceptionWrapper where
show (SyncExceptionWrapper e) = show e
instance Exception SyncExceptionWrapper where
#if MIN_VERSION_base(4,8,0)
displayException (SyncExceptionWrapper e) = displayException e
#endif
-- | Convert an exception into a synchronous exception.
--
-- For synchronous exceptions, this is the same as 'toException'.
-- For asynchronous exceptions, this will wrap up the exception with
-- 'SyncExceptionWrapper'.
toSyncException :: Exception e => e -> SomeException
toSyncException e =
case fromException se of
Just (SomeAsyncException _) -> toException (SyncExceptionWrapper e)
Nothing -> se
where
se = toException e
-- | Check if the given exception is synchronous.
isSyncException :: Exception e => e -> Bool
isSyncException e =
case fromException (toException e) of
Just (SomeAsyncException _) -> False
Nothing -> True
-- | Unlifted version of 'EUnsafe.mask'.
mask :: MonadUnliftIO m => ((forall a. m a -> m a) -> m b) -> m b
mask f = withRunInIO $ \run -> EUnsafe.mask $ \unmask ->
run $ f $ liftIO . unmask . run
-- | Unlifted version of 'EUnsafe.uninterruptibleMask'.
uninterruptibleMask :: MonadUnliftIO m => ((forall a. m a -> m a) -> m b) -> m b
uninterruptibleMask f = withRunInIO $ \run -> EUnsafe.uninterruptibleMask $ \unmask ->
run $ f $ liftIO . unmask . run

View File

@ -5,13 +5,12 @@
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
-- | Operations from "Control.Exception" lifted into effectful contexts using 'Control.Effect.Lift.Lift'.
-- | Operations from "UnliftIO.Exception" lifted into effectful contexts using 'Control.Effect.Lift.Lift'.
--
-- @since 1.1.2.0
module Control.Effect.Exception.UnliftIO
( -- * Lifted "Control.Exception" operations
-- @since 1.1.2.1
module Control.Effect.UnliftIO.Exception
( -- * Lifted "UnliftIO.Exception" operations
throwIO
, ioError
, throwTo
, catch
, catches
@ -26,9 +25,6 @@ module Control.Effect.Exception.UnliftIO
, mask_
, uninterruptibleMask
, uninterruptibleMask_
, getMaskingState
, interruptible
, allowInterrupt
, bracket
, bracket_
, bracketOnError
@ -46,11 +42,9 @@ module Control.Effect.Exception.UnliftIO
) where
import Control.Concurrent (ThreadId)
import qualified Control.Effect.Exception.Internal as Exc
import Control.Effect.Lift
import Control.Exception hiding
( Handler
, allowInterrupt
, bracket
, bracketOnError
, bracket_
@ -59,11 +53,8 @@ import Control.Exception hiding
, catches
, evaluate
, finally
, getMaskingState
, handle
, handleJust
, interruptible
, ioError
, mask
, mask_
, onException
@ -74,44 +65,34 @@ import Control.Exception hiding
, uninterruptibleMask
, uninterruptibleMask_
)
import qualified Control.Exception as EUnsafe
import Control.Monad.IO.Unlift (MonadIO, MonadUnliftIO, liftIO, withRunInIO)
import Prelude hiding (ioError)
import qualified UnliftIO.Exception as Exc
-- | See @"Control.Exception".throwIO@.
-- | See @"UnliftIO.Exception".'Exc.throwIO'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
throwIO
:: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
=> e
-> m a
throwIO = sendM @n . Exc.throwIO
throwIO = sendM @n . liftIO . Exc.throwIO
-- | See @"Control.Exception".'Exc.ioError'@.
-- | See @"UnliftIO.Exception".'Exc.throwTo'@.
--
-- @since 1.1.2.0
ioError
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
=> IOError
-> m a
ioError = sendM @n . liftIO . EUnsafe.ioError
-- | See @"Control.Exception".'Exc.throwTo'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
throwTo
:: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
=> ThreadId
-> e
-> m ()
throwTo thread = sendM @n . liftIO . EUnsafe.throwTo thread
throwTo thread = sendM @n . liftIO . Exc.throwTo thread
-- | See @"Control.Exception".catch@.
-- | See @"UnliftIO.Exception".'Exc.catch'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
catch
:: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -119,33 +100,33 @@ catch
-> (e -> m a)
-> m a
catch m h = liftWith @n $
\run ctx -> run (m <$ ctx) `Exc.catch` (run . (<$ ctx) . h)
\hdl ctx -> hdl (m <$ ctx) `Exc.catch` (hdl . (<$ ctx) . h)
-- | See @"Control.Exception".catches@.
-- | See @"UnliftIO.Exception".'Exc.catches'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
catches
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m a
-> [Handler m a]
-> [Exc.Handler m a]
-> m a
catches m hs = liftWith @n $
\ run ctx -> Exc.catches
(run (m <$ ctx))
(map (\ (Handler h) -> Exc.Handler (run . (<$ ctx) . h)) hs)
\ hdl ctx -> Exc.catches
(hdl (m <$ ctx))
(map (\ (Exc.Handler h) -> Exc.Handler (hdl . (<$ ctx) . h)) hs)
-- | See @"Control.Exception".'Exc.Handler'@.
-- | See @"UnliftIO.Exception".'Exc.Handler'@.
--
-- @since <version>
-- @since 1.1.2.1
data Handler m a
= forall e . Exc.Exception e => Handler (e -> m a)
deriving instance Functor m => Functor (Handler m)
-- | See @"Control.Exception".catchJust@.
-- | See @"UnliftIO.Exception".'Exc.catchJust'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
catchJust
:: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -154,11 +135,11 @@ catchJust
-> (b -> m a)
-> m a
catchJust p m h = liftWith @n $
\ run ctx -> Exc.catchJust p (run (m <$ ctx)) (run . (<$ ctx) . h)
\ hdl ctx -> Exc.catchJust p (hdl (m <$ ctx)) (hdl . (<$ ctx) . h)
-- | See @"Control.Exception".'Exc.handle'@.
-- | See @"UnliftIO.Exception".'Exc.handle'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
handle
:: forall n e sig m a
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -167,9 +148,9 @@ handle
-> m a
handle = flip $ catch @n
-- | See @"Control.Exception".'Exc.handleJust'@.
-- | See @"UnliftIO.Exception".'Exc.handleJust'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
handleJust
:: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -179,9 +160,9 @@ handleJust
-> m a
handleJust p = flip (catchJust @n p)
-- | See @"Control.Exception".'Exc.try'@.
-- | See @"UnliftIO.Exception".'Exc.try'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
try
:: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -189,9 +170,9 @@ try
-> m (Either e a)
try m = catch @n (Right <$> m) (pure . Left)
-- | See @"Control.Exception".'Exc.tryJust'@.
-- | See @"UnliftIO.Exception".'Exc.tryJust'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
tryJust
:: forall n e sig m a b
. (MonadUnliftIO n, Exc.Exception e, Has (Lift n) sig m)
@ -200,26 +181,26 @@ tryJust
-> m (Either b a)
tryJust p m = catchJust @n p (Right <$> m) (pure . Left)
-- | See @"Control.Exception".evaluate@.
-- | See @"UnliftIO.Exception".evaluate@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
evaluate :: forall n sig m a. (MonadUnliftIO n, Has (Lift n) sig m) => a -> m a
evaluate = sendM @n . Exc.evaluate
-- | See @"Control.Exception".mask@.
-- | See @"UnliftIO.Exception".mask@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
mask
:: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m)
=> ((forall a . m a -> m a) -> m b)
=> ((forall a. m a -> m a) -> m b)
-> m b
mask with = liftWith @n $ \ run ctx -> Exc.mask $ \ restore ->
run (with (\ m -> liftWith $ \ run' ctx' -> restore (run' (m <$ ctx'))) <$ ctx)
mask with = liftWith @n $ \ hdl ctx -> Exc.mask $ \ restore ->
hdl (with (\ m -> liftWith $ \ hdl' ctx' -> restore (hdl' (m <$ ctx'))) <$ ctx)
-- | See @"Control.Exception".'Exc.mask_'@.
-- | See @"UnliftIO.Exception".'Exc.mask_'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
mask_
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
@ -227,22 +208,22 @@ mask_
-> m a
mask_ m = mask @n (const m)
-- | See @"Control.Exception".uninterruptibleMask@.
-- | See @"UnliftIO.Exception".uninterruptibleMask@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
uninterruptibleMask
:: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m)
=> ((forall a . m a -> m a) -> m b)
=> ((forall a. m a -> m a) -> m b)
-> m b
uninterruptibleMask with = liftWith @n $
\ run ctx -> Exc.uninterruptibleMask $ \ restore ->
run (with (\ m -> liftWith $
\ run' ctx' -> restore (run' (m <$ ctx'))) <$ ctx)
\ hdl ctx -> Exc.uninterruptibleMask $ \ restore ->
hdl (with (\ m -> liftWith $
\ hdl' ctx' -> restore (hdl' (m <$ ctx'))) <$ ctx)
-- | See @"Control.Exception".'Exc.uninterruptibleMask_'@.
-- | See @"UnliftIO.Exception".'Exc.uninterruptibleMask_'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
uninterruptibleMask_
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
@ -250,38 +231,9 @@ uninterruptibleMask_
-> m a
uninterruptibleMask_ m = uninterruptibleMask @n (const m)
-- | See @"Control.Exception".'Exc.getMaskingState'@.
-- | See @"UnliftIO.Exception".'Exc.bracket'@.
--
-- @since 1.1.2.0
getMaskingState
:: forall n sig m
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m EUnsafe.MaskingState
getMaskingState = sendM @n (liftIO EUnsafe.getMaskingState)
-- | See @"Control.Exception".'Exc.interruptible'@.
--
-- @since 1.1.2.0
interruptible
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m a
-> m a
interruptible m = liftWith @n $ \ run ctx -> withRunInIO $ \runInIO ->
EUnsafe.interruptible (runInIO $ run (m <$ ctx))
-- | See @"Control.Exception".'Exc.allowInterrupt'@.
--
-- @since 1.1.2.0
allowInterrupt
:: forall n sig m a
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m ()
allowInterrupt = sendM @n (liftIO EUnsafe.allowInterrupt)
-- | See @"Control.Exception".'Exc.bracket'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
bracket
:: forall n sig m a b c
. (MonadUnliftIO n, Has (Lift n) sig m)
@ -289,14 +241,12 @@ bracket
-> (a -> m b)
-> (a -> m c)
-> m c
bracket acquire release m = mask @n $ \ restore -> do
a <- acquire
r <- onException @n (restore $ m a) (release a)
r <$ release a
bracket acquire release m = liftWith @n $ \ hdl ctx ->
Exc.bracket (hdl (acquire <$ ctx)) (hdl . (release <$>)) (hdl . (m <$>))
-- | See @"Control.Exception".'Exc.bracket_'@.
-- | See @"UnliftIO.Exception".'Exc.bracket_'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
bracket_
:: forall n sig m a b c
. (MonadUnliftIO n, Has (Lift n) sig m)
@ -306,9 +256,9 @@ bracket_
-> m c
bracket_ before after thing = bracket @n before (const after) (const thing)
-- | See @"Control.Exception".'Exc.bracketOnError'@.
-- | See @"UnliftIO.Exception".'Exc.bracketOnError'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
bracketOnError
:: forall n sig m a b c
. (MonadUnliftIO n, Has (Lift n) sig m)
@ -316,30 +266,32 @@ bracketOnError
-> (a -> m b)
-> (a -> m c)
-> m c
bracketOnError acquire release m = mask @n $ \ restore -> do
a <- acquire
onException @n (restore $ m a) (release a)
bracketOnError acquire release m = liftWith @n $ \ hdl ctx ->
Exc.bracketOnError
(hdl (acquire <$ ctx))
(hdl . (release <$>))
(hdl . (m <$>))
-- | See @"Control.Exception".'Exc.finally'@.
-- | See @"UnliftIO.Exception".'Exc.finally'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
finally
:: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m a
-> m b
-> m a
finally m sequel = mask @n $
\ restore -> onException @n (restore m) sequel <* sequel
finally m sequel = liftWith @n $ \ hdl ctx ->
Exc.finally (hdl (m <$ ctx)) (hdl (sequel <$ ctx))
-- | See @"Control.Exception".'Exc.onException'@.
-- | See @"UnliftIO.Exception".'Exc.onException'@.
--
-- @since 1.1.2.0
-- @since 1.1.2.1
onException
:: forall n sig m a b
. (MonadUnliftIO n, Has (Lift n) sig m)
=> m a
-> m b
-> m a
onException io what = catch @n io $
\e -> what >> throwIO @n @Exc.SomeException e
onException io what = liftWith @n $ \ hdl ctx ->
Exc.onException (hdl (io <$ ctx)) (hdl (what <$ ctx))