]> git.proxmox.com Git - mirror_qemu.git/blobdiff - io/channel-tls.c
Merge remote-tracking branch 'remotes/kraxel/tags/vga-20190705-pull-request' into...
[mirror_qemu.git] / io / channel-tls.c
index 7608fd9de09046b191bd90fa728c00a4e377d983..7ec8ceff2f01744080126985c101621a1aa50e01 100644 (file)
@@ -19,6 +19,8 @@
  */
 
 #include "qemu/osdep.h"
+#include "qapi/error.h"
+#include "qemu/module.h"
 #include "io/channel-tls.h"
 #include "trace.h"
 
@@ -110,8 +112,8 @@ qio_channel_tls_new_client(QIOChannel *master,
     ioc = QIO_CHANNEL(tioc);
 
     tioc->master = master;
-    if (master->features & (1 << QIO_CHANNEL_FEATURE_SHUTDOWN)) {
-        ioc->features |= (1 << QIO_CHANNEL_FEATURE_SHUTDOWN);
+    if (qio_channel_has_feature(master, QIO_CHANNEL_FEATURE_SHUTDOWN)) {
+        qio_channel_set_feature(ioc, QIO_CHANNEL_FEATURE_SHUTDOWN);
     }
     object_ref(OBJECT(master));
 
@@ -139,21 +141,28 @@ qio_channel_tls_new_client(QIOChannel *master,
     return NULL;
 }
 
+struct QIOChannelTLSData {
+    QIOTask *task;
+    GMainContext *context;
+};
+typedef struct QIOChannelTLSData QIOChannelTLSData;
 
 static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
                                              GIOCondition condition,
                                              gpointer user_data);
 
 static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
-                                           QIOTask *task)
+                                           QIOTask *task,
+                                           GMainContext *context)
 {
     Error *err = NULL;
     QCryptoTLSSessionHandshakeStatus status;
 
     if (qcrypto_tls_session_handshake(ioc->session, &err) < 0) {
         trace_qio_channel_tls_handshake_fail(ioc);
-        qio_task_abort(task, err);
-        goto cleanup;
+        qio_task_set_error(task, err);
+        qio_task_complete(task);
+        return;
     }
 
     status = qcrypto_tls_session_get_handshake_status(ioc->session);
@@ -162,13 +171,22 @@ static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
         if (qcrypto_tls_session_check_credentials(ioc->session,
                                                   &err) < 0) {
             trace_qio_channel_tls_credentials_deny(ioc);
-            qio_task_abort(task, err);
-            goto cleanup;
+            qio_task_set_error(task, err);
+        } else {
+            trace_qio_channel_tls_credentials_allow(ioc);
         }
-        trace_qio_channel_tls_credentials_allow(ioc);
         qio_task_complete(task);
     } else {
         GIOCondition condition;
+        QIOChannelTLSData *data = g_new0(typeof(*data), 1);
+
+        data->task = task;
+        data->context = context;
+
+        if (context) {
+            g_main_context_ref(context);
+        }
+
         if (status == QCRYPTO_TLS_HANDSHAKE_SENDING) {
             condition = G_IO_OUT;
         } else {
@@ -176,15 +194,13 @@ static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
         }
 
         trace_qio_channel_tls_handshake_pending(ioc, status);
-        qio_channel_add_watch(ioc->master,
-                              condition,
-                              qio_channel_tls_handshake_io,
-                              task,
-                              NULL);
+        qio_channel_add_watch_full(ioc->master,
+                                   condition,
+                                   qio_channel_tls_handshake_io,
+                                   data,
+                                   NULL,
+                                   context);
     }
-
- cleanup:
-    error_free(err);
 }
 
 
@@ -192,14 +208,18 @@ static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
                                              GIOCondition condition,
                                              gpointer user_data)
 {
-    QIOTask *task = user_data;
+    QIOChannelTLSData *data = user_data;
+    QIOTask *task = data->task;
+    GMainContext *context = data->context;
     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(
         qio_task_get_source(task));
 
-    qio_channel_tls_handshake_task(
-       tioc, task);
+    g_free(data);
+    qio_channel_tls_handshake_task(tioc, task, context);
 
-    object_unref(OBJECT(tioc));
+    if (context) {
+        g_main_context_unref(context);
+    }
 
     return FALSE;
 }
@@ -207,7 +227,8 @@ static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
 void qio_channel_tls_handshake(QIOChannelTLS *ioc,
                                QIOTaskFunc func,
                                gpointer opaque,
-                               GDestroyNotify destroy)
+                               GDestroyNotify destroy,
+                               GMainContext *context)
 {
     QIOTask *task;
 
@@ -215,7 +236,7 @@ void qio_channel_tls_handshake(QIOChannelTLS *ioc,
                         func, opaque, destroy);
 
     trace_qio_channel_tls_handshake_start(ioc);
-    qio_channel_tls_handshake_task(ioc, task);
+    qio_channel_tls_handshake_task(ioc, task, context);
 }
 
 
@@ -255,6 +276,9 @@ static ssize_t qio_channel_tls_readv(QIOChannel *ioc,
                 } else {
                     return QIO_CHANNEL_ERR_BLOCK;
                 }
+            } else if (errno == ECONNABORTED &&
+                       (tioc->shutdown & QIO_CHANNEL_SHUTDOWN_READ)) {
+                return 0;
             }
 
             error_setg_errno(errp, errno,
@@ -337,6 +361,8 @@ static int qio_channel_tls_shutdown(QIOChannel *ioc,
 {
     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
 
+    tioc->shutdown |= how;
+
     return qio_channel_shutdown(tioc->master, how, errp);
 }
 
@@ -348,6 +374,17 @@ static int qio_channel_tls_close(QIOChannel *ioc,
     return qio_channel_close(tioc->master, errp);
 }
 
+static void qio_channel_tls_set_aio_fd_handler(QIOChannel *ioc,
+                                               AioContext *ctx,
+                                               IOHandler *io_read,
+                                               IOHandler *io_write,
+                                               void *opaque)
+{
+    QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
+
+    qio_channel_set_aio_fd_handler(tioc->master, ctx, io_read, io_write, opaque);
+}
+
 static GSource *qio_channel_tls_create_watch(QIOChannel *ioc,
                                              GIOCondition condition)
 {
@@ -375,6 +412,7 @@ static void qio_channel_tls_class_init(ObjectClass *klass,
     ioc_klass->io_close = qio_channel_tls_close;
     ioc_klass->io_shutdown = qio_channel_tls_shutdown;
     ioc_klass->io_create_watch = qio_channel_tls_create_watch;
+    ioc_klass->io_set_aio_fd_handler = qio_channel_tls_set_aio_fd_handler;
 }
 
 static const TypeInfo qio_channel_tls_info = {