# (original) (raw)
import unittest
import os
import time
import multiprocessing as mp

from barrier import Barrier, BrokenBarrierError


# Many of the tests for threading.Barrier use a list as an atomic
# counter: a value is appended to increment the counter, and the
# length of the list gives the value. We use the class DummyList
# for the same purpose.
class DummyList(object):
    """Process-shared counter with a list-like ``append``/``len`` interface.

    Only the number of appended items is stored (in a shared ``RawValue``
    guarded by an ``mp.Lock``); the appended values themselves are dropped.
    """

    def __init__(self):
        self._length = mp.RawValue('i', 0)
        self._lock = mp.Lock()

    def append(self, _):
        """Increment the shared counter; the argument is ignored."""
        with self._lock:
            self._length.value += 1

    def __len__(self):
        """Return the number of ``append`` calls made so far."""
        with self._lock:
            return self._length.value


def _wait():
    # A crude wait/yield function not relying on synchronization primitives.
    time.sleep(0.01)


class Bunch(object):
    """
    A bunch of processes running the same function.
    """

    def __init__(self, f, args, n, wait_before_exit=False):
        """
        Construct a bunch of `n` processes running the same function `f`.
        If `wait_before_exit` is True, the processes won't terminate until
        do_finish() is called.
        """
        self.f = f
        self.args = args
        self.n = n
        self.started = DummyList()
        self.finished = DummyList()
        # Counts children whose call to `f` raised. An exception in a child
        # process is otherwise invisible to the parent test.
        self.failed = DummyList()
        self._can_exit = mp.Value('i', not wait_before_exit)
        processes = []
        for _ in range(n):
            p = mp.Process(target=self.task)
            p.start()
            processes.append(p)
        # Assigned only after every start() so the Process handles are never
        # part of the state shipped to a child.
        self._processes = processes

    def task(self):
        """Child-process body: run `f`, record the outcome, maybe linger."""
        pid = os.getpid()
        self.started.append(pid)
        try:
            self.f(*self.args)
        except Exception:
            # Record the failure before `finished` is bumped (the finally
            # runs afterwards), then re-raise so the child still prints its
            # traceback and exits with a non-zero status.
            self.failed.append(pid)
            raise
        finally:
            self.finished.append(pid)
            while not self._can_exit.value:
                _wait()

    def wait_for_started(self):
        while len(self.started) < self.n:
            _wait()

    def wait_for_finished(self):
        while len(self.finished) < self.n:
            _wait()

    def do_finish(self):
        self._can_exit.value = True

    def join(self):
        """Reap all child processes.

        Blocks until every child has exited. If the bunch was built with
        ``wait_before_exit=True``, call do_finish() first.
        """
        for p in self._processes:
            p.join()


