]>
Commit | Line | Data |
---|---|---|
4ccaab03 JS |
1 | """ |
2 | Generic Asynchronous Message-based Protocol Support | |
3 | ||
4 | This module provides a generic framework for sending and receiving | |
5 | messages over an asyncio stream. `AsyncProtocol` is an abstract class | |
6 | that implements the core mechanisms of a simple send/receive protocol, | |
7 | and is designed to be extended. | |
8 | ||
9 | In this package, it is used as the implementation for the `QMPClient` | |
10 | class. | |
11 | """ | |
12 | ||
13 | import asyncio | |
14 | from asyncio import StreamReader, StreamWriter | |
c58b42e0 JS |
15 | from enum import Enum |
16 | from functools import wraps | |
50e53306 | 17 | import logging |
4ccaab03 | 18 | from ssl import SSLContext |
c58b42e0 | 19 | from typing import ( |
4ccaab03 JS |
20 | Any, |
21 | Awaitable, | |
22 | Callable, | |
23 | Generic, | |
24 | List, | |
25 | Optional, | |
26 | Tuple, | |
27 | TypeVar, | |
28 | Union, | |
c58b42e0 | 29 | cast, |
4ccaab03 JS |
30 | ) |
31 | ||
32 | from .error import AQMPError | |
33 | from .util import ( | |
34 | bottom_half, | |
35 | create_task, | |
50e53306 | 36 | exception_summary, |
4ccaab03 JS |
37 | flush, |
38 | is_closing, | |
50e53306 | 39 | pretty_traceback, |
4ccaab03 JS |
40 | upper_half, |
41 | wait_closed, | |
42 | ) | |
43 | ||
44 | ||
#: Type of a single protocol message; chosen by each `AsyncProtocol` subclass.
T = TypeVar('T')
#: Bottom-half unit of work that `_bh_loop_forever` runs repeatedly.
_TaskFN = Callable[[], Awaitable[None]]  # aka ``async def func() -> None``
#: An optional future; lets `_cleanup` erase task references type-safely.
_FutureT = TypeVar('_FutureT', bound=Optional['asyncio.Future[Any]'])
48 | ||
49 | ||
c58b42e0 JS |
class Runstate(Enum):
    """
    Protocol session runstate.

    A connection normally cycles IDLE -> CONNECTING -> RUNNING ->
    DISCONNECTING -> IDLE; a failed connection attempt goes from
    CONNECTING through DISCONNECTING back to IDLE.
    """

    #: Fully quiesced and disconnected.
    IDLE = 0
    #: In the process of connecting or establishing a session.
    CONNECTING = 1
    #: Fully connected and active session.
    RUNNING = 2
    #: In the process of disconnecting.
    #: Runstate may be returned to `IDLE` by calling `disconnect()`.
    DISCONNECTING = 3
62 | ||
63 | ||
4ccaab03 JS |
class ConnectError(AQMPError):
    """
    Raised when the initial connection process has failed.

    Every instance carries the underlying "root cause" exception, which
    callers may inspect for further detail.

    :param error_message: Human-readable string describing the error.
    :param exc: The root-cause exception.
    """
    def __init__(self, error_message: str, exc: Exception):
        super().__init__(error_message)
        #: Human-readable error string
        self.error_message: str = error_message
        #: Wrapped root cause exception
        self.exc: Exception = exc

    def __str__(self) -> str:
        # Rendered as "<message>: <root cause>".
        cause_text = str(self.exc)
        return f"{self.error_message}: {cause_text}"
83 | ||
84 | ||
c58b42e0 JS |
class StateError(AQMPError):
    """
    An API command (connect, execute, etc) was issued at an inappropriate time.

    Raised, for instance, when :py:meth:`~AsyncProtocol.connect()` is
    called while the protocol is not in the state that command needs.

    :param error_message: Human-readable string describing the state violation.
    :param state: The actual `Runstate` seen at the time of the violation.
    :param required: The `Runstate` required to process this command.
    """
    def __init__(self, error_message: str,
                 state: Runstate, required: Runstate):
        super().__init__(error_message)
        #: Human-readable description of the violation
        self.error_message: str = error_message
        #: `Runstate` observed when the command was issued
        self.state: Runstate = state
        #: `Runstate` the command needed
        self.required: Runstate = required
