git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blobdiff - drivers/vhost/vhost.c
vhost: introduce O(1) vq metadata cache
[mirror_ubuntu-bionic-kernel.git] / drivers / vhost / vhost.c
index 1f7e4e4e6f8efee40262024c493c225e197d2ae8..998bed505530df03e033719e88c5d38405d34ac6 100644 (file)
@@ -282,6 +282,22 @@ void vhost_poll_queue(struct vhost_poll *poll)
 }
 EXPORT_SYMBOL_GPL(vhost_poll_queue);
 
+/* Invalidate the per-virtqueue metadata IOTLB cache: drop every cached
+ * umem node so the next metadata access falls back to a full IOTLB
+ * translation.
+ */
+static void __vhost_vq_meta_reset(struct vhost_virtqueue *vq)
+{
+       int j;
+
+       for (j = 0; j < VHOST_NUM_ADDRS; j++)
+               vq->meta_iotlb[j] = NULL;
+}
+
+/* Reset the metadata cache of every virtqueue owned by this device.
+ * Called from the IOTLB update/invalidate paths, where a message may
+ * have made previously cached translations stale.
+ */
+static void vhost_vq_meta_reset(struct vhost_dev *d)
+{
+       int i;
+
+       for (i = 0; i < d->nvqs; ++i)
+               __vhost_vq_meta_reset(d->vqs[i]);
+}
+
 static void vhost_vq_reset(struct vhost_dev *dev,
                           struct vhost_virtqueue *vq)
 {
@@ -312,6 +328,7 @@ static void vhost_vq_reset(struct vhost_dev *dev,
        vq->busyloop_timeout = 0;
        vq->umem = NULL;
        vq->iotlb = NULL;
+       __vhost_vq_meta_reset(vq);
 }
 
 static int vhost_worker(void *data)
@@ -691,6 +708,18 @@ static int vq_memory_access_ok(void __user *log_base, struct vhost_umem *umem,
        return 1;
 }
 
+/* Fast path: translate a metadata access through the O(1) cache.
+ *
+ * @addr: guest IOVA of the access
+ * @size: length of the access (currently unused: a node is only cached
+ *        when it spans the whole metadata area, so any in-range access
+ *        is covered -- see iotlb_access_ok())
+ * @type: which metadata area (VHOST_ADDR_DESC/AVAIL/USED slot)
+ *
+ * Returns the translated user-space pointer, or NULL on a cache miss;
+ * the caller must then fall back to a full IOTLB translation.
+ */
+static inline void __user *vhost_vq_meta_fetch(struct vhost_virtqueue *vq,
+                                              u64 addr, unsigned int size,
+                                              int type)
+{
+       const struct vhost_umem_node *node = vq->meta_iotlb[type];
+
+       if (!node)
+               return NULL;
+
+       /* Annotate the address space in the cast: the function returns a
+        * __user pointer, and a bare (void *) would trigger a sparse
+        * address-space warning.
+        */
+       return (void __user *)(uintptr_t)(node->userspace_addr + addr - node->start);
+}
+
 /* Can we switch to this memory table? */
 /* Caller should have device mutex but not vq mutex */
 static int memory_access_ok(struct vhost_dev *d, struct vhost_umem *umem,
@@ -733,8 +762,14 @@ static int vhost_copy_to_user(struct vhost_virtqueue *vq, void __user *to,
                 * could be access through iotlb. So -EAGAIN should
                 * not happen in this case.
                 */
-               /* TODO: more fast path */
                struct iov_iter t;
+               void __user *uaddr = vhost_vq_meta_fetch(vq,
+                                    (u64)(uintptr_t)to, size,
+                                    VHOST_ADDR_DESC);
+
+               if (uaddr)
+                       return __copy_to_user(uaddr, from, size);
+
                ret = translate_desc(vq, (u64)(uintptr_t)to, size, vq->iotlb_iov,
                                     ARRAY_SIZE(vq->iotlb_iov),
                                     VHOST_ACCESS_WO);
@@ -762,8 +797,14 @@ static int vhost_copy_from_user(struct vhost_virtqueue *vq, void *to,
                 * could be access through iotlb. So -EAGAIN should
                 * not happen in this case.
                 */
-               /* TODO: more fast path */
+               void __user *uaddr = vhost_vq_meta_fetch(vq,
+                                    (u64)(uintptr_t)from, size,
+                                    VHOST_ADDR_DESC);
                struct iov_iter f;
+
+               if (uaddr)
+                       return __copy_from_user(to, uaddr, size);
+
                ret = translate_desc(vq, (u64)(uintptr_t)from, size, vq->iotlb_iov,
                                     ARRAY_SIZE(vq->iotlb_iov),
                                     VHOST_ACCESS_RO);
@@ -783,17 +824,12 @@ out:
        return ret;
 }
 
-static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
-                                    void __user *addr, unsigned size)
+static void __user *__vhost_get_user_slow(struct vhost_virtqueue *vq,
+                                         void __user *addr, unsigned int size,
+                                         int type)
 {
        int ret;
 
-       /* This function should be called after iotlb
-        * prefetch, which means we're sure that vq
-        * could be access through iotlb. So -EAGAIN should
-        * not happen in this case.
-        */
-       /* TODO: more fast path */
        ret = translate_desc(vq, (u64)(uintptr_t)addr, size, vq->iotlb_iov,
                             ARRAY_SIZE(vq->iotlb_iov),
                             VHOST_ACCESS_RO);
@@ -814,14 +850,32 @@ static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
        return vq->iotlb_iov[0].iov_base;
 }
 
-#define vhost_put_user(vq, x, ptr) \
+/* This function should be called after iotlb
+ * prefetch, which means we're sure that vq
+ * could be access through iotlb. So -EAGAIN should
+ * not happen in this case.
+ *
+ * Try the O(1) metadata cache first; on a miss, fall back to the full
+ * IOTLB translation in __vhost_get_user_slow().
+ */
+static inline void __user *__vhost_get_user(struct vhost_virtqueue *vq,
+                                           void __user *addr, unsigned int size,
+                                           int type)
+{
+       /* addr carries the __user annotation to match
+        * __vhost_get_user_slow() and the user-space pointers passed in
+        * by the vhost_put_user()/vhost_get_user() macros.
+        */
+       void __user *uaddr = vhost_vq_meta_fetch(vq,
+                            (u64)(uintptr_t)addr, size, type);
+       if (uaddr)
+               return uaddr;
+
+       return __vhost_get_user_slow(vq, addr, size, type);
+}
+
+#define vhost_put_user(vq, x, ptr)             \
 ({ \
        int ret = -EFAULT; \
        if (!vq->iotlb) { \
                ret = __put_user(x, ptr); \
        } else { \
                __typeof__(ptr) to = \
-                       (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+                       (__typeof__(ptr)) __vhost_get_user(vq, ptr,     \
+                                         sizeof(*ptr), VHOST_ADDR_USED); \
                if (to != NULL) \
                        ret = __put_user(x, to); \
                else \
@@ -830,14 +884,16 @@ static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
        ret; \
 })
 
-#define vhost_get_user(vq, x, ptr) \
+#define vhost_get_user(vq, x, ptr, type)               \
 ({ \
        int ret; \
        if (!vq->iotlb) { \
                ret = __get_user(x, ptr); \
        } else { \
                __typeof__(ptr) from = \
-                       (__typeof__(ptr)) __vhost_get_user(vq, ptr, sizeof(*ptr)); \
+                       (__typeof__(ptr)) __vhost_get_user(vq, ptr, \
+                                                          sizeof(*ptr), \
+                                                          type); \
                if (from != NULL) \
                        ret = __get_user(x, from); \
                else \
@@ -846,6 +902,12 @@ static void __user *__vhost_get_user(struct vhost_virtqueue *vq,
        ret; \
 })
 
+/* Typed wrappers over vhost_get_user() that tag the access with the
+ * metadata area it touches, so the matching meta_iotlb slot is used.
+ */
+#define vhost_get_avail(vq, x, ptr) \
+       vhost_get_user(vq, x, ptr, VHOST_ADDR_AVAIL)
+
+#define vhost_get_used(vq, x, ptr) \
+       vhost_get_user(vq, x, ptr, VHOST_ADDR_USED)
+
+
 static void vhost_dev_lock_vqs(struct vhost_dev *d)
 {
        int i = 0;
@@ -951,6 +1013,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
                        ret = -EFAULT;
                        break;
                }
+               vhost_vq_meta_reset(dev);
                if (vhost_new_umem_range(dev->iotlb, msg->iova, msg->size,
                                         msg->iova + msg->size - 1,
                                         msg->uaddr, msg->perm)) {
@@ -960,6 +1023,7 @@ static int vhost_process_iotlb_msg(struct vhost_dev *dev,
                vhost_iotlb_notify_vq(dev, msg);
                break;
        case VHOST_IOTLB_INVALIDATE:
+               vhost_vq_meta_reset(dev);
                vhost_del_umem_range(dev->iotlb, msg->iova,
                                     msg->iova + msg->size - 1);
                break;
@@ -1103,12 +1167,26 @@ static int vq_access_ok(struct vhost_virtqueue *vq, unsigned int num,
                        sizeof *used + num * sizeof *used->ring + s);
 }
 
+/* Cache the umem node backing one metadata area, but only when its
+ * permissions allow the access that area needs: write for the used
+ * ring, read for desc/avail.  A node failing the check is simply not
+ * cached, so such accesses keep taking the (checked) slow path.
+ */
+static void vhost_vq_meta_update(struct vhost_virtqueue *vq,
+                                const struct vhost_umem_node *node,
+                                int type)
+{
+       int access = (type == VHOST_ADDR_USED) ?
+                    VHOST_ACCESS_WO : VHOST_ACCESS_RO;
+
+       if (likely(node->perm & access))
+               vq->meta_iotlb[type] = node;
+}
+
 static int iotlb_access_ok(struct vhost_virtqueue *vq,
-                          int access, u64 addr, u64 len)
+                          int access, u64 addr, u64 len, int type)
 {
        const struct vhost_umem_node *node;
        struct vhost_umem *umem = vq->iotlb;
-       u64 s = 0, size;
+       u64 s = 0, size, orig_addr = addr;
+
+       if (vhost_vq_meta_fetch(vq, addr, len, type))
+               return true;
 
        while (len > s) {
                node = vhost_umem_interval_tree_iter_first(&umem->umem_tree,
@@ -1125,6 +1203,10 @@ static int iotlb_access_ok(struct vhost_virtqueue *vq,
                }
 
                size = node->size - addr + node->start;
+
+               if (orig_addr == addr && size >= len)
+                       vhost_vq_meta_update(vq, node, type);
+
                s += size;
                addr += size;
        }
@@ -1141,13 +1223,15 @@ int vq_iotlb_prefetch(struct vhost_virtqueue *vq)
                return 1;
 
        return iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->desc,
-                              num * sizeof *vq->desc) &&
+                              num * sizeof(*vq->desc), VHOST_ADDR_DESC) &&
               iotlb_access_ok(vq, VHOST_ACCESS_RO, (u64)(uintptr_t)vq->avail,
                               sizeof *vq->avail +
-                              num * sizeof *vq->avail->ring + s) &&
+                              num * sizeof(*vq->avail->ring) + s,
+                              VHOST_ADDR_AVAIL) &&
               iotlb_access_ok(vq, VHOST_ACCESS_WO, (u64)(uintptr_t)vq->used,
                               sizeof *vq->used +
-                              num * sizeof *vq->used->ring + s);
+                              num * sizeof(*vq->used->ring) + s,
+                              VHOST_ADDR_USED);
 }
 EXPORT_SYMBOL_GPL(vq_iotlb_prefetch);
 
@@ -1728,7 +1812,7 @@ int vhost_vq_init_access(struct vhost_virtqueue *vq)
                r = -EFAULT;
                goto err;
        }
-       r = vhost_get_user(vq, last_used_idx, &vq->used->idx);
+       r = vhost_get_used(vq, last_used_idx, &vq->used->idx);
        if (r) {
                vq_err(vq, "Can't access used idx at %p\n",
                       &vq->used->idx);
@@ -1932,7 +2016,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
        last_avail_idx = vq->last_avail_idx;
 
        if (vq->avail_idx == vq->last_avail_idx) {
-               if (unlikely(vhost_get_user(vq, avail_idx, &vq->avail->idx))) {
+               if (unlikely(vhost_get_avail(vq, avail_idx, &vq->avail->idx))) {
                        vq_err(vq, "Failed to access avail idx at %p\n",
                                &vq->avail->idx);
                        return -EFAULT;
@@ -1959,7 +2043,7 @@ int vhost_get_vq_desc(struct vhost_virtqueue *vq,
 
        /* Grab the next descriptor number they're advertising, and increment
         * the index we've seen. */
-       if (unlikely(vhost_get_user(vq, ring_head,
+       if (unlikely(vhost_get_avail(vq, ring_head,
                     &vq->avail->ring[last_avail_idx & (vq->num - 1)]))) {
                vq_err(vq, "Failed to read head: idx %d address %p\n",
                       last_avail_idx,
@@ -2175,7 +2259,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
                 * with the barrier that the Guest executes when enabling
                 * interrupts. */
                smp_mb();
-               if (vhost_get_user(vq, flags, &vq->avail->flags)) {
+               if (vhost_get_avail(vq, flags, &vq->avail->flags)) {
                        vq_err(vq, "Failed to get flags");
                        return true;
                }
@@ -2202,7 +2286,7 @@ static bool vhost_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
         * interrupts. */
        smp_mb();
 
-       if (vhost_get_user(vq, event, vhost_used_event(vq))) {
+       if (vhost_get_avail(vq, event, vhost_used_event(vq))) {
                vq_err(vq, "Failed to get used event idx");
                return true;
        }
@@ -2246,7 +2330,7 @@ bool vhost_vq_avail_empty(struct vhost_dev *dev, struct vhost_virtqueue *vq)
        __virtio16 avail_idx;
        int r;
 
-       r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
+       r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
        if (r)
                return false;
 
@@ -2281,7 +2365,7 @@ bool vhost_enable_notify(struct vhost_dev *dev, struct vhost_virtqueue *vq)
        /* They could have slipped one in as we were doing that: make
         * sure it's written, then check again. */
        smp_mb();
-       r = vhost_get_user(vq, avail_idx, &vq->avail->idx);
+       r = vhost_get_avail(vq, avail_idx, &vq->avail->idx);
        if (r) {
                vq_err(vq, "Failed to check avail idx at %p: %d\n",
                       &vq->avail->idx, r);