GHC/Conc/Sync.lhs (original) (raw)
\begin{code}
#include "Typeable.h"
module GHC.Conc.Sync ( ThreadId(..)
, forkIO
, forkIOUnmasked
, forkIOWithUnmask
, forkOn
, forkOnIO
, forkOnIOUnmasked
, forkOnWithUnmask
, numCapabilities
, getNumCapabilities
, setNumCapabilities
, getNumProcessors
, numSparks
, childHandler
, myThreadId
, killThread
, throwTo
, par
, pseq
, runSparks
, yield
, labelThread
, mkWeakThreadId
, ThreadStatus(..), BlockReason(..)
, threadStatus
, threadCapability
, STM(..)
, atomically
, retry
, orElse
, throwSTM
, catchSTM
, alwaysSucceeds
, always
, TVar(..)
, newTVar
, newTVarIO
, readTVar
, readTVarIO
, writeTVar
, unsafeIOToSTM
, withMVar
, modifyMVar_
, setUncaughtExceptionHandler
, getUncaughtExceptionHandler
, reportError, reportStackOverflow
, sharedCAF
) where
import Foreign hiding (unsafePerformIO) import Foreign.C
#ifdef mingw32_HOST_OS import Data.Typeable #endif
#ifndef mingw32_HOST_OS import Data.Dynamic #endif import Control.Monad import Data.Maybe
import GHC.Base import GHC.IO.Handle ( hFlush ) import GHC.IO.Handle.FD ( stdout ) import GHC.IO import GHC.IO.Encoding.UTF8 import GHC.IO.Exception import GHC.Exception import qualified GHC.Foreign import GHC.IORef import GHC.MVar import GHC.Ptr import GHC.Real ( fromIntegral ) import GHC.Show ( Show(..), showString ) import GHC.Weak
-- Both sequencing combinators bind as loosely as possible and associate to
-- the right, so  a `par` b `pseq` c  parses as  a `par` (b `pseq` c).
-- (The backticks are required: without them this is not a valid fixity
-- declaration for the functions par and pseq.)
infixr 0 `par`, `pseq`
\end{code}

%************************************************************************
%*                                                                      *
\subsection{@ThreadId@, @par@, and @fork@}
%*                                                                      *
%************************************************************************

\begin{code}
-- | A handle to a Haskell thread: a box around the runtime's primitive
-- 'ThreadId#'. Obtained from 'myThreadId' or from the fork functions in
-- this module.
data ThreadId = ThreadId ThreadId# deriving( Typeable )
-- Renders as "ThreadId <n>", where <n> is the numeric id obtained from the
-- RTS through getThreadId. The precedence argument is forwarded unchanged
-- to the numeric showsPrec; no parentheses are wrapped around the result.
instance Show ThreadId where
    showsPrec prec tid =
        let rtsId = getThreadId (id2TSO tid)
        in  showString "ThreadId " . showsPrec prec rtsId
-- | RTS helper returning the numeric identifier of a thread; used by the
-- 'Show' instance for 'ThreadId'.
foreign import ccall unsafe "rts_getThreadId" getThreadId :: ThreadId# -> CInt
-- | Unwrap a 'ThreadId' to the primitive 'ThreadId#' expected by the RTS
-- foreign imports and primops.
id2TSO :: ThreadId -> ThreadId#
id2TSO tid = case tid of
    ThreadId prim -> prim
-- | RTS comparison of two threads by id. The RTS implementation returns
-- -1, 0 or 1 when the first id is smaller than, equal to, or larger than
-- the second.
foreign import ccall unsafe "cmp_thread" cmp_thread :: ThreadId# -> ThreadId# -> CInt

-- | Total order on threads, derived from the RTS comparison.
--
-- The C result is compared against zero rather than matched against
-- individual literals. The previous alternative @1 -> LT@ mapped a
-- /greater/ result to 'LT' (and @-1@ to 'GT'), which inverted the order.
cmpThread :: ThreadId -> ThreadId -> Ordering
cmpThread t1 t2 = compare (cmp_thread (id2TSO t1) (id2TSO t2)) 0
-- Two ThreadIds are equal exactly when cmpThread reports EQ.
-- (cmpThread must be used infix with backticks; written bare, the
-- expression applied t1 as if it were a function.)
instance Eq ThreadId where
    t1 == t2 =
        case t1 `cmpThread` t2 of
            EQ -> True
            _  -> False