103 | ||
104 | ||
F = TypeVar('F', bound=Callable[..., Any])  # pylint: disable=invalid-name


# Don't Panic.
def require(required_state: Runstate) -> Callable[[F], F]:
    """
    Decorator: protect a method so it can only be run in a certain `Runstate`.

    Usage::

        @require(Runstate.IDLE)
        def foo(self): ...

    The check is made when the decorated method is *called*. For an
    ``async def`` method the `StateError` is therefore raised
    synchronously at call time, before any coroutine is created.

    :param required_state: The `Runstate` required to invoke this method.
    :raise StateError: When the required `Runstate` is not met.
    """
    def _decorator(func: F) -> F:
        # Produced by calling require(...); wraps one concrete method.

        @wraps(func)
        def _wrapper(proto: 'AsyncProtocol[Any]',
                     *args: Any, **kwargs: Any) -> Any:
            # Runs in place of the decorated method: check state first,
            # then either delegate or refuse.
            name = type(proto).__name__
            current = proto.runstate

            if current == required_state:
                return func(proto, *args, **kwargs)

            explanations = {
                Runstate.CONNECTING: f"{name} is currently connecting.",
                Runstate.DISCONNECTING: (
                    f"{name} is disconnecting."
                    " Call disconnect() to return to IDLE state."
                ),
                Runstate.RUNNING: f"{name} is already connected and running.",
                Runstate.IDLE: f"{name} is disconnected and idle.",
            }
            assert current in explanations
            raise StateError(explanations[current], current, required_state)

        # Present the wrapper to type checkers as the original callable.
        return cast(F, _wrapper)

    return _decorator
153 | ||
154 | ||
4ccaab03 JS |
155 | class AsyncProtocol(Generic[T]): |
156 | """ | |
157 | AsyncProtocol implements a generic async message-based protocol. | |
158 | ||
159 | This protocol assumes the basic unit of information transfer between | |
160 | client and server is a "message", the details of which are left up | |
161 | to the implementation. It assumes the sending and receiving of these | |
162 | messages is full-duplex and not necessarily correlated; i.e. it | |
163 | supports asynchronous inbound messages. | |
164 | ||
165 | It is designed to be extended by a specific protocol which provides | |
166 | the implementations for how to read and send messages. These must be | |
167 | defined in `_do_recv()` and `_do_send()`, respectively. | |
168 | ||
169 | Other callbacks have a default implementation, but are intended to be | |
170 | either extended or overridden: | |
171 | ||
172 | - `_establish_session`: | |
173 | The base implementation starts the reader/writer tasks. | |
174 | A protocol implementation can override this call, inserting | |
175 | actions to be taken prior to starting the reader/writer tasks | |
176 | before the super() call; actions needing to occur afterwards | |
177 | can be written after the super() call. | |
178 | - `_on_message`: | |
179 | Actions to be performed when a message is received. | |
12c7a57f JS |
180 | - `_cb_outbound`: |
181 | Logging/Filtering hook for all outbound messages. | |
182 | - `_cb_inbound`: | |
183 | Logging/Filtering hook for all inbound messages. | |
184 | This hook runs *before* `_on_message()`. | |
50e53306 JS |
185 | |
186 | :param name: | |
187 | Name used for logging messages, if any. By default, messages | |
188 | will log to 'qemu.aqmp.protocol', but each individual connection | |
189 | can be given its own logger by giving it a name; messages will | |
190 | then log to 'qemu.aqmp.protocol.${name}'. | |
4ccaab03 JS |
191 | """ |
192 | # pylint: disable=too-many-instance-attributes | |
193 | ||
50e53306 JS |
194 | #: Logger object for debugging messages from this connection. |
195 | logger = logging.getLogger(__name__) | |
196 | ||
2686ac13 JS |
197 | # Maximum allowable size of read buffer |
198 | _limit = (64 * 1024) | |
199 | ||
4ccaab03 JS |
200 | # ------------------------- |
201 | # Section: Public interface | |
202 | # ------------------------- | |
203 | ||
50e53306 JS |
def __init__(self, name: Optional[str] = None) -> None:
    #: The nickname for this connection, if any.
    self.name: Optional[str] = name
    if self.name is not None:
        # Give this connection its own child of the class-level logger.
        self.logger = self.logger.getChild(self.name)

    # Session lifecycle. The Event is only created on demand (see the
    # `_runstate_event` property), never here in the constructor.
    self._runstate = Runstate.IDLE
    self._runstate_changed: Optional[asyncio.Event] = None

    # Transport streams, populated by _do_connect() / _do_accept().
    self._reader: Optional[StreamReader] = None
    self._writer: Optional[StreamWriter] = None

    # Outbound message queue: declared here only; the Queue object is
    # created per-session in _establish_session().
    self._outgoing: asyncio.Queue[T]

    # Bottom-half reader/writer tasks, plus their gathered aggregate,
    # which is what the upper half awaits to collect their exceptions.
    self._reader_task: Optional[asyncio.Future[None]] = None
    self._writer_task: Optional[asyncio.Future[None]] = None
    self._bh_tasks: Optional[asyncio.Future[Tuple[None, None]]] = None

    #: Disconnect task. The disconnect implementation runs in a task
    #: so that asynchronous disconnects (initiated by the
    #: reader/writer) are allowed to wait for the reader/writers to
    #: exit.
    self._dc_task: Optional[asyncio.Future[None]] = None
