bpo-43352: Add a Barrier object in asyncio lib (GH-24903) · python/cpython@d03acd7 (original) (raw)

1

1

`"""Synchronization primitives."""

`

2

2

``

3

``

`-

all = ('Lock', 'Event', 'Condition', 'Semaphore', 'BoundedSemaphore')

`

``

3

`+

all = ('Lock', 'Event', 'Condition', 'Semaphore',

`

``

4

`+

'BoundedSemaphore', 'Barrier')

`

4

5

``

5

6

`import collections

`

``

7

`+

import enum

`

6

8

``

7

9

`from . import exceptions

`

8

10

`from . import mixins

`

9

11

`from . import tasks

`

10

12

``

11

``

-

12

13

`class _ContextManagerMixin:

`

13

14

`async def aenter(self):

`

14

15

`await self.acquire()

`

`@@ -416,3 +417,155 @@ def release(self):

`

416

417

`if self._value >= self._bound_value:

`

417

418

`raise ValueError('BoundedSemaphore released too many times')

`

418

419

`super().release()

`

``

420

+

``

421

+

``

422

+

``

423

`+

class _BarrierState(enum.Enum):

`

``

424

`+

FILLING = 'filling'

`

``

425

`+

DRAINING = 'draining'

`

``

426

`+

RESETTING = 'resetting'

`

``

427

`+

BROKEN = 'broken'

`

``

428

+

``

429

+

``

430

`+

class Barrier(mixins._LoopBoundMixin):

`

``

431

`+

"""Asyncio equivalent to threading.Barrier

`

``

432

+

``

433

`+

Implements a Barrier primitive.

`

``

434

`+

Useful for synchronizing a fixed number of tasks at known synchronization

`

``

435

`+

points. Tasks block on 'wait()' and are simultaneously awoken once they

`

``

436

`+

have all made their call.

`

``

437

`+

"""

`

``

438

+

``

439

`+

def init(self, parties):

`

``

440

`+

"""Create a barrier, initialised to 'parties' tasks."""

`

``

441

`+

if parties < 1:

`

``

442

`+

raise ValueError('parties must be > 0')

`

``

443

+

``

444

`+

self._cond = Condition() # notify all tasks when state changes

`

``

445

+

``

446

`+

self._parties = parties

`

``

447

`+

self._state = _BarrierState.FILLING

`

``

448

`+

self._count = 0 # count tasks in Barrier

`

``

449

+

``

450

`+

def repr(self):

`

``

451

`+

res = super().repr()

`

``

452

`+

extra = f'{self._state.value}'

`

``

453

`+

if not self.broken:

`

``

454

`+

extra += f', waiters:{self.n_waiting}/{self.parties}'

`

``

455

`+

return f'<{res[1:-1]} [{extra}]>'

`

``

456

+

``

457

`+

async def aenter(self):

`

``

458

`+

wait for the barrier reaches the parties number

`

``

459

`+

when start draining release and return index of waited task

`

``

460

`+

return await self.wait()

`

``

461

+

``

462

`+

async def aexit(self, *args):

`

``

463

`+

pass

`

``

464

+

``

465

`+

async def wait(self):

`

``

466

`+

"""Wait for the barrier.

`

``

467

+

``

468

`+

When the specified number of tasks have started waiting, they are all

`

``

469

`+

simultaneously awoken.

`

``

470

`+

Returns an unique and individual index number from 0 to 'parties-1'.

`

``

471

`+

"""

`

``

472

`+

async with self._cond:

`

``

473

`+

await self._block() # Block while the barrier drains or resets.

`

``

474

`+

try:

`

``

475

`+

index = self._count

`

``

476

`+

self._count += 1

`

``

477

`+

if index + 1 == self._parties:

`

``

478

`+

We release the barrier

`

``

479

`+

await self._release()

`

``

480

`+

else:

`

``

481

`+

await self._wait()

`

``

482

`+

return index

`

``

483

`+

finally:

`

``

484

`+

self._count -= 1

`

``

485

`+

Wake up any tasks waiting for barrier to drain.

`

``

486

`+

self._exit()

`

``

487

+

``

488