class BarrierTests(unittest.TestCase):
    """
    Tests for Barrier objects.
    """
    N = 5
    defaultTimeout = 2.0
    barriertype = Barrier

    def setUp(self):
        self.barrier = self.barriertype(self.N, timeout=self.defaultTimeout)

    def tearDown(self):
        self.barrier.abort()

    def run_threads(self, f, args):
        """Run ``f(*args)`` in N-1 child processes and in this process.

        Fails the test if ``f`` raised in any child process. Without this
        check, assertions made inside the children would be silently lost.
        """
        b = Bunch(f, args, self.N-1)
        f(*args)
        b.wait_for_finished()
        b.join()
        self.assertEqual(len(b.failed), 0,
                         "%d child process(es) raised" % len(b.failed))

    def multipass(self, results, n):
        m = self.barrier.parties
        self.assertEqual(m, self.N)
        for i in range(n):
            results[0].append(True)
            self.assertEqual(len(results[1]), i * m)
            self.barrier.wait()
            results[1].append(True)
            self.assertEqual(len(results[0]), (i + 1) * m)
            self.barrier.wait()
        try:
            self.assertEqual(self.barrier.n_waiting, 0)
        except NotImplementedError:
            pass
        self.assertFalse(self.barrier.broken)

    def test_barrier(self, passes=1):
        """
        Test that a barrier is passed in lockstep
        """
        results = [DummyList(), DummyList()]
        self.run_threads(self.multipass, (results, passes))

    def test_barrier_10(self):
        """
        Test that a barrier works for 10 consecutive runs
        """
        # No `return`: test methods returning a value are deprecated.
        self.test_barrier(10)

    def _test_wait_return_f(self, q):
        r = self.barrier.wait()
        q.put(r)

    def test_wait_return(self):
        """
        test the return value from barrier.wait
        """
        q = mp.Queue()
        self.run_threads(self._test_wait_return_f, (q,))
        results = [q.get() for i in range(self.N)]
        self.assertEqual(results.count(0), 1)

    def _test_action_action(self, results):
        results.append(True)

    def _test_action_f(self, results, barrier):
        barrier.wait()
        self.assertEqual(len(results), 1)

    def test_action(self):
        """
        Test the 'action' callback
        """
        results = DummyList()
        barrier = self.barriertype(self.N, action=self._test_action_action,
                                   action_args=(results,))
        self.run_threads(self._test_action_f, (results, barrier))

    def _test_abort_f(self, results1, results2):
        try:
            i = self.barrier.wait()
            if i == 0:
                raise RuntimeError
            self.barrier.wait()
            results1.append(True)
        except BrokenBarrierError:
            results2.append(True)
        except RuntimeError:
            self.barrier.abort()

    def test_abort(self):
        """
        Test that an abort will put the barrier in a broken state
        """
        results1 = DummyList()
        results2 = DummyList()
        self.run_threads(self._test_abort_f, (results1, results2))
        self.assertEqual(len(results1), 0)
        self.assertEqual(len(results2), self.N-1)
        self.assertTrue(self.barrier.broken)

    # NOTE: The two tests below are disabled. They are still written against
    # the threading-based API: plain lists and closures as shared state,
    # `threading.BrokenBarrierError`, and `run_threads(f)` without `args`.
    # They need porting to DummyList and module-level functions before they
    # can be re-enabled.
    # def test_reset(self):
    #     """
    #     Test that a 'reset' on a barrier frees the waiting threads
    #     """
    #     results1 = []
    #     results2 = []
    #     results3 = []
    #     def f():
    #         i = self.barrier.wait()
    #         if i == self.N//2:
    #             # Wait until the other threads are all in the barrier.
    #             while self.barrier.n_waiting < self.N-1:
    #                 time.sleep(0.001)
    #             self.barrier.reset()
    #         else:
    #             try:
    #                 self.barrier.wait()
    #                 results1.append(True)
    #             except threading.BrokenBarrierError:
    #                 results2.append(True)
    #         # Now, pass the barrier again
    #         self.barrier.wait()
    #         results3.append(True)
    #     self.run_threads(f)
    #     self.assertEqual(len(results1), 0)
    #     self.assertEqual(len(results2), self.N-1)
    #     self.assertEqual(len(results3), self.N)

    # def test_abort_and_reset(self):
    #     """
    #     Test that a barrier can be reset after being broken.
    #     """
    #     results1 = []
    #     results2 = []
    #     results3 = []
    #     barrier2 = self.barriertype(self.N)
    #     def f():
    #         try:
    #             i = self.barrier.wait()
    #             if i == self.N//2:
    #                 raise RuntimeError
    #             self.barrier.wait()
    #             results1.append(True)
    #         except threading.BrokenBarrierError:
    #             results2.append(True)
    #         except RuntimeError:
    #             self.barrier.abort()
    #             pass
    #         # Synchronize and reset the barrier. Must synchronize first so
    #         # that everyone has left it when we reset, and after so that no
    #         # one enters it before the reset.
    #         if barrier2.wait() == self.N//2:
    #             self.barrier.reset()
    #         barrier2.wait()
    #         self.barrier.wait()
    #         results3.append(True)
    #     self.run_threads(f)
    #     self.assertEqual(len(results1), 0)
    #     self.assertEqual(len(results2), self.N-1)
    #     self.assertEqual(len(results3), self.N)

    def _test_timeout_f(self):
        i = self.barrier.wait()
        if i == 0:
            # One thread is late!
            time.sleep(1.0)
        # Default timeout is 2.0, so this is shorter.
        self.assertRaises(BrokenBarrierError, self.barrier.wait, 0.5)

    def test_timeout(self):
        """
        Test wait(timeout)
        """
        self.run_threads(self._test_timeout_f, ())

    def _test_default_timeout_f(self, barrier):
        i = barrier.wait()
        if i == 0:
            # One thread is later than the default timeout of 0.6s.
            time.sleep(1.5)
        self.assertRaises(BrokenBarrierError, barrier.wait)

    def test_default_timeout(self):
        """
        Test the barrier's default timeout
        """
        # XXX for Windows debug build, original timeout of 0.3 was too small
        # create a barrier with a low default timeout
        barrier = self.barriertype(self.N, timeout=0.6)
        self.run_threads(self._test_default_timeout_f, (barrier,))

    def test_single_thread(self):
        b = self.barriertype(1)
        b.wait()
        b.wait()

    def _test_thousand_f(self, passes, conn, lock):
        for i in range(passes):
            self.barrier.wait()
            with lock:
                conn.send(i)

    def test_thousand(self):
        passes = 1000
        lock = mp.Lock()
        conn, child_conn = mp.Pipe(False)
        processes = []
        for _ in range(self.N):
            p = mp.Process(target=self._test_thousand_f,
                           args=(passes, child_conn, lock))
            p.start()
            processes.append(p)
        # Drop the parent's copy of the write end. If every child dies early,
        # recv() then raises EOFError instead of blocking forever.
        child_conn.close()
        try:
            for i in range(passes):
                for _ in range(self.N):
                    self.assertEqual(conn.recv(), i)
        finally:
            conn.close()
        for p in processes:
            p.join()


if __name__ == '__main__':
    unittest.main()