(original) (raw)
import multiprocessing as mp
import time


class BrokenBarrierError(Exception):
    """Raised when a ``Barrier`` is, or becomes, broken."""


class Barrier(object):
    """Reusable barrier for a fixed number of parties, built on
    ``multiprocessing`` synchronization primitives.

    Every primitive comes from ``multiprocessing``, so the barrier can
    coordinate several processes (it also works between threads). Once
    broken -- by ``abort()``, by a ``wait()`` that times out, or by an
    exception raised inside ``wait()`` (including one from ``action``) --
    the barrier stays broken; there is no ``reset()``.

    Args:
        parties: number of participants that must call ``wait()`` before
            any of them returns. Must be > 0.
        action: optional callable run by the last party to arrive, before
            any other party returns from ``wait()``.
        timeout: default timeout in seconds for ``wait()``; ``None`` means
            wait forever.
        action_args: positional arguments passed to ``action``.

    Raises:
        ValueError: if ``parties <= 0``.
    """

    # Seconds the control party waits for each released early arriver to
    # hand back its ``_counter`` slot. Arrivers do that right after being
    # woken, so running out of time means the barrier state is inconsistent.
    _WAKEUP_TIMEOUT = 5

    def __init__(self, parties, action=None, timeout=None, action_args=()):
        if parties <= 0:
            raise ValueError('parties must be greater than 0')
        self._parties = parties
        self._action = action
        self._action_args = action_args
        self._timeout = timeout
        # Serializes arrivals; the control party holds it for the whole
        # release / reset / action sequence.
        self._lock = mp.RLock()
        # One slot per early arriver; the party that finds it empty is the
        # control party.
        self._counter = mp.Semaphore(parties - 1)
        # Early arrivers block here until the control party (or abort())
        # releases them.
        self._wait_sem = mp.Semaphore(0)
        # Non-zero once the barrier has been broken.
        self._broken = mp.Semaphore(0)

    def wait(self, timeout=None):
        """Block until all ``parties`` have called ``wait()``.

        Each arriving party tries a non-blocking acquire on ``_counter``.
        Since ``_counter`` starts at ``parties - 1``, only the last party to
        arrive fails; it becomes the "control" party and is responsible for
        waking the parties that arrived before it, waiting for them to
        respond, resetting the barrier and calling ``action`` (if any).

        Args:
            timeout: seconds an early arriver waits to be released; ``None``
                falls back to the constructor's ``timeout``.

        Returns:
            0 in the control party, -1 in every other party.

        Raises:
            BrokenBarrierError: if the barrier is broken on entry, becomes
                broken while waiting, or the wait times out (which also
                breaks it).
            Exception: anything raised by ``action`` propagates after the
                barrier has been broken.
        """
        if timeout is None:
            timeout = self._timeout
        with self._lock:
            if self.broken:
                raise BrokenBarrierError
            try:
                if self._counter.acquire(timeout=0):
                    # We are an early arriver, not the control party: let
                    # other parties in while we wait to be released.
                    self._lock.release()
                    try:
                        if not self._wait_sem.acquire(timeout=timeout):
                            raise BrokenBarrierError
                        res = -1
                    finally:
                        # Hand our slot back (the control party counts
                        # these) and retake the lock so the ``with`` can
                        # release it on exit.
                        self._counter.release()
                        self._lock.acquire()
                else:
                    # We are the control party.
                    n_early = self._parties - 1
                    wakeup_timeout = self._WAKEUP_TIMEOUT
                    # - release the early arrivers
                    for _ in range(n_early):
                        self._wait_sem.release()
                    # - wait for every early arriver to hand back its slot.
                    #   Explicit check rather than ``assert``: under
                    #   ``python -O`` an assert is stripped, and the reset
                    #   below would then over-fill ``_counter``.
                    for _ in range(n_early):
                        if not self._counter.acquire(timeout=wakeup_timeout):
                            raise BrokenBarrierError
                    # - reset state of the barrier
                    for _ in range(n_early):
                        self._counter.release()
                    # - carry out action and return
                    if self._action is not None:
                        self._action(*self._action_args)
                    res = 0
            except BaseException:
                # Any failure (timeout, action error, KeyboardInterrupt)
                # breaks the barrier for every party before re-raising.
                self.abort()
                raise
        # The barrier may have been aborted while we waited (abort() also
        # wakes early arrivers); report that instead of a normal return.
        if self.broken:
            raise BrokenBarrierError
        return res

    def abort(self):
        """Put the barrier into the broken state and wake every waiter.

        Parties blocked in ``wait()`` are released and raise
        ``BrokenBarrierError``; later ``wait()`` calls raise it immediately.
        Calling ``abort()`` on an already-broken barrier does nothing.
        """
        with self._lock:
            if self.broken:
                return
            self._broken.release()
            # Release any waiters; each wakes, sees ``broken`` and raises.
            for _ in range(self._parties - 1):
                self._wait_sem.release()

    @property
    def broken(self):
        """True once the barrier has been broken.

        Relies on the private ``SemLock._is_zero()`` helper of CPython's
        ``multiprocessing`` implementation.
        """
        return not self._broken._semlock._is_zero()

    @property
    def parties(self):
        """Number of parties required to trip the barrier."""
        return self._parties

    @property
    def n_waiting(self):
        """Number of parties currently blocked in ``wait()``.

        Raises:
            BrokenBarrierError: if the barrier is broken.

        NOTE(review): uses ``Semaphore.get_value()``, which may raise
        ``NotImplementedError`` on platforms without ``sem_getvalue()``
        (e.g. macOS) -- confirm for the target platforms.
        """
        with self._lock:
            if self.broken:
                raise BrokenBarrierError
            return (self._parties - 1) - self._counter.get_value()