]> git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blobdiff - drivers/vhost/vhost.c
vhost: correctly check the return value of translate_desc() in log_used()
[mirror_ubuntu-bionic-kernel.git] / drivers / vhost / vhost.c
index 33ac2b186b85eb1f4883d26d6d0d9b3a8532fc01..7cbc18bdaf672f4b5b9c48c9f8f04fe0220eaf8a 100644 (file)
@@ -30,6 +30,7 @@
 #include <linux/sched/mm.h>
 #include <linux/sched/signal.h>
 #include <linux/interval_tree_generic.h>
+#include <linux/nospec.h>
 
 #include "vhost.h"
 
@@ -213,8 +214,7 @@ int vhost_poll_start(struct vhost_poll *poll, struct file *file)
        if (mask)
                vhost_poll_wakeup(&poll->wait, 0, 0, (void *)mask);
        if (mask & POLLERR) {
-               if (poll->wqh)
-                       remove_wait_queue(poll->wqh, &poll->wait);
+               vhost_poll_stop(poll);
                ret = -EINVAL;
        }
 
@@ -757,7 +757,7 @@ static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
                struct iov_iter t;
                void __user *uaddr = vhost_vq_meta_fetch(vq,
                                     (u64)(uintptr_t)to, size,
-                                    VHOST_ADDR_DESC);
+                                    VHOST_ADDR_USED);
 
                if (uaddr)
                        return __copy_to_user(uaddr, from, size);
@@ -904,7 +904,7 @@ static void vhost_dev_lock_vqs(struct vhost_dev *d)
 {
        int i = 0;
        for (i = 0; i < d->nvqs; ++i)
-               mutex_lock(&d->vqs[i]->mutex);
+               mutex_lock_nested(&d->vqs[i]->mutex, i);
 }
 
 static void vhost_dev_unlock_vqs(struct vhost_dev *d)
