]> git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blobdiff - drivers/vhost/net.c
vhost_net: try batch dequing from skb array
[mirror_ubuntu-bionic-kernel.git] / drivers / vhost / net.c
index f61f852d6cfd114978036208bf2b0a5c841231d4..e3d7ea1288c68a55efead365e5a760f37d487015 100644 (file)
@@ -28,6 +28,8 @@
 #include <linux/if_macvlan.h>
 #include <linux/if_tap.h>
 #include <linux/if_vlan.h>
+#include <linux/skb_array.h>
+#include <linux/skbuff.h>
 
 #include <net/sock.h>
 
@@ -85,6 +87,13 @@ struct vhost_net_ubuf_ref {
        struct vhost_virtqueue *vq;
 };
 
+#define VHOST_RX_BATCH 64
+struct vhost_net_buf {
+       struct sk_buff **queue;
+       int tail;
+       int head;
+};
+
 struct vhost_net_virtqueue {
        struct vhost_virtqueue vq;
        size_t vhost_hlen;
@@ -99,6 +108,8 @@ struct vhost_net_virtqueue {
        /* Reference counting for outstanding ubufs.
         * Protected by vq mutex. Writers must also take device mutex. */
        struct vhost_net_ubuf_ref *ubufs;
+       struct skb_array *rx_array;
+       struct vhost_net_buf rxq;
 };
 
 struct vhost_net {
@@ -117,6 +128,71 @@ struct vhost_net {
 
 static unsigned vhost_net_zcopy_mask __read_mostly;
 
+static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
+{
+       if (rxq->tail != rxq->head)
+               return rxq->queue[rxq->head];
+       else
+               return NULL;
+}
+
+static int vhost_net_buf_get_size(struct vhost_net_buf *rxq)
+{
+       return rxq->tail - rxq->head;
+}
+
+static int vhost_net_buf_is_empty(struct vhost_net_buf *rxq)
+{
+       return rxq->tail == rxq->head;
+}
+
+static void *vhost_net_buf_consume(struct vhost_net_buf *rxq)
+{
+       void *ret = vhost_net_buf_get_ptr(rxq);
+       ++rxq->head;
+       return ret;
+}
+
+static int vhost_net_buf_produce(struct vhost_net_virtqueue *nvq)
+{
+       struct vhost_net_buf *rxq = &nvq->rxq;
+
+       rxq->head = 0;
+       rxq->tail = skb_array_consume_batched(nvq->rx_array, rxq->queue,
+                                             VHOST_RX_BATCH);
+       return rxq->tail;
+}
+
+static void vhost_net_buf_unproduce(struct vhost_net_virtqueue *nvq)
+{
+       struct vhost_net_buf *rxq = &nvq->rxq;
+
+       if (nvq->rx_array && !vhost_net_buf_is_empty(rxq)) {
+               skb_array_unconsume(nvq->rx_array, rxq->queue + rxq->head,
+                                   vhost_net_buf_get_size(rxq));
+               rxq->head = rxq->tail = 0;
+       }
+}
+
+static int vhost_net_buf_peek(struct vhost_net_virtqueue *nvq)
+{
+       struct vhost_net_buf *rxq = &nvq->rxq;
+
+       if (!vhost_net_buf_is_empty(rxq))
+               goto out;
+
+       if (!vhost_net_buf_produce(nvq))
+               return 0;
+
+out:
+       return __skb_array_len_with_tag(vhost_net_buf_get_ptr(rxq));
+}
+
+static void vhost_net_buf_init(struct vhost_net_buf *rxq)
+{
+       rxq->head = rxq->tail = 0;
+}
+
 static void vhost_net_enable_zcopy(int vq)
 {
        vhost_net_zcopy_mask |= 0x1 << vq;
@@ -201,6 +277,7 @@ static void vhost_net_vq_reset(struct vhost_net *n)
                n->vqs[i].ubufs = NULL;
                n->vqs[i].vhost_hlen = 0;
                n->vqs[i].sock_hlen = 0;
+               vhost_net_buf_init(&n->vqs[i].rxq);
        }
 
 }
@@ -503,15 +580,14 @@ out:
        mutex_unlock(&vq->mutex);
 }
 
-static int peek_head_len(struct sock *sk)
+static int peek_head_len(struct vhost_net_virtqueue *rvq, struct sock *sk)
 {
-       struct socket *sock = sk->sk_socket;
        struct sk_buff *head;
        int len = 0;
        unsigned long flags;
 
-       if (sock->ops->peek_len)
-               return sock->ops->peek_len(sock);
+       if (rvq->rx_array)
+               return vhost_net_buf_peek(rvq);
 
        spin_lock_irqsave(&sk->sk_receive_queue.lock, flags);
        head = skb_peek(&sk->sk_receive_queue);
@@ -537,10 +613,11 @@ static int sk_has_rx_data(struct sock *sk)
 
 static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
 {
+       struct vhost_net_virtqueue *rvq = &net->vqs[VHOST_NET_VQ_RX];
        struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_TX];
        struct vhost_virtqueue *vq = &nvq->vq;
        unsigned long uninitialized_var(endtime);
-       int len = peek_head_len(sk);
+       int len = peek_head_len(rvq, sk);
 
        if (!len && vq->busyloop_timeout) {
                /* Both tx vq and rx socket were polled here */
@@ -561,7 +638,7 @@ static int vhost_net_rx_peek_head_len(struct vhost_net *net, struct sock *sk)
                        vhost_poll_queue(&vq->poll);
                mutex_unlock(&vq->mutex);
 
-               len = peek_head_len(sk);
+               len = peek_head_len(rvq, sk);
        }
 
        return len;
@@ -699,6 +776,8 @@ static void handle_rx(struct vhost_net *net)
                /* On error, stop handling until the next kick. */
                if (unlikely(headcount < 0))
                        goto out;
+               if (nvq->rx_array)
+                       msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
                /* On overrun, truncate and discard */
                if (unlikely(headcount > UIO_MAXIOV)) {
                        iov_iter_init(&msg.msg_iter, READ, vq->iov, 1, 1);
@@ -815,6 +894,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)
        struct vhost_net *n;
        struct vhost_dev *dev;
        struct vhost_virtqueue **vqs;
+       struct sk_buff **queue;
        int i;
 
        n = kvmalloc(sizeof *n, GFP_KERNEL | __GFP_REPEAT);
@@ -826,6 +906,15 @@ static int vhost_net_open(struct inode *inode, struct file *f)
                return -ENOMEM;
        }
 
+       queue = kmalloc_array(VHOST_RX_BATCH, sizeof(struct sk_buff *),
+                             GFP_KERNEL);
+       if (!queue) {
+               kfree(vqs);
+               kvfree(n);
+               return -ENOMEM;
+       }
+       n->vqs[VHOST_NET_VQ_RX].rxq.queue = queue;
+
        dev = &n->dev;
        vqs[VHOST_NET_VQ_TX] = &n->vqs[VHOST_NET_VQ_TX].vq;
        vqs[VHOST_NET_VQ_RX] = &n->vqs[VHOST_NET_VQ_RX].vq;
@@ -838,6 +927,7 @@ static int vhost_net_open(struct inode *inode, struct file *f)
                n->vqs[i].done_idx = 0;
                n->vqs[i].vhost_hlen = 0;
                n->vqs[i].sock_hlen = 0;
+               vhost_net_buf_init(&n->vqs[i].rxq);
        }
        vhost_dev_init(dev, vqs, VHOST_NET_VQ_MAX);
 