232 | ||
50e53306 JS |
def __repr__(self) -> str:
    # e.g. "<QMPClient name='vm0' runstate=RUNNING>"; the name field is
    # left out entirely when this connection has none.
    fields = [f"runstate={self.runstate.name}"]
    if self.name is not None:
        fields.insert(0, f"name={self.name!r}")
    return f"<{type(self).__name__} {' '.join(fields)}>"
240 | ||
c58b42e0 JS |
@property  # @upper_half
def runstate(self) -> Runstate:
    """
    The current `Runstate` of the connection.

    Read-only; transitions are made internally via `_set_state()`.
    """
    return self._runstate
245 | ||
@upper_half
async def runstate_changed(self) -> Runstate:
    """
    Wait for the `runstate` to change, then return that runstate.

    :return: The `Runstate` in effect when this coroutine resumes.
    """
    changed = self._runstate_event
    await changed.wait()
    return self.runstate
253 | ||
774c64a5 JS |
@upper_half
@require(Runstate.IDLE)
async def accept(self, address: Union[str, Tuple[str, int]],
                 ssl: Optional[SSLContext] = None) -> None:
    """
    Accept a connection and begin processing message queues.

    This is the server-side counterpart of `connect()`. If this call
    fails, `runstate` is guaranteed to be set back to `IDLE`.

    :param address:
        Address to listen to; UNIX socket path or TCP address/port.
    :param ssl: SSL context to use, if any.

    :raise StateError: When the `Runstate` is not `IDLE`.
    :raise ConnectError: If a connection could not be accepted.
    """
    # Shared session machinery, in listening mode.
    await self._new_session(address, ssl, accept=True)
271 | ||
c58b42e0 JS |
@upper_half
@require(Runstate.IDLE)
async def connect(self, address: Union[str, Tuple[str, int]],
                  ssl: Optional[SSLContext] = None) -> None:
    """
    Connect to the server and begin processing message queues.

    If this call fails, `runstate` is guaranteed to be set back to `IDLE`.

    :param address:
        Address to connect to; UNIX socket path or TCP address/port.
    :param ssl: SSL context to use, if any.

    :raise StateError: When the `Runstate` is not `IDLE`.
    :raise ConnectError: If a connection cannot be made to the server.
    """
    # Shared session machinery, in client (outbound) mode.
    await self._new_session(address, ssl)
289 | ||
@upper_half
async def disconnect(self) -> None:
    """
    Disconnect and wait for all tasks to fully stop.

    Safe to call in any runstate; afterwards the runstate is `IDLE`.
    Any exception that made the reader/writer stop early surfaces here.

    :raise Exception: When the reader or writer terminate unexpectedly.
    """
    self.logger.debug("disconnect() called.")
    # Kick off (or join) the disconnect task, then block until done.
    self._schedule_disconnect()
    await self._wait_disconnect()
303 | ||
304 | # -------------------------- | |
305 | # Section: Session machinery | |
306 | # -------------------------- | |
307 | ||
c58b42e0 JS |
@property
def _runstate_event(self) -> asyncio.Event:
    # Built lazily: asyncio.Event() should not be created before an event
    # loop is running, so it is only materialized on first use, which
    # happens from within the loop.
    if self._runstate_changed is None:
        self._runstate_changed = asyncio.Event()
    return self._runstate_changed
