GHC/Conc/Sync.lhs (original) (raw)

\begin{code}

#include "Typeable.h"

module GHC.Conc.Sync ( ThreadId(..)

    , forkIO        
    , forkIOUnmasked
    , forkIOWithUnmask
    , forkOn      
    , forkOnIO    
    , forkOnIOUnmasked
    , forkOnWithUnmask
    , numCapabilities 
    , getNumCapabilities 
    , setNumCapabilities 
    , getNumProcessors   
    , numSparks      
    , childHandler  
    , myThreadId    
    , killThread    
    , throwTo       
    , par           
    , pseq          
    , runSparks
    , yield         
    , labelThread   
    , mkWeakThreadId 

    , ThreadStatus(..), BlockReason(..)
    , threadStatus  
    , threadCapability

    
    , STM(..)
    , atomically    
    , retry         
    , orElse        
    , throwSTM      
    , catchSTM      
    , alwaysSucceeds 
    , always        
    , TVar(..)
    , newTVar       
    , newTVarIO     
    , readTVar      
    , readTVarIO    
    , writeTVar     
    , unsafeIOToSTM 

    
    , withMVar
    , modifyMVar_

    , setUncaughtExceptionHandler      
    , getUncaughtExceptionHandler      

    , reportError, reportStackOverflow

    , sharedCAF
    ) where

import Foreign hiding (unsafePerformIO) import Foreign.C

#ifdef mingw32_HOST_OS import Data.Typeable #endif

#ifndef mingw32_HOST_OS import Data.Dynamic #endif import Control.Monad import Data.Maybe

import GHC.Base import GHC.IO.Handle ( hFlush ) import GHC.IO.Handle.FD ( stdout ) import GHC.IO import GHC.IO.Encoding.UTF8 import GHC.IO.Exception import GHC.Exception import qualified GHC.Foreign import GHC.IORef import GHC.MVar import GHC.Ptr import GHC.Real ( fromIntegral ) import GHC.Show ( Show(..), showString ) import GHC.Weak

-- Both combinators bind as loosely as possible (precedence 0) and associate
-- to the right, so  a `par` b `pseq` c  parses as  a `par` (b `pseq` c).
-- Alphanumeric names must be written in backticks in a fixity declaration.
infixr 0 `par`, `pseq`

\end{code}

%************************************************************************
%*                                                                      *
\subsection{@ThreadId@, @par@, and @fork@}
%*                                                                      *
%************************************************************************

\begin{code}

-- | A boxed handle on a Haskell thread: a single-constructor wrapper around
-- the primitive 'ThreadId#', so that it can be stored, compared and shown
-- like an ordinary value.
data ThreadId
  = ThreadId ThreadId#
  deriving (Typeable)

-- | Renders a thread as @ThreadId N@, where @N@ is the number returned by
-- the RTS function @rts_getThreadId@.
instance Show ThreadId where
  showsPrec prec tid =
    showString "ThreadId " . showsPrec prec (getThreadId (id2TSO tid))

-- | Ask the RTS for the numeric identifier of a thread.
foreign import ccall unsafe "rts_getThreadId"
  getThreadId :: ThreadId# -> CInt

-- | Extract the primitive thread handle from a boxed 'ThreadId'.
id2TSO :: ThreadId -> ThreadId#
id2TSO (ThreadId raw) = raw

-- | RTS comparison of two threads.  The result is negative, zero or
-- positive, in the style of C @strcmp@.
foreign import ccall unsafe "cmp_thread"
  cmp_thread :: ThreadId# -> ThreadId# -> CInt

-- | Ordering on threads, delegated to the RTS function @cmp_thread@.
--
-- Comparing the C result against zero maps negative, zero and positive
-- onto 'LT', 'EQ' and 'GT'.  The previous table sent @1@ (first thread
-- greater) to 'LT' and @-1@ to 'GT', which inverted the ordering.
cmpThread :: ThreadId -> ThreadId -> Ordering
cmpThread t1 t2 = compare (cmp_thread (id2TSO t1) (id2TSO t2)) 0

-- | Two 'ThreadId's are equal exactly when 'cmpThread' reports 'EQ'.
-- (The infix use of 'cmpThread' needs backticks; without them the left
-- operand would be applied as a function.)
instance Eq ThreadId where
  t1 == t2 = case t1 `cmpThread` t2 of
    EQ -> True
    _  -> False

-- | Threads are ordered by 'cmpThread'.
instance Ord ThreadId where
  compare = cmpThread

