virtio-net: batch stats updating
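The stats hunks in this diff (receive_buf()/virtnet_receive() on the rx side, free_old_xmit_skbs() on the tx side) stop taking the u64_stats sync per packet: byte and packet counts are accumulated in locals and folded into the per-cpu counters once per NAPI poll, or skipped entirely when nothing was completed. The following is a minimal userspace sketch of that batching pattern, assuming a pthread mutex in place of the driver's u64_stats seqcount; the helper names (rx_stats, fake_queue, poll_queue) are illustrative and not part of the driver.

#include <pthread.h>
#include <stdio.h>

struct rx_stats {
	pthread_mutex_t lock;          /* stands in for the u64_stats seqcount */
	unsigned long long bytes;
	unsigned long long packets;
};

static struct rx_stats stats = { .lock = PTHREAD_MUTEX_INITIALIZER };

/* Illustrative stand-in for virtqueue_get_buf(): returns a packet length,
 * or 0 once the queue is drained. */
static unsigned int fake_queue(void)
{
	static unsigned int left = 5;

	return left ? 64 * left-- : 0;
}

/* One "NAPI poll": accumulate locally, publish once at the end. */
static unsigned int poll_queue(unsigned int budget)
{
	unsigned int received = 0, bytes = 0, len;

	while (received < budget && (len = fake_queue()) != 0) {
		bytes += len;          /* previously: lock, update, unlock per packet */
		received++;
	}

	pthread_mutex_lock(&stats.lock);   /* single update per poll */
	stats.bytes += bytes;
	stats.packets += received;
	pthread_mutex_unlock(&stats.lock);

	return received;
}

int main(void)
{
	poll_queue(64);
	printf("packets=%llu bytes=%llu\n", stats.packets, stats.bytes);
	return 0;
}
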
diff --git a/drivers/net/virtio_net.c b/drivers/net/virtio_net.c
index 4a105006ca637bc985698fa378ff6d14e1494b3f..05a83dbc910d04387ae90ca9a0da299d59967810 100644
--- a/drivers/net/virtio_net.c
+++ b/drivers/net/virtio_net.c
 #include <linux/virtio.h>
 #include <linux/virtio_net.h>
 #include <linux/bpf.h>
+#include <linux/bpf_trace.h>
 #include <linux/scatterlist.h>
 #include <linux/if_vlan.h>
 #include <linux/slab.h>
 #include <linux/cpu.h>
 #include <linux/average.h>
-#include <net/busy_poll.h>
 
 static int napi_weight = NAPI_POLL_WEIGHT;
 module_param(napi_weight, int, 0444);
@@ -41,6 +41,9 @@ module_param(gso, bool, 0444);
 #define GOOD_PACKET_LEN (ETH_HLEN + VLAN_HLEN + ETH_DATA_LEN)
 #define GOOD_COPY_LEN  128
 
+/* Amount of XDP headroom to prepend to packets for use by xdp_adjust_head */
+#define VIRTIO_XDP_HEADROOM 256
+
 /* RX packet size EWMA. The average packet size is used to determine the packet
  * buffer size when refilling RX rings. As the entire RX ring may be refilled
  * at once, the weight is chosen so that the EWMA will be insensitive to short-
@@ -48,8 +51,16 @@ module_param(gso, bool, 0444);
  */
 DECLARE_EWMA(pkt_len, 1, 64)
 
+/* With mergeable buffers we align buffer address and use the low bits to
+ * encode its true size. Buffer size is up to 1 page so we need to align to
+ * square root of page size to ensure we reserve enough bits to encode the true
+ * size.
+ */
+#define MERGEABLE_BUFFER_MIN_ALIGN_SHIFT ((PAGE_SHIFT + 1) / 2)
+
 /* Minimum alignment for mergeable packet buffers. */
-#define MERGEABLE_BUFFER_ALIGN max(L1_CACHE_BYTES, 256)
+#define MERGEABLE_BUFFER_ALIGN max(L1_CACHE_BYTES, \
+                                  1 << MERGEABLE_BUFFER_MIN_ALIGN_SHIFT)
 
 #define VIRTNET_DRIVER_VERSION "1.0.0"
 
@@ -330,17 +341,21 @@ static struct sk_buff *page_to_skb(struct virtnet_info *vi,
        return skb;
 }
 
-static void virtnet_xdp_xmit(struct virtnet_info *vi,
+static bool virtnet_xdp_xmit(struct virtnet_info *vi,
                             struct receive_queue *rq,
-                            struct send_queue *sq,
                             struct xdp_buff *xdp,
                             void *data)
 {
        struct virtio_net_hdr_mrg_rxbuf *hdr;
        unsigned int num_sg, len;
+       struct send_queue *sq;
+       unsigned int qp;
        void *xdp_sent;
        int err;
 
+       qp = vi->curr_queue_pairs - vi->xdp_queue_pairs + smp_processor_id();
+       sq = &vi->sq[qp];
+
        /* Free up any pending old buffers before queueing new ones. */
        while ((xdp_sent = virtqueue_get_buf(sq->vq, &len)) != NULL) {
                if (vi->mergeable_rx_bufs) {
@@ -355,6 +370,7 @@ static void virtnet_xdp_xmit(struct virtnet_info *vi,
        }
 
        if (vi->mergeable_rx_bufs) {
+               xdp->data -= sizeof(struct virtio_net_hdr_mrg_rxbuf);
                /* Zero header and leave csum up to XDP layers */
                hdr = xdp->data;
                memset(hdr, 0, vi->hdr_len);
@@ -371,7 +387,9 @@ static void virtnet_xdp_xmit(struct virtnet_info *vi,
                num_sg = 2;
                sg_init_table(sq->sg, 2);
                sg_set_buf(sq->sg, hdr, vi->hdr_len);
-               skb_to_sgvec(skb, sq->sg + 1, 0, skb->len);
+               skb_to_sgvec(skb, sq->sg + 1,
+                            xdp->data - xdp->data_hard_start,
+                            xdp->data_end - xdp->data);
        }
        err = virtqueue_add_outbuf(sq->vq, sq->sg, num_sg,
                                   data, GFP_ATOMIC);
@@ -382,53 +400,12 @@ static void virtnet_xdp_xmit(struct virtnet_info *vi,
                        put_page(page);
                } else /* small buffer */
                        kfree_skb(data);
-               return; // On error abort to avoid unnecessary kick
+               /* On error abort to avoid unnecessary kick */
+               return false;
        }
 
        virtqueue_kick(sq->vq);
-}
-
-static u32 do_xdp_prog(struct virtnet_info *vi,
-                      struct receive_queue *rq,
-                      struct bpf_prog *xdp_prog,
-                      void *data, int len)
-{
-       int hdr_padded_len;
-       struct xdp_buff xdp;
-       void *buf;
-       unsigned int qp;
-       u32 act;
-
-       if (vi->mergeable_rx_bufs) {
-               hdr_padded_len = sizeof(struct virtio_net_hdr_mrg_rxbuf);
-               xdp.data = data + hdr_padded_len;
-               xdp.data_end = xdp.data + (len - vi->hdr_len);
-               buf = data;
-       } else { /* small buffers */
-               struct sk_buff *skb = data;
-
-               xdp.data = skb->data;
-               xdp.data_end = xdp.data + len;
-               buf = skb->data;
-       }
-
-       act = bpf_prog_run_xdp(xdp_prog, &xdp);
-       switch (act) {
-       case XDP_PASS:
-               return XDP_PASS;
-       case XDP_TX:
-               qp = vi->curr_queue_pairs -
-                       vi->xdp_queue_pairs +
-                       smp_processor_id();
-               xdp.data = buf;
-               virtnet_xdp_xmit(vi, rq, &vi->sq[qp], &xdp, data);
-               return XDP_TX;
-       default:
-               bpf_warn_invalid_xdp_action(act);
-       case XDP_ABORTED:
-       case XDP_DROP:
-               return XDP_DROP;
-       }
+       return true;
 }
 
 static struct sk_buff *receive_small(struct net_device *dev,
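
virtnet_xdp_xmit() above now looks up its own send queue rather than having one passed in by the removed do_xdp_prog(): the XDP transmit queues sit after the queues used by the regular stack, one per CPU, so each CPU owns its slot without a tx lock. A small standalone illustration of that index computation, with made-up queue counts, might look like this:

#include <stdio.h>

/* Illustrative only: the per-CPU XDP TX queue index used above. */
static unsigned int xdp_tx_queue(unsigned int curr_queue_pairs,
				 unsigned int xdp_queue_pairs,
				 unsigned int cpu)
{
	return curr_queue_pairs - xdp_queue_pairs + cpu;
}

int main(void)
{
	/* e.g. 4 regular pairs plus 4 XDP pairs on a 4-CPU guest */
	unsigned int curr = 8, xdp = 4;

	for (unsigned int cpu = 0; cpu < 4; cpu++)
		printf("cpu %u -> sq[%u]\n", cpu, xdp_tx_queue(curr, xdp, cpu));
	return 0;
}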
@@ -440,30 +417,44 @@ static struct sk_buff *receive_small(struct net_device *dev,
        struct bpf_prog *xdp_prog;
 
        len -= vi->hdr_len;
-       skb_trim(skb, len);
 
        rcu_read_lock();
        xdp_prog = rcu_dereference(rq->xdp_prog);
        if (xdp_prog) {
                struct virtio_net_hdr_mrg_rxbuf *hdr = buf;
+               struct xdp_buff xdp;
                u32 act;
 
                if (unlikely(hdr->hdr.gso_type || hdr->hdr.flags))
                        goto err_xdp;
-               act = do_xdp_prog(vi, rq, xdp_prog, skb, len);
+
+               xdp.data_hard_start = skb->data;
+               xdp.data = skb->data + VIRTIO_XDP_HEADROOM;
+               xdp.data_end = xdp.data + len;
+               act = bpf_prog_run_xdp(xdp_prog, &xdp);
+
                switch (act) {
                case XDP_PASS:
+                       /* Recalculate length in case bpf program changed it */
+                       __skb_pull(skb, xdp.data - xdp.data_hard_start);
+                       len = xdp.data_end - xdp.data;
                        break;
                case XDP_TX:
+                       if (unlikely(!virtnet_xdp_xmit(vi, rq, &xdp, skb)))
+                               trace_xdp_exception(vi->dev, xdp_prog, act);
                        rcu_read_unlock();
                        goto xdp_xmit;
-               case XDP_DROP:
                default:
+                       bpf_warn_invalid_xdp_action(act);
+               case XDP_ABORTED:
+                       trace_xdp_exception(vi->dev, xdp_prog, act);
+               case XDP_DROP:
                        goto err_xdp;
                }
        }
        rcu_read_unlock();
 
+       skb_trim(skb, len);
        return skb;
 
 err_xdp:
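
Because receive buffers are now posted with VIRTIO_XDP_HEADROOM of slack in front of the packet, an XDP program may move xdp->data backwards or forwards via bpf_xdp_adjust_head(), and receive_small() above recomputes where the frame starts and how long it is before handing the skb up. The sketch below walks that pointer arithmetic with a stand-in struct and an assumed 14-byte head adjustment; it is an illustration, not the kernel's struct xdp_buff.

#include <stdio.h>
#include <stdint.h>

/* Minimal stand-in for struct xdp_buff, just enough to show the pointer
 * arithmetic used on XDP_PASS above. */
struct fake_xdp_buff {
	uint8_t *data_hard_start;
	uint8_t *data;
	uint8_t *data_end;
};

#define VIRTIO_XDP_HEADROOM 256

int main(void)
{
	static uint8_t buffer[VIRTIO_XDP_HEADROOM + 1500];
	unsigned int len = 60;                 /* received frame length */
	struct fake_xdp_buff xdp = {
		.data_hard_start = buffer,
		.data            = buffer + VIRTIO_XDP_HEADROOM,
		.data_end        = buffer + VIRTIO_XDP_HEADROOM + len,
	};

	/* Pretend the program called bpf_xdp_adjust_head(ctx, -14), e.g. to
	 * prepend an extra encapsulation header to the frame. */
	xdp.data -= 14;

	/* What receive_small() recomputes on XDP_PASS: how far into the
	 * buffer the packet now starts, and its new length. */
	unsigned int pull   = xdp.data - xdp.data_hard_start;  /* 256 - 14 = 242 */
	unsigned int newlen = xdp.data_end - xdp.data;         /* 60 + 14 = 74   */

	printf("__skb_pull by %u, new len %u\n", pull, newlen);
	return 0;
}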
@@ -512,7 +503,7 @@ static struct page *xdp_linearize_page(struct receive_queue *rq,
                                       unsigned int *len)
 {
        struct page *page = alloc_page(GFP_ATOMIC);
-       unsigned int page_off = 0;
+       unsigned int page_off = VIRTIO_XDP_HEADROOM;
 
        if (!page)
                return NULL;
@@ -548,7 +539,8 @@ static struct page *xdp_linearize_page(struct receive_queue *rq,
                put_page(p);
        }
 
-       *len = page_off;
+       /* Headroom does not contribute to packet length */
+       *len = page_off - VIRTIO_XDP_HEADROOM;
        return page;
 err_buf:
        __free_pages(page, 0);
@@ -576,6 +568,8 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
        xdp_prog = rcu_dereference(rq->xdp_prog);
        if (xdp_prog) {
                struct page *xdp_page;
+               struct xdp_buff xdp;
+               void *data;
                u32 act;
 
                /* This happens when rx buffer size is underestimated */
@@ -585,7 +579,7 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
                                                      page, offset, &len);
                        if (!xdp_page)
                                goto err_xdp;
-                       offset = 0;
+                       offset = VIRTIO_XDP_HEADROOM;
                } else {
                        xdp_page = page;
                }
@@ -598,28 +592,47 @@ static struct sk_buff *receive_mergeable(struct net_device *dev,
                if (unlikely(hdr->hdr.gso_type))
                        goto err_xdp;
 
-               act = do_xdp_prog(vi, rq, xdp_prog,
-                                 page_address(xdp_page) + offset, len);
+               /* Allow consuming headroom but reserve enough space to push
+                * the descriptor on if we get an XDP_TX return code.
+                */
+               data = page_address(xdp_page) + offset;
+               xdp.data_hard_start = data - VIRTIO_XDP_HEADROOM + vi->hdr_len;
+               xdp.data = data + vi->hdr_len;
+               xdp.data_end = xdp.data + (len - vi->hdr_len);
+               act = bpf_prog_run_xdp(xdp_prog, &xdp);
+
                switch (act) {
                case XDP_PASS:
+                       /* recalculate offset to account for any header
+                        * adjustments. Note other cases do not build an
+                        * skb and avoid using offset
+                        */
+                       offset = xdp.data -
+                                       page_address(xdp_page) - vi->hdr_len;
+
                        /* We can only create skb based on xdp_page. */
                        if (unlikely(xdp_page != page)) {
                                rcu_read_unlock();
                                put_page(page);
                                head_skb = page_to_skb(vi, rq, xdp_page,
-                                                      0, len, PAGE_SIZE);
+                                                      offset, len, PAGE_SIZE);
                                ewma_pkt_len_add(&rq->mrg_avg_pkt_len, len);
                                return head_skb;
                        }
                        break;
                case XDP_TX:
+                       if (unlikely(!virtnet_xdp_xmit(vi, rq, &xdp, data)))
+                               trace_xdp_exception(vi->dev, xdp_prog, act);
                        ewma_pkt_len_add(&rq->mrg_avg_pkt_len, len);
                        if (unlikely(xdp_page != page))
                                goto err_xdp;
                        rcu_read_unlock();
                        goto xdp_xmit;
-               case XDP_DROP:
                default:
+                       bpf_warn_invalid_xdp_action(act);
+               case XDP_ABORTED:
+                       trace_xdp_exception(vi->dev, xdp_prog, act);
+               case XDP_DROP:
                        if (unlikely(xdp_page != page))
                                __free_pages(xdp_page, 0);
                        ewma_pkt_len_add(&rq->mrg_avg_pkt_len, len);
@@ -706,13 +719,13 @@ xdp_xmit:
        return NULL;
 }
 
-static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
-                       void *buf, unsigned int len)
+static int receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
+                      void *buf, unsigned int len)
 {
        struct net_device *dev = vi->dev;
-       struct virtnet_stats *stats = this_cpu_ptr(vi->stats);
        struct sk_buff *skb;
        struct virtio_net_hdr_mrg_rxbuf *hdr;
+       int ret;
 
        if (unlikely(len < vi->hdr_len + ETH_HLEN)) {
                pr_debug("%s: short packet %i\n", dev->name, len);
@@ -726,7 +739,7 @@ static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
                } else {
                        dev_kfree_skb(buf);
                }
-               return;
+               return 0;
        }
 
        if (vi->mergeable_rx_bufs)
@@ -737,14 +750,11 @@ static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
                skb = receive_small(dev, vi, rq, buf, len);
 
        if (unlikely(!skb))
-               return;
+               return 0;
 
        hdr = skb_vnet_hdr(skb);
 
-       u64_stats_update_begin(&stats->rx_syncp);
-       stats->rx_bytes += skb->len;
-       stats->rx_packets++;
-       u64_stats_update_end(&stats->rx_syncp);
+       ret = skb->len;
 
        if (hdr->hdr.flags & VIRTIO_NET_HDR_F_DATA_VALID)
                skb->ip_summed = CHECKSUM_UNNECESSARY;
@@ -762,30 +772,38 @@ static void receive_buf(struct virtnet_info *vi, struct receive_queue *rq,
                 ntohs(skb->protocol), skb->len, skb->pkt_type);
 
        napi_gro_receive(&rq->napi, skb);
-       return;
+       return ret;
 
 frame_err:
        dev->stats.rx_frame_errors++;
        dev_kfree_skb(skb);
+       return 0;
+}
+
+static unsigned int virtnet_get_headroom(struct virtnet_info *vi)
+{
+       return vi->xdp_queue_pairs ? VIRTIO_XDP_HEADROOM : 0;
 }
 
 static int add_recvbuf_small(struct virtnet_info *vi, struct receive_queue *rq,
                             gfp_t gfp)
 {
+       int headroom = GOOD_PACKET_LEN + virtnet_get_headroom(vi);
+       unsigned int xdp_headroom = virtnet_get_headroom(vi);
        struct sk_buff *skb;
        struct virtio_net_hdr_mrg_rxbuf *hdr;
        int err;
 
-       skb = __netdev_alloc_skb_ip_align(vi->dev, GOOD_PACKET_LEN, gfp);
+       skb = __netdev_alloc_skb_ip_align(vi->dev, headroom, gfp);
        if (unlikely(!skb))
                return -ENOMEM;
 
-       skb_put(skb, GOOD_PACKET_LEN);
+       skb_put(skb, headroom);
 
        hdr = skb_vnet_hdr(skb);
        sg_init_table(rq->sg, 2);
        sg_set_buf(rq->sg, hdr, vi->hdr_len);
-       skb_to_sgvec(skb, rq->sg + 1, 0, skb->len);
+       skb_to_sgvec(skb, rq->sg + 1, xdp_headroom, skb->len - xdp_headroom);
 
        err = virtqueue_add_inbuf(rq->vq, rq->sg, 2, skb, gfp);
        if (err < 0)
@@ -853,24 +871,27 @@ static unsigned int get_mergeable_buf_len(struct ewma_pkt_len *avg_pkt_len)
        return ALIGN(len, MERGEABLE_BUFFER_ALIGN);
 }
 
-static int add_recvbuf_mergeable(struct receive_queue *rq, gfp_t gfp)
+static int add_recvbuf_mergeable(struct virtnet_info *vi,
+                                struct receive_queue *rq, gfp_t gfp)
 {
        struct page_frag *alloc_frag = &rq->alloc_frag;
+       unsigned int headroom = virtnet_get_headroom(vi);
        char *buf;
        unsigned long ctx;
        int err;
        unsigned int len, hole;
 
        len = get_mergeable_buf_len(&rq->mrg_avg_pkt_len);
-       if (unlikely(!skb_page_frag_refill(len, alloc_frag, gfp)))
+       if (unlikely(!skb_page_frag_refill(len + headroom, alloc_frag, gfp)))
                return -ENOMEM;
 
        buf = (char *)page_address(alloc_frag->page) + alloc_frag->offset;
+       buf += headroom; /* advance address leaving hole at front of pkt */
        ctx = mergeable_buf_to_ctx(buf, len);
        get_page(alloc_frag->page);
-       alloc_frag->offset += len;
+       alloc_frag->offset += len + headroom;
        hole = alloc_frag->size - alloc_frag->offset;
-       if (hole < len) {
+       if (hole < len + headroom) {
                /* To avoid internal fragmentation, if there is very likely not
                 * enough space for another buffer, add the remaining space to
                 * the current buffer. This extra space is not included in
@@ -904,7 +925,7 @@ static bool try_fill_recv(struct virtnet_info *vi, struct receive_queue *rq,
        gfp |= __GFP_COLD;
        do {
                if (vi->mergeable_rx_bufs)
-                       err = add_recvbuf_mergeable(rq, gfp);
+                       err = add_recvbuf_mergeable(vi, rq, gfp);
                else if (vi->big_packets)
                        err = add_recvbuf_big(vi, rq, gfp);
                else
@@ -971,12 +992,13 @@ static void refill_work(struct work_struct *work)
 static int virtnet_receive(struct receive_queue *rq, int budget)
 {
        struct virtnet_info *vi = rq->vq->vdev->priv;
-       unsigned int len, received = 0;
+       unsigned int len, received = 0, bytes = 0;
        void *buf;
+       struct virtnet_stats *stats = this_cpu_ptr(vi->stats);
 
        while (received < budget &&
               (buf = virtqueue_get_buf(rq->vq, &len)) != NULL) {
-               receive_buf(vi, rq, buf, len);
+               bytes += receive_buf(vi, rq, buf, len);
                received++;
        }
 
@@ -985,6 +1007,11 @@ static int virtnet_receive(struct receive_queue *rq, int budget)
                        schedule_delayed_work(&vi->refill, 0);
        }
 
+       u64_stats_update_begin(&stats->rx_syncp);
+       stats->rx_bytes += bytes;
+       stats->rx_packets += received;
+       u64_stats_update_end(&stats->rx_syncp);
+
        return received;
 }
 
@@ -999,53 +1026,17 @@ static int virtnet_poll(struct napi_struct *napi, int budget)
        /* Out of packets? */
        if (received < budget) {
                r = virtqueue_enable_cb_prepare(rq->vq);
-               napi_complete_done(napi, received);
-               if (unlikely(virtqueue_poll(rq->vq, r)) &&
-                   napi_schedule_prep(napi)) {
-                       virtqueue_disable_cb(rq->vq);
-                       __napi_schedule(napi);
-               }
-       }
-
-       return received;
-}
-
-#ifdef CONFIG_NET_RX_BUSY_POLL
-/* must be called with local_bh_disable()d */
-static int virtnet_busy_poll(struct napi_struct *napi)
-{
-       struct receive_queue *rq =
-               container_of(napi, struct receive_queue, napi);
-       struct virtnet_info *vi = rq->vq->vdev->priv;
-       int r, received = 0, budget = 4;
-
-       if (!(vi->status & VIRTIO_NET_S_LINK_UP))
-               return LL_FLUSH_FAILED;
-
-       if (!napi_schedule_prep(napi))
-               return LL_FLUSH_BUSY;
-
-       virtqueue_disable_cb(rq->vq);
-
-again:
-       received += virtnet_receive(rq, budget);
-
-       r = virtqueue_enable_cb_prepare(rq->vq);
-       clear_bit(NAPI_STATE_SCHED, &napi->state);
-       if (unlikely(virtqueue_poll(rq->vq, r)) &&
-           napi_schedule_prep(napi)) {
-               virtqueue_disable_cb(rq->vq);
-               if (received < budget) {
-                       budget -= received;
-                       goto again;
-               } else {
-                       __napi_schedule(napi);
+               if (napi_complete_done(napi, received)) {
+                       if (unlikely(virtqueue_poll(rq->vq, r)) &&
+                           napi_schedule_prep(napi)) {
+                               virtqueue_disable_cb(rq->vq);
+                               __napi_schedule(napi);
+                       }
                }
        }
 
        return received;
 }
-#endif /* CONFIG_NET_RX_BUSY_POLL */
 
 static int virtnet_open(struct net_device *dev)
 {
@@ -1069,17 +1060,28 @@ static void free_old_xmit_skbs(struct send_queue *sq)
        unsigned int len;
        struct virtnet_info *vi = sq->vq->vdev->priv;
        struct virtnet_stats *stats = this_cpu_ptr(vi->stats);
+       unsigned int packets = 0;
+       unsigned int bytes = 0;
 
        while ((skb = virtqueue_get_buf(sq->vq, &len)) != NULL) {
                pr_debug("Sent skb %p\n", skb);
 
-               u64_stats_update_begin(&stats->tx_syncp);
-               stats->tx_bytes += skb->len;
-               stats->tx_packets++;
-               u64_stats_update_end(&stats->tx_syncp);
+               bytes += skb->len;
+               packets++;
 
                dev_kfree_skb_any(skb);
        }
+
+       /* Avoid overhead when no packets have been processed
+        * happens when called speculatively from start_xmit.
+        */
+       if (!packets)
+               return;
+
+       u64_stats_update_begin(&stats->tx_syncp);
+       stats->tx_bytes += bytes;
+       stats->tx_packets += packets;
+       u64_stats_update_end(&stats->tx_syncp);
 }
 
 static int xmit_skb(struct send_queue *sq, struct sk_buff *skb)
@@ -1104,7 +1106,7 @@ static int xmit_skb(struct send_queue *sq, struct sk_buff *skb)
                hdr = skb_vnet_hdr(skb);
 
        if (virtio_net_hdr_from_skb(skb, &hdr->hdr,
-                                   virtio_is_little_endian(vi->vdev)))
+                                   virtio_is_little_endian(vi->vdev), false))
                BUG();
 
        if (vi->mergeable_rx_bufs)
@@ -1236,10 +1238,9 @@ static int virtnet_set_mac_address(struct net_device *dev, void *p)
        struct sockaddr *addr;
        struct scatterlist sg;
 
-       addr = kmalloc(sizeof(*addr), GFP_KERNEL);
+       addr = kmemdup(p, sizeof(*addr), GFP_KERNEL);
        if (!addr)
                return -ENOMEM;
-       memcpy(addr, p, sizeof(*addr));
 
        ret = eth_prepare_mac_addr_change(dev, addr);
        if (ret)
@@ -1273,8 +1274,8 @@ out:
        return ret;
 }
 
-static struct rtnl_link_stats64 *virtnet_stats(struct net_device *dev,
-                                              struct rtnl_link_stats64 *tot)
+static void virtnet_stats(struct net_device *dev,
+                         struct rtnl_link_stats64 *tot)
 {
        struct virtnet_info *vi = netdev_priv(dev);
        int cpu;
@@ -1307,8 +1308,6 @@ static struct rtnl_link_stats64 *virtnet_stats(struct net_device *dev,
        tot->rx_dropped = dev->stats.rx_dropped;
        tot->rx_length_errors = dev->stats.rx_length_errors;
        tot->rx_frame_errors = dev->stats.rx_frame_errors;
-
-       return tot;
 }
 
 #ifdef CONFIG_NET_POLL_CONTROLLER
@@ -1331,7 +1330,7 @@ static void virtnet_ack_link_announce(struct virtnet_info *vi)
        rtnl_unlock();
 }
 
-static int virtnet_set_queues(struct virtnet_info *vi, u16 queue_pairs)
+static int _virtnet_set_queues(struct virtnet_info *vi, u16 queue_pairs)
 {
        struct scatterlist sg;
        struct net_device *dev = vi->dev;
@@ -1357,6 +1356,16 @@ static int virtnet_set_queues(struct virtnet_info *vi, u16 queue_pairs)
        return 0;
 }
 
+static int virtnet_set_queues(struct virtnet_info *vi, u16 queue_pairs)
+{
+       int err;
+
+       rtnl_lock();
+       err = _virtnet_set_queues(vi, queue_pairs);
+       rtnl_unlock();
+       return err;
+}
+
 static int virtnet_close(struct net_device *dev)
 {
        struct virtnet_info *vi = netdev_priv(dev);
@@ -1609,7 +1618,7 @@ static int virtnet_set_channels(struct net_device *dev,
                return -EINVAL;
 
        get_online_cpus();
-       err = virtnet_set_queues(vi, queue_pairs);
+       err = _virtnet_set_queues(vi, queue_pairs);
        if (!err) {
                netif_set_real_num_tx_queues(dev, queue_pairs);
                netif_set_real_num_rx_queues(dev, queue_pairs);
@@ -1699,12 +1708,89 @@ static const struct ethtool_ops virtnet_ethtool_ops = {
        .set_settings = virtnet_set_settings,
 };
 
+static void virtnet_freeze_down(struct virtio_device *vdev)
+{
+       struct virtnet_info *vi = vdev->priv;
+       int i;
+
+       /* Make sure no work handler is accessing the device */
+       flush_work(&vi->config_work);
+
+       netif_device_detach(vi->dev);
+       cancel_delayed_work_sync(&vi->refill);
+
+       if (netif_running(vi->dev)) {
+               for (i = 0; i < vi->max_queue_pairs; i++)
+                       napi_disable(&vi->rq[i].napi);
+       }
+}
+
+static int init_vqs(struct virtnet_info *vi);
+static void _remove_vq_common(struct virtnet_info *vi);
+
+static int virtnet_restore_up(struct virtio_device *vdev)
+{
+       struct virtnet_info *vi = vdev->priv;
+       int err, i;
+
+       err = init_vqs(vi);
+       if (err)
+               return err;
+
+       virtio_device_ready(vdev);
+
+       if (netif_running(vi->dev)) {
+               for (i = 0; i < vi->curr_queue_pairs; i++)
+                       if (!try_fill_recv(vi, &vi->rq[i], GFP_KERNEL))
+                               schedule_delayed_work(&vi->refill, 0);
+
+               for (i = 0; i < vi->max_queue_pairs; i++)
+                       virtnet_napi_enable(&vi->rq[i]);
+       }
+
+       netif_device_attach(vi->dev);
+       return err;
+}
+
+static int virtnet_reset(struct virtnet_info *vi)
+{
+       struct virtio_device *dev = vi->vdev;
+       int ret;
+
+       virtio_config_disable(dev);
+       dev->failed = dev->config->get_status(dev) & VIRTIO_CONFIG_S_FAILED;
+       virtnet_freeze_down(dev);
+       _remove_vq_common(vi);
+
+       dev->config->reset(dev);
+       virtio_add_status(dev, VIRTIO_CONFIG_S_ACKNOWLEDGE);
+       virtio_add_status(dev, VIRTIO_CONFIG_S_DRIVER);
+
+       ret = virtio_finalize_features(dev);
+       if (ret)
+               goto err;
+
+       ret = virtnet_restore_up(dev);
+       if (ret)
+               goto err;
+       ret = _virtnet_set_queues(vi, vi->curr_queue_pairs);
+       if (ret)
+               goto err;
+
+       virtio_add_status(dev, VIRTIO_CONFIG_S_DRIVER_OK);
+       virtio_config_enable(dev);
+       return 0;
+err:
+       virtio_add_status(dev, VIRTIO_CONFIG_S_FAILED);
+       return ret;
+}
+
 static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog)
 {
        unsigned long int max_sz = PAGE_SIZE - sizeof(struct padded_vnet_hdr);
        struct virtnet_info *vi = netdev_priv(dev);
        struct bpf_prog *old_prog;
-       u16 xdp_qp = 0, curr_qp;
+       u16 oxdp_qp, xdp_qp = 0, curr_qp;
        int i, err;
 
        if (virtio_has_feature(vi->vdev, VIRTIO_NET_F_GUEST_TSO4) ||
@@ -1736,21 +1822,32 @@ static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog)
                return -ENOMEM;
        }
 
-       err = virtnet_set_queues(vi, curr_qp + xdp_qp);
+       if (prog) {
+               prog = bpf_prog_add(prog, vi->max_queue_pairs - 1);
+               if (IS_ERR(prog))
+                       return PTR_ERR(prog);
+       }
+
+       err = _virtnet_set_queues(vi, curr_qp + xdp_qp);
        if (err) {
                dev_warn(&dev->dev, "XDP Device queue allocation failure.\n");
-               return err;
+               goto virtio_queue_err;
        }
 
-       if (prog) {
-               prog = bpf_prog_add(prog, vi->max_queue_pairs - 1);
-               if (IS_ERR(prog)) {
-                       virtnet_set_queues(vi, curr_qp);
-                       return PTR_ERR(prog);
-               }
+       oxdp_qp = vi->xdp_queue_pairs;
+
+       /* Changing the headroom in buffers is a disruptive operation because
+        * existing buffers must be flushed and reallocated. This will happen
+        * when a xdp program is initially added or xdp is disabled by removing
+        * the xdp program resulting in number of XDP queues changing.
+        */
+       if (vi->xdp_queue_pairs != xdp_qp) {
+               vi->xdp_queue_pairs = xdp_qp;
+               err = virtnet_reset(vi);
+               if (err)
+                       goto virtio_reset_err;
        }
 
-       vi->xdp_queue_pairs = xdp_qp;
        netif_set_real_num_rx_queues(dev, curr_qp + xdp_qp);
 
        for (i = 0; i < vi->max_queue_pairs; i++) {
@@ -1761,6 +1858,21 @@ static int virtnet_xdp_set(struct net_device *dev, struct bpf_prog *prog)
        }
 
        return 0;
+
+virtio_reset_err:
+       /* On reset error do our best to unwind XDP changes inflight and return
+        * error up to user space for resolution. The underlying reset hung on
+        * us so not much we can do here.
+        */
+       dev_warn(&dev->dev, "XDP reset failure and queues unstable\n");
+       vi->xdp_queue_pairs = oxdp_qp;
+virtio_queue_err:
+       /* On queue set error we can unwind bpf ref count and user space can
+        * retry this is most likely an allocation failure.
+        */
+       if (prog)
+               bpf_prog_sub(prog, vi->max_queue_pairs - 1);
+       return err;
 }
 
 static bool virtnet_xdp_query(struct net_device *dev)
@@ -1800,9 +1912,6 @@ static const struct net_device_ops virtnet_netdev = {
        .ndo_vlan_rx_kill_vid = virtnet_vlan_rx_kill_vid,
 #ifdef CONFIG_NET_POLL_CONTROLLER
        .ndo_poll_controller = virtnet_netpoll,
-#endif
-#ifdef CONFIG_NET_RX_BUSY_POLL
-       .ndo_busy_poll          = virtnet_busy_poll,
 #endif
        .ndo_xdp                = virtnet_xdp,
 };
@@ -1864,12 +1973,11 @@ static void virtnet_free_queues(struct virtnet_info *vi)
        kfree(vi->sq);
 }
 
-static void free_receive_bufs(struct virtnet_info *vi)
+static void _free_receive_bufs(struct virtnet_info *vi)
 {
        struct bpf_prog *old_prog;
        int i;
 
-       rtnl_lock();
        for (i = 0; i < vi->max_queue_pairs; i++) {
                while (vi->rq[i].pages)
                        __free_pages(get_a_page(&vi->rq[i], GFP_KERNEL), 0);
@@ -1879,6 +1987,12 @@ static void free_receive_bufs(struct virtnet_info *vi)
                if (old_prog)
                        bpf_prog_put(old_prog);
        }
+}
+
+static void free_receive_bufs(struct virtnet_info *vi)
+{
+       rtnl_lock();
+       _free_receive_bufs(vi);
        rtnl_unlock();
 }
 
@@ -1890,8 +2004,12 @@ static void free_receive_page_frags(struct virtnet_info *vi)
                        put_page(vi->rq[i].alloc_frag.page);
 }
 
-static bool is_xdp_queue(struct virtnet_info *vi, int q)
+static bool is_xdp_raw_buffer_queue(struct virtnet_info *vi, int q)
 {
+       /* For small receive mode always use kfree_skb variants */
+       if (!vi->mergeable_rx_bufs)
+               return false;
+
        if (q < (vi->curr_queue_pairs - vi->xdp_queue_pairs))
                return false;
        else if (q < vi->curr_queue_pairs)
@@ -1908,7 +2026,7 @@ static void free_unused_bufs(struct virtnet_info *vi)
        for (i = 0; i < vi->max_queue_pairs; i++) {
                struct virtqueue *vq = vi->sq[i].vq;
                while ((buf = virtqueue_detach_unused_buf(vq)) != NULL) {
-                       if (!is_xdp_queue(vi, i))
+                       if (!is_xdp_raw_buffer_queue(vi, i))
                                dev_kfree_skb(buf);
                        else
                                put_page(virt_to_head_page(buf));
@@ -2313,9 +2431,7 @@ static int virtnet_probe(struct virtio_device *vdev)
                goto free_unregister_netdev;
        }
 
-       rtnl_lock();
        virtnet_set_queues(vi, vi->curr_queue_pairs);
-       rtnl_unlock();
 
        /* Assume link up if device can't report link status,
           otherwise get link status from config. */
@@ -2347,6 +2463,15 @@ free:
        return err;
 }
 
+static void _remove_vq_common(struct virtnet_info *vi)
+{
+       vi->vdev->config->reset(vi->vdev);
+       free_unused_bufs(vi);
+       _free_receive_bufs(vi);
+       free_receive_page_frags(vi);
+       virtnet_del_vqs(vi);
+}
+
 static void remove_vq_common(struct virtnet_info *vi)
 {
        vi->vdev->config->reset(vi->vdev);
@@ -2382,21 +2507,9 @@ static void virtnet_remove(struct virtio_device *vdev)
 static int virtnet_freeze(struct virtio_device *vdev)
 {
        struct virtnet_info *vi = vdev->priv;
-       int i;
 
        virtnet_cpu_notif_remove(vi);
-
-       /* Make sure no work handler is accessing the device */
-       flush_work(&vi->config_work);
-
-       netif_device_detach(vi->dev);
-       cancel_delayed_work_sync(&vi->refill);
-
-       if (netif_running(vi->dev)) {
-               for (i = 0; i < vi->max_queue_pairs; i++)
-                       napi_disable(&vi->rq[i].napi);
-       }
-
+       virtnet_freeze_down(vdev);
        remove_vq_common(vi);
 
        return 0;
@@ -2405,28 +2518,12 @@ static int virtnet_freeze(struct virtio_device *vdev)
 static int virtnet_restore(struct virtio_device *vdev)
 {
        struct virtnet_info *vi = vdev->priv;
-       int err, i;
+       int err;
 
-       err = init_vqs(vi);
+       err = virtnet_restore_up(vdev);
        if (err)
                return err;
-
-       virtio_device_ready(vdev);
-
-       if (netif_running(vi->dev)) {
-               for (i = 0; i < vi->curr_queue_pairs; i++)
-                       if (!try_fill_recv(vi, &vi->rq[i], GFP_KERNEL))
-                               schedule_delayed_work(&vi->refill, 0);
-
-               for (i = 0; i < vi->max_queue_pairs; i++)
-                       virtnet_napi_enable(&vi->rq[i]);
-       }
-
-       netif_device_attach(vi->dev);
-
-       rtnl_lock();
        virtnet_set_queues(vi, vi->curr_queue_pairs);
-       rtnl_unlock();
 
        err = virtnet_cpu_notif_add(vi);
        if (err)