bpo-36801: Fix waiting in StreamWriter.drain for closing SSL transpor… · python/cpython@1cc0ee7 (original) (raw)

`@@ -199,6 +199,9 @@ async def _drain_helper(self):

`

199

199

`self._drain_waiter = waiter

`

200

200

`await waiter

`

201

201

``

``

202

`+

def _get_close_waiter(self, stream):

`

``

203

`+

raise NotImplementedError

`

``

204

+

202

205

``

203

206

`class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):

`

204

207

`"""Helper class to adapt between Protocol and StreamReader.

`

`@@ -315,6 +318,9 @@ def eof_received(self):

`

315

318

`return False

`

316

319

`return True

`

317

320

``

``

321

`+

def _get_close_waiter(self, stream):

`

``

322

`+

return self._closed

`

``

323

+

318

324

`def del(self):

`

319

325

`# Prevent reports about unhandled exceptions.

`

320

326

`# Better than self._closed._log_traceback = False hack

`

`@@ -376,7 +382,7 @@ def is_closing(self):

`

376

382

`return self._transport.is_closing()

`

377

383

``

378

384

`async def wait_closed(self):

`

379

``

`-

await self._protocol._closed

`

``

385

`+

await self._protocol._get_close_waiter(self)

`

380

386

``

381

387

`def get_extra_info(self, name, default=None):

`

382

388

`return self._transport.get_extra_info(name, default)

`

`@@ -394,13 +400,12 @@ async def drain(self):

`

394

400

`if exc is not None:

`

395

401

`raise exc

`

396

402

`if self._transport.is_closing():

`

397

``

`-

Yield to the event loop so connection_lost() may be

`

398

``

`-

called. Without this, _drain_helper() would return

`

399

``

`-

immediately, and code that calls

`

400

``

`-

write(...); await drain()

`

401

``

`-

in a loop would never call connection_lost(), so it

`

402

``

`-

would not see an error when the socket is closed.

`

403

``

`-

await sleep(0, loop=self._loop)

`

``

403

`+

Wait for protocol.connection_lost() call

`

``

404

`+

Raise connection closing error if any,

`

``

405

`+

ConnectionResetError otherwise

`

``

406

`+

fut = self._protocol._get_close_waiter(self)

`

``

407

`+

await fut

`

``

408

`+

raise ConnectionResetError('Connection lost')

`

404

409

`await self._protocol._drain_helper()

`

405

410

``

406

411

`async def aclose(self):

`