bpo-29988: Test signal safety of with statements by ncoghlan · Pull Request #2 · ncoghlan/cpython (original) (raw)

@@ -0,0 +1,200 @@

"""Additional signal safety tests for "with" and "async with"

"""

from test.support import cpython_only, verbose

from _testcapi import install_error_injection_hook

import asyncio

import dis

import sys

import threading

import unittest

class InjectedException(Exception):

"""Exception injected into a running frame via a trace function"""

pass

def raise_after_offset(target_function, target_offset):

"""Sets a trace function to inject an exception into given function

Relies on the ability to request that a trace function be called for

every executed opcode, not just every line

"""

target_code = target_function.__code__

def inject_exception():

exc = InjectedException(f"Failing after {target_offset}")

raise exc

# This installs a trace hook that's implemented in C, and hence won't

# trigger any of the per-bytecode processing in the eval loop

# This means it can register the pending call that raises the exception and

# the pending call won't be processed until after the trace hook returns

install_error_injection_hook(target_code, target_offset, inject_exception)

# TODO: Add a test case that ensures raise_after_offset is working

# properly (otherwise there's a risk the tests will pass due to the

# exception not being injected properly)

@cpython_only

class CheckFunctionSignalSafety(unittest.TestCase):

"""Ensure with statements are signal-safe.

Signal safety means that, regardless of when external signals (e.g.

KeyboardInterrupt) are received, if __enter__ succeeds, __exit__ will

be called.

See https://bugs.python.org/issue29988 for more details

"""

def setUp(self):

old_trace = sys.gettrace()

self.addCleanup(sys.settrace, old_trace)

sys.settrace(None)

def assert_lock_released(self, test_lock, target_offset, traced_code):

just_acquired = test_lock.acquire(blocking=False)

# Either we just acquired the lock, or the test didn't release it

test_lock.release()

if not just_acquired:

msg = ("Context manager entered without exit due to "

f"exception injected at offset {target_offset} in:\n"

f"{dis.Bytecode(traced_code).dis()}")

self.fail(msg)

def _check_CM_exits_correctly(self, traced_function):

# Must use a signal-safe CM, otherwise __exit__ will start

# but then fail to actually run as the pending call gets processed

test_lock = threading.Lock()

target_offset = -1

traced_code = dis.Bytecode(traced_function)

for instruction in traced_code:

if instruction.opname == "RETURN_VALUE":

break

max_offset = instruction.offset

while target_offset < max_offset:

target_offset += 1

raise_after_offset(traced_function, target_offset)

try:

traced_function(test_lock)

except InjectedException:

# key invariant: if we entered the CM, we exited it

self.assert_lock_released(test_lock, target_offset, traced_code)

else:

try:

msg = (f"Exception wasn't raised @{target_offset} in:\n"

f"{traced_code.dis()}")

self.fail(msg)

except InjectedException:

# The pending call was still active when we tried to report

# the fact the exception wasn't raised by the traced function

msg = (f"Pending calls weren't processed after @{target_offset} in:\n"

f"{traced_code.dis()}")

self.fail(msg)

def test_with_statement_completed(self):

def traced_function(test_cm):

with test_cm:

1 + 1

return # Make implicit final return explicit

self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_return(self):

def traced_function(test_cm):

with test_cm:

1 + 1

return

return # Make implicit final return explicit

self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_continue(self):

def traced_function(test_cm):

for i in range(1):

with test_cm:

1 + 1

continue

return # Make implicit final return explicit

self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_break(self):

def traced_function(test_cm):

while True:

with test_cm:

1 + 1

break

return # Make implicit final return explicit

self._check_CM_exits_correctly(traced_function)

def test_with_statement_exited_via_raise(self):

def traced_function(test_cm):

try:

with test_cm:

1 + 1

1/0

except ZeroDivisionError:

pass

return # Make implicit final return explicit

self._check_CM_exits_correctly(traced_function)

@cpython_only

class CheckCoroutineSignalSafety(unittest.TestCase):

"""Ensure async with statements are signal-safe.

Signal safety means that, regardless of when external signals (e.g.

KeyboardInterrupt) are received, if __aenter__ succeeeds, __aexit__ will

be called *and* the resulting awaitable will be awaited.

See https://bugs.python.org/issue29988 for more details

"""

def setUp(self):

old_trace = sys.gettrace()

self.addCleanup(sys.settrace, old_trace)

sys.settrace(None)

def assert_CM_balanced(self, test_cm, target_offset, traced_code):

if test_cm.enter_without_exit:

msg = ("Context manager entered without exit due to "

f"exception injected at offset {target_offset} in:\n"

f"{traced_code.dis()}")

self.fail(msg)

def _check_CM_exits_correctly(self, traced_coroutine):

# NOTE: to get this to work, we also needed to update ceval to ensure

# that at least one line in a frame is executed before signals are

# processed (otherwise __aexit__'s body doesn't run)

class AsyncTrackingCM():

def __init__(self):

self.enter_without_exit = None

async def __aenter__(self):

self.enter_without_exit = True

async def __aexit__(self, *args):

self.enter_without_exit = False

test_cm = AsyncTrackingCM()

target_offset = -1

traced_code = dis.Bytecode(traced_coroutine)

for instruction in traced_code:

if instruction.opname == "RETURN_VALUE":

break

max_offset = instruction.offset

loop = asyncio.get_event_loop()

while target_offset < max_offset:

target_offset += 1

raise_after_offset(traced_coroutine, target_offset)

try:

loop.run_until_complete(traced_coroutine(test_cm))

except InjectedException:

# key invariant: if we entered the CM, we exited it

self.assert_CM_balanced(test_cm, target_offset, traced_code)

else:

msg = (f"Exception wasn't raised @{target_offset} in:\n"

f"{traced_code.dis()}")

self.fail(msg)

def test_async_with_statement_completed(self):

async def traced_coroutine(test_cm):

async with test_cm:

1 + 1

return # Make implicit final return explicit

self._check_CM_exits_correctly(traced_coroutine)

if __name__ == '__main__':

unittest.main()