]> git.proxmox.com Git - mirror_qemu.git/blobdiff - io/channel-tls.c
Merge remote-tracking branch 'remotes/armbru/tags/pull-qapi-2019-01-24' into staging
[mirror_qemu.git] / io / channel-tls.c
index d24dc8c613ab494d701246367d685b93501bd1b4..c98ead21b01ea23cbfa0a99a3cff4c01141158e6 100644 (file)
@@ -140,21 +140,28 @@ qio_channel_tls_new_client(QIOChannel *master,
     return NULL;
 }
 
+struct QIOChannelTLSData {
+    QIOTask *task;
+    GMainContext *context;
+};
+typedef struct QIOChannelTLSData QIOChannelTLSData;
 
 static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
                                              GIOCondition condition,
                                              gpointer user_data);
 
 static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
-                                           QIOTask *task)
+                                           QIOTask *task,
+                                           GMainContext *context)
 {
     Error *err = NULL;
     QCryptoTLSSessionHandshakeStatus status;
 
     if (qcrypto_tls_session_handshake(ioc->session, &err) < 0) {
         trace_qio_channel_tls_handshake_fail(ioc);
-        qio_task_abort(task, err);
-        goto cleanup;
+        qio_task_set_error(task, err);
+        qio_task_complete(task);
+        return;
     }
 
     status = qcrypto_tls_session_get_handshake_status(ioc->session);
@@ -163,13 +170,22 @@ static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
         if (qcrypto_tls_session_check_credentials(ioc->session,
                                                   &err) < 0) {
             trace_qio_channel_tls_credentials_deny(ioc);
-            qio_task_abort(task, err);
-            goto cleanup;
+            qio_task_set_error(task, err);
+        } else {
+            trace_qio_channel_tls_credentials_allow(ioc);
         }
-        trace_qio_channel_tls_credentials_allow(ioc);
         qio_task_complete(task);
     } else {
         GIOCondition condition;
+        QIOChannelTLSData *data = g_new0(typeof(*data), 1);
+
+        data->task = task;
+        data->context = context;
+
+        if (context) {
+            g_main_context_ref(context);
+        }
+
         if (status == QCRYPTO_TLS_HANDSHAKE_SENDING) {
             condition = G_IO_OUT;
         } else {
@@ -177,15 +193,13 @@ static void qio_channel_tls_handshake_task(QIOChannelTLS *ioc,
         }
 
         trace_qio_channel_tls_handshake_pending(ioc, status);
-        qio_channel_add_watch(ioc->master,
-                              condition,
-                              qio_channel_tls_handshake_io,
-                              task,
-                              NULL);
+        qio_channel_add_watch_full(ioc->master,
+                                   condition,
+                                   qio_channel_tls_handshake_io,
+                                   data,
+                                   NULL,
+                                   context);
     }
-
- cleanup:
-    error_free(err);
 }
 
 
@@ -193,14 +207,18 @@ static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
                                              GIOCondition condition,
                                              gpointer user_data)
 {
-    QIOTask *task = user_data;
+    QIOChannelTLSData *data = user_data;
+    QIOTask *task = data->task;
+    GMainContext *context = data->context;
     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(
         qio_task_get_source(task));
 
-    qio_channel_tls_handshake_task(
-       tioc, task);
+    g_free(data);
+    qio_channel_tls_handshake_task(tioc, task, context);
 
-    object_unref(OBJECT(tioc));
+    if (context) {
+        g_main_context_unref(context);
+    }
 
     return FALSE;
 }
@@ -208,7 +226,8 @@ static gboolean qio_channel_tls_handshake_io(QIOChannel *ioc,
 void qio_channel_tls_handshake(QIOChannelTLS *ioc,
                                QIOTaskFunc func,
                                gpointer opaque,
-                               GDestroyNotify destroy)
+                               GDestroyNotify destroy,
+                               GMainContext *context)
 {
     QIOTask *task;
 
@@ -216,7 +235,7 @@ void qio_channel_tls_handshake(QIOChannelTLS *ioc,
                         func, opaque, destroy);
 
     trace_qio_channel_tls_handshake_start(ioc);
-    qio_channel_tls_handshake_task(ioc, task);
+    qio_channel_tls_handshake_task(ioc, task, context);
 }
 
 
@@ -256,6 +275,9 @@ static ssize_t qio_channel_tls_readv(QIOChannel *ioc,
                 } else {
                     return QIO_CHANNEL_ERR_BLOCK;
                 }
+            } else if (errno == ECONNABORTED &&
+                       (tioc->shutdown & QIO_CHANNEL_SHUTDOWN_READ)) {
+                return 0;
             }
 
             error_setg_errno(errp, errno,
@@ -338,6 +360,8 @@ static int qio_channel_tls_shutdown(QIOChannel *ioc,
 {
     QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
 
+    tioc->shutdown |= how;
+
     return qio_channel_shutdown(tioc->master, how, errp);
 }
 
@@ -349,6 +373,17 @@ static int qio_channel_tls_close(QIOChannel *ioc,
     return qio_channel_close(tioc->master, errp);
 }
 
+static void qio_channel_tls_set_aio_fd_handler(QIOChannel *ioc,
+                                               AioContext *ctx,
+                                               IOHandler *io_read,
+                                               IOHandler *io_write,
+                                               void *opaque)
+{
+    QIOChannelTLS *tioc = QIO_CHANNEL_TLS(ioc);
+
+    qio_channel_set_aio_fd_handler(tioc->master, ctx, io_read, io_write, opaque);
+}
+
 static GSource *qio_channel_tls_create_watch(QIOChannel *ioc,
                                              GIOCondition condition)
 {
@@ -376,6 +411,7 @@ static void qio_channel_tls_class_init(ObjectClass *klass,
     ioc_klass->io_close = qio_channel_tls_close;
     ioc_klass->io_shutdown = qio_channel_tls_shutdown;
     ioc_klass->io_create_watch = qio_channel_tls_create_watch;
+    ioc_klass->io_set_aio_fd_handler = qio_channel_tls_set_aio_fd_handler;
 }
 
 static const TypeInfo qio_channel_tls_info = {