]> git.proxmox.com Git - mirror_ubuntu-kernels.git/commitdiff
net/packet: support mergeable feature of virtio
authorJianfeng Tan <henry.tjf@antgroup.com>
Wed, 19 Apr 2023 07:24:16 +0000 (15:24 +0800)
committerDavid S. Miller <davem@davemloft.net>
Fri, 21 Apr 2023 11:01:58 +0000 (12:01 +0100)
Packet sockets, like tap, can be used as the backend for kernel vhost.
In packet sockets, virtio net header size is currently hardcoded to be
the size of struct virtio_net_hdr, which is 10 bytes; however, it is not
always the case: some virtio features, such as mrg_rxbuf, need virtio
net header to be 12-byte long.

Mergeable buffers, as a virtio feature, is worthy of supporting: packets
that are larger than one-mbuf size will be dropped in vhost worker's
handle_rx if mrg_rxbuf feature is not used, but large packets
cannot be avoided and increasing mbuf's size is not economical.

With this virtio feature enabled by virtio-user, packet sockets with
hardcoded 10-byte virtio net header will parse mac head incorrectly in
packet_snd by taking the last two bytes of virtio net header as part of
mac header.
This incorrect mac header parsing will cause packet to be dropped due to
invalid ether head checking in later under-layer device packet receiving.

By adding extra field vnet_hdr_sz with utilizing holes in struct
packet_sock to record currently used virtio net header size and supporting
extra sockopt PACKET_VNET_HDR_SZ to set specified vnet_hdr_sz, packet
sockets can know the exact length of virtio net header that virtio user
gives.
In packet_snd, tpacket_snd and packet_recvmsg, instead of using
hardcoded virtio net header size, it can get the exact vnet_hdr_sz from
corresponding packet_sock, and parse mac header correctly based on this
information to avoid the packets being mistakenly dropped.

Signed-off-by: Jianfeng Tan <henry.tjf@antgroup.com>
Co-developed-by: Anqi Shen <amy.saq@antgroup.com>
Signed-off-by: Anqi Shen <amy.saq@antgroup.com>
Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/uapi/linux/if_packet.h
net/packet/af_packet.c
net/packet/diag.c
net/packet/internal.h