-- Ordering on threads delegates entirely to cmpThread.
instance Ord ThreadId where
    compare a b = cmpThread a b
-- | Spawn a new Haskell thread running @action@ and return its 'ThreadId'.
-- Any exception escaping @action@ is handed to 'childHandler' instead of
-- propagating.
forkIO :: IO () -> IO ThreadId
forkIO action = IO spawn
  where
    guarded = action `catchException` childHandler
    spawn st = case fork# guarded st of
        (# st', prim #) -> (# st', ThreadId prim #)
-- | Like 'forkIO', but the child action is wrapped in 'unsafeUnmask'.
forkIOUnmasked :: IO () -> IO ThreadId
forkIOUnmasked = forkIO . unsafeUnmask

-- | Like 'forkIO', but the child receives 'unsafeUnmask' as an argument so
-- it can choose which parts run unmasked.
forkIOWithUnmask :: ((forall a . IO a -> IO a) -> IO ()) -> IO ThreadId
forkIOWithUnmask io = forkIO $ io unsafeUnmask
-- | Like 'forkIO', but the capability number @cpu@ is passed to the
-- 'forkOn#' primop to choose where the new thread runs.
forkOn :: Int -> IO () -> IO ThreadId
forkOn (I# cpu) action = IO $ \st ->
    case forkOn# cpu guarded st of
        (# st', prim #) -> (# st', ThreadId prim #)
  where
    guarded = action `catchException` childHandler
-- | Alias for 'forkOn'.
forkOnIO :: Int -> IO () -> IO ThreadId
forkOnIO cpu action = forkOn cpu action

-- | 'forkOn' with the child action wrapped in 'unsafeUnmask'.
forkOnIOUnmasked :: Int -> IO () -> IO ThreadId
forkOnIOUnmasked cpu = forkOn cpu . unsafeUnmask

-- | 'forkOn' counterpart of 'forkIOWithUnmask'.
forkOnWithUnmask :: Int -> ((forall a . IO a -> IO a) -> IO ()) -> IO ThreadId
forkOnWithUnmask cpu io = forkOn cpu $ io unsafeUnmask
-- | The number of capabilities, read through 'unsafePerformIO'. Because it
-- is a pure top-level value, it is not guaranteed to reflect later calls
-- to 'setNumCapabilities'; prefer 'getNumCapabilities'.
numCapabilities :: Int
numCapabilities = unsafePerformIO getNumCapabilities

-- | Read the current number of capabilities from the RTS variable
-- @n_capabilities@.
getNumCapabilities :: IO Int
getNumCapabilities = fmap fromIntegral (peek n_capabilities)
-- | Ask the RTS to use the given number of capabilities.
--
-- The count must be positive. Previously a negative 'Int' was silently
-- wrapped by 'fromIntegral' into a huge 'CUInt', and zero was passed
-- straight to the RTS. Both are now rejected with an 'IOError' raised via
-- 'ioError' . 'userError', before any foreign call is made.
setNumCapabilities :: Int -> IO ()
setNumCapabilities i
  | i <= 0    = ioError $ userError $
      "setNumCapabilities: Capability count (" ++ show i ++ ") must be positive"
  | otherwise = c_setNumCapabilities (fromIntegral i)

-- | RTS entry point, imported as a @safe@ foreign call.
foreign import ccall safe "setNumCapabilities" c_setNumCapabilities :: CUInt -> IO ()
-- | Number of processors as reported by the RTS helper
-- @getNumberOfProcessors@, converted to 'Int'.
getNumProcessors :: IO Int
getNumProcessors = do
    count <- c_getNumberOfProcessors
    return (fromIntegral count)

-- | RTS helper, imported as an @unsafe@ foreign call.
foreign import ccall unsafe "getNumberOfProcessors" c_getNumberOfProcessors :: IO CUInt
-- | The spark count reported by the 'numSparks#' primop, boxed as an 'Int'.
numSparks :: IO Int
numSparks = IO $ \st ->
    case numSparks# st of
        (# st', count #) -> (# st', I# count #)
-- | Address of the RTS global @n_capabilities@, read by
-- 'getNumCapabilities'.
-- NOTE(review): declared here as a 'CInt'; confirm this matches the width
-- and signedness of the C declaration.
foreign import ccall "&n_capabilities" n_capabilities :: Ptr CInt
-- | Handler installed around every thread started by 'forkIO' / 'forkOn'.
-- If 'real_handler' itself throws, the new exception is routed back
-- through 'childHandler'.
childHandler :: SomeException -> IO ()
childHandler err = real_handler err `catchException` childHandler
-- | Decide what to do with an exception that escaped a forked thread:
--
-- * 'BlockedIndefinitelyOnMVar', 'BlockedIndefinitelyOnSTM' and
--   'ThreadKilled' are silently ignored;
--
-- * 'StackOverflow' is passed to 'reportStackOverflow';
--
-- * anything else is passed to 'reportError'.
real_handler :: SomeException -> IO ()
real_handler se@(SomeException ex)
  | Just BlockedIndefinitelyOnMVar <- cast ex = return ()
  | Just BlockedIndefinitelyOnSTM  <- cast ex = return ()
  | Just ThreadKilled              <- cast ex = return ()
  | Just StackOverflow             <- cast ex = reportStackOverflow
  | otherwise                                 = reportError se
-- | Raise 'ThreadKilled' in the target thread.
killThread :: ThreadId -> IO ()
killThread = flip throwTo ThreadKilled

-- | Raise an arbitrary exception in the target thread via the
-- 'killThread#' primop.
throwTo :: Exception e => ThreadId -> e -> IO ()
throwTo (ThreadId target) ex = IO (\st ->
    case killThread# target (toException ex) st of
        st' -> (# st', () #))
-- | The 'ThreadId' of the calling thread.
myThreadId :: IO ThreadId
myThreadId = IO go
  where
    go st = case myThreadId# st of
        (# st', prim #) -> (# st', ThreadId prim #)

-- | Invoke the 'yield#' primop, giving the scheduler a chance to run
-- another thread.
yield :: IO ()
yield = IO go
  where
    go st = case yield# st of
        st' -> (# st', () #)
-- | Attach a label to a thread. The label is encoded as UTF-8 into a
-- temporary C string and passed to the 'labelThread#' primop.
-- (The continuation passed to withCString must be a lambda; the backslash
-- before the Ptr pattern was missing.)
labelThread :: ThreadId -> String -> IO ()
labelThread (ThreadId t) str =
    GHC.Foreign.withCString utf8 str $ \(Ptr p) ->
        IO $ \s -> case labelThread# t p s of
            s1 -> (# s1, () #)
-- | Evaluate @x@ to weak head normal form, then return @y@. Unlike plain
-- 'seq', the 'lazy' wrapper on @y@ keeps the compiler from evaluating @y@
-- before @x@, so the evaluation order is fixed.
-- (seq must be used infix with backticks; written bare, x was applied as
-- a function.)
pseq :: a -> b -> b
pseq x y = x `seq` lazy y
-- | Create a spark for @x@ with the 'par#' primop and return @y@; 'lazy'
-- keeps @y@ from being evaluated early.
par :: a -> b -> b
par x y = case par# x of
    _sparked -> lazy y
-- | Repeatedly take a spark with 'getSpark#' and evaluate it to weak head
-- normal form, stopping when 'getSpark#' reports no spark (flag @0#@).
-- (seq must be used infix with backticks; written bare, p was applied as
-- a function.)
runSparks :: IO ()
runSparks = IO loop
  where
    loop s = case getSpark# s of
        (# s', n, p #) ->
            if n ==# 0#
                then (# s', () #)
                else p `seq` loop s'
-- | Why a thread is blocked, as carried by 'ThreadBlocked' and decoded by
-- 'threadStatus'. 'Ord' is derived, so constructor order defines the
-- ordering.
data BlockReason = BlockedOnMVar
                   -- ^ blocked on an 'MVar'
                 | BlockedOnBlackHole
                   -- ^ blocked on a black hole (presumably a value being
                   -- evaluated by another thread)
                 | BlockedOnException
                   -- ^ blocked while throwing an exception to another thread
                 | BlockedOnSTM
                   -- ^ blocked inside an 'STM' transaction
                 | BlockedOnForeignCall
                   -- ^ in a foreign call
                 | BlockedOnOther
                   -- ^ any status code 'threadStatus' does not recognise
                 deriving (Eq,Ord,Show)
-- | Scheduler state of a thread, as returned by 'threadStatus'.
-- 'Ord' is derived, so constructor order defines the ordering.
data ThreadStatus = ThreadRunning
                    -- ^ the thread is runnable or running
                  | ThreadFinished
                    -- ^ the thread has finished
                  | ThreadBlocked BlockReason
                    -- ^ the thread is blocked, for the given reason
                  | ThreadDied
                    -- ^ the thread died (presumably from an uncaught exception)
                  deriving (Eq,Ord,Show)
-- | Query the scheduler state of a thread. The numeric code produced by
-- 'threadStatus#' is translated by the local table below; unrecognised
-- codes become @ThreadBlocked BlockedOnOther@.
--
-- NOTE(review): the codes appear to mirror RTS-internal blocking
-- constants; keep this table in sync with the RTS headers.
threadStatus :: ThreadId -> IO ThreadStatus
threadStatus (ThreadId t) = IO $ \st ->
    case threadStatus# t st of
        (# st', code, _cap, _locked #) -> (# st', decode (I# code) #)
  where
    decode :: Int -> ThreadStatus
    decode c = case c of
        0  -> ThreadRunning
        1  -> ThreadBlocked BlockedOnMVar
        2  -> ThreadBlocked BlockedOnBlackHole
        6  -> ThreadBlocked BlockedOnSTM
        10 -> ThreadBlocked BlockedOnForeignCall
        11 -> ThreadBlocked BlockedOnForeignCall
        12 -> ThreadBlocked BlockedOnException
        16 -> ThreadFinished
        17 -> ThreadDied
        _  -> ThreadBlocked BlockedOnOther
-- | Return the capability number reported for a thread by
-- 'threadStatus#', and whether its lock flag is non-zero.
threadCapability :: ThreadId -> IO (Int, Bool)
threadCapability (ThreadId t) = IO $ \st ->
    case threadStatus# t st of
        (# st', _status, capNo, lockFlag #) ->
            (# st', (I# capNo, lockFlag /=# 0#) #)
-- | Create a weak pointer with no finalizer, keyed on the primitive
-- 'ThreadId#' and holding the boxed 'ThreadId' as its value.
mkWeakThreadId :: ThreadId -> IO (Weak ThreadId)
mkWeakThreadId tid = case tid of
    ThreadId prim -> IO $ \st ->
        case mkWeakNoFinalizer# prim tid st of
            (# st', w #) -> (# st', Weak w #)
\end{code}

%************************************************************************
%*                                                                      *
\subsection[stm]{Transactional heap operations}
%*                                                                      *
%************************************************************************

TVars are shared memory locations that support atomic memory
transactions.

\begin{code}
-- | A memory transaction: a state-passing function over 'RealWorld' with
-- the same representation as 'IO' (see 'unsafeIOToSTM'). Run it with
-- 'atomically'.
newtype STM a = STM (State# RealWorld -> (# State# RealWorld, a #))
-- | Unwrap an 'STM' action to its underlying state-passing function.
unSTM :: STM a -> (State# RealWorld -> (# State# RealWorld, a #))
unSTM stm = case stm of
    STM body -> body
-- Typeable instance for STM, generated by the INSTANCE_TYPEABLE1 CPP macro
-- from Typeable.h (included at the top of this file).
INSTANCE_TYPEABLE1(STM,stmTc,"STM")
-- Mapping over a transaction result, expressed through the Monad instance.
instance Functor STM where
    fmap f m = m >>= \x -> return (f x)
-- Sequencing of transactions; every method delegates to a dedicated helper.
instance Monad STM where
    return = returnSTM
    (>>=)  = bindSTM
    (>>)   = thenSTM
-- | Run @m@, feed its result to @k@, then run the resulting transaction.
bindSTM :: STM a -> (a -> STM b) -> STM b
bindSTM (STM m) k = STM step
  where
    step st = case m st of
        (# st', a #) -> unSTM (k a) st'

-- | Run @m@ for its effects, discard its result, then run @k@.
thenSTM :: STM a -> STM b -> STM b
thenSTM (STM m) k = STM step
  where
    step st = case m st of
        (# st', _ #) -> unSTM k st'

-- | A transaction with no effects that yields @x@.
returnSTM :: a -> STM a
returnSTM x = STM $ \st -> (# st, x #)
-- mzero abandons the transaction via retry; mplus is orElse.
instance MonadPlus STM where
    mplus = orElse
    mzero = retry
-- | Reuse an 'IO' action unchanged as an 'STM' action. No checks are made;
-- the underlying state-passing function is simply rewrapped.
unsafeIOToSTM :: IO a -> STM a
unsafeIOToSTM io = case io of
    IO m -> STM m

-- | Execute a transaction with the 'atomically#' primop.
atomically :: STM a -> IO a
atomically (STM txn) = IO $ \st -> atomically# txn st
-- | Abandon the current transaction attempt via the 'retry#' primop.
retry :: STM a
retry = STM (\st -> retry# st)

-- | Run the first transaction; if it retries, run the second one instead
-- (implemented with the 'catchRetry#' primop).
orElse :: STM a -> STM a -> STM a
orElse (STM first) alt = STM $ \st -> catchRetry# first (unSTM alt) st
-- | Raise an exception from inside a transaction.
throwSTM :: Exception e => e -> STM a
throwSTM err = STM (raiseIO# (toException err))

-- | Catch exceptions of type @e@ raised inside a transaction; exceptions of
-- any other type are re-raised unchanged.
catchSTM :: Exception e => STM a -> (e -> STM a) -> STM a
catchSTM (STM m) handler = STM (catchSTM# m handler')
  where
    handler' se = case fromException se of
        Nothing -> raiseIO# se
        Just e  -> unSTM (handler e)
-- | Pass the transaction @m@ to the 'check#' primop (invariant checking).
checkInv :: STM a -> STM ()
checkInv (STM m) = STM $ \st -> check# m st
-- | Register @i@ as a transactional invariant.
--
-- @i@ is first run once on its own. It is followed by 'retry' so that its
-- effects are discarded, and 'orElse' then continues with @return ()@. An
-- exception thrown by @i@ is not caught by 'orElse' and propagates. After
-- that, @i@ is handed to 'checkInv'.
-- (orElse must be used infix with backticks; written bare, the
-- parenthesised expression was applied as a function.)
alwaysSucceeds :: STM a -> STM ()
alwaysSucceeds i = do
    (i >> retry) `orElse` return ()
    checkInv i
-- | Like 'alwaysSucceeds', but the invariant must also return 'True';
-- a 'False' result raises an 'error'.
always :: STM Bool -> STM ()
always i = alwaysSucceeds (i >>= verdict)
  where
    verdict v = if v then return () else error "Transactional invariant violation"
-- | A shared mutable location for use in 'STM' transactions, wrapping the
-- primitive 'TVar#'.
data TVar a = TVar (TVar# RealWorld a)
-- Typeable instance for TVar, generated by the INSTANCE_TYPEABLE1 CPP macro
-- from Typeable.h (included at the top of this file).
INSTANCE_TYPEABLE1(TVar,tvarTc,"TVar")
-- Two TVars are equal when sameTVar# reports the same underlying location.
instance Eq (TVar a) where
    (TVar a) == (TVar b) = sameTVar# a b
-- | Allocate a new 'TVar' holding @val@, from inside a transaction.
newTVar :: a -> STM (TVar a)
newTVar val = STM $ \st ->
    case newTVar# val st of
        (# st', prim #) -> (# st', TVar prim #)

-- | Allocate a new 'TVar' holding @val@, directly in 'IO'.
newTVarIO :: a -> IO (TVar a)
newTVarIO val = IO $ \st ->
    case newTVar# val st of
        (# st', prim #) -> (# st', TVar prim #)
-- | Read a 'TVar' directly from 'IO' with the 'readTVarIO#' primop.
readTVarIO :: TVar a -> IO a
readTVarIO (TVar tv) = IO (\st -> readTVarIO# tv st)

-- | Read a 'TVar' inside a transaction.
readTVar :: TVar a -> STM a
readTVar (TVar tv) = STM (\st -> readTVar# tv st)

-- | Write a new value to a 'TVar' inside a transaction.
writeTVar :: TVar a -> a -> STM ()
writeTVar (TVar tv) new = STM $ \st ->
    case writeTVar# tv new st of
        st' -> (# st', () #)
\end{code}

MVar utilities

\begin{code}
-- | Take the contents of @m@ and run @io@ on them, with the masking state
-- restored for @io@. The original value is put back whether @io@ returns
-- normally or throws.
--
-- The handler re-raises with 'throwIO' rather than the pure 'throw'. The
-- exception is then raised as an explicitly sequenced 'IO' action after
-- 'putMVar', instead of by forcing a bottom value.
withMVar :: MVar a -> (a -> IO b) -> IO b
withMVar m io =
    mask $ \restore -> do
        a <- takeMVar m
        b <- catchAny (restore (io a))
                      (\e -> do putMVar m a; throwIO e)
        putMVar m a
        return b
-- | Replace the contents of @m@ with the result of @io@, running @io@ with
-- the masking state restored. If @io@ throws, the original value is put
-- back and the exception is re-raised.
--
-- The handler re-raises with 'throwIO' rather than the pure 'throw', so
-- the exception is raised as a sequenced 'IO' action after 'putMVar'.
modifyMVar_ :: MVar a -> (a -> IO a) -> IO ()
modifyMVar_ m io =
    mask $ \restore -> do
        a  <- takeMVar m
        a' <- catchAny (restore (io a))
                       (\e -> do putMVar m a; throwIO e)
        putMVar m a'
        return ()
\end{code}

%************************************************************************
%*                                                                      *
\subsection{Thread waiting}
%*                                                                      *
%************************************************************************

\begin{code}
-- | Offer a stable pointer to @a@ to @get_or_set@, which returns either
-- that same pointer or a different one. If ours comes back, @a@ is
-- returned. Otherwise our stable pointer is freed and the value behind
-- the returned pointer is used instead. The whole exchange runs under
-- 'mask_'.
sharedCAF :: a -> (Ptr a -> IO (Ptr a)) -> IO a
sharedCAF a get_or_set =
    mask_ $ do
        mine <- newStablePtr a
        let myRef = castPtr (castStablePtrToPtr mine)
        winner <- get_or_set myRef
        if winner == myRef
            then return a
            else do
                freeStablePtr mine
                deRefStablePtr (castPtrToStablePtr (castPtr winner))
-- | Report a stack overflow through the RTS @stackOverflow@ hook.
reportStackOverflow :: IO ()
reportStackOverflow = callStackOverflowHook

-- | Pass an exception to whichever uncaught-exception handler is
-- currently installed.
reportError :: SomeException -> IO ()
reportError ex = getUncaughtExceptionHandler >>= ($ ex)
-- | C hook @stackOverflow@, invoked by 'reportStackOverflow'.
foreign import ccall unsafe "stackOverflow" callStackOverflowHook :: IO ()
{-# NOINLINE uncaughtExceptionHandler #-}
-- | Global, mutable handler consulted by 'reportError' and
-- 'getUncaughtExceptionHandler'.
--
-- NOINLINE is required. This is a top-level 'IORef' created with
-- 'unsafePerformIO', and if GHC inlined the definition, each use site
-- could allocate its own 'IORef'. 'setUncaughtExceptionHandler' would then
-- have no visible effect.
uncaughtExceptionHandler :: IORef (SomeException -> IO ())
uncaughtExceptionHandler = unsafePerformIO (newIORef defaultHandler)
  where
    -- Default behaviour: flush stdout (ignoring any failure to do so), then
    -- hand a message for the exception to errorBelch with format "%s".
    -- (catchAny must be used infix with backticks; written bare, the
    -- parenthesised hFlush action was applied as a function.)
    defaultHandler :: SomeException -> IO ()
    defaultHandler se@(SomeException ex) = do
        (hFlush stdout) `catchAny` (\ _ -> return ())
        let msg = case cast ex of
                    Just Deadlock -> "no threads to run: infinite loop or deadlock?"
                    _ -> case cast ex of
                           Just (ErrorCall s) -> s
                           _ -> showsPrec 0 se ""
        withCString "%s" $ \cfmt ->
            withCString msg $ \cmsg ->
                errorBelch cfmt cmsg
-- | C helper @errorBelch2@ from HsBase.h, taking a format string and one
-- string argument. Used by the default uncaught-exception handler with
-- the format "%s".
foreign import ccall unsafe "HsBase.h errorBelch2" errorBelch :: CString -> CString -> IO ()
-- | Replace the handler that 'reportError' passes exceptions to.
setUncaughtExceptionHandler :: (SomeException -> IO ()) -> IO ()
setUncaughtExceptionHandler handler = writeIORef uncaughtExceptionHandler handler

-- | Fetch the currently installed uncaught-exception handler.
getUncaughtExceptionHandler :: IO (SomeException -> IO ())
getUncaughtExceptionHandler = readIORef uncaughtExceptionHandler
\end{code}