316 | ||
@upper_half
@bottom_half
def _set_state(self, state: Runstate) -> None:
    """
    Change the `Runstate` of the protocol connection.

    Signals the `runstate_changed` event. Requesting the state the
    protocol is already in is a silent no-op.
    """
    if state != self._runstate:
        self.logger.debug("Transitioning from '%s' to '%s'.",
                          str(self._runstate), str(state))
        self._runstate = state
        # Pulse the event: wake everyone currently waiting, then re-arm
        # it so later waiters block until the next transition.
        event = self._runstate_event
        event.set()
        event.clear()
333 | ||
4ccaab03 JS |
@upper_half
async def _new_session(self,
                       address: Union[str, Tuple[str, int]],
                       ssl: Optional[SSLContext] = None,
                       accept: bool = False) -> None:
    """
    Establish a new connection and initialize the session.

    Connect or accept a new connection, then begin the protocol
    session machinery. If this call fails, `runstate` is guaranteed
    to be set back to `IDLE`.

    :param address:
        Address to connect to/listen on;
        UNIX socket path or TCP address/port.
    :param ssl: SSL context to use, if any.
    :param accept: Accept a connection instead of connecting when `True`.

    :raise ConnectError:
        When a connection or session cannot be established.

        This exception will wrap a more concrete one. In most cases,
        the wrapped exception will be `OSError` or `EOFError`. If a
        protocol-level failure occurs while establishing a new
        session, the wrapped error may also be an `AQMPError`.
    """
    assert self.runstate == Runstate.IDLE

    try:
        phase = "connection"
        await self._establish_connection(address, ssl, accept)

        phase = "session"
        await self._establish_session()

    except BaseException as err:
        emsg = f"Failed to establish {phase}"
        self.logger.error("%s: %s", emsg, exception_summary(err))
        self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
        try:
            # Reset from CONNECTING back to IDLE.
            await self.disconnect()
        except BaseException:
            # Tearing down the half-built session failed as well. Log it
            # loudly and let it propagate; `err` remains attached as its
            # implicit __context__.
            emsg = "Unexpected bottom half exception"
            self.logger.critical("%s:\n%s\n", emsg, pretty_traceback())
            raise

        # NB: Before Python 3.8, CancelledError derives from Exception
        # (not directly from BaseException), so it must be re-raised
        # explicitly here or it would be wrapped in ConnectError below.
        if isinstance(err, asyncio.CancelledError):
            raise

        if isinstance(err, Exception):
            raise ConnectError(emsg, err) from err

        # Raise BaseExceptions un-wrapped, they're more important.
        raise

    assert self.runstate == Runstate.RUNNING
392 | ||
4ccaab03 JS |
@upper_half
async def _establish_connection(
    self,
    address: Union[str, Tuple[str, int]],
    ssl: Optional[SSLContext] = None,
    accept: bool = False
) -> None:
    """
    Establish a new connection.

    Moves the runstate from `IDLE` to `CONNECTING`, then either listens
    for one incoming client or dials out to a server.

    :param address:
        Address to connect to/listen on;
        UNIX socket path or TCP address/port.
    :param ssl: SSL context to use, if any.
    :param accept: Accept a connection instead of connecting when `True`.
    """
    assert self.runstate == Runstate.IDLE
    self._set_state(Runstate.CONNECTING)

    # Yield once so runstate watchers can observe CONNECTING even when
    # the streaming layer fails synchronously without ever yielding.
    await asyncio.sleep(0)

    establish = self._do_accept if accept else self._do_connect
    await establish(address, ssl)