-- | Spawn a new Haskell thread running the given action and return its
-- 'ThreadId'.
--
-- Any exception escaping the action is handed to 'childHandler'.
forkIO :: IO () -> IO ThreadId
forkIO action = IO spawn
  where
    guarded = action `catchException` childHandler
    spawn s = case fork# guarded s of
      (# s', tid #) -> (# s', ThreadId tid #)

-- | Like 'forkIO', but the action is wrapped in 'unsafeUnmask', so it runs
-- with asynchronous exceptions unmasked.
forkIOUnmasked :: IO () -> IO ThreadId
forkIOUnmasked = forkIO . unsafeUnmask

-- | Like 'forkIO', but the body receives 'unsafeUnmask' as an argument, so
-- it can choose which parts run with asynchronous exceptions unmasked.
forkIOWithUnmask :: ((forall a . IO a -> IO a) -> IO ()) -> IO ThreadId
forkIOWithUnmask body = forkIO (body unsafeUnmask)

-- | Like 'forkIO', but passes the given capability number through to the
-- 'forkOn#' primop.
forkOn :: Int -> IO () -> IO ThreadId
forkOn (I# cpu) action = IO $ \s0 ->
  case forkOn# cpu guarded s0 of
    (# s1, tid #) -> (# s1, ThreadId tid #)
  where
    -- Exceptions escaping the child are routed to 'childHandler'.
    guarded = catchException action childHandler

-- | Another name for 'forkOn'.
forkOnIO :: Int -> IO () -> IO ThreadId
forkOnIO = forkOn

-- | Like 'forkOn', but the action is wrapped in 'unsafeUnmask'.
forkOnIOUnmasked :: Int -> IO () -> IO ThreadId
forkOnIOUnmasked cpu = forkOn cpu . unsafeUnmask

-- | Like 'forkOn', but the body receives 'unsafeUnmask' as an argument.
forkOnWithUnmask :: Int -> ((forall a . IO a -> IO a) -> IO ()) -> IO ThreadId
forkOnWithUnmask cpu body = forkOn cpu (body unsafeUnmask)

-- | The number of capabilities, obtained by running 'getNumCapabilities'
-- under 'unsafePerformIO'.
--
-- NOTE(review): as a top-level value this is presumably evaluated at most
-- once, so it may not reflect later calls to 'setNumCapabilities'; prefer
-- 'getNumCapabilities' when the current value matters.
numCapabilities :: Int
numCapabilities = unsafePerformIO getNumCapabilities

-- | Read the current number of capabilities from the RTS variable exposed
-- as 'n_capabilities'.
getNumCapabilities :: IO Int
getNumCapabilities = fmap fromIntegral (peek n_capabilities)

-- | Ask the RTS to change the number of capabilities to the given count.
--
-- The count is handed to the RTS as an unsigned C integer.  Without a
-- check, zero would be passed through unchanged and a negative value would
-- wrap around to a huge count, so non-positive values are rejected here
-- with an 'IOError'.
setNumCapabilities :: Int -> IO ()
setNumCapabilities i
  | i <= 0 =
      ioError (userError ("setNumCapabilities: Capability count ("
                          ++ show i ++ ") must be positive"))
  | otherwise = c_setNumCapabilities (fromIntegral i)

-- | RTS entry point, imported as a @safe@ foreign call.
foreign import ccall safe "setNumCapabilities"
  c_setNumCapabilities :: CUInt -> IO ()

-- | The processor count reported by the RTS function
-- @getNumberOfProcessors@, converted to 'Int'.
getNumProcessors :: IO Int
getNumProcessors = do
  n <- c_getNumberOfProcessors
  return (fromIntegral n)

foreign import ccall unsafe "getNumberOfProcessors"
  c_getNumberOfProcessors :: IO CUInt

-- | The spark count reported by the 'numSparks#' primop, boxed as an 'Int'.
-- NOTE(review): presumably the sparks queued on the current capability;
-- confirm against the primop documentation.
numSparks :: IO Int
numSparks = IO go
  where
    go s = case numSparks# s of
      (# s', n #) -> (# s', I# n #)

-- | Address of the RTS variable @n_capabilities@, read by
-- 'getNumCapabilities'.
foreign import ccall "&n_capabilities"
  n_capabilities :: Ptr CInt

-- | Top-level exception handler installed by the fork functions.  It runs
-- 'real_handler'; if that itself throws, the new exception is handled by
-- 'childHandler' again.
childHandler :: SomeException -> IO ()
childHandler err = real_handler err `catchException` childHandler

-- | Decide what to do with an exception that escaped a forked thread.
--
-- 'BlockedIndefinitelyOnMVar', 'BlockedIndefinitelyOnSTM' and 'ThreadKilled'
-- are silently ignored; 'StackOverflow' goes to 'reportStackOverflow';
-- everything else goes to 'reportError'.
real_handler :: SomeException -> IO ()
real_handler se@(SomeException ex)
  | Just BlockedIndefinitelyOnMVar <- cast ex = return ()
  | Just BlockedIndefinitelyOnSTM  <- cast ex = return ()
  | Just ThreadKilled              <- cast ex = return ()
  | Just StackOverflow             <- cast ex = reportStackOverflow
  | otherwise                                 = reportError se

-- | Raise 'ThreadKilled' in the target thread using 'throwTo'.
killThread :: ThreadId -> IO ()
killThread target = target `throwTo` ThreadKilled

-- | Raise an exception in the target thread: the exception is wrapped with
-- 'toException' and delivered by the 'killThread#' primop.
throwTo :: Exception e => ThreadId -> e -> IO ()
throwTo (ThreadId target) ex = IO $ \s0 ->
  case killThread# target (toException ex) s0 of
    s1 -> (# s1, () #)

-- | The 'ThreadId' of the thread running this action.
myThreadId :: IO ThreadId
myThreadId = IO go
  where
    go s = case myThreadId# s of
      (# s', raw #) -> (# s', ThreadId raw #)

-- | Invoke the 'yield#' primop for the current thread.
yield :: IO ()
yield = IO go
  where
    go s = case yield# s of
      s' -> (# s', () #)

-- | Attach a label to a thread.  The string is encoded as UTF-8 into a
-- temporary C string whose raw address is passed to the 'labelThread#'
-- primop.
--
-- The continuation must be a lambda, @\\(Ptr p) -> ...@; the backslash was
-- missing, which made this a syntax error.
labelThread :: ThreadId -> String -> IO ()
labelThread (ThreadId t) str =
  GHC.Foreign.withCString utf8 str $ \(Ptr p) ->
    IO $ \s -> case labelThread# t p s of
      s1 -> (# s1, () #)

-- | Evaluate the first argument to weak head normal form, then return the
-- second.
--
-- 'lazy' keeps the strictness analyser from evaluating @y@ before @x@.
-- 'seq' is used infix and therefore needs backticks; without them @x@
-- would be applied as a function.
pseq :: a -> b -> b
pseq x y = x `seq` lazy y

-- | Create a spark for the first argument with the 'par#' primop, then
-- return the second argument (wrapped in 'lazy').
par :: a -> b -> b
par spark result =
  case par# spark of
    _ -> lazy result

-- | Repeatedly take a spark with 'getSpark#' and force it to weak head
-- normal form, stopping when the primop reports none (flag 0).
--
-- 'seq' is used infix and needs backticks; without them the spark @p@
-- would be applied as a function.
runSparks :: IO ()
runSparks = IO loop
  where
    loop s = case getSpark# s of
      (# s', n, p #) ->
        if n ==# 0#
          then (# s', () #)
          else p `seq` loop s'

-- | The reason a thread is blocked, as decoded by 'threadStatus' from the
-- numeric status reported by the RTS.
data BlockReason
  = BlockedOnMVar
  | BlockedOnBlackHole
  | BlockedOnException
  | BlockedOnSTM
  | BlockedOnForeignCall
  | BlockedOnOther
    -- ^ any status code that 'threadStatus' does not map to one of the
    -- more specific reasons above
  deriving (Eq, Ord, Show)

-- | The scheduling state of a thread, as returned by 'threadStatus'.
data ThreadStatus
  = ThreadRunning
  | ThreadFinished
  | ThreadBlocked BlockReason
    -- ^ blocked, together with the decoded reason
  | ThreadDied
  deriving (Eq, Ord, Show)

-- | Query the scheduling status of a thread via the 'threadStatus#'
-- primop, decoding its numeric status code into a 'ThreadStatus'.  The
-- capability and locked fields of the primop result are ignored here.
--
-- NOTE(review): the numeric codes below must agree with the constants the
-- RTS uses; confirm against the RTS headers when they change.
threadStatus :: ThreadId -> IO ThreadStatus
threadStatus (ThreadId t) = IO $ \s ->
  case threadStatus# t s of
    (# s', code, _cap, _locked #) -> (# s', decode (I# code) #)
  where
    decode :: Int -> ThreadStatus
    decode n = case n of
      0  -> ThreadRunning
      1  -> ThreadBlocked BlockedOnMVar
      2  -> ThreadBlocked BlockedOnBlackHole
      6  -> ThreadBlocked BlockedOnSTM
      10 -> ThreadBlocked BlockedOnForeignCall
      11 -> ThreadBlocked BlockedOnForeignCall
      12 -> ThreadBlocked BlockedOnException
      16 -> ThreadFinished
      17 -> ThreadDied
      _  -> ThreadBlocked BlockedOnOther

-- | The capability number of a thread, paired with whether the locked flag
-- reported by 'threadStatus#' is non-zero.
threadCapability :: ThreadId -> IO (Int, Bool)
threadCapability (ThreadId t) = IO $ \s ->
  case threadStatus# t s of
    (# s', _, capN, lockedFlag #) ->
      (# s', (I# capN, lockedFlag /=# 0#) #)

-- | Make a weak pointer with no finalizer.  The key is the primitive
-- 'ThreadId#' and the value is the boxed 'ThreadId'.
mkWeakThreadId :: ThreadId -> IO (Weak ThreadId)
mkWeakThreadId tid@(ThreadId raw) = IO $ \s0 ->
  case mkWeakNoFinalizer# raw tid s0 of
    (# s1, w #) -> (# s1, Weak w #)

\end{code}

%************************************************************************
%*                                                                      *
\subsection[stm]{Transactional heap operations}
%*                                                                      *
%************************************************************************

TVars are shared memory locations which support atomic memory transactions.

\begin{code}

-- | A computation in the software-transactional-memory monad.  It is a
-- state-passing function over 'RealWorld', the same representation that
-- 'unsafeIOToSTM' takes from 'IO'.
newtype STM a = STM (State# RealWorld -> (# State# RealWorld, a #))

-- | Unwrap an 'STM' action to its underlying state-passing function.
unSTM :: STM a -> (State# RealWorld -> (# State# RealWorld, a #))
unSTM (STM f) = f

-- Expands the INSTANCE_TYPEABLE1 macro from Typeable.h for the STM type
-- constructor (tycon binding stmTc, name string STM).
-- NOTE(review): presumably this defines the Typeable instance for STM;
-- confirm against Typeable.h.
INSTANCE_TYPEABLE1(STM,stmTc,"STM")

-- | Map over the result of a transaction by binding it and returning the
-- transformed value.
instance Functor STM where
  fmap f m = m >>= \x -> return (f x)

-- | Sequencing for 'STM', delegating to the state-threading helpers
-- 'returnSTM', 'bindSTM' and 'thenSTM'.
instance Monad STM where
  return = returnSTM
  (>>=)  = bindSTM
  (>>)   = thenSTM

-- | Run the first transaction, feed its result to the continuation, and run
-- the resulting transaction on the updated state.
bindSTM :: STM a -> (a -> STM b) -> STM b
bindSTM (STM m) k = STM go
  where
    go s = case m s of
      (# s', a #) -> unSTM (k a) s'

-- | Run the first transaction, discard its result, then run the second.
thenSTM :: STM a -> STM b -> STM b
thenSTM (STM m) k = STM go
  where
    go s = case m s of
      (# s', _ #) -> unSTM k s'

-- | A transaction that leaves the state untouched and yields the value.
returnSTM :: a -> STM a
returnSTM x = STM $ \st -> (# st, x #)

-- | 'mplus' is 'orElse' and 'mzero' is 'retry'.
instance MonadPlus STM where
  mplus = orElse
  mzero = retry

-- | Reinterpret an 'IO' action as an 'STM' action by reusing its
-- state-passing function unchanged.
--
-- NOTE(review): nothing here undoes the IO effects if the surrounding
-- transaction is re-run or aborted; callers presumably must tolerate that.
unsafeIOToSTM :: IO a -> STM a
unsafeIOToSTM (IO act) = STM act

-- | Run a transaction from 'IO' via the 'atomically#' primop.
atomically :: STM a -> IO a
atomically (STM txn) = IO $ \s -> atomically# txn s

-- | Abandon the current transaction via the 'retry#' primop.
retry :: STM a
retry = STM (\s -> retry# s)

-- | Run the first transaction; if it retries, run the alternative instead
-- (implemented by the 'catchRetry#' primop).
orElse :: STM a -> STM a -> STM a
orElse (STM first) alternative =
  STM (\s -> catchRetry# first (unSTM alternative) s)

-- | Throw an exception inside a transaction: it is wrapped with
-- 'toException' and raised with 'raiseIO#'.
throwSTM :: Exception e => e -> STM a
throwSTM exc = STM (raiseIO# (toException exc))

-- | Run a transaction and handle exceptions of type @e@ raised in it.
-- Exceptions of other types are re-raised unchanged.
catchSTM :: Exception e => STM a -> (e -> STM a) -> STM a
catchSTM (STM action) handler = STM (catchSTM# action onErr)
  where
    onErr someExc = case fromException someExc of
      Nothing  -> raiseIO# someExc
      Just exc -> unSTM (handler exc)

-- | Hand an invariant transaction to the 'check#' primop.
checkInv :: STM a -> STM ()
checkInv (STM inv) = STM $ \s -> check# inv s

-- | Check an invariant now, then register it with 'checkInv'.
--
-- The first statement runs the invariant and then 'retry'.  'orElse' turns
-- that retry into @return ()@, discarding the effects of the invariant;
-- an exception thrown by the invariant still propagates.  'orElse' is used
-- infix and needs backticks; without them the first operand would be
-- applied as a function.
--
-- NOTE(review): what 'check#' does afterwards is presumably to re-check
-- the invariant at the end of later transactions; confirm against the
-- primop documentation.
alwaysSucceeds :: STM a -> STM ()
alwaysSucceeds i = do
  (i >> retry) `orElse` return ()
  checkInv i

-- | Register a boolean invariant with 'alwaysSucceeds'.  A 'False' result
-- is turned into an 'error' call.
always :: STM Bool -> STM ()
always i = alwaysSucceeds check
  where
    check = i >>= \ok ->
      if ok then return () else error "Transactional invariant violation"

-- | A transactional variable holding a value of type @a@: a boxed wrapper
-- around the primitive 'TVar#'.
data TVar a
  = TVar (TVar# RealWorld a)

-- Expands the INSTANCE_TYPEABLE1 macro from Typeable.h for the TVar type
-- constructor (tycon binding tvarTc, name string TVar).
-- NOTE(review): presumably this defines the Typeable instance for TVar;
-- confirm against Typeable.h.
INSTANCE_TYPEABLE1(TVar,tvarTc,"TVar")

-- | Two 'TVar's are equal when 'sameTVar#' says they are the same
-- variable; their contents are never inspected.
instance Eq (TVar a) where
  TVar x == TVar y = sameTVar# x y

-- | Allocate a new 'TVar' holding the given value, inside a transaction.
newTVar :: a -> STM (TVar a)
newTVar val = STM go
  where
    go s = case newTVar# val s of
      (# s', tv #) -> (# s', TVar tv #)

-- | Allocate a new 'TVar' directly in 'IO', using the same 'newTVar#'
-- primop as 'newTVar'.
newTVarIO :: a -> IO (TVar a)
newTVarIO val = IO go
  where
    go s = case newTVar# val s of
      (# s', tv #) -> (# s', TVar tv #)

-- | Read a 'TVar' directly from 'IO' via the 'readTVarIO#' primop.
readTVarIO :: TVar a -> IO a
readTVarIO (TVar tv) = IO (\s -> readTVarIO# tv s)

-- | Read a 'TVar' inside a transaction via the 'readTVar#' primop.
readTVar :: TVar a -> STM a
readTVar (TVar tv) = STM (\s -> readTVar# tv s)

-- | Write a new value into a 'TVar' inside a transaction.
writeTVar :: TVar a -> a -> STM ()
writeTVar (TVar tv) val = STM go
  where
    go s = case writeTVar# tv val s of
      s' -> (# s', () #)

\end{code}

MVar utilities

\begin{code}

-- | Run an action on the contents of an 'MVar', then put the original value
-- back.
--
-- The take and the put-back run with asynchronous exceptions masked; only
-- the user action runs under the restored masking state of the caller.  If
-- the action throws, the original value is put back first and the exception
-- is then rethrown with 'throwIO'.  Unlike pure 'throw', 'throwIO' is
-- guaranteed to be ordered after the preceding 'putMVar', so the 'MVar' is
-- never left empty.
withMVar :: MVar a -> (a -> IO b) -> IO b
withMVar m io =
  mask $ \restore -> do
    a <- takeMVar m
    b <- catchAny (restore (io a))
                  (\e -> do putMVar m a
                            throwIO e)
    putMVar m a
    return b

-- | Replace the contents of an 'MVar' with the result of an action applied
-- to the old value.
--
-- The take and the put run with asynchronous exceptions masked; only the
-- user action runs under the restored masking state of the caller.  If the
-- action throws, the old value is put back first and the exception is then
-- rethrown with 'throwIO'.  Unlike pure 'throw', 'throwIO' is guaranteed to
-- be ordered after the preceding 'putMVar'.
modifyMVar_ :: MVar a -> (a -> IO a) -> IO ()
modifyMVar_ m io =
  mask $ \restore -> do
    a  <- takeMVar m
    a' <- catchAny (restore (io a))
                   (\e -> do putMVar m a
                             throwIO e)
    putMVar m a'
    return ()

\end{code}

%************************************************************************
%*                                                                      *
\subsection{Thread waiting}
%*                                                                      *
%************************************************************************

\begin{code}

-- | Publish a value through a get-or-set accessor and return whichever
-- value the accessor settles on.  Runs with asynchronous exceptions masked.
--
-- A stable pointer to @a@ is offered to @get_or_set@.  If the accessor
-- hands back that same pointer, @a@ is returned.  Otherwise the offered
-- stable pointer is freed and the value behind the returned pointer is
-- dereferenced instead.
--
-- NOTE(review): presumably @get_or_set@ is an RTS-backed store-if-unset
-- accessor, so that separate copies of this code agree on one value;
-- confirm against its callers.
sharedCAF :: a -> (Ptr a -> IO (Ptr a)) -> IO a
sharedCAF a get_or_set = mask_ $ do
  offered <- newStablePtr a
  let mine = castPtr (castStablePtrToPtr offered)
  winner <- get_or_set mine
  if winner /= mine
    then do
      freeStablePtr offered
      deRefStablePtr (castPtrToStablePtr (castPtr winner))
    else return a

-- | Report a stack overflow by calling the RTS hook imported below.
reportStackOverflow :: IO ()
reportStackOverflow = callStackOverflowHook

-- | Pass an exception to whatever handler 'getUncaughtExceptionHandler'
-- currently returns.
reportError :: SomeException -> IO ()
reportError ex = getUncaughtExceptionHandler >>= \handler -> handler ex

-- | The RTS function @stackOverflow@.
foreign import ccall unsafe "stackOverflow"
  callStackOverflowHook :: IO ()

-- | Process-wide handler used by 'reportError' for exceptions that escape
-- a thread.  It is read by 'getUncaughtExceptionHandler' and replaced by
-- 'setUncaughtExceptionHandler'.
--
-- NOINLINE is required: this is a global 'IORef' created with
-- 'unsafePerformIO'.  If GHC inlined it, each use site could allocate its
-- own 'IORef', and a handler installed with 'setUncaughtExceptionHandler'
-- would never be seen by 'reportError'.
{-# NOINLINE uncaughtExceptionHandler #-}
uncaughtExceptionHandler :: IORef (SomeException -> IO ())
uncaughtExceptionHandler = unsafePerformIO (newIORef defaultHandler)
  where
    -- Default: flush stdout, ignoring any failure while doing so, then
    -- format a message and pass it to the RTS through errorBelch.
    defaultHandler :: SomeException -> IO ()
    defaultHandler se@(SomeException ex) = do
      catchAny (hFlush stdout) (\_ -> return ())
      let msg = case cast ex of
                  Just Deadlock -> "no threads to run: infinite loop or deadlock?"
                  _ -> case cast ex of
                         Just (ErrorCall s) -> s
                         _                  -> showsPrec 0 se ""
      withCString "%s" $ \cfmt ->
        withCString msg $ \cmsg ->
          errorBelch cfmt cmsg

-- | The RTS function @errorBelch2@ (from HsBase.h): a format string plus
-- one string argument.
foreign import ccall unsafe "HsBase.h errorBelch2"
  errorBelch :: CString -> CString -> IO ()

-- | Install a new process-wide handler for uncaught exceptions by
-- overwriting 'uncaughtExceptionHandler'.
setUncaughtExceptionHandler :: (SomeException -> IO ()) -> IO ()
setUncaughtExceptionHandler handler =
  writeIORef uncaughtExceptionHandler handler

-- | Fetch the currently installed uncaught-exception handler.
getUncaughtExceptionHandler :: IO (SomeException -> IO ())
getUncaughtExceptionHandler = readIORef uncaughtExceptionHandler

\end{code}