index 78c981d6a9d4484d6ab701be3ce72569d954e8b1..9efc42382fdb98f9e9b5470ae2a66ccf4d0ba4e1 100644 (file)
@@ -59,6 +59,7 @@ struct sockaddr_ll {
 #define PACKET_ROLLOVER_STATS          21
 #define PACKET_FANOUT_DATA             22
 #define PACKET_IGNORE_OUTGOING         23
+#define PACKET_VNET_HDR_SZ             24
 
 #define PACKET_FANOUT_HASH             0
 #define PACKET_FANOUT_LB               1
index 568f8d76e3c124f3b322a8d88dc3dcfbc45e7c0e..6080c0db08148fa09b197d26df0d0c8d3469083a 100644 (file)
@@ -2090,18 +2090,18 @@ static unsigned int run_filter(struct sk_buff *skb,
 }
 
 static int packet_rcv_vnet(struct msghdr *msg, const struct sk_buff *skb,
-                          size_t *len)
+                          size_t *len, int vnet_hdr_sz)
 {
-       struct virtio_net_hdr vnet_hdr;
+       struct virtio_net_hdr_mrg_rxbuf vnet_hdr = { .num_buffers = 0 };
 
-       if (*len < sizeof(vnet_hdr))
+       if (*len < vnet_hdr_sz)
                return -EINVAL;
-       *len -= sizeof(vnet_hdr);
+       *len -= vnet_hdr_sz;
 
-       if (virtio_net_hdr_from_skb(skb, &vnet_hdr, vio_le(), true, 0))
+       if (virtio_net_hdr_from_skb(skb, (struct virtio_net_hdr *)&vnet_hdr, vio_le(), true, 0))
                return -EINVAL;
 
-       return memcpy_to_msg(msg, (void *)&vnet_hdr, sizeof(vnet_hdr));
+       return memcpy_to_msg(msg, (void *)&vnet_hdr, vnet_hdr_sz);
 }
 
 /*
@@ -2250,7 +2250,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
        __u32 ts_status;
        bool is_drop_n_account = false;
        unsigned int slot_id = 0;
-       bool do_vnet = false;
+       int vnet_hdr_sz = 0;
 
        /* struct tpacket{2,3}_hdr is aligned to a multiple of TPACKET_ALIGNMENT.
         * We may add members to them until current aligned size without forcing
@@ -2308,10 +2308,9 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                netoff = TPACKET_ALIGN(po->tp_hdrlen +
                                       (maclen < 16 ? 16 : maclen)) +
                                       po->tp_reserve;
-               if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
-                       netoff += sizeof(struct virtio_net_hdr);
-                       do_vnet = true;
-               }
+               vnet_hdr_sz = READ_ONCE(po->vnet_hdr_sz);
+               if (vnet_hdr_sz)
+                       netoff += vnet_hdr_sz;
                macoff = netoff - maclen;
        }
        if (netoff > USHRT_MAX) {
@@ -2337,7 +2336,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                        snaplen = po->rx_ring.frame_size - macoff;
                        if ((int)snaplen < 0) {
                                snaplen = 0;
-                               do_vnet = false;
+                               vnet_hdr_sz = 0;
                        }
                }
        } else if (unlikely(macoff + snaplen >
@@ -2351,7 +2350,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                if (unlikely((int)snaplen < 0)) {
                        snaplen = 0;
                        macoff = GET_PBDQC_FROM_RB(&po->rx_ring)->max_frame_len;
-                       do_vnet = false;
+                       vnet_hdr_sz = 0;
                }
        }
        spin_lock(&sk->sk_receive_queue.lock);
@@ -2367,7 +2366,7 @@ static int tpacket_rcv(struct sk_buff *skb, struct net_device *dev,
                __set_bit(slot_id, po->rx_ring.rx_owner_map);
        }
 
-       if (do_vnet &&
+       if (vnet_hdr_sz &&
            virtio_net_hdr_from_skb(skb, h.raw + macoff -
                                    sizeof(struct virtio_net_hdr),
                                    vio_le(), true, 0)) {
@@ -2551,16 +2550,26 @@ static int __packet_snd_vnet_parse(struct virtio_net_hdr *vnet_hdr, size_t len)
 }
 
 static int packet_snd_vnet_parse(struct msghdr *msg, size_t *len,
-                                struct virtio_net_hdr *vnet_hdr)
+                                struct virtio_net_hdr *vnet_hdr, int vnet_hdr_sz)
 {
-       if (*len < sizeof(*vnet_hdr))
+       int ret;
+
+       if (*len < vnet_hdr_sz)
                return -EINVAL;
-       *len -= sizeof(*vnet_hdr);
+       *len -= vnet_hdr_sz;
 
        if (!copy_from_iter_full(vnet_hdr, sizeof(*vnet_hdr), &msg->msg_iter))
                return -EFAULT;
 
-       return __packet_snd_vnet_parse(vnet_hdr, *len);
+       ret = __packet_snd_vnet_parse(vnet_hdr, *len);
+       if (ret)
+               return ret;
+
+       /* move iter to point to the start of mac header */
+       if (vnet_hdr_sz != sizeof(struct virtio_net_hdr))
+               iov_iter_advance(&msg->msg_iter, vnet_hdr_sz - sizeof(struct virtio_net_hdr));
+
+       return 0;
 }
 
 static int tpacket_fill_skb(struct packet_sock *po, struct sk_buff *skb,
@@ -2722,6 +2731,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
        void *ph;
        DECLARE_SOCKADDR(struct sockaddr_ll *, saddr, msg->msg_name);
        bool need_wait = !(msg->msg_flags & MSG_DONTWAIT);
+       int vnet_hdr_sz = READ_ONCE(po->vnet_hdr_sz);
        unsigned char *addr = NULL;
        int tp_len, size_max;
        void *data;
@@ -2779,8 +2789,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
        size_max = po->tx_ring.frame_size
                - (po->tp_hdrlen - sizeof(struct sockaddr_ll));
 
-       if ((size_max > dev->mtu + reserve + VLAN_HLEN) &&
-           !packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR))
+       if ((size_max > dev->mtu + reserve + VLAN_HLEN) && !vnet_hdr_sz)
                size_max = dev->mtu + reserve + VLAN_HLEN;
 
        reinit_completion(&po->skb_completion);
