]> git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blobdiff - net/l2tp/l2tp_core.c
l2tp: take reference on sessions being dumped
[mirror_ubuntu-bionic-kernel.git] / net / l2tp / l2tp_core.c
index 8adab6335ced9f1018318094be20c132a70f8475..e37d9554da7b47df0571c41478e6e758b8a8e6f9 100644 (file)
@@ -278,7 +278,57 @@ struct l2tp_session *l2tp_session_find(struct net *net, struct l2tp_tunnel *tunn
 }
 EXPORT_SYMBOL_GPL(l2tp_session_find);
 
-struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
+/* Like l2tp_session_find() but takes a reference on the returned session.
+ * Optionally calls session->ref() too if do_ref is true.
+ */
+struct l2tp_session *l2tp_session_get(struct net *net,
+                                     struct l2tp_tunnel *tunnel,
+                                     u32 session_id, bool do_ref)
+{
+       struct hlist_head *session_list;
+       struct l2tp_session *session;
+
+       if (!tunnel) {
+               struct l2tp_net *pn = l2tp_pernet(net);
+
+               session_list = l2tp_session_id_hash_2(pn, session_id);
+
+               rcu_read_lock_bh();
+               hlist_for_each_entry_rcu(session, session_list, global_hlist) {
+                       if (session->session_id == session_id) {
+                               l2tp_session_inc_refcount(session);
+                               if (do_ref && session->ref)
+                                       session->ref(session);
+                               rcu_read_unlock_bh();
+
+                               return session;
+                       }
+               }
+               rcu_read_unlock_bh();
+
+               return NULL;
+       }
+
+       session_list = l2tp_session_id_hash(tunnel, session_id);
+       read_lock_bh(&tunnel->hlist_lock);
+       hlist_for_each_entry(session, session_list, hlist) {
+               if (session->session_id == session_id) {
+                       l2tp_session_inc_refcount(session);
+                       if (do_ref && session->ref)
+                               session->ref(session);
+                       read_unlock_bh(&tunnel->hlist_lock);
+
+                       return session;
+               }
+       }
+       read_unlock_bh(&tunnel->hlist_lock);
+
+       return NULL;
+}
+EXPORT_SYMBOL_GPL(l2tp_session_get);
+
+struct l2tp_session *l2tp_session_get_nth(struct l2tp_tunnel *tunnel, int nth,
+                                         bool do_ref)
 {
        int hash;
        struct l2tp_session *session;
@@ -288,6 +338,9 @@ struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
        for (hash = 0; hash < L2TP_HASH_SIZE; hash++) {
                hlist_for_each_entry(session, &tunnel->session_hlist[hash], hlist) {
                        if (++count > nth) {
+                               l2tp_session_inc_refcount(session);
+                               if (do_ref && session->ref)
+                                       session->ref(session);
                                read_unlock_bh(&tunnel->hlist_lock);
                                return session;
                        }
@@ -298,12 +351,13 @@ struct l2tp_session *l2tp_session_find_nth(struct l2tp_tunnel *tunnel, int nth)
 
        return NULL;
 }
-EXPORT_SYMBOL_GPL(l2tp_session_find_nth);
+EXPORT_SYMBOL_GPL(l2tp_session_get_nth);
 
 /* Lookup a session by interface name.
  * This is very inefficient but is only used by management interfaces.
  */
-struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
+struct l2tp_session *l2tp_session_get_by_ifname(struct net *net, char *ifname,
+                                               bool do_ref)
 {
        struct l2tp_net *pn = l2tp_pernet(net);
        int hash;
@@ -313,7 +367,11 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
        for (hash = 0; hash < L2TP_HASH_SIZE_2; hash++) {
                hlist_for_each_entry_rcu(session, &pn->l2tp_session_hlist[hash], global_hlist) {
                        if (!strcmp(session->ifname, ifname)) {
+                               l2tp_session_inc_refcount(session);
+                               if (do_ref && session->ref)
+                                       session->ref(session);
                                rcu_read_unlock_bh();
+
                                return session;
                        }
                }
@@ -323,7 +381,49 @@ struct l2tp_session *l2tp_session_find_by_ifname(struct net *net, char *ifname)
 
        return NULL;
 }
-EXPORT_SYMBOL_GPL(l2tp_session_find_by_ifname);
+EXPORT_SYMBOL_GPL(l2tp_session_get_by_ifname);
+
+static int l2tp_session_add_to_tunnel(struct l2tp_tunnel *tunnel,
+                                     struct l2tp_session *session)
+{
+       struct l2tp_session *session_walk;
+       struct hlist_head *g_head;
+       struct hlist_head *head;
+       struct l2tp_net *pn;
+
+       head = l2tp_session_id_hash(tunnel, session->session_id);
+
+       write_lock_bh(&tunnel->hlist_lock);
+       hlist_for_each_entry(session_walk, head, hlist)
+               if (session_walk->session_id == session->session_id)
+                       goto exist;
+
+       if (tunnel->version == L2TP_HDR_VER_3) {
+               pn = l2tp_pernet(tunnel->l2tp_net);
+               g_head = l2tp_session_id_hash_2(l2tp_pernet(tunnel->l2tp_net),
+                                               session->session_id);
+
+               spin_lock_bh(&pn->l2tp_session_hlist_lock);
+               hlist_for_each_entry(session_walk, g_head, global_hlist)
+                       if (session_walk->session_id == session->session_id)
+                               goto exist_glob;
+
+               hlist_add_head_rcu(&session->global_hlist, g_head);
+               spin_unlock_bh(&pn->l2tp_session_hlist_lock);
+       }
+
+       hlist_add_head(&session->hlist, head);
+       write_unlock_bh(&tunnel->hlist_lock);
+
+       return 0;
+
+exist_glob:
+       spin_unlock_bh(&pn->l2tp_session_hlist_lock);
+exist:
+       write_unlock_bh(&tunnel->hlist_lock);
+
+       return -EEXIST;
+}
 
 /* Lookup a tunnel by id
  */
@@ -633,6 +733,9 @@ discard:
  * a data (not control) frame before coming here. Fields up to the
  * session-id have already been parsed and ptr points to the data
  * after the session-id.
+ *
+ * session->ref() must have been called prior to l2tp_recv_common().
+ * session->deref() will be called automatically after skb is processed.
  */
 void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
                      unsigned char *ptr, unsigned char *optr, u16 hdrflags,
@@ -642,14 +745,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
        int offset;
        u32 ns, nr;
 
-       /* The ref count is increased since we now hold a pointer to
-        * the session. Take care to decrement the refcnt when exiting
-        * this function from now on...
-        */
-       l2tp_session_inc_refcount(session);
-       if (session->ref)
-               (*session->ref)(session);
-
        /* Parse and check optional cookie */
        if (session->peer_cookie_len > 0) {
                if (memcmp(ptr, &session->peer_cookie[0], session->peer_cookie_len)) {
@@ -802,8 +897,6 @@ void l2tp_recv_common(struct l2tp_session *session, struct sk_buff *skb,
        /* Try to dequeue as many skbs from reorder_q as we can. */
        l2tp_recv_dequeue(session);
 
-       l2tp_session_dec_refcount(session);
-
        return;
 
 discard:
@@ -812,8 +905,6 @@ discard:
 
        if (session->deref)
                (*session->deref)(session);
-
-       l2tp_session_dec_refcount(session);
 }
 EXPORT_SYMBOL(l2tp_recv_common);
 
@@ -920,8 +1011,14 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
        }
 
        /* Find the session context */
-       session = l2tp_session_find(tunnel->l2tp_net, tunnel, session_id);
+       session = l2tp_session_get(tunnel->l2tp_net, tunnel, session_id, true);
        if (!session || !session->recv_skb) {
+               if (session) {
+                       if (session->deref)
+                               session->deref(session);
+                       l2tp_session_dec_refcount(session);
+               }
+
                /* Not found? Pass to userspace to deal with */
                l2tp_info(tunnel, L2TP_MSG_DATA,
                          "%s: no session found (%u/%u). Passing up.\n",
@@ -930,6 +1027,7 @@ static int l2tp_udp_recv_core(struct l2tp_tunnel *tunnel, struct sk_buff *skb,
        }
 
        l2tp_recv_common(session, skb, ptr, optr, hdrflags, length, payload_hook);
+       l2tp_session_dec_refcount(session);
 
        return 0;
 
@@ -1738,6 +1836,7 @@ EXPORT_SYMBOL_GPL(l2tp_session_set_header_len);
 struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunnel, u32 session_id, u32 peer_session_id, struct l2tp_session_cfg *cfg)
 {
        struct l2tp_session *session;
+       int err;
 
        session = kzalloc(sizeof(struct l2tp_session) + priv_size, GFP_KERNEL);
        if (session != NULL) {
@@ -1793,6 +1892,13 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
 
                l2tp_session_set_header_len(session, tunnel->version);
 
+               err = l2tp_session_add_to_tunnel(tunnel, session);
+               if (err) {
+                       kfree(session);
+
+                       return ERR_PTR(err);
+               }
+
                /* Bump the reference count. The session context is deleted
                 * only when this drops to zero.
                 */
@@ -1802,28 +1908,14 @@ struct l2tp_session *l2tp_session_create(int priv_size, struct l2tp_tunnel *tunn
                /* Ensure tunnel socket isn't deleted */
                sock_hold(tunnel->sock);
 
-               /* Add session to the tunnel's hash list */
-               write_lock_bh(&tunnel->hlist_lock);
-               hlist_add_head(&session->hlist,
-                              l2tp_session_id_hash(tunnel, session_id));
-               write_unlock_bh(&tunnel->hlist_lock);
-
-               /* And to the global session list if L2TPv3 */
-               if (tunnel->version != L2TP_HDR_VER_2) {
-                       struct l2tp_net *pn = l2tp_pernet(tunnel->l2tp_net);
-
-                       spin_lock_bh(&pn->l2tp_session_hlist_lock);
-                       hlist_add_head_rcu(&session->global_hlist,
-                                          l2tp_session_id_hash_2(pn, session_id));
-                       spin_unlock_bh(&pn->l2tp_session_hlist_lock);
-               }
-
                /* Ignore management session in session count value */
                if (session->session_id != 0)
                        atomic_inc(&l2tp_session_count);
+
+               return session;
        }
 
-       return session;
+       return ERR_PTR(-ENOMEM);
 }
 EXPORT_SYMBOL_GPL(l2tp_session_create);