@@ -961,7 +961,7 @@ static void vhost_iotlb_notify_vq(struct vhost_dev *d,
        list_for_each_entry_safe(node, n, &d->pending_list, node) {
                struct vhost_iotlb_msg *vq_msg = &node->msg.iotlb;
                if (msg->iova <= vq_msg->iova &&
-                   msg->iova + msg->size - 1 > vq_msg->iova &&
+                   msg->iova + msg->size - 1 >= vq_msg->iova &&
                    vq_msg->type == VHOST_IOTLB_MISS) {
                        vhost_poll_queue(&node->vq->poll);
                        list_del(&node->node);
@@ -994,6 +994,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
 {
        int ret = 0;
 
+       mutex_lock(&dev->mutex);
        vhost_dev_lock_vqs(dev);
        switch (msg->type) {
        case VHOST_IOTLB_UPDATE:
@@ -1015,6 +1016,10 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
                vhost_iotlb_notify_vq(dev, msg);
                break;
        case VHOST_IOTLB_INVALIDATE:
+               if (!dev->iotlb) {
+                       ret = -EFAULT;
+                       break;
+               }
                vhost_vq_meta_reset(dev);
                vhost_del_umem_range(dev->iotlb, msg->iova,
                                     msg->iova + msg->size - 1);
@@ -1025,6 +1030,8 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
        }
 
        vhost_dev_unlock_vqs(dev);
+       mutex_unlock(&dev->mutex);
+
        return ret;
 }
 ssize_t vhost_chr_write_iter(struct vhost_dev *dev,
@@ -1253,14 +1260,14 @@ static int vq_log_access_ok(struct vhost_virtqueue *vq,
 /* Caller should have vq mutex and device mutex */
 int vhost_vq_access_ok(struct vhost_virtqueue *vq)
 {
-       if (vq->iotlb) {
-               /* When device IOTLB was used, the access validation
-                * will be validated during prefetching.
-                */
+       if (!vq_log_access_ok(vq, vq->log_base))
+               return 0;
+
+       /* Access validation occurs at prefetch time with IOTLB */
+       if (vq->iotlb)
                return 1;
-       }
-       return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used) &&
-               vq_log_access_ok(vq, vq->log_base);
+
+       return vq_access_ok(vq, vq->num, vq->desc, vq->avail, vq->used);
 }
 EXPORT_SYMBOL_GPL(vhost_vq_access_ok);
 
@@ -1364,6 +1371,7 @@ long vhost_vring_ioctl(struct vhost_dev *d, int ioctl, void __user *argp)
        if (idx >= d->nvqs)
                return -ENOBUFS;
 
+       idx = array_index_nospec(idx, d->nvqs);
        vq = d->vqs[idx];
 
        mutex_lock(&vq->mutex);
@@ -1576,9 +1584,12 @@ int vhost_init_device_iotlb(struct vhost_dev *d, bool enabled)
        d->iotlb = niotlb;
 
        for (i = 0; i < d->nvqs; ++i) {
-               mutex_lock(&d->vqs[i]->mutex);
-               d->vqs[i]->iotlb = niotlb;
-               mutex_unlock(&d->vqs[i]->mutex);
+               struct vhost_virtqueue *vq = d->vqs[i];
+
+               mutex_lock(&vq->mutex);
+               vq->iotlb = niotlb;
+               __vhost_vq_meta_reset(vq);
+               mutex_unlock(&vq->mutex);
        }
 
        vhost_umem_clean(oiotlb);
@@ -1719,13 +1730,87 @@ static int log_write(void __user *log_base,
        return r;
 }
 
+static int log_write_hva(struct vhost_virtqueue *vq, u64 hva, u64 len)
+{
+       struct vhost_umem *umem = vq->umem;
+       struct vhost_umem_node *u;
+       u64 start, end, l, min;
+       int r;
+       bool hit = false;
+
+       while (len) {
+               min = len;
+               /* More than one GPAs can be mapped into a single HVA. So
+                * iterate all possible umems here to be safe.
+                */
+               list_for_each_entry(u, &umem->umem_list, link) {
+                       if (u->userspace_addr > hva - 1 + len ||
+                           u->userspace_addr - 1 + u->size < hva)
+                               continue;
+                       start = max(u->userspace_addr, hva);
+                       end = min(u->userspace_addr - 1 + u->size,
+                                 hva - 1 + len);
+                       l = end - start + 1;
+                       r = log_write(vq->log_base,
+                                     u->start + start - u->userspace_addr,
+                                     l);
+                       if (r < 0)
+                               return r;
+                       hit = true;
+                       min = min(l, min);
+               }
+
+               if (!hit)
+                       return -EFAULT;
+
+               len -= min;
+               hva += min;
+       }
+
+       return 0;
+}
+
+static int log_used(struct vhost_virtqueue *vq, u64 used_offset, u64 len)
+{
+       struct iovec iov[64];
+       int i, ret;
+
+       if (!vq->iotlb)
+               return log_write(vq->log_base, vq->log_addr + used_offset, len);
+
+       ret = translate_desc(vq, (uintptr_t)vq->used + used_offset,
+                            len, iov, 64, VHOST_ACCESS_WO);
+       if (ret < 0)
+               return ret;
+
+       for (i = 0; i < ret; i++) {
+               ret = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+                                   iov[i].iov_len);
+               if (ret)
+                       return ret;
+       }
+
+       return 0;
+}
+
 int vhost_log_write(struct vhost_virtqueue *vq, struct vhost_log *log,
-                   unsigned int log_num, u64 len)
+                   unsigned int log_num, u64 len, struct iovec *iov, int count)
 {
        int i, r;
 
        /* Make sure data written is seen before log. */
        smp_wmb();
+
+       if (vq->iotlb) {
+               for (i = 0; i < count; i++) {
+                       r = log_write_hva(vq, (uintptr_t)iov[i].iov_base,
+                                         iov[i].iov_len);
+                       if (r < 0)
+                               return r;
+               }
+               return 0;
+       }
+
        for (i = 0; i < log_num; ++i) {
                u64 l = min(log[i].len, len);
                r = log_write(vq->log_base, log[i].addr, l);
@@ -1755,9 +1840,8 @@ static int vhost_update_used_flags(struct vhost_virtqueue *vq)
                smp_wmb();
                /* Log used flag write. */
                used = &vq->used->flags;
-               log_write(vq->log_base, vq->log_addr +
-                         (used - (void __user *)vq->used),
-                         sizeof vq->used->flags);
+               log_used(vq, (used - (void __user *)vq->used),
+                        sizeof vq->used->flags);
                if (vq->log_ctx)
                        eventfd_signal(vq->log_ctx, 1);
        }
@@ -1775,9 +1859,8 @@ static int vhost_update_avail_event(struct vhost_virtqueue *vq, u16 avail_event)
                smp_wmb();
                /* Log avail event write */
                used = vhost_avail_event(vq);
-               log_write(vq->log_base, vq->log_addr +
-                         (used - (void __user *)vq->used),
-                         sizeof *vhost_avail_event(vq));
+               log_used(vq, (used - (void __user *)vq->used),
+                        sizeof *vhost_avail_event(vq));
                if (vq->log_ctx)
                        eventfd_signal(vq->log_ctx, 1);
        }
@@ -2182,10 +2265,8 @@ static int __vhost_add_used_n(struct vhost_virtqueue *vq,
                /* Make sure data is seen before log. */
                smp_wmb();
                /* Log used ring entry write. */
-               log_write(vq->log_base,
-                         vq->log_addr +
-                          ((void __user *)used - (void __user *)vq->used),
-                         count * sizeof *used);
+               log_used(vq, ((void __user *)used - (void __user *)vq->used),
+                        count * sizeof *used);
        }
        old = vq->last_used_idx;
        new = (vq->last_used_idx += count);
@@ -2224,10 +2305,11 @@ int vhost_add_used_n(struct vhost_virtqueue *vq, struct vring_used_elem *heads,
                return -EFAULT;
        }
        if (unlikely(vq->log_used)) {
+               /* Make sure used idx is seen before log. */
+               smp_wmb();
                /* Log used index update. */
-               log_write(vq->log_base,
-                         vq->log_addr + offsetof(struct vring_used, idx),
-                         sizeof vq->used->idx);
+               log_used(vq, offsetof(struct vring_used, idx),
+                        sizeof vq->used->idx);
                if (vq->log_ctx)
                        eventfd_signal(vq->log_ctx, 1);
        }
@@ -2380,6 +2462,9 @@ struct vhost_msg_node *vhost_new_msg(struct vhost_virtqueue *vq, int type)
        struct vhost_msg_node *node = kmalloc(sizeof *node, GFP_KERNEL);
        if (!node)
                return NULL;
+
+       /* Make sure all padding within the structure is initialized. */
+       memset(&node->msg, 0, sizeof node->msg);
        node->vq = vq;
        node->msg.type = type;
        return node;