]> git.proxmox.com Git - mirror_qemu.git/commitdiff
python/qmp: allow sockets to be passed to connect()
authorJohn Snow <jsnow@redhat.com>
Wed, 17 May 2023 16:34:02 +0000 (12:34 -0400)
committerJohn Snow <jsnow@redhat.com>
Wed, 31 May 2023 20:25:35 +0000 (16:25 -0400)
Allow existing sockets to be passed to connect(). The changes are pretty
minimal, and this allows for far greater flexibility in setting up
communications with an endpoint.

Signed-off-by: John Snow <jsnow@redhat.com>
Message-id: 20230517163406.2593480-2-jsnow@redhat.com
Signed-off-by: John Snow <jsnow@redhat.com>
python/qemu/qmp/protocol.py

index 22e60298d2807c0e71e28b7bed278c8c6c595543..d534db4631f45d614b589b6bf2fd68b57a4cfccf 100644 (file)
@@ -370,7 +370,7 @@ class AsyncProtocol(Generic[T]):
 
     @upper_half
     @require(Runstate.IDLE)
-    async def connect(self, address: SocketAddrT,
+    async def connect(self, address: Union[SocketAddrT, socket.socket],
                       ssl: Optional[SSLContext] = None) -> None:
         """
         Connect to the server and begin processing message queues.
@@ -615,7 +615,7 @@ class AsyncProtocol(Generic[T]):
         self.logger.debug("Connection accepted.")
 
     @upper_half
-    async def _do_connect(self, address: SocketAddrT,
+    async def _do_connect(self, address: Union[SocketAddrT, socket.socket],
                           ssl: Optional[SSLContext] = None) -> None:
         """
         Acting as the transport client, initiate a connection to a server.
@@ -634,9 +634,17 @@ class AsyncProtocol(Generic[T]):
         # otherwise yield.
         await asyncio.sleep(0)
 
-        self.logger.debug("Connecting to %s ...", address)
-
-        if isinstance(address, tuple):
+        if isinstance(address, socket.socket):
+            self.logger.debug("Connecting with existing socket: "
+                              "fd=%d, family=%r, type=%r",
+                              address.fileno(), address.family, address.type)
+            connect = asyncio.open_connection(
+                limit=self._limit,
+                ssl=ssl,
+                sock=address,
+            )
+        elif isinstance(address, tuple):
+            self.logger.debug("Connecting to %s ...", address)
             connect = asyncio.open_connection(
                 address[0],
                 address[1],
@@ -644,13 +652,14 @@ class AsyncProtocol(Generic[T]):
                 limit=self._limit,
             )
         else:
+            self.logger.debug("Connecting to file://%s ...", address)
             connect = asyncio.open_unix_connection(
                 path=address,
                 ssl=ssl,
                 limit=self._limit,
             )
-        self._reader, self._writer = await connect
 
+        self._reader, self._writer = await connect
         self.logger.debug("Connected.")
 
     @upper_half