bpo-46752: Add TaskGroup; add Task..cancelled(),.uncancel() (GH-31270) · python/cpython@602630a (original) (raw)

``

1

`+

Adapted with permission from the EdgeDB project.

`

``

2

+

``

3

+

``

4

`+

all = ["TaskGroup"]

`

``

5

+

``

6

`+

import itertools

`

``

7

`+

import textwrap

`

``

8

`+

import traceback

`

``

9

`+

import types

`

``

10

`+

import weakref

`

``

11

+

``

12

`+

from . import events

`

``

13

`+

from . import exceptions

`

``

14

`+

from . import tasks

`

``

15

+

``

16

`+

class TaskGroup:

`

``

17

+

``

18

`+

def init(self, *, name=None):

`

``

19

`+

if name is None:

`

``

20

`+

self._name = f'tg-{_name_counter()}'

`

``

21

`+

else:

`

``

22

`+

self._name = str(name)

`

``

23

+

``

24

`+

self._entered = False

`

``

25

`+

self._exiting = False

`

``

26

`+

self._aborting = False

`

``

27

`+

self._loop = None

`

``

28

`+

self._parent_task = None

`

``

29

`+

self._parent_cancel_requested = False

`

``

30

`+

self._tasks = weakref.WeakSet()

`

``

31

`+

self._unfinished_tasks = 0

`

``

32

`+

self._errors = []

`

``

33

`+

self._base_error = None

`

``

34

`+

self._on_completed_fut = None

`

``

35

+

``

36

`+

def get_name(self):

`

``

37

`+

return self._name

`

``

38

+

``

39

`+

def repr(self):

`

``

40

`+

msg = f'<TaskGroup {self._name!r}'

`

``

41

`+

if self._tasks:

`

``

42

`+

msg += f' tasks:{len(self._tasks)}'

`

``

43

`+

if self._unfinished_tasks:

`

``

44

`+

msg += f' unfinished:{self._unfinished_tasks}'

`

``

45

`+

if self._errors:

`

``

46

`+

msg += f' errors:{len(self._errors)}'

`

``

47

`+

if self._aborting:

`

``

48

`+

msg += ' cancelling'

`

``

49

`+

elif self._entered:

`

``

50

`+

msg += ' entered'

`

``

51

`+

msg += '>'

`

``

52

`+

return msg

`

``

53

+

``

54

`+

async def aenter(self):

`

``

55

`+

if self._entered:

`

``

56

`+

raise RuntimeError(

`

``

57

`+

f"TaskGroup {self!r} has been already entered")

`

``

58

`+

self._entered = True

`

``

59

+

``

60

`+

if self._loop is None:

`

``

61

`+

self._loop = events.get_running_loop()

`

``

62

+

``

63

`+

self._parent_task = tasks.current_task(self._loop)

`

``

64

`+

if self._parent_task is None:

`

``

65

`+

raise RuntimeError(

`

``

66

`+

f'TaskGroup {self!r} cannot determine the parent task')

`

``

67

+

``

68

`+

return self

`

``

69

+

``

70

`+

async def aexit(self, et, exc, tb):

`

``

71

`+

self._exiting = True

`

``

72

`+

propagate_cancellation_error = None

`

``

73

+

``

74

`+

if (exc is not None and

`

``

75

`+

self._is_base_error(exc) and

`

``

76

`+

self._base_error is None):

`

``

77

`+

self._base_error = exc

`

``

78

+

``

79

`+

if et is exceptions.CancelledError:

`

``

80

`+

if self._parent_cancel_requested:

`

``

81

`+

Only if we did request task to cancel ourselves

`

``

82

`+

we mark it as no longer cancelled.

`

``

83

`+

self._parent_task.uncancel()

`

``

84

`+

else:

`

``

85

`+

propagate_cancellation_error = et

`

``

86

+

``

87

`+

if et is not None and not self._aborting:

`

``

88

`+

Our parent task is being cancelled:

`

``

89

`+

`

``

90

`+

async with TaskGroup() as g:

`

``

91

`+

g.create_task(...)

`

``

92

`+

await ... # <- CancelledError

`

``

93

`+

`

``

94

`+

if et is exceptions.CancelledError:

`

``

95

`+

propagate_cancellation_error = et

`

``

96

+

``

97

`+

or there's an exception in "async with":

`

``

98

`+

`

``

99

`+

async with TaskGroup() as g:

`

``

100

`+

g.create_task(...)

`

``

101

`+

1 / 0

`

``

102

`+

`

``

103

`+

self._abort()

`

``

104

+

``

105

`+

We use while-loop here because "self._on_completed_fut"

`

``

106

`+

can be cancelled multiple times if our parent task

`

``

107

`+

is being cancelled repeatedly (or even once, when

`

``

108

`+

our own cancellation is already in progress)

`

``

109

`+

while self._unfinished_tasks:

`

``

110

`+

if self._on_completed_fut is None:

`

``

111

`+

self._on_completed_fut = self._loop.create_future()

`

``

112

+

``

113

`+

try:

`

``

114

`+

await self._on_completed_fut

`

``

115

`+

except exceptions.CancelledError as ex:

`

``

116

`+

if not self._aborting:

`

``

117

`+

Our parent task is being cancelled:

`

``

118

`+

`

``

119

`+

async def wrapper():

`

``

120

`+

async with TaskGroup() as g:

`

``

121

`+

g.create_task(foo)

`