421 | ||
@upper_half
async def _do_accept(self, address: Union[str, Tuple[str, int]],
                     ssl: Optional[SSLContext] = None) -> None:
    """
    Acting as the transport server, accept a single connection.

    The listening server is always shut down before this method
    returns: right after the first client is accepted, or when this
    method exits early because of an error or cancellation. Any extra
    client that slips in before listening stops is closed immediately
    instead of replacing the first one.

    :param address:
        Address to listen on; UNIX socket path or TCP address/port.
    :param ssl: SSL context to use, if any.

    :raise OSError: For stream-related errors.
    """
    self.logger.debug("Awaiting connection on %s ...", address)
    connected = asyncio.Event()
    server: Optional[asyncio.AbstractServer] = None

    def _stop_listening() -> None:
        """Stop accepting new connections; idempotent."""
        nonlocal server
        if server is None:
            return
        # Deliberately not awaiting server.wait_closed(): since Python
        # 3.12.1 it also waits for every *accepted* connection to close,
        # which would deadlock on the very connection we are keeping.
        server.close()
        server = None

    async def _client_connected_cb(reader: asyncio.StreamReader,
                                   writer: asyncio.StreamWriter) -> None:
        """Used to accept a single incoming connection, see below."""
        # There are no awaits in here, so concurrent callbacks cannot
        # interleave: exactly one of them wins.
        if connected.is_set():
            # The listen backlog can admit more than one client before
            # the server stops listening; keep only the first.
            self.logger.warning(
                "Extraneous connection inadvertently accepted"
            )
            writer.close()
            return

        # A connection has been accepted; stop listening for new ones.
        _stop_listening()

        # Register this client as being connected
        self._reader, self._writer = (reader, writer)

        # Signal back: We've accepted a client!
        connected.set()

    if isinstance(address, tuple):
        coro = asyncio.start_server(
            _client_connected_cb,
            host=address[0],
            port=address[1],
            ssl=ssl,
            backlog=1,
            limit=self._limit,
        )
    else:
        coro = asyncio.start_unix_server(
            _client_connected_cb,
            path=address,
            ssl=ssl,
            backlog=1,
            limit=self._limit,
        )

    server = await coro  # Starts listening
    try:
        await connected.wait()  # Waits for the callback to fire (and finish)
    finally:
        # After a successful accept this is a no-op; on error or
        # cancellation it stops the listening socket from leaking.
        _stop_listening()

    self.logger.debug("Connection accepted.")
4ccaab03 JS |
479 | |
@upper_half
async def _do_connect(self, address: Union[str, Tuple[str, int]],
                      ssl: Optional[SSLContext] = None) -> None:
    """
    Acting as the transport client, initiate a connection to a server.

    On success the new stream pair is stored in ``_reader``/``_writer``.

    :param address:
        Address to connect to; UNIX socket path or TCP address/port.
    :param ssl: SSL context to use, if any.

    :raise OSError: For stream-related errors.
    """
    self.logger.debug("Connecting to %s ...", address)

    if isinstance(address, tuple):
        # TCP: (host, port)
        host, port = address[0], address[1]
        pending = asyncio.open_connection(
            host, port, ssl=ssl, limit=self._limit,
        )
    else:
        # UNIX domain socket path
        pending = asyncio.open_unix_connection(
            path=address, ssl=ssl, limit=self._limit,
        )

    reader, writer = await pending
    self._reader = reader
    self._writer = writer

    self.logger.debug("Connected.")
510 | ||
4ccaab03 JS |
@upper_half
async def _establish_session(self) -> None:
    """
    Establish a new session.

    Creates a fresh outbound queue and launches the reader and writer
    bottom-half loops. Subclasses may perform their own negotiations
    before or after calling this base implementation. The runstate is
    `RUNNING` when this returns.
    """
    assert self.runstate == Runstate.CONNECTING

    self._outgoing = asyncio.Queue()

    recv_loop = self._bh_loop_forever(self._bh_recv_message, 'Reader')
    send_loop = self._bh_loop_forever(self._bh_send_message, 'Writer')
    self._reader_task = create_task(recv_loop)
    self._writer_task = create_task(send_loop)

    # One aggregate future, through which the upper half later collects
    # any exception raised by either loop.
    self._bh_tasks = asyncio.gather(self._reader_task, self._writer_task)

    self._set_state(Runstate.RUNNING)
    await asyncio.sleep(0)  # Let runstate watchers observe RUNNING.
537 | ||
4ccaab03 JS |
@upper_half
@bottom_half
def _schedule_disconnect(self) -> None:
    """
    Initiate a disconnect; idempotent.

    Reached from the upper half via `disconnect()`, and from the bottom
    half when a reader/writer task dies with an unhandled exception.
    It may be called in any `runstate`.
    """
    if self._dc_task is not None:
        # Already scheduled; the reference is only dropped by _cleanup().
        return

    self._set_state(Runstate.DISCONNECTING)
    self.logger.debug("Scheduling disconnect.")
    self._dc_task = create_task(self._bh_disconnect())