@@ -853,11 +943,14 @@ static struct socket *vhost_net_stop_vq(struct vhost_net *n,
                                        struct vhost_virtqueue *vq)
 {
        struct socket *sock;
+       struct vhost_net_virtqueue *nvq =
+               container_of(vq, struct vhost_net_virtqueue, vq);
 
        mutex_lock(&vq->mutex);
        sock = vq->private_data;
        vhost_net_disable_vq(n, vq);
        vq->private_data = NULL;
+       vhost_net_buf_unproduce(nvq);
        mutex_unlock(&vq->mutex);
        return sock;
 }
@@ -912,6 +1005,7 @@ static int vhost_net_release(struct inode *inode, struct file *f)
        /* We do an extra flush before freeing memory,
         * since jobs can re-queue themselves. */
        vhost_net_flush(n);
+       kfree(n->vqs[VHOST_NET_VQ_RX].rxq.queue);
        kfree(n->dev.vqs);
        kvfree(n);
        return 0;
@@ -950,6 +1044,25 @@ err:
        return ERR_PTR(r);
 }
 
+static struct skb_array *get_tap_skb_array(int fd)
+{
+       struct skb_array *array;
+       struct file *file = fget(fd);
+
+       if (!file)
+               return NULL;
+       array = tun_get_skb_array(file);
+       if (!IS_ERR(array))
+               goto out;
+       array = tap_get_skb_array(file);
+       if (!IS_ERR(array))
+               goto out;
+       array = NULL;
+out:
+       fput(file);
+       return array;
+}
+
 static struct socket *get_tap_socket(int fd)
 {
        struct file *file = fget(fd);
@@ -1026,6 +1139,9 @@ static long vhost_net_set_backend(struct vhost_net *n, unsigned index, int fd)
 
                vhost_net_disable_vq(n, vq);
                vq->private_data = sock;
+               vhost_net_buf_unproduce(nvq);
+               if (index == VHOST_NET_VQ_RX)
+                       nvq->rx_array = get_tap_skb_array(fd);
                r = vhost_vq_init_access(vq);
                if (r)
                        goto err_used;