@@ -2809,10 +2818,10 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                status = TP_STATUS_SEND_REQUEST;
                hlen = LL_RESERVED_SPACE(dev);
                tlen = dev->needed_tailroom;
-               if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
+               if (vnet_hdr_sz) {
                        vnet_hdr = data;
-                       data += sizeof(*vnet_hdr);
-                       tp_len -= sizeof(*vnet_hdr);
+                       data += vnet_hdr_sz;
+                       tp_len -= vnet_hdr_sz;
                        if (tp_len < 0 ||
                            __packet_snd_vnet_parse(vnet_hdr, tp_len)) {
                                tp_len = -EINVAL;
@@ -2837,7 +2846,7 @@ static int tpacket_snd(struct packet_sock *po, struct msghdr *msg)
                                          addr, hlen, copylen, &sockc);
                if (likely(tp_len >= 0) &&
                    tp_len > dev->mtu + reserve &&
-                   !packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR) &&
+                   !vnet_hdr_sz &&
                    !packet_extra_vlan_len_allowed(dev, skb))
                        tp_len = -EMSGSIZE;
 
@@ -2856,7 +2865,7 @@ tpacket_error:
                        }
                }
 
-               if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
+               if (vnet_hdr_sz) {
                        if (virtio_net_hdr_to_skb(skb, vnet_hdr, vio_le())) {
                                tp_len = -EINVAL;
                                goto tpacket_error;
@@ -2946,7 +2955,7 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
        struct virtio_net_hdr vnet_hdr = { 0 };
        int offset = 0;
        struct packet_sock *po = pkt_sk(sk);
-       bool has_vnet_hdr = false;
+       int vnet_hdr_sz = READ_ONCE(po->vnet_hdr_sz);
        int hlen, tlen, linear;
        int extra_len = 0;
 
@@ -2990,11 +2999,10 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 
        if (sock->type == SOCK_RAW)
                reserve = dev->hard_header_len;
-       if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR)) {
-               err = packet_snd_vnet_parse(msg, &len, &vnet_hdr);
+       if (vnet_hdr_sz) {
+               err = packet_snd_vnet_parse(msg, &len, &vnet_hdr, vnet_hdr_sz);
                if (err)
                        goto out_unlock;
-               has_vnet_hdr = true;
        }
 
        if (unlikely(sock_flag(sk, SOCK_NOFCS))) {
@@ -3064,11 +3072,11 @@ static int packet_snd(struct socket *sock, struct msghdr *msg, size_t len)
 
        packet_parse_headers(skb, sock);
 
-       if (has_vnet_hdr) {
+       if (vnet_hdr_sz) {
                err = virtio_net_hdr_to_skb(skb, &vnet_hdr, vio_le());
                if (err)
                        goto out_free;
-               len += sizeof(vnet_hdr);
+               len += vnet_hdr_sz;
                virtio_net_hdr_set_proto(skb, &vnet_hdr);
        }
 
@@ -3408,7 +3416,7 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
        struct sock *sk = sock->sk;
        struct sk_buff *skb;
        int copied, err;
-       int vnet_hdr_len = 0;
+       int vnet_hdr_len = READ_ONCE(pkt_sk(sk)->vnet_hdr_sz);
        unsigned int origlen = 0;
 
        err = -EINVAL;
@@ -3449,11 +3457,10 @@ static int packet_recvmsg(struct socket *sock, struct msghdr *msg, size_t len,
 
        packet_rcv_try_clear_pressure(pkt_sk(sk));
 
-       if (packet_sock_flag(pkt_sk(sk), PACKET_SOCK_HAS_VNET_HDR)) {
-               err = packet_rcv_vnet(msg, skb, &len);
+       if (vnet_hdr_len) {
+               err = packet_rcv_vnet(msg, skb, &len, vnet_hdr_len);
                if (err)
                        goto out_free;
-               vnet_hdr_len = sizeof(struct virtio_net_hdr);
        }
 
        /* You lose any data beyond the buffer you gave. If it worries
@@ -3915,8 +3922,9 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
                return 0;
        }
        case PACKET_VNET_HDR:
+       case PACKET_VNET_HDR_SZ:
        {
-               int val;
+               int val, hdr_len;
 
                if (sock->type != SOCK_RAW)
                        return -EINVAL;
@@ -3925,11 +3933,19 @@ packet_setsockopt(struct socket *sock, int level, int optname, sockptr_t optval,
                if (copy_from_sockptr(&val, optval, sizeof(val)))
                        return -EFAULT;
 
+               if (optname == PACKET_VNET_HDR_SZ) {
+                       if (val && val != sizeof(struct virtio_net_hdr) &&
+                           val != sizeof(struct virtio_net_hdr_mrg_rxbuf))
+                               return -EINVAL;
+                       hdr_len = val;
+               } else {
+                       hdr_len = val ? sizeof(struct virtio_net_hdr) : 0;
+               }
                lock_sock(sk);
                if (po->rx_ring.pg_vec || po->tx_ring.pg_vec) {
                        ret = -EBUSY;
                } else {
-                       packet_sock_flag_set(po, PACKET_SOCK_HAS_VNET_HDR, val);
+                       WRITE_ONCE(po->vnet_hdr_sz, hdr_len);
                        ret = 0;
                }
                release_sock(sk);
@@ -4062,7 +4078,10 @@ static int packet_getsockopt(struct socket *sock, int level, int optname,
                val = packet_sock_flag(po, PACKET_SOCK_ORIGDEV);
                break;
        case PACKET_VNET_HDR:
-               val = packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR);
+               val = !!READ_ONCE(po->vnet_hdr_sz);
+               break;
+       case PACKET_VNET_HDR_SZ:
+               val = READ_ONCE(po->vnet_hdr_sz);
                break;
        case PACKET_VERSION:
                val = po->tp_version;
index de4ced5cf3e8c5798530ab3bfbe162cc3913b318..d0c4eda4cdc6f2d0ef259a889c884c69f1f81790 100644 (file)
@@ -27,7 +27,7 @@ static int pdiag_put_info(const struct packet_sock *po, struct sk_buff *nlskb)
                pinfo.pdi_flags |= PDI_AUXDATA;
        if (packet_sock_flag(po, PACKET_SOCK_ORIGDEV))
                pinfo.pdi_flags |= PDI_ORIGDEV;
-       if (packet_sock_flag(po, PACKET_SOCK_HAS_VNET_HDR))
+       if (READ_ONCE(po->vnet_hdr_sz))
                pinfo.pdi_flags |= PDI_VNETHDR;
        if (packet_sock_flag(po, PACKET_SOCK_TP_LOSS))
                pinfo.pdi_flags |= PDI_LOSS;
index 27930f69f368ed17afa1f4c1c1bf834034d621fd..63f4865202c1394dfa1b4c8b6a7890d5a4cd14c4 100644 (file)
@@ -118,6 +118,7 @@ struct packet_sock {
        struct mutex            pg_vec_lock;
        unsigned long           flags;
        int                     ifindex;        /* bound device         */
+       u8                      vnet_hdr_sz;
        __be16                  num;
        struct packet_rollover  *rollover;
        struct packet_mclist    *mclist;
@@ -139,7 +140,6 @@ enum packet_sock_flags {
        PACKET_SOCK_AUXDATA,
        PACKET_SOCK_TX_HAS_OFF,
        PACKET_SOCK_TP_LOSS,
-       PACKET_SOCK_HAS_VNET_HDR,
        PACKET_SOCK_RUNNING,
        PACKET_SOCK_PRESSURE,
        PACKET_SOCK_QDISC_BYPASS,