]> git.proxmox.com Git - mirror_frr.git/blobdiff - zebra/kernel_socket.c
zebra: refactor route socket message handling
[mirror_frr.git] / zebra / kernel_socket.c
index 25c6e6c64b67ebc443f8bfc898ecab2f2250914e..8a7cb0e5284cee64eeb96b65c16fd6934a0db32b 100644 (file)
@@ -131,60 +131,6 @@ extern struct zebra_privs_t zserv_privs;
 
 #endif /* !SA_SIZE */
 
-/*
- * We use a call to an inline function to copy (PNT) to (DEST)
- * 1. Calculating the length of the copy requires an #ifdef to determine
- *    if sa_len is a field and can't be used directly inside a #define
- * 2. So the compiler doesn't complain when DEST is NULL, which is only true
- *    when we are skipping the copy and incrementing to the next SA
- */
-static inline void rta_copy(union sockunion *dest, caddr_t src)
-{
-       int len;
-       if (!dest)
-               return;
-#ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
-       len = (((struct sockaddr *)src)->sa_len > sizeof(*dest))
-                     ? sizeof(*dest)
-                     : ((struct sockaddr *)src)->sa_len;
-#else
-       len = (SAROUNDUP(src) > sizeof(*dest)) ? sizeof(*dest) : SAROUNDUP(src);
-#endif
-       memcpy(dest, src, len);
-}
-
-#define RTA_ADDR_GET(DEST, RTA, RTMADDRS, PNT)                                 \
-       if ((RTMADDRS) & (RTA)) {                                              \
-               int len = SAROUNDUP((PNT));                                    \
-               if (af_check(((struct sockaddr *)(PNT))->sa_family))           \
-                       rta_copy((DEST), (PNT));                               \
-               (PNT) += len;                                                  \
-       }
-#define RTA_ATTR_GET(DEST, RTA, RTMADDRS, PNT)                                 \
-       if ((RTMADDRS) & (RTA)) {                                              \
-               int len = SAROUNDUP((PNT));                                    \
-               rta_copy((DEST), (PNT));                                       \
-               (PNT) += len;                                                  \
-       }
-
-#define RTA_NAME_GET(DEST, RTA, RTMADDRS, PNT, LEN)                            \
-       if ((RTMADDRS) & (RTA)) {                                              \
-               uint8_t *pdest = (uint8_t *)(DEST);                            \
-               int len = SAROUNDUP((PNT));                                    \
-               struct sockaddr_dl *sdl = (struct sockaddr_dl *)(PNT);         \
-               if (IS_ZEBRA_DEBUG_KERNEL)                                     \
-                       zlog_debug("%s: RTA_SDL_GET nlen %d, alen %d",         \
-                                  __func__, sdl->sdl_nlen, sdl->sdl_alen);    \
-               if (((DEST) != NULL) && (sdl->sdl_family == AF_LINK)           \
-                   && (sdl->sdl_nlen < IFNAMSIZ) && (sdl->sdl_nlen <= len)) { \
-                       memcpy(pdest, sdl->sdl_data, sdl->sdl_nlen);           \
-                       pdest[sdl->sdl_nlen] = '\0';                           \
-                       (LEN) = sdl->sdl_nlen;                                 \
-               }                                                              \
-               (PNT) += len;                                                  \
-       } else {                                                               \
-               (LEN) = 0;                                                     \
-       }
 /* Routing socket message types. */
 const struct message rtm_type_str[] = {{RTM_ADD, "RTM_ADD"},
                                       {RTM_DELETE, "RTM_DELETE"},
@@ -284,6 +230,9 @@ int dplane_routing_sock = -1;
 /* Yes I'm checking ugly routing socket behavior. */
 /* #define DEBUG */
 
+size_t rta_get(caddr_t sap, void *dest, size_t destlen);
+size_t rta_getsdlname(caddr_t sap, void *dest, short *destlen);
+
 /* Supported address family check. */
 static inline int af_check(int family)
 {
@@ -294,6 +243,63 @@ static inline int af_check(int family)
        return 0;
 }
 
+size_t rta_get(caddr_t sap, void *destp, size_t destlen)
+{
+       struct sockaddr *sa = (struct sockaddr *)sap;
+       uint8_t *dest = destp;
+       size_t tlen, copylen;
+
+#ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
+       copylen = sa->sa_len;
+       tlen = (copylen == 0) ? sizeof(ROUNDUP_TYPE) : ROUNDUP(copylen);
+#else  /* !HAVE_STRUCT_SOCKADDR_SA_LEN */
+       copylen = tlen = SAROUNDUP(sap);
+#endif /* !HAVE_STRUCT_SOCKADDR_SA_LEN */
+
+       if (copylen > 0 && dest != NULL && af_check(sa->sa_family)) {
+               if (copylen > destlen) {
+                       zlog_warn("%s: destination buffer too small (%lu vs %lu)",
+                                 __func__, copylen, destlen);
+                       memcpy(dest, sap, destlen);
+               } else
+                       memcpy(dest, sap, copylen);
+       }
+
+       return tlen;
+}
+
+size_t rta_getsdlname(caddr_t sap, void *destp, short *destlen)
+{
+       struct sockaddr_dl *sdl = (struct sockaddr_dl *)sap;
+       struct sockaddr *sa = (struct sockaddr *)sap;
+       uint8_t *dest = destp;
+       size_t tlen, copylen;
+
+       copylen = sdl->sdl_nlen;
+#ifdef HAVE_STRUCT_SOCKADDR_SA_LEN
+       tlen = (sa->sa_len == 0) ? sizeof(ROUNDUP_TYPE) : ROUNDUP(sa->sa_len);
+#else  /* !HAVE_STRUCT_SOCKADDR_SA_LEN */
+       tlen = SAROUNDUP(sap);
+#endif /* !HAVE_STRUCT_SOCKADDR_SA_LEN */
+
+       if (copylen > 0 && dest != NULL && sdl->sdl_family == AF_LINK) {
+               if (copylen > IFNAMSIZ) {
+                       zlog_warn("%s: destination buffer too small (%lu vs %d)",
+                                 __func__, copylen, IFNAMSIZ);
+                       memcpy(dest, sdl->sdl_data, IFNAMSIZ);
+                       dest[IFNAMSIZ] = 0;
+                       *destlen = IFNAMSIZ;
+               } else {
+                       memcpy(dest, sdl->sdl_data, copylen);
+                       dest[copylen] = 0;
+                       *destlen = copylen;
+               }
+       } else
+               *destlen = 0;
+
+       return tlen;
+}
+
 /* Dump routing table flag for debug purpose. */
 static void rtm_flag_dump(int flag)
 {
@@ -406,6 +412,7 @@ int ifm_read(struct if_msghdr *ifm)
        struct sockaddr_dl *sdl;
        char ifname[IFNAMSIZ];
        short ifnlen = 0;
+       int maskbit;
        caddr_t cp;
 
        /* terminate ifname at head (for strnlen) and tail (for safety) */
@@ -437,21 +444,19 @@ int ifm_read(struct if_msghdr *ifm)
                cp = cp + 12;
 #endif
 
-       RTA_ADDR_GET(NULL, RTA_DST, ifm->ifm_addrs, cp);
-       RTA_ADDR_GET(NULL, RTA_GATEWAY, ifm->ifm_addrs, cp);
-       RTA_ATTR_GET(NULL, RTA_NETMASK, ifm->ifm_addrs, cp);
-       RTA_ADDR_GET(NULL, RTA_GENMASK, ifm->ifm_addrs, cp);
-       sdl = (struct sockaddr_dl *)cp;
-       RTA_NAME_GET(ifname, RTA_IFP, ifm->ifm_addrs, cp, ifnlen);
-       RTA_ADDR_GET(NULL, RTA_IFA, ifm->ifm_addrs, cp);
-       RTA_ADDR_GET(NULL, RTA_AUTHOR, ifm->ifm_addrs, cp);
-       RTA_ADDR_GET(NULL, RTA_BRD, ifm->ifm_addrs, cp);
-#ifdef RTA_LABEL
-       RTA_ATTR_GET(NULL, RTA_LABEL, ifm->ifm_addrs, cp);
-#endif
-#ifdef RTA_SRC
-       RTA_ADDR_GET(NULL, RTA_SRC, ifm->ifm_addrs, cp);
-#endif
+       /* Look up for RTA_IFP and skip others. */
+       for (maskbit = 1; maskbit; maskbit <<= 1) {
+               if ((maskbit & ifm->ifm_addrs) == 0)
+                       continue;
+               if (maskbit != RTA_IFP) {
+                       cp += rta_get(cp, NULL, 0);
+                       continue;
+               }
+
+               /* Save the pointer to the structure. */
+               sdl = (struct sockaddr_dl *)cp;
+               cp += rta_getsdlname(cp, ifname, &ifnlen);
+       }
 
        if (IS_ZEBRA_DEBUG_KERNEL)
                zlog_debug("%s: sdl ifname %s", __func__,
@@ -558,7 +563,7 @@ int ifm_read(struct if_msghdr *ifm)
                 *  - Solaris has no sdl_len, but sdl_data[244]
                 *    presumably, it's not going to run past that, so sizeof()
                 *    is fine here.
-                * a nonzero ifnlen from RTA_NAME_GET() means sdl is valid
+                * a nonzero ifnlen from rta_getsdlname() means sdl is valid
                 */
                ifp->ll_type = ZEBRA_LLT_UNKNOWN;
                ifp->hw_addr_len = 0;
@@ -652,6 +657,7 @@ static void ifam_read_mesg(struct ifa_msghdr *ifm, union sockunion *addr,
        caddr_t pnt, end;
        union sockunion dst;
        union sockunion gateway;
+       int maskbit;
 
        pnt = (caddr_t)(ifm + 1);
        end = ((caddr_t)ifm) + ifm->ifam_msglen;
@@ -664,20 +670,41 @@ static void ifam_read_mesg(struct ifa_msghdr *ifm, union sockunion *addr,
        memset(&gateway, 0, sizeof(union sockunion));
 
        /* We fetch each socket variable into sockunion. */
-       RTA_ADDR_GET(&dst, RTA_DST, ifm->ifam_addrs, pnt);
-       RTA_ADDR_GET(&gateway, RTA_GATEWAY, ifm->ifam_addrs, pnt);
-       RTA_ATTR_GET(mask, RTA_NETMASK, ifm->ifam_addrs, pnt);
-       RTA_ADDR_GET(NULL, RTA_GENMASK, ifm->ifam_addrs, pnt);
-       RTA_NAME_GET(ifname, RTA_IFP, ifm->ifam_addrs, pnt, *ifnlen);
-       RTA_ADDR_GET(addr, RTA_IFA, ifm->ifam_addrs, pnt);
-       RTA_ADDR_GET(NULL, RTA_AUTHOR, ifm->ifam_addrs, pnt);
-       RTA_ADDR_GET(brd, RTA_BRD, ifm->ifam_addrs, pnt);
-#ifdef RTA_LABEL
-       RTA_ATTR_GET(NULL, RTA_LABEL, ifm->ifam_addrs, pnt);
-#endif
-#ifdef RTA_SRC
-       RTA_ADDR_GET(NULL, RTA_SRC, ifm->ifam_addrs, pnt);
-#endif
+       for (maskbit = 1; maskbit; maskbit <<= 1) {
+               if ((maskbit & ifm->ifam_addrs) == 0)
+                       continue;
+
+               switch (maskbit) {
+               case RTA_DST:
+                       pnt += rta_get(pnt, &dst, sizeof(dst));
+                       break;
+               case RTA_GATEWAY:
+                       pnt += rta_get(pnt, &gateway, sizeof(gateway));
+                       break;
+               case RTA_NETMASK:
+                       pnt += rta_get(pnt, mask, sizeof(*mask));
+                       break;
+               case RTA_IFP:
+                       pnt += rta_getsdlname(pnt, ifname, ifnlen);
+                       break;
+               case RTA_IFA:
+                       pnt += rta_get(pnt, addr, sizeof(*addr));
+                       break;
+               case RTA_BRD:
+                       pnt += rta_get(pnt, brd, sizeof(*brd));
+                       break;
+
+               default:
+                       pnt += rta_get(pnt, NULL, 0);
+                       break;
+               }
+
+               if (pnt > end) {
+                       zlog_warn("%s: overflow detected (pnt:%p end:%p)",
+                                 __func__, pnt, end);
+                       break;
+               }
+       }
 
        if (IS_ZEBRA_DEBUG_KERNEL) {
                int family = sockunion_family(addr);
@@ -712,6 +739,7 @@ static void ifam_read_mesg(struct ifa_msghdr *ifm, union sockunion *addr,
        }
 
        /* Assert read up end point matches to end point */
+       pnt = (caddr_t)ROUNDUP((size_t)pnt);
        if (pnt != end)
                zlog_debug("ifam_read() doesn't read all socket data");
 }
@@ -820,6 +848,7 @@ static int rtm_read_mesg(struct rt_msghdr *rtm, union sockunion *dest,
                         char *ifname, short *ifnlen)
 {
        caddr_t pnt, end;
+       int maskbit;
 
        /* Pnt points out socket data start point. */
        pnt = (caddr_t)(rtm + 1);
@@ -838,25 +867,36 @@ static int rtm_read_mesg(struct rt_msghdr *rtm, union sockunion *dest,
        memset(mask, 0, sizeof(union sockunion));
 
        /* We fetch each socket variable into sockunion. */
-       RTA_ADDR_GET(dest, RTA_DST, rtm->rtm_addrs, pnt);
-       RTA_ADDR_GET(gate, RTA_GATEWAY, rtm->rtm_addrs, pnt);
-       RTA_ATTR_GET(mask, RTA_NETMASK, rtm->rtm_addrs, pnt);
-       RTA_ADDR_GET(NULL, RTA_GENMASK, rtm->rtm_addrs, pnt);
-       RTA_NAME_GET(ifname, RTA_IFP, rtm->rtm_addrs, pnt, *ifnlen);
-       RTA_ADDR_GET(NULL, RTA_IFA, rtm->rtm_addrs, pnt);
-       RTA_ADDR_GET(NULL, RTA_AUTHOR, rtm->rtm_addrs, pnt);
-       RTA_ADDR_GET(NULL, RTA_BRD, rtm->rtm_addrs, pnt);
-#ifdef RTA_LABEL
-#if 0
-       union sockunion label;
-       memset(&label, 0, sizeof(label));
-       RTA_ATTR_GET(&label, RTA_LABEL, rtm->rtm_addrs, pnt);
-#endif
-       RTA_ATTR_GET(NULL, RTA_LABEL, rtm->rtm_addrs, pnt);
-#endif
-#ifdef RTA_SRC
-       RTA_ADDR_GET(NULL, RTA_SRC, rtm->rtm_addrs, pnt);
-#endif
+       /* We fetch each socket variable into sockunion. */
+       for (maskbit = 1; maskbit; maskbit <<= 1) {
+               if ((maskbit & rtm->rtm_addrs) == 0)
+                       continue;
+
+               switch (maskbit) {
+               case RTA_DST:
+                       pnt += rta_get(pnt, dest, sizeof(*dest));
+                       break;
+               case RTA_GATEWAY:
+                       pnt += rta_get(pnt, gate, sizeof(*gate));
+                       break;
+               case RTA_NETMASK:
+                       pnt += rta_get(pnt, mask, sizeof(*mask));
+                       break;
+               case RTA_IFP:
+                       pnt += rta_getsdlname(pnt, ifname, ifnlen);
+                       break;
+
+               default:
+                       pnt += rta_get(pnt, NULL, 0);
+                       break;
+               }
+
+               if (pnt > end) {
+                       zlog_warn("%s: overflow detected (pnt:%p end:%p)",
+                                 __func__, pnt, end);
+                       break;
+               }
+       }
 
        /* If there is netmask information set it's family same as
           destination family*/