bpo-43352: Add a Barrier object in asyncio lib (GH-24903) · python/cpython@d03acd7 (original) (raw)
1
1
`"""Synchronization primitives."""
`
2
2
``
3
``
`-
all = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')
`
``
3
`+
all = ('Lock', 'Event', 'Condition', 'Semaphore',
`
``
4
`+
'BoundedSemaphore', 'Barrier')
`
4
5
``
5
6
`import collections
`
``
7
`+
import enum
`
6
8
``
7
9
`from . import exceptions
`
8
10
`from . import mixins
`
9
11
`from . import tasks
`
10
12
``
11
``
-
12
13
`class _ContextManagerMixin:
`
13
14
`async def aenter(self):
`
14
15
`await self.acquire()
`
`@@ -416,3 +417,155 @@ def release(self):
`
416
417
`if self._value >= self._bound_value:
`
417
418
`raise ValueError('BoundedSemaphore released too many times')
`
418
419
`super().release()
`
``
420
+
``
421
+
``
422
+
``
423
`+
class _BarrierState(enum.Enum):
`
``
424
`+
FILLING = 'filling'
`
``
425
`+
DRAINING = 'draining'
`
``
426
`+
RESETTING = 'resetting'
`
``
427
`+
BROKEN = 'broken'
`
``
428
+
``
429
+
``
430
`+
class Barrier(mixins._LoopBoundMixin):
`
``
431
`+
"""Asyncio equivalent to threading.Barrier
`
``
432
+
``
433
`+
Implements a Barrier primitive.
`
``
434
`+
Useful for synchronizing a fixed number of tasks at known synchronization
`
``
435
`+
points. Tasks block on 'wait()' and are simultaneously awoken once they
`
``
436
`+
have all made their call.
`
``
437
`+
"""
`
``
438
+
``
439
`+
def init(self, parties):
`
``
440
`+
"""Create a barrier, initialised to 'parties' tasks."""
`
``
441
`+
if parties < 1:
`
``
442
`+
raise ValueError('parties must be > 0')
`
``
443
+
``
444
`+
self._cond = Condition() # notify all tasks when state changes
`
``
445
+
``
446
`+
self._parties = parties
`
``
447
`+
self._state = _BarrierState.FILLING
`
``
448
`+
self._count = 0 # count tasks in Barrier
`
``
449
+
``
450
`+
def repr(self):
`
``
451
`+
res = super().repr()
`
``
452
`+
extra = f'{self._state.value}'
`
``
453
`+
if not self.broken:
`
``
454
`+
extra += f', waiters:{self.n_waiting}/{self.parties}'
`
``
455
`+
return f'<{res[1:-1]} [{extra}]>'
`
``
456
+
``
457
`+
async def aenter(self):
`
``
458
`+
wait for the barrier reaches the parties number
`
``
459
`+
when start draining release and return index of waited task
`
``
460
`+
return await self.wait()
`
``
461
+
``
462
`+
async def aexit(self, *args):
`
``
463
`+
pass
`
``
464
+
``
465
`+
async def wait(self):
`
``
466
`+
"""Wait for the barrier.
`
``
467
+
``
468
`+
When the specified number of tasks have started waiting, they are all
`
``
469
`+
simultaneously awoken.
`
``
470
`+
Returns an unique and individual index number from 0 to 'parties-1'.
`
``
471
`+
"""
`
``
472
`+
async with self._cond:
`
``
473
`+
await self._block() # Block while the barrier drains or resets.
`
``
474
`+
try:
`
``
475
`+
index = self._count
`
``
476
`+
self._count += 1
`
``
477
`+
if index + 1 == self._parties:
`
``
478
`+
We release the barrier
`
``
479
`+
await self._release()
`
``
480
`+
else:
`
``
481
`+
await self._wait()
`
``
482
`+
return index
`
``
483
`+
finally:
`
``
484
`+
self._count -= 1
`
``
485
`+
Wake up any tasks waiting for barrier to drain.
`
``
486
`+
self._exit()
`
``
487
+
``
488
`+
async def _block(self):
`
``
489
`+
Block until the barrier is ready for us,
`
``
490
`+
or raise an exception if it is broken.
`
``
491
`+
`
``
492
`+
It is draining or resetting, wait until done
`
``
493
`+
unless a CancelledError occurs
`
``
494
`+
await self._cond.wait_for(
`
``
495
`+
lambda: self._state not in (
`
``
496
`+
_BarrierState.DRAINING, _BarrierState.RESETTING
`
``
497
`+
)
`
``
498
`+
)
`
``
499
+
``
500
`+
see if the barrier is in a broken state
`
``
501
`+
if self._state is _BarrierState.BROKEN:
`
``
502
`+
raise exceptions.BrokenBarrierError("Barrier aborted")
`
``
503
+
``
504
`+
async def _release(self):
`
``
505
`+
Release the tasks waiting in the barrier.
`
``
506
+
``
507
`+
Enter draining state.
`
``
508
`+
Next waiting tasks will be blocked until the end of draining.
`
``
509
`+
self._state = _BarrierState.DRAINING
`
``
510
`+
self._cond.notify_all()
`
``
511
+
``
512
`+
async def _wait(self):
`
``
513
`+
Wait in the barrier until we are released. Raise an exception
`
``
514
`+
if the barrier is reset or broken.
`
``
515
+
``
516
`+
wait for end of filling
`
``
517
`+
unless a CancelledError occurs
`
``
518
`+
await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)
`
``
519
+
``
520
`+
if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):
`
``
521
`+
raise exceptions.BrokenBarrierError("Abort or reset of barrier")
`
``
522
+
``
523
`+
def _exit(self):
`
``
524
`+
If we are the last tasks to exit the barrier, signal any tasks
`
``
525
`+
waiting for the barrier to drain.
`
``
526
`+
if self._count == 0:
`
``
527
`+
if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):
`
``
528
`+
self._state = _BarrierState.FILLING
`
``
529
`+
self._cond.notify_all()
`
``
530
+
``
531
`+
async def reset(self):
`
``
532
`+
"""Reset the barrier to the initial state.
`
``
533
+
``
534
`+
Any tasks currently waiting will get the BrokenBarrier exception
`
``
535
`+
raised.
`
``
536
`+
"""
`
``
537
`+
async with self._cond:
`
``
538
`+
if self._count > 0:
`
``
539
`+
if self._state is not _BarrierState.RESETTING:
`
``
540
`+
#reset the barrier, waking up tasks
`
``
541
`+
self._state = _BarrierState.RESETTING
`
``
542
`+
else:
`
``
543
`+
self._state = _BarrierState.FILLING
`
``
544
`+
self._cond.notify_all()
`
``
545
+
``
546
`+
async def abort(self):
`
``
547
`+
"""Place the barrier into a 'broken' state.
`
``
548
+
``
549
`+
Useful in case of error. Any currently waiting tasks and tasks
`
``
550
`+
attempting to 'wait()' will have BrokenBarrierError raised.
`
``
551
`+
"""
`
``
552
`+
async with self._cond:
`
``
553
`+
self._state = _BarrierState.BROKEN
`
``
554
`+
self._cond.notify_all()
`
``
555
+
``
556
`+
@property
`
``
557
`+
def parties(self):
`
``
558
`+
"""Return the number of tasks required to trip the barrier."""
`
``
559
`+
return self._parties
`
``
560
+
``
561
`+
@property
`
``
562
`+
def n_waiting(self):
`
``
563
`+
"""Return the number of tasks currently waiting at the barrier."""
`
``
564
`+
if self._state is _BarrierState.FILLING:
`
``
565
`+
return self._count
`
``
566
`+
return 0
`
``
567
+
``
568
`+
@property
`
``
569
`+
def broken(self):
`
``
570
`+
"""Return True if the barrier is in a broken state."""
`
``
571
`+
return self._state is _BarrierState.BROKEN
`