``

122

`+

`

``

123

`+

"wrapper" is being cancelled while "foo" is

`

``

124

`+

still running.

`

``

125

`+

propagate_cancellation_error = ex

`

``

126

`+

self._abort()

`

``

127

+

``

128

`+

self._on_completed_fut = None

`

``

129

+

``

130

`+

assert self._unfinished_tasks == 0

`

``

131

`+

self._on_completed_fut = None # no longer needed

`

``

132

+

``

133

`+

if self._base_error is not None:

`

``

134

`+

raise self._base_error

`

``

135

+

``

136

`+

if propagate_cancellation_error is not None:

`

``

137

`+

The wrapping task was cancelled; since we're done with

`

``

138

`+

closing all child tasks, just propagate the cancellation

`

``

139

`+

request now.

`

``

140

`+

raise propagate_cancellation_error

`

``

141

+

``

142

`+

if et is not None and et is not exceptions.CancelledError:

`

``

143

`+

self._errors.append(exc)

`

``

144

+

``

145

`+

if self._errors:

`

``

146

`+

Exceptions are heavy objects that can have object

`

``

147

`+

cycles (bad for GC); let's not keep a reference to

`

``

148

`+

a bunch of them.

`

``

149

`+

errors = self._errors

`

``

150

`+

self._errors = None

`

``

151

+

``

152

`+

me = BaseExceptionGroup('unhandled errors in a TaskGroup', errors)

`

``

153

`+

raise me from None

`

``

154

+

``

155

`+

def create_task(self, coro):

`

``

156

`+

if not self._entered:

`

``

157

`+

raise RuntimeError(f"TaskGroup {self!r} has not been entered")

`

``

158

`+

if self._exiting and self._unfinished_tasks == 0:

`

``

159

`+

raise RuntimeError(f"TaskGroup {self!r} is finished")

`

``

160

`+

task = self._loop.create_task(coro)

`

``

161

`+

task.add_done_callback(self._on_task_done)

`

``

162

`+

self._unfinished_tasks += 1

`

``

163

`+

self._tasks.add(task)

`

``

164

`+

return task

`

``

165

+

``

166

`+

Since Python 3.8 Tasks propagate all exceptions correctly,

`

``

167

`+

except for KeyboardInterrupt and SystemExit which are

`

``

168

`+

still considered special.

`

``

169

+

``

170

`+

def _is_base_error(self, exc: BaseException) -> bool:

`

``

171

`+

assert isinstance(exc, BaseException)

`

``

172

`+

return isinstance(exc, (SystemExit, KeyboardInterrupt))

`

``

173

+

``

174

`+

def _abort(self):

`

``

175

`+

self._aborting = True

`

``

176

+

``

177

`+

for t in self._tasks:

`

``

178

`+

if not t.done():

`

``

179

`+

t.cancel()

`

``

180

+

``

181

`+

def _on_task_done(self, task):

`

``

182

`+

self._unfinished_tasks -= 1

`

``

183

`+

assert self._unfinished_tasks >= 0

`

``

184

+

``

185

`+

if self._on_completed_fut is not None and not self._unfinished_tasks:

`

``

186

`+

if not self._on_completed_fut.done():

`

``

187

`+

self._on_completed_fut.set_result(True)

`

``

188

+

``

189

`+

if task.cancelled():

`

``

190

`+

return

`

``

191

+

``

192

`+

exc = task.exception()

`

``

193

`+

if exc is None:

`

``

194

`+

return

`

``

195

+

``

196

`+

self._errors.append(exc)

`

``

197

`+

if self._is_base_error(exc) and self._base_error is None:

`

``

198

`+

self._base_error = exc

`

``

199

+

``

200

`+

if self._parent_task.done():

`

``

201

`+

Not sure if this case is possible, but we want to handle

`

``

202

`+

it anyways.

`

``

203

`+

self._loop.call_exception_handler({

`

``

204

`+

'message': f'Task {task!r} has errored out but its parent '

`

``

205

`+

f'task {self._parent_task} is already completed',

`

``

206

`+

'exception': exc,

`

``

207

`+

'task': task,

`

``

208

`+

})

`

``

209

`+

return

`

``

210

+

``

211

`+

self._abort()

`

``

212

`+

if not self._parent_task.cancelling():

`

``

213

`+

If parent task is not being cancelled, it means that we want

`

``

214

`+

to manually cancel it to abort whatever is being run right now

`

``

215

`+

in the TaskGroup. But we want to mark parent task as

`

``

216

`+

"not cancelled" later in aexit. Example situation that

`

``

217

`+

we need to handle:

`

``

218

`+

`

``

219

`+

async def foo():

`

``

220

`+

try:

`

``

221

`+

async with TaskGroup() as g:

`

``

222

`+

g.create_task(crash_soon())

`

``

223

`+

await something # <- this needs to be canceled

`

``

224

`+

# by the TaskGroup, e.g.

`

``

225

`+

# foo() needs to be cancelled

`

``

226

`+

except Exception:

`

``

227

`+

# Ignore any exceptions raised in the TaskGroup

`

``

228

`+

pass

`

``

229

`+

await something_else # this line has to be called

`

``

230

`+

# after TaskGroup is finished.

`

``

231

`+

self._parent_cancel_requested = True

`

``

232

`+

self._parent_task.cancel()

`

``

233

+

``

234

+

``

235

`+

_name_counter = itertools.count(1).next

`