554 | ||
@upper_half
async def _wait_disconnect(self) -> None:
    """
    Waits for a previously scheduled disconnect to finish.

    This method will gather any bottom half exceptions and re-raise
    the one that occurred first; presuming it to be the root cause
    of any subsequent Exceptions. It is intended to be used in the
    upper half of the call chain.

    Whether or not an exception is raised, all task/stream references
    are cleared via `_cleanup()` and the runstate ends up `IDLE`.

    :raise Exception:
        Arbitrary exception re-raised on behalf of the reader/writer.
    """
    assert self.runstate == Runstate.DISCONNECTING
    assert self._dc_task

    # The reader/writer aggregate (if any) is placed ahead of the
    # disconnect task. For futures that are already done when gather()
    # is built, their outcomes are examined in list order, so this
    # favors a reader/writer exception over a disconnect exception.
    aws: List[Awaitable[object]] = [self._dc_task]
    if self._bh_tasks:
        aws.insert(0, self._bh_tasks)
    all_defined_tasks = asyncio.gather(*aws)

    # Ensure disconnect is done; Exception (if any) is not raised here:
    await asyncio.wait((self._dc_task,))

    try:
        await all_defined_tasks  # Raise Exceptions from the bottom half.
    finally:
        self._cleanup()
        self._set_state(Runstate.IDLE)
4ccaab03 JS |
584 | |
@upper_half
def _cleanup(self) -> None:
    """
    Drop every per-session task and stream reference.

    All tasks must already be finished; this is asserted. This method
    does not change the runstate; in this file, `_wait_disconnect` calls
    it and then performs the transition to `IDLE` itself.
    """
    def _release(fut: _FutureT) -> Optional[_FutureT]:
        # Refuse to forget about a future that is still running.
        assert fut is None or fut.done()
        if fut is not None and fut.done():
            return None
        return fut

    assert self.runstate == Runstate.DISCONNECTING
    self._dc_task = _release(self._dc_task)
    self._reader_task = _release(self._reader_task)
    self._writer_task = _release(self._writer_task)
    self._bh_tasks = _release(self._bh_tasks)

    self._reader = None
    self._writer = None

    # NB: _runstate_changed cannot be cleared because we still need it to
    # send the final runstate changed event ...!
606 | ||
4ccaab03 JS |
607 | # ---------------------------- |
608 | # Section: Bottom Half methods | |
609 | # ---------------------------- | |
610 | ||
@bottom_half
async def _bh_disconnect(self) -> None:
    """
    Disconnect and cancel all outstanding tasks.

    It is designed to be called from its task context,
    :py:obj:`~AsyncProtocol._dc_task`. By running in its own task,
    it is free to wait on any pending actions that may still need to
    occur in either the reader or writer tasks.

    Sequence:
      1. Flush pending output, only if neither worker has already
         finished.
      2. Cancel any still-running worker and wait for both to exit.
      3. Close the stream.

    :raise Exception:
        Whatever flushing the writer or closing the stream raised.
    """
    assert self.runstate == Runstate.DISCONNECTING

    def _done(task: Optional['asyncio.Future[Any]']) -> bool:
        return task is not None and task.done()

    # NB: We can't rely on _bh_tasks being done() here, it may not
    # yet have had a chance to run and gather itself.
    tasks = tuple(filter(None, (self._writer_task, self._reader_task)))
    # The workers loop forever unless they fail or are cancelled (which
    # only happens below), so an already-finished worker presumably
    # means something went wrong: skip the graceful flush in that case.
    error_pathway = _done(self._reader_task) or _done(self._writer_task)

    try:
        # Try to flush the writer, if possible:
        if not error_pathway:
            await self._bh_flush_writer()
    except BaseException as err:
        error_pathway = True
        emsg = "Failed to flush the writer"
        self.logger.error("%s: %s", emsg, exception_summary(err))
        self.logger.debug("%s:\n%s\n", emsg, pretty_traceback())
        raise
    finally:
        # Cancel any still-running tasks:
        if self._writer_task is not None and not self._writer_task.done():
            self.logger.debug("Cancelling writer task.")
            self._writer_task.cancel()
        if self._reader_task is not None and not self._reader_task.done():
            self.logger.debug("Cancelling reader task.")
            self._reader_task.cancel()

        # Close out the tasks entirely (Won't raise):
        if tasks:
            self.logger.debug("Waiting for tasks to complete ...")
            await asyncio.wait(tasks)

        # Lastly, close the stream itself. (May raise):
        await self._bh_close_stream(error_pathway)
        self.logger.debug("Disconnected.")