`+

async def _block(self):

`

``

489

`+

Block until the barrier is ready for us,

`

``

490

`+

or raise an exception if it is broken.

`

``

491

`+

`

``

492

`+

It is draining or resetting, wait until done

`

``

493

`+

unless a CancelledError occurs

`

``

494

`+

await self._cond.wait_for(

`

``

495

`+

lambda: self._state not in (

`

``

496

`+

_BarrierState.DRAINING, _BarrierState.RESETTING

`

``

497

`+

)

`

``

498

`+

)

`

``

499

+

``

500

`+

see if the barrier is in a broken state

`

``

501

`+

if self._state is _BarrierState.BROKEN:

`

``

502

`+

raise exceptions.BrokenBarrierError("Barrier aborted")

`

``

503

+

``

504

`+

async def _release(self):

`

``

505

`+

Release the tasks waiting in the barrier.

`

``

506

+

``

507

`+

Enter draining state.

`

``

508

`+

Next waiting tasks will be blocked until the end of draining.

`

``

509

`+

self._state = _BarrierState.DRAINING

`

``

510

`+

self._cond.notify_all()

`

``

511

+

``

512

`+

async def _wait(self):

`

``

513

`+

Wait in the barrier until we are released. Raise an exception

`

``

514

`+

if the barrier is reset or broken.

`

``

515

+

``

516

`+

wait for end of filling

`

``

517

`+

unless a CancelledError occurs

`

``

518

`+

await self._cond.wait_for(lambda: self._state is not _BarrierState.FILLING)

`

``

519

+

``

520

`+

if self._state in (_BarrierState.BROKEN, _BarrierState.RESETTING):

`

``

521

`+

raise exceptions.BrokenBarrierError("Abort or reset of barrier")

`

``

522

+

``

523

`+

def _exit(self):

`

``

524

`+

If we are the last tasks to exit the barrier, signal any tasks

`

``

525

`+

waiting for the barrier to drain.

`

``

526

`+

if self._count == 0:

`

``

527

`+

if self._state in (_BarrierState.RESETTING, _BarrierState.DRAINING):

`

``

528

`+

self._state = _BarrierState.FILLING

`

``

529

`+

self._cond.notify_all()

`

``

530

+

``

531

`+

async def reset(self):

`

``

532

`+

"""Reset the barrier to the initial state.

`

``

533

+

``

534

`+

Any tasks currently waiting will get the BrokenBarrier exception

`

``

535

`+

raised.

`

``

536

`+

"""

`

``

537

`+

async with self._cond:

`

``

538

`+

if self._count > 0:

`

``

539

`+

if self._state is not _BarrierState.RESETTING:

`

``

540

`+

#reset the barrier, waking up tasks

`

``

541

`+

self._state = _BarrierState.RESETTING

`

``

542

`+

else:

`

``

543

`+

self._state = _BarrierState.FILLING

`

``

544

`+

self._cond.notify_all()

`

``

545

+

``

546

`+

async def abort(self):

`

``

547

`+

"""Place the barrier into a 'broken' state.

`

``

548

+

``

549

`+

Useful in case of error. Any currently waiting tasks and tasks

`

``

550

`+

attempting to 'wait()' will have BrokenBarrierError raised.

`

``

551

`+

"""

`

``

552

`+

async with self._cond:

`

``

553

`+

self._state = _BarrierState.BROKEN

`

``

554

`+

self._cond.notify_all()

`

``

555

+

``

556

`+

@property

`

``

557

`+

def parties(self):

`

``

558

`+

"""Return the number of tasks required to trip the barrier."""

`

``

559

`+

return self._parties

`

``

560

+

``

561

`+

@property

`

``

562

`+

def n_waiting(self):

`

``

563

`+

"""Return the number of tasks currently waiting at the barrier."""

`

``

564

`+

if self._state is _BarrierState.FILLING:

`

``

565

`+

return self._count

`

``

566

`+

return 0

`

``

567

+

``

568

`+

@property

`

``

569

`+

def broken(self):

`

``

570

`+

"""Return True if the barrier is in a broken state."""

`

``

571

`+

return self._state is _BarrierState.BROKEN

`