(original) (raw)

Index: Lib/unittest/test/test_suite.py =================================================================== --- Lib/unittest/test/test_suite.py (revision 85014) +++ Lib/unittest/test/test_suite.py (working copy) @@ -345,5 +345,19 @@ self.assertEqual(result.testsRun, 2) + def test_overriding_call(self): + class MySuite(unittest.TestSuite): + called = False + def __call__(self, *args, **kw): + self.called = True + unittest.TestSuite.__call__(self, *args, **kw) + + suite = MySuite() + wrapper = unittest.TestSuite() + wrapper.addTest(suite) + wrapper(unittest.TestResult()) + self.assertTrue(suite.called) + + if __name__ == '__main__': unittest.main() Index: Lib/unittest/suite.py =================================================================== --- Lib/unittest/suite.py (revision 85014) +++ Lib/unittest/suite.py (working copy) @@ -78,22 +78,11 @@ """ - def run(self, result): - self._wrapped_run(result) - self._tearDownPreviousClass(None, result) - self._handleModuleTearDown(result) - return result + def run(self, result, debug=False): + topLevel = False + if not hasattr(result, 'foobar'): + result.foobar = topLevel = True - def debug(self): - """Run the tests without collecting errors in a TestResult""" - debug = _DebugResult() - self._wrapped_run(debug, True) - self._tearDownPreviousClass(None, debug) - self._handleModuleTearDown(debug) - - ################################ - # private methods - def _wrapped_run(self, result, debug=False): for test in self: if result.shouldStop: break @@ -108,13 +97,23 @@ getattr(result, '_moduleSetUpFailed', False)): continue - if hasattr(test, '_wrapped_run'): - test._wrapped_run(result, debug) - elif not debug: + if not debug: test(result) else: test.debug() + if topLevel: + self._tearDownPreviousClass(None, result) + self._handleModuleTearDown(result) + return result + + def debug(self): + """Run the tests without collecting errors in a TestResult""" + debug = _DebugResult() + self.run(debug, True) + + ################################ + def _handleClassSetUp(self, test, result): previousClass = getattr(result, '_previousTestClass', None) currentClass = test.__class__