]> git.proxmox.com Git - mirror_ubuntu-zesty-kernel.git/blobdiff - net/sctp/input.c
UBUNTU: SAUCE: hio: Fix incorrect use of enum req_opf values
[mirror_ubuntu-zesty-kernel.git] / net / sctp / input.c
index 1555fb8c68e0ed753e67d9c0c857a7d62be71743..458e506ef84bae3c53c239d6cf89a9349faafb11 100644 (file)
@@ -181,9 +181,10 @@ int sctp_rcv(struct sk_buff *skb)
         * bound to another interface, via SO_BINDTODEVICE, treat it as OOTB
         */
        if (sk->sk_bound_dev_if && (sk->sk_bound_dev_if != af->skb_iif(skb))) {
-               if (asoc) {
-                       sctp_association_put(asoc);
+               if (transport) {
+                       sctp_transport_put(transport);
                        asoc = NULL;
+                       transport = NULL;
                } else {
                        sctp_endpoint_put(ep);
                        ep = NULL;
@@ -269,8 +270,8 @@ int sctp_rcv(struct sk_buff *skb)
        bh_unlock_sock(sk);
 
        /* Release the asoc/ep ref we took in the lookup calls. */
-       if (asoc)
-               sctp_association_put(asoc);
+       if (transport)
+               sctp_transport_put(transport);
        else
                sctp_endpoint_put(ep);
 
@@ -283,8 +284,8 @@ discard_it:
 
 discard_release:
        /* Release the asoc/ep ref we took in the lookup calls. */
-       if (asoc)
-               sctp_association_put(asoc);
+       if (transport)
+               sctp_transport_put(transport);
        else
                sctp_endpoint_put(ep);
 
@@ -300,6 +301,7 @@ int sctp_backlog_rcv(struct sock *sk, struct sk_buff *skb)
 {
        struct sctp_chunk *chunk = SCTP_INPUT_CB(skb)->chunk;
        struct sctp_inq *inqueue = &chunk->rcvr->inqueue;
+       struct sctp_transport *t = chunk->transport;
        struct sctp_ep_common *rcvr = NULL;
        int backloged = 0;
 
@@ -351,7 +353,7 @@ int sctp_backlog_rcv(struct sock *sk, struct sk_buff *skb)
 done:
        /* Release the refs we took in sctp_add_backlog */
        if (SCTP_EP_TYPE_ASSOCIATION == rcvr->type)
-               sctp_association_put(sctp_assoc(rcvr));
+               sctp_transport_put(t);
        else if (SCTP_EP_TYPE_SOCKET == rcvr->type)
                sctp_endpoint_put(sctp_ep(rcvr));
        else
@@ -363,6 +365,7 @@ done:
 static int sctp_add_backlog(struct sock *sk, struct sk_buff *skb)
 {
        struct sctp_chunk *chunk = SCTP_INPUT_CB(skb)->chunk;
+       struct sctp_transport *t = chunk->transport;
        struct sctp_ep_common *rcvr = chunk->rcvr;
        int ret;
 
@@ -373,7 +376,7 @@ static int sctp_add_backlog(struct sock *sk, struct sk_buff *skb)
                 * from us
                 */
                if (SCTP_EP_TYPE_ASSOCIATION == rcvr->type)
-                       sctp_association_hold(sctp_assoc(rcvr));
+                       sctp_transport_hold(t);
                else if (SCTP_EP_TYPE_SOCKET == rcvr->type)
                        sctp_endpoint_hold(sctp_ep(rcvr));
                else
@@ -537,15 +540,15 @@ struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb,
        return sk;
 
 out:
-       sctp_association_put(asoc);
+       sctp_transport_put(transport);
        return NULL;
 }
 
 /* Common cleanup code for icmp/icmpv6 error handler. */
-void sctp_err_finish(struct sock *sk, struct sctp_association *asoc)
+void sctp_err_finish(struct sock *sk, struct sctp_transport *t)
 {
        bh_unlock_sock(sk);
-       sctp_association_put(asoc);
+       sctp_transport_put(t);
 }
 
 /*
@@ -605,7 +608,7 @@ void sctp_v4_err(struct sk_buff *skb, __u32 info)
                /* PMTU discovery (RFC1191) */
                if (ICMP_FRAG_NEEDED == code) {
                        sctp_icmp_frag_needed(sk, asoc, transport,
-                                             WORD_TRUNC(info));
+                                             SCTP_TRUNC4(info));
                        goto out_unlock;
                } else {
                        if (ICMP_PROT_UNREACH == code) {
@@ -641,7 +644,7 @@ void sctp_v4_err(struct sk_buff *skb, __u32 info)
        }
 
 out_unlock:
-       sctp_err_finish(sk, asoc);
+       sctp_err_finish(sk, transport);
 }
 
 /*
@@ -673,7 +676,7 @@ static int sctp_rcv_ootb(struct sk_buff *skb)
                if (ntohs(ch->length) < sizeof(sctp_chunkhdr_t))
                        break;
 
-               ch_end = offset + WORD_ROUND(ntohs(ch->length));
+               ch_end = offset + SCTP_PAD4(ntohs(ch->length));
                if (ch_end > skb->len)
                        break;
 
@@ -787,10 +790,9 @@ hit:
 
 /* rhashtable for transport */
 struct sctp_hash_cmp_arg {
-       const struct sctp_endpoint      *ep;
-       const union sctp_addr           *laddr;
-       const union sctp_addr           *paddr;
-       const struct net                *net;
+       const union sctp_addr   *paddr;
+       const struct net        *net;
+       u16                     lport;
 };
 
 static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg,
@@ -798,7 +800,6 @@ static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg,
 {
        struct sctp_transport *t = (struct sctp_transport *)ptr;
        const struct sctp_hash_cmp_arg *x = arg->key;
-       struct sctp_association *asoc;
        int err = 1;
 
        if (!sctp_cmp_addr_exact(&t->ipaddr, x->paddr))
@@ -806,19 +807,10 @@ static inline int sctp_hash_cmp(struct rhashtable_compare_arg *arg,
        if (!sctp_transport_hold(t))
                return err;
 
-       asoc = t->asoc;
-       if (!net_eq(sock_net(asoc->base.sk), x->net))
+       if (!net_eq(sock_net(t->asoc->base.sk), x->net))
+               goto out;
+       if (x->lport != htons(t->asoc->base.bind_addr.port))
                goto out;
-       if (x->ep) {
-               if (x->ep != asoc->ep)
-                       goto out;
-       } else {
-               if (x->laddr->v4.sin_port != htons(asoc->base.bind_addr.port))
-                       goto out;
-               if (!sctp_bind_addr_match(&asoc->base.bind_addr,
-                                         x->laddr, sctp_sk(asoc->base.sk)))
-                       goto out;
-       }
 
        err = 0;
 out:
@@ -848,11 +840,9 @@ static inline u32 sctp_hash_key(const void *data, u32 len, u32 seed)
        const struct sctp_hash_cmp_arg *x = data;
        const union sctp_addr *paddr = x->paddr;
        const struct net *net = x->net;
-       u16 lport;
+       u16 lport = x->lport;
        u32 addr;
 
-       lport = x->ep ? htons(x->ep->base.bind_addr.port) :
-                       x->laddr->v4.sin_port;
        if (paddr->sa.sa_family == AF_INET6)
                addr = jhash(&paddr->v6.sin6_addr, 16, seed);
        else
@@ -872,29 +862,32 @@ static const struct rhashtable_params sctp_hash_params = {
 
 int sctp_transport_hashtable_init(void)
 {
-       return rhashtable_init(&sctp_transport_hashtable, &sctp_hash_params);
+       return rhltable_init(&sctp_transport_hashtable, &sctp_hash_params);
 }
 
 void sctp_transport_hashtable_destroy(void)
 {
-       rhashtable_destroy(&sctp_transport_hashtable);
+       rhltable_destroy(&sctp_transport_hashtable);
 }
 
-void sctp_hash_transport(struct sctp_transport *t)
+int sctp_hash_transport(struct sctp_transport *t)
 {
        struct sctp_hash_cmp_arg arg;
+       int err;
 
        if (t->asoc->temp)
-               return;
+               return 0;
 
-       arg.ep = t->asoc->ep;
-       arg.paddr = &t->ipaddr;
        arg.net   = sock_net(t->asoc->base.sk);
+       arg.paddr = &t->ipaddr;
+       arg.lport = htons(t->asoc->base.bind_addr.port);
+
+       err = rhltable_insert_key(&sctp_transport_hashtable, &arg,
+                                 &t->node, sctp_hash_params);
+       if (err)
+               pr_err_once("insert transport fail, errno %d\n", err);
 
-reinsert:
-       if (rhashtable_lookup_insert_key(&sctp_transport_hashtable, &arg,
-                                        &t->node, sctp_hash_params) == -EBUSY)
-               goto reinsert;
+       return err;
 }
 
 void sctp_unhash_transport(struct sctp_transport *t)
@@ -902,39 +895,62 @@ void sctp_unhash_transport(struct sctp_transport *t)
        if (t->asoc->temp)
                return;
 
-       rhashtable_remove_fast(&sctp_transport_hashtable, &t->node,
-                              sctp_hash_params);
+       rhltable_remove(&sctp_transport_hashtable, &t->node,
+                       sctp_hash_params);
 }
 
+/* return a transport with holding it */
 struct sctp_transport *sctp_addrs_lookup_transport(
                                struct net *net,
                                const union sctp_addr *laddr,
                                const union sctp_addr *paddr)
 {
+       struct rhlist_head *tmp, *list;
+       struct sctp_transport *t;
        struct sctp_hash_cmp_arg arg = {
-               .ep    = NULL,
-               .laddr = laddr,
                .paddr = paddr,
                .net   = net,
+               .lport = laddr->v4.sin_port,
        };
 
-       return rhashtable_lookup_fast(&sctp_transport_hashtable, &arg,
-                                     sctp_hash_params);
+       list = rhltable_lookup(&sctp_transport_hashtable, &arg,
+                              sctp_hash_params);
+
+       rhl_for_each_entry_rcu(t, tmp, list, node) {
+               if (!sctp_transport_hold(t))
+                       continue;
+
+               if (sctp_bind_addr_match(&t->asoc->base.bind_addr,
+                                        laddr, sctp_sk(t->asoc->base.sk)))
+                       return t;
+               sctp_transport_put(t);
+       }
+
+       return NULL;
 }
 
+/* return a transport without holding it, as it's only used under sock lock */
 struct sctp_transport *sctp_epaddr_lookup_transport(
                                const struct sctp_endpoint *ep,
                                const union sctp_addr *paddr)
 {
        struct net *net = sock_net(ep->base.sk);
+       struct rhlist_head *tmp, *list;
+       struct sctp_transport *t;
        struct sctp_hash_cmp_arg arg = {
-               .ep    = ep,
                .paddr = paddr,
                .net   = net,
+               .lport = htons(ep->base.bind_addr.port),
        };
 
-       return rhashtable_lookup_fast(&sctp_transport_hashtable, &arg,
-                                     sctp_hash_params);
+       list = rhltable_lookup(&sctp_transport_hashtable, &arg,
+                              sctp_hash_params);
+
+       rhl_for_each_entry_rcu(t, tmp, list, node)
+               if (ep == t->asoc->ep)
+                       return t;
+
+       return NULL;
 }
 
 /* Look up an association. */
@@ -948,15 +964,12 @@ static struct sctp_association *__sctp_lookup_association(
        struct sctp_association *asoc = NULL;
 
        t = sctp_addrs_lookup_transport(net, local, peer);
-       if (!t || !sctp_transport_hold(t))
+       if (!t)
                goto out;
 
        asoc = t->asoc;
-       sctp_association_hold(asoc);
        *pt = t;
 
-       sctp_transport_put(t);
-
 out:
        return asoc;
 }
@@ -986,7 +999,7 @@ int sctp_has_association(struct net *net,
        struct sctp_transport *transport;
 
        if ((asoc = sctp_lookup_association(net, laddr, paddr, &transport))) {
-               sctp_association_put(asoc);
+               sctp_transport_put(transport);
                return 1;
        }
 
@@ -1021,7 +1034,6 @@ static struct sctp_association *__sctp_rcv_init_lookup(struct net *net,
        struct sctphdr *sh = sctp_hdr(skb);
        union sctp_params params;
        sctp_init_chunk_t *init;
-       struct sctp_transport *transport;
        struct sctp_af *af;
 
        /*
@@ -1052,7 +1064,7 @@ static struct sctp_association *__sctp_rcv_init_lookup(struct net *net,
 
                af->from_addr_param(paddr, params.addr, sh->source, 0);
 
-               asoc = __sctp_lookup_association(net, laddr, paddr, &transport);
+               asoc = __sctp_lookup_association(net, laddr, paddr, transportp);
                if (asoc)
                        return asoc;
        }
@@ -1128,7 +1140,7 @@ static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net,
                if (ntohs(ch->length) < sizeof(sctp_chunkhdr_t))
                        break;
 
-               ch_end = ((__u8 *)ch) + WORD_ROUND(ntohs(ch->length));
+               ch_end = ((__u8 *)ch) + SCTP_PAD4(ntohs(ch->length));
                if (ch_end > skb_tail_pointer(skb))
                        break;
 
@@ -1197,7 +1209,7 @@ static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net,
         * that the chunk length doesn't cause overflow.  Otherwise, we'll
         * walk off the end.
         */
-       if (WORD_ROUND(ntohs(ch->length)) > skb->len)
+       if (SCTP_PAD4(ntohs(ch->length)) > skb->len)
                return NULL;
 
        /* If this is INIT/INIT-ACK look inside the chunk too. */