4ccaab03 JS |
658 | |
@bottom_half
async def _bh_flush_writer(self) -> None:
    """
    Wait for queued outbound messages to be sent, then flush the stream.

    Does nothing when there is no writer task.
    """
    if self._writer_task:
        self.logger.debug("Draining the outbound queue ...")
        await self._outgoing.join()
        if self._writer is not None:
            self.logger.debug("Flushing the StreamWriter ...")
            await flush(self._writer)
669 | ||
@bottom_half
async def _bh_close_stream(self, error_pathway: bool = False) -> None:
    """
    Close the stream writer (if any) and wait for it to finish closing.

    :param error_pathway:
        `True` when the caller is already handling a failure. In that
        case an `Exception` raised while waiting for the close is logged
        at debug level and discarded rather than propagated.
    :raise Exception:
        From waiting on the close, only when ``error_pathway`` is `False`.
    """
    # NB: Closing the writer also implicitly closes the reader.
    if not self._writer:
        return

    if not is_closing(self._writer):
        self.logger.debug("Closing StreamWriter.")
        self._writer.close()

    self.logger.debug("Waiting for StreamWriter to close ...")
    try:
        await wait_closed(self._writer)
    except Exception:  # pylint: disable=broad-except
        # It's hard to tell if the Stream is already closed or
        # not. Even if one of the tasks has failed, it may have
        # failed for a higher-layered protocol reason. The
        # stream could still be open and perfectly fine.
        # I don't know how to discern its health here.

        if error_pathway:
            # We already know that *something* went wrong. Let's
            # just trust that the Exception we already have is the
            # better one to present to the user, even if we don't
            # genuinely *know* the relationship between the two.
            self.logger.debug(
                "Discarding Exception from wait_closed:\n%s\n",
                pretty_traceback(),
            )
        else:
            # Oops, this is a brand-new error!
            raise
    finally:
        # NB: this is logged even when an exception is propagating.
        self.logger.debug("StreamWriter closed.")
4ccaab03 JS |
704 | |
@bottom_half
async def _bh_loop_forever(self, async_fn: _TaskFN, name: str) -> None:
    """
    Run one of the bottom-half methods in a loop forever.

    Cancellation ends the loop quietly. Any other exception is logged,
    schedules a disconnect of the whole protocol, and is then re-raised.

    :param async_fn: The bottom-half method to run in a loop.
    :param name: The name of this task, used for logging.
    """
    try:
        while True:
            await async_fn()
    except asyncio.CancelledError:
        # Cancelled by _bh_disconnect: exit gracefully.
        self.logger.debug("Task.%s: cancelled.", name)
        return
    except BaseException as err:
        # EOF is an ordinary way for the peer to go away, so it is logged
        # at INFO; everything else is an ERROR.
        level = logging.INFO if isinstance(err, EOFError) else logging.ERROR
        self.logger.log(level, "Task.%s: %s", name, exception_summary(err))
        self.logger.debug("Task.%s: failure:\n%s\n",
                          name, pretty_traceback())
        self._schedule_disconnect()
        raise
    finally:
        self.logger.debug("Task.%s: exiting.", name)
4ccaab03 JS |
735 | |
@bottom_half
async def _bh_send_message(self) -> None:
    """
    Take one message off the outbound queue and send it.

    The queue item is marked done even if sending fails, so a pending
    ``join()`` on the queue is never left waiting on it.

    Designed to be run in `_bh_loop_forever()`.
    """
    outbound = await self._outgoing.get()
    try:
        await self._send(outbound)
    finally:
        self._outgoing.task_done()
748 | ||
@bottom_half
async def _bh_recv_message(self) -> None:
    """
    Receive one (already filtered) message and hand it to `_on_message`.

    Designed to be run in `_bh_loop_forever()`.
    """
    await self._on_message(await self._recv())
