Raise at next Checkpoint if Non-awaited coroutine found. by Carreau · Pull Request #176 · python-trio/trio (original) (raw)

@@ -0,0 +1,146 @@

"""

This module provides utilities to protect against non-awaited coroutine.

Mostly it provide a protector which can install itself with

`sys.set_coroutine_wrapper` and track the creation of all coroutines.

Every now and then we can go over all the coroutines we have reference to, and

check their state. In trio, the trio-runner will do that, at least on every

checkpoint, but that's not the responsibility of this module.

If the coroutine have been awaited at least once, we discard them.

A :class:`CoroProtector` also provide a convenience method

:meth:`await_later(coro)` that return the coroutine unchanged but will ignore it

if not-awaited at next checkpoint.

A default instance of coroutine protector is provided under the attribute `protector`,

and is shared between `trio.run` and the `MultiError.catch`

"""

import sys

import inspect

import textwrap

from ._exceptions import NonAwaitedCoroutines

try:

from tracemalloc import get_object_traceback as _get_tb

except ImportError: # Not available on, for example, PyPy

def _get_tb(obj):

return None

__all__ = ["CoroProtector", "protector"]

################################################################

# Protection against non-awaited coroutines

################################################################

class CoroProtector:

"""

Protector preventing the creation of non-awaited coroutines

between two checkpoints.

"""

def __init__(self):

self._enabled = True

self._pending_test = set()

self._key = object()

self._previous_coro_wrapper = None

def _coro_wrapper(self, coro):

"""

Coroutine wrapper to track creation of coroutines.

"""

if self._enabled:

self._pending_test.add(coro)

if not self._previous_coro_wrapper:

return coro

else:

return self._previous_coro_wrapper(coro)

def await_later(self, coro):

"""

Mark a coroutine as safe to no be awaited, and return it.

"""

self._pending_test.discard(coro)

return coro

def install(self) -> None:

"""install a coroutine wrapper to track created coroutines.

If a coroutine wrapper is already set wrap and call it.

"""

self._previous_coro_wrapper = sys.get_coroutine_wrapper()

sys.set_coroutine_wrapper(self._coro_wrapper)

def uninstall(self) -> None:

assert sys.get_coroutine_wrapper() == self._coro_wrapper

sys.set_coroutine_wrapper(self._previous_coro_wrapper)

def has_unawaited_coroutines(self) -> bool:

"""

Return whether there are unawaited coroutines.

Flush all internally tracked awaited coroutine. Does not discard non-awaited

ones. You need to call `pop_all_unawaited_coroutines` to do that.

"""

return len(self.get_all_unawaited_coroutines()) > 0

def get_all_unawaited_coroutines(self):

state = inspect.getcoroutinestate

self._pending_test = {

coro

for coro in self._pending_test if state(coro) == 'CORO_CREATED'

}

return set(self._pending_test)

def forget(self, coroutines) -> None:

self._pending_test.difference_update(coroutines)

def pop_all_unawaited_coroutines(self):

"""

Check that since last invocation no coroutine has been left unawaited.

Return a list of unawaited coroutines since last call to this function,

and stop tracking them.

"""

coros = self.get_all_unawaited_coroutines()

self._pending_test = set()

return coros

@staticmethod

def make_non_awaited_coroutines_error(coros):

"""

Construct a nice NonAwaitedCoroutines error messages with the origin of the

coroutine if possible.

"""

err = []

for coro in coros:

tb = _get_tb(coro)

if tb:

err.append(' - {coro} ({tb})'.format(coro=coro, tb=tb)

) # pragma: no cover

else:

err.append(' - {coro}'.format(coro=coro))

err = '\n'.join(err)

return NonAwaitedCoroutines(

textwrap.dedent(

'''

One or more coroutines where not awaited:

{err}

Trio has detected that at least a coroutine has not been between awaited

between this checkpoint point and previous one. This is may be due

to a missing `await`.

''' [1:]

).format(err=err),

coroutines=coros

)

protector = CoroProtector()