758 | ||
759 | # -------------------- | |
760 | # Section: Message I/O | |
761 | # -------------------- | |
762 | ||
12c7a57f JS |
@upper_half
@bottom_half
def _cb_outbound(self, msg: T) -> T:
    """
    Callback: outbound message hook.

    Subclasses may override this to filter or rewrite outgoing
    messages. The base implementation only logs the message and
    returns it untouched.

    :param msg: raw outbound message
    :return: final outbound message
    """
    rendered = str(msg)
    self.logger.debug("--> %s", rendered)
    return msg
779 | ||
@upper_half
@bottom_half
def _cb_inbound(self, msg: T) -> T:
    """
    Callback: inbound message hook.

    Subclasses may override this to filter or rewrite incoming
    messages. The base implementation only logs the message and
    returns it untouched.

    This is a filter, not a handler: messages are ultimately delivered
    to `_on_message()`.

    :param msg: raw inbound message
    :return: processed inbound message
    """
    rendered = str(msg)
    self.logger.debug("<-- %s", rendered)
    return msg
799 | ||
762bd4d7 JS |
@upper_half
@bottom_half
async def _readline(self) -> bytes:
    """
    Wait for a newline from the incoming reader.

    A convenience for line-based upper-layer protocols.

    If EOF arrives after *some* bytes were received, those bytes are
    returned without a trailing newline and the following call raises
    `EOFError`. Whether a partial message is meaningful is left to the
    layer 5 protocol.

    :raise OSError: For stream-related errors.
    :raise EOFError:
        If the reader stream is at EOF and there are no bytes to return.
    :return: bytes, including the newline.
    """
    assert self._reader is not None
    line = await self._reader.readline()

    # An empty read counts as end-of-stream only if the reader says so.
    if not line and self._reader.at_eof():
        raise EOFError

    return line
828 | ||
4ccaab03 JS |
@upper_half
@bottom_half
async def _do_recv(self) -> T:
    """
    Abstract: Read from the stream and return a message.

    Very low-level; intended to only be called by `_recv()`.
    Subclasses must override this; the base version always raises
    `NotImplementedError`.
    """
    raise NotImplementedError
838 | ||
@upper_half
@bottom_half
async def _recv(self) -> T:
    """
    Read an arbitrary protocol message.

    .. warning::
        This method is intended primarily for `_bh_recv_message()`
        to use in an asynchronous task loop. Using it outside of
        this loop will "steal" messages from the normal routing
        mechanism. It is safe to use prior to `_establish_session()`,
        but should not be used otherwise.

    The raw message comes from `_do_recv()` and is passed through
    `_cb_inbound()` before being returned.

    :return: A single (filtered, processed) protocol message.
    """
    return self._cb_inbound(await self._do_recv())
4ccaab03 JS |
859 | |
@upper_half
@bottom_half
def _do_send(self, msg: T) -> None:
    """
    Abstract: Write a message to the stream.

    Very low-level; intended to only be called by `_send()`.
    Note that this is a plain (synchronous) method: `_send()` calls it
    without awaiting. Subclasses must override it; the base version
    always raises `NotImplementedError`.
    """
    raise NotImplementedError
869 | ||
@upper_half
@bottom_half
async def _send(self, msg: T) -> None:
    """
    Send an arbitrary protocol message.

    The message is first passed through `_cb_outbound()`, and the
    result is handed to `_do_send()`.

    .. warning::
        Like `_recv()`, this method is intended to be called by
        the writer task loop that processes outgoing
        messages. Calling it directly may circumvent logic
        implemented by the caller meant to correlate outgoing and
        incoming messages.

    :raise OSError: For problems with the underlying stream.
    """
    self._do_send(self._cb_outbound(msg))
890 | ||
891 | @bottom_half | |
892 | async def _on_message(self, msg: T) -> None: | |
893 | """ | |
894 | Called to handle the receipt of a new message. | |
895 | ||
896 | .. caution:: | |
897 | This is executed from within the reader loop, so be advised | |
898 | that waiting on either the reader or writer task will lead | |
899 | to deadlock. Additionally, any unhandled exceptions will | |
900 | directly cause the loop to halt, so logic may be best-kept | |
901 | to a minimum if at all possible. | |
902 | ||
12c7a57f | 903 | :param msg: The incoming message, already logged/filtered. |
4ccaab03 JS |
904 | """ |
905 | # Nothing to do in the abstract case. |