]> git.proxmox.com Git - mirror_ubuntu-artful-kernel.git/blobdiff - net/xfrm/xfrm_state.c
xfrm: state: simplify rcu_read_unlock handling in two spots
[mirror_ubuntu-artful-kernel.git] / net / xfrm / xfrm_state.c
index a30f898dc1c5a82fb6221248799f5adbb07f2236..a62097e640b5690edc4d6f9bf794d1a2c80725f7 100644 (file)
@@ -20,7 +20,7 @@
 #include <linux/module.h>
 #include <linux/cache.h>
 #include <linux/audit.h>
-#include <asm/uaccess.h>
+#include <linux/uaccess.h>
 #include <linux/ktime.h>
 #include <linux/slab.h>
 #include <linux/interrupt.h>
 
 #include "xfrm_hash.h"
 
+#define xfrm_state_deref_prot(table, net) \
+       rcu_dereference_protected((table), lockdep_is_held(&(net)->xfrm.xfrm_state_lock))
+
+static void xfrm_state_gc_task(struct work_struct *work);
+
 /* Each xfrm_state may be linked to two tables:
 
    1. Hash table by (spi,daddr,ah/esp) to find SA by SPI. (input,ctl)
  */
 
 static unsigned int xfrm_state_hashmax __read_mostly = 1 * 1024 * 1024;
+static __read_mostly seqcount_t xfrm_state_hash_generation = SEQCNT_ZERO(xfrm_state_hash_generation);
+
+static DECLARE_WORK(xfrm_state_gc_work, xfrm_state_gc_task);
+static HLIST_HEAD(xfrm_state_gc_list);
+
+static inline bool xfrm_state_hold_rcu(struct xfrm_state __rcu *x)
+{
+       return atomic_inc_not_zero(&x->refcnt);
+}
 
 static inline unsigned int xfrm_dst_hash(struct net *net,
                                         const xfrm_address_t *daddr,
@@ -76,18 +90,18 @@ static void xfrm_hash_transfer(struct hlist_head *list,
                h = __xfrm_dst_hash(&x->id.daddr, &x->props.saddr,
                                    x->props.reqid, x->props.family,
                                    nhashmask);
-               hlist_add_head(&x->bydst, ndsttable+h);
+               hlist_add_head_rcu(&x->bydst, ndsttable + h);
 
                h = __xfrm_src_hash(&x->id.daddr, &x->props.saddr,
                                    x->props.family,
                                    nhashmask);
-               hlist_add_head(&x->bysrc, nsrctable+h);
+               hlist_add_head_rcu(&x->bysrc, nsrctable + h);
 
                if (x->id.spi) {
                        h = __xfrm_spi_hash(&x->id.daddr, x->id.spi,
                                            x->id.proto, x->props.family,
                                            nhashmask);
-                       hlist_add_head(&x->byspi, nspitable+h);
+                       hlist_add_head_rcu(&x->byspi, nspitable + h);
                }
        }
 }
@@ -122,25 +136,29 @@ static void xfrm_hash_resize(struct work_struct *work)
        }
 
        spin_lock_bh(&net->xfrm.xfrm_state_lock);
+       write_seqcount_begin(&xfrm_state_hash_generation);
 
        nhashmask = (nsize / sizeof(struct hlist_head)) - 1U;
+       odst = xfrm_state_deref_prot(net->xfrm.state_bydst, net);
        for (i = net->xfrm.state_hmask; i >= 0; i--)
-               xfrm_hash_transfer(net->xfrm.state_bydst+i, ndst, nsrc, nspi,
-                                  nhashmask);
+               xfrm_hash_transfer(odst + i, ndst, nsrc, nspi, nhashmask);
 
-       odst = net->xfrm.state_bydst;
-       osrc = net->xfrm.state_bysrc;
-       ospi = net->xfrm.state_byspi;
+       osrc = xfrm_state_deref_prot(net->xfrm.state_bysrc, net);
+       ospi = xfrm_state_deref_prot(net->xfrm.state_byspi, net);
        ohashmask = net->xfrm.state_hmask;
 
-       net->xfrm.state_bydst = ndst;
-       net->xfrm.state_bysrc = nsrc;
-       net->xfrm.state_byspi = nspi;
+       rcu_assign_pointer(net->xfrm.state_bydst, ndst);
+       rcu_assign_pointer(net->xfrm.state_bysrc, nsrc);
+       rcu_assign_pointer(net->xfrm.state_byspi, nspi);
        net->xfrm.state_hmask = nhashmask;
 
+       write_seqcount_end(&xfrm_state_hash_generation);
        spin_unlock_bh(&net->xfrm.xfrm_state_lock);
 
        osize = (ohashmask + 1) * sizeof(struct hlist_head);
+
+       synchronize_rcu();
+
        xfrm_hash_free(odst, osize);
        xfrm_hash_free(osrc, osize);
        xfrm_hash_free(ospi, osize);
@@ -174,7 +192,7 @@ int xfrm_register_type(const struct xfrm_type *type, unsigned short family)
        else
                err = -EEXIST;
        spin_unlock_bh(&xfrm_type_lock);
-       xfrm_state_put_afinfo(afinfo);
+       rcu_read_unlock();
        return err;
 }
 EXPORT_SYMBOL(xfrm_register_type);
@@ -195,7 +213,7 @@ int xfrm_unregister_type(const struct xfrm_type *type, unsigned short family)
        else
                typemap[type->proto] = NULL;
        spin_unlock_bh(&xfrm_type_lock);
-       xfrm_state_put_afinfo(afinfo);
+       rcu_read_unlock();
        return err;
 }
 EXPORT_SYMBOL(xfrm_unregister_type);
@@ -213,17 +231,18 @@ retry:
                return NULL;
        typemap = afinfo->type_map;
 
-       type = typemap[proto];
+       type = READ_ONCE(typemap[proto]);
        if (unlikely(type && !try_module_get(type->owner)))
                type = NULL;
+
+       rcu_read_unlock();
+
        if (!type && !modload_attempted) {
-               xfrm_state_put_afinfo(afinfo);
                request_module("xfrm-type-%d-%d", family, proto);
                modload_attempted = 1;
                goto retry;
        }
 
-       xfrm_state_put_afinfo(afinfo);
        return type;
 }
 
@@ -262,7 +281,7 @@ int xfrm_register_mode(struct xfrm_mode *mode, int family)
 
 out:
        spin_unlock_bh(&xfrm_mode_lock);
-       xfrm_state_put_afinfo(afinfo);
+       rcu_read_unlock();
        return err;
 }
 EXPORT_SYMBOL(xfrm_register_mode);
@@ -290,7 +309,7 @@ int xfrm_unregister_mode(struct xfrm_mode *mode, int family)
        }
 
        spin_unlock_bh(&xfrm_mode_lock);
-       xfrm_state_put_afinfo(afinfo);
+       rcu_read_unlock();
        return err;
 }
 EXPORT_SYMBOL(xfrm_unregister_mode);
@@ -309,17 +328,17 @@ retry:
        if (unlikely(afinfo == NULL))
                return NULL;
 
-       mode = afinfo->mode_map[encap];
+       mode = READ_ONCE(afinfo->mode_map[encap]);
        if (unlikely(mode && !try_module_get(mode->owner)))
                mode = NULL;
+
+       rcu_read_unlock();
        if (!mode && !modload_attempted) {
-               xfrm_state_put_afinfo(afinfo);
                request_module("xfrm-mode-%d-%d", family, encap);
                modload_attempted = 1;
                goto retry;
        }
 
-       xfrm_state_put_afinfo(afinfo);
        return mode;
 }
 
@@ -356,27 +375,20 @@ static void xfrm_state_gc_destroy(struct xfrm_state *x)
 
 static void xfrm_state_gc_task(struct work_struct *work)
 {
-       struct net *net = container_of(work, struct net, xfrm.state_gc_work);
        struct xfrm_state *x;
        struct hlist_node *tmp;
        struct hlist_head gc_list;
 
        spin_lock_bh(&xfrm_state_gc_lock);
-       hlist_move_list(&net->xfrm.state_gc_list, &gc_list);
+       hlist_move_list(&xfrm_state_gc_list, &gc_list);
        spin_unlock_bh(&xfrm_state_gc_lock);
 
+       synchronize_rcu();
+
        hlist_for_each_entry_safe(x, tmp, &gc_list, gclist)
                xfrm_state_gc_destroy(x);
 }
 
-static inline unsigned long make_jiffies(long secs)
-{
-       if (secs >= (MAX_SCHEDULE_TIMEOUT-1)/HZ)
-               return MAX_SCHEDULE_TIMEOUT-1;
-       else
-               return secs*HZ;
-}
-
 static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me)
 {
        struct tasklet_hrtimer *thr = container_of(me, struct tasklet_hrtimer, timer);
@@ -398,7 +410,7 @@ static enum hrtimer_restart xfrm_timer_handler(struct hrtimer *me)
                        if (x->xflags & XFRM_SOFT_EXPIRE) {
                                /* enter hard expire without soft expire first?!
                                 * setting a new date could trigger this.
-                                * workarbound: fix x->curflt.add_time by below:
+                                * workaround: fix x->curflt.add_time by below:
                                 */
                                x->curlft.add_time = now - x->saved_tmo - 1;
                                tmo = x->lft.hard_add_expires_seconds - x->saved_tmo;
@@ -501,14 +513,12 @@ EXPORT_SYMBOL(xfrm_state_alloc);
 
 void __xfrm_state_destroy(struct xfrm_state *x)
 {
-       struct net *net = xs_net(x);
-
        WARN_ON(x->km.state != XFRM_STATE_DEAD);
 
        spin_lock_bh(&xfrm_state_gc_lock);
-       hlist_add_head(&x->gclist, &net->xfrm.state_gc_list);
+       hlist_add_head(&x->gclist, &xfrm_state_gc_list);
        spin_unlock_bh(&xfrm_state_gc_lock);
-       schedule_work(&net->xfrm.state_gc_work);
+       schedule_work(&xfrm_state_gc_work);
 }
 EXPORT_SYMBOL(__xfrm_state_destroy);
 
@@ -521,10 +531,10 @@ int __xfrm_state_delete(struct xfrm_state *x)
                x->km.state = XFRM_STATE_DEAD;
                spin_lock(&net->xfrm.xfrm_state_lock);
                list_del(&x->km.all);
-               hlist_del(&x->bydst);
-               hlist_del(&x->bysrc);
+               hlist_del_rcu(&x->bydst);
+               hlist_del_rcu(&x->bysrc);
                if (x->id.spi)
-                       hlist_del(&x->byspi);
+                       hlist_del_rcu(&x->byspi);
                net->xfrm.state_num--;
                spin_unlock(&net->xfrm.xfrm_state_lock);
 
@@ -630,26 +640,23 @@ void xfrm_sad_getinfo(struct net *net, struct xfrmk_sadinfo *si)
 }
 EXPORT_SYMBOL(xfrm_sad_getinfo);
 
-static int
+static void
 xfrm_init_tempstate(struct xfrm_state *x, const struct flowi *fl,
                    const struct xfrm_tmpl *tmpl,
                    const xfrm_address_t *daddr, const xfrm_address_t *saddr,
                    unsigned short family)
 {
-       struct xfrm_state_afinfo *afinfo = xfrm_state_get_afinfo(family);
-       if (!afinfo)
-               return -1;
-       afinfo->init_tempsel(&x->sel, fl);
+       struct xfrm_state_afinfo *afinfo = xfrm_state_afinfo_get_rcu(family);
+
+       if (afinfo)
+               afinfo->init_tempsel(&x->sel, fl);
 
        if (family != tmpl->encap_family) {
-               xfrm_state_put_afinfo(afinfo);
-               afinfo = xfrm_state_get_afinfo(tmpl->encap_family);
+               afinfo = xfrm_state_afinfo_get_rcu(tmpl->encap_family);
                if (!afinfo)
-                       return -1;
+                       return;
        }
        afinfo->init_temprop(x, tmpl, daddr, saddr);
-       xfrm_state_put_afinfo(afinfo);
-       return 0;
 }
 
 static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
@@ -660,7 +667,7 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
        unsigned int h = xfrm_spi_hash(net, daddr, spi, proto, family);
        struct xfrm_state *x;
 
-       hlist_for_each_entry(x, net->xfrm.state_byspi+h, byspi) {
+       hlist_for_each_entry_rcu(x, net->xfrm.state_byspi + h, byspi) {
                if (x->props.family != family ||
                    x->id.spi       != spi ||
                    x->id.proto     != proto ||
@@ -669,7 +676,8 @@ static struct xfrm_state *__xfrm_state_lookup(struct net *net, u32 mark,
 
                if ((mark & x->mark.m) != x->mark.v)
                        continue;
-               xfrm_state_hold(x);
+               if (!xfrm_state_hold_rcu(x))
+                       continue;
                return x;
        }
 
@@ -684,7 +692,7 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark,
        unsigned int h = xfrm_src_hash(net, daddr, saddr, family);
        struct xfrm_state *x;
 
-       hlist_for_each_entry(x, net->xfrm.state_bysrc+h, bysrc) {
+       hlist_for_each_entry_rcu(x, net->xfrm.state_bysrc + h, bysrc) {
                if (x->props.family != family ||
                    x->id.proto     != proto ||
                    !xfrm_addr_equal(&x->id.daddr, daddr, family) ||
@@ -693,7 +701,8 @@ static struct xfrm_state *__xfrm_state_lookup_byaddr(struct net *net, u32 mark,
 
                if ((mark & x->mark.m) != x->mark.v)
                        continue;
-               xfrm_state_hold(x);
+               if (!xfrm_state_hold_rcu(x))
+                       continue;
                return x;
        }
 
@@ -776,13 +785,16 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
        struct xfrm_state *best = NULL;
        u32 mark = pol->mark.v & pol->mark.m;
        unsigned short encap_family = tmpl->encap_family;
+       unsigned int sequence;
        struct km_event c;
 
        to_put = NULL;
 
-       spin_lock_bh(&net->xfrm.xfrm_state_lock);
+       sequence = read_seqcount_begin(&xfrm_state_hash_generation);
+
+       rcu_read_lock();
        h = xfrm_dst_hash(net, daddr, saddr, tmpl->reqid, encap_family);
-       hlist_for_each_entry(x, net->xfrm.state_bydst+h, bydst) {
+       hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h, bydst) {
                if (x->props.family == encap_family &&
                    x->props.reqid == tmpl->reqid &&
                    (mark & x->mark.m) == x->mark.v &&
@@ -798,7 +810,7 @@ xfrm_state_find(const xfrm_address_t *daddr, const xfrm_address_t *saddr,
                goto found;
 
        h_wildcard = xfrm_dst_hash(net, daddr, &saddr_wildcard, tmpl->reqid, encap_family);
-       hlist_for_each_entry(x, net->xfrm.state_bydst+h_wildcard, bydst) {
+       hlist_for_each_entry_rcu(x, net->xfrm.state_bydst + h_wildcard, bydst) {
                if (x->props.family == encap_family &&
                    x->props.reqid == tmpl->reqid &&
                    (mark & x->mark.m) == x->mark.v &&
@@ -851,19 +863,21 @@ found:
                }
 
                if (km_query(x, tmpl, pol) == 0) {
+                       spin_lock_bh(&net->xfrm.xfrm_state_lock);
                        x->km.state = XFRM_STATE_ACQ;
                        list_add(&x->km.all, &net->xfrm.state_all);
-                       hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
+                       hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h);
                        h = xfrm_src_hash(net, daddr, saddr, encap_family);
-                       hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
+                       hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h);
                        if (x->id.spi) {
                                h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, encap_family);
-                               hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
+                               hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h);
                        }
                        x->lft.hard_add_expires_seconds = net->xfrm.sysctl_acq_expires;
                        tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL);
                        net->xfrm.state_num++;
                        xfrm_hash_grow_check(net, x->bydst.next != NULL);
+                       spin_unlock_bh(&net->xfrm.xfrm_state_lock);
                } else {
                        x->km.state = XFRM_STATE_DEAD;
                        to_put = x;
@@ -872,13 +886,26 @@ found:
                }
        }
 out:
-       if (x)
-               xfrm_state_hold(x);
-       else
+       if (x) {
+               if (!xfrm_state_hold_rcu(x)) {
+                       *err = -EAGAIN;
+                       x = NULL;
+               }
+       } else {
                *err = acquire_in_progress ? -EAGAIN : error;
-       spin_unlock_bh(&net->xfrm.xfrm_state_lock);
+       }
+       rcu_read_unlock();
        if (to_put)
                xfrm_state_put(to_put);
+
+       if (read_seqcount_retry(&xfrm_state_hash_generation, sequence)) {
+               *err = -EAGAIN;
+               if (x) {
+                       xfrm_state_put(x);
+                       x = NULL;
+               }
+       }
+
        return x;
 }
 
@@ -946,16 +973,16 @@ static void __xfrm_state_insert(struct xfrm_state *x)
 
        h = xfrm_dst_hash(net, &x->id.daddr, &x->props.saddr,
                          x->props.reqid, x->props.family);
-       hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
+       hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h);
 
        h = xfrm_src_hash(net, &x->id.daddr, &x->props.saddr, x->props.family);
-       hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
+       hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h);
 
        if (x->id.spi) {
                h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto,
                                  x->props.family);
 
-               hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
+               hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h);
        }
 
        tasklet_hrtimer_start(&x->mtimer, ktime_set(1, 0), HRTIMER_MODE_REL);
@@ -1064,9 +1091,9 @@ static struct xfrm_state *__find_acq_core(struct net *net,
                xfrm_state_hold(x);
                tasklet_hrtimer_start(&x->mtimer, ktime_set(net->xfrm.sysctl_acq_expires, 0), HRTIMER_MODE_REL);
                list_add(&x->km.all, &net->xfrm.state_all);
-               hlist_add_head(&x->bydst, net->xfrm.state_bydst+h);
+               hlist_add_head_rcu(&x->bydst, net->xfrm.state_bydst + h);
                h = xfrm_src_hash(net, daddr, saddr, family);
-               hlist_add_head(&x->bysrc, net->xfrm.state_bysrc+h);
+               hlist_add_head_rcu(&x->bysrc, net->xfrm.state_bysrc + h);
 
                net->xfrm.state_num++;
 
@@ -1375,7 +1402,7 @@ int xfrm_state_check_expire(struct xfrm_state *x)
        if (x->curlft.bytes >= x->lft.hard_byte_limit ||
            x->curlft.packets >= x->lft.hard_packet_limit) {
                x->km.state = XFRM_STATE_EXPIRED;
-               tasklet_hrtimer_start(&x->mtimer, ktime_set(0, 0), HRTIMER_MODE_REL);
+               tasklet_hrtimer_start(&x->mtimer, 0, HRTIMER_MODE_REL);
                return -EINVAL;
        }
 
@@ -1395,9 +1422,9 @@ xfrm_state_lookup(struct net *net, u32 mark, const xfrm_address_t *daddr, __be32
 {
        struct xfrm_state *x;
 
-       spin_lock_bh(&net->xfrm.xfrm_state_lock);
+       rcu_read_lock();
        x = __xfrm_state_lookup(net, mark, daddr, spi, proto, family);
-       spin_unlock_bh(&net->xfrm.xfrm_state_lock);
+       rcu_read_unlock();
        return x;
 }
 EXPORT_SYMBOL(xfrm_state_lookup);
@@ -1445,7 +1472,7 @@ xfrm_tmpl_sort(struct xfrm_tmpl **dst, struct xfrm_tmpl **src, int n,
        if (afinfo->tmpl_sort)
                err = afinfo->tmpl_sort(dst, src, n);
        spin_unlock_bh(&net->xfrm.xfrm_state_lock);
-       xfrm_state_put_afinfo(afinfo);
+       rcu_read_unlock();
        return err;
 }
 EXPORT_SYMBOL(xfrm_tmpl_sort);
@@ -1465,7 +1492,7 @@ xfrm_state_sort(struct xfrm_state **dst, struct xfrm_state **src, int n,
        if (afinfo->state_sort)
                err = afinfo->state_sort(dst, src, n);
        spin_unlock_bh(&net->xfrm.xfrm_state_lock);
-       xfrm_state_put_afinfo(afinfo);
+       rcu_read_unlock();
        return err;
 }
 EXPORT_SYMBOL(xfrm_state_sort);
@@ -1582,7 +1609,7 @@ int xfrm_alloc_spi(struct xfrm_state *x, u32 low, u32 high)
        if (x->id.spi) {
                spin_lock_bh(&net->xfrm.xfrm_state_lock);
                h = xfrm_spi_hash(net, &x->id.daddr, x->id.spi, x->id.proto, x->props.family);
-               hlist_add_head(&x->byspi, net->xfrm.state_byspi+h);
+               hlist_add_head_rcu(&x->byspi, net->xfrm.state_byspi + h);
                spin_unlock_bh(&net->xfrm.xfrm_state_lock);
 
                err = 0;
@@ -1903,10 +1930,10 @@ EXPORT_SYMBOL(xfrm_unregister_km);
 int xfrm_state_register_afinfo(struct xfrm_state_afinfo *afinfo)
 {
        int err = 0;
-       if (unlikely(afinfo == NULL))
-               return -EINVAL;
-       if (unlikely(afinfo->family >= NPROTO))
+
+       if (WARN_ON(afinfo->family >= NPROTO))
                return -EAFNOSUPPORT;
+
        spin_lock_bh(&xfrm_state_afinfo_lock);
        if (unlikely(xfrm_state_afinfo[afinfo->family] != NULL))
                err = -EEXIST;
@@ -1919,14 +1946,14 @@ EXPORT_SYMBOL(xfrm_state_register_afinfo);
 
 int xfrm_state_unregister_afinfo(struct xfrm_state_afinfo *afinfo)
 {
-       int err = 0;
-       if (unlikely(afinfo == NULL))
-               return -EINVAL;
-       if (unlikely(afinfo->family >= NPROTO))
+       int err = 0, family = afinfo->family;
+
+       if (WARN_ON(family >= NPROTO))
                return -EAFNOSUPPORT;
+
        spin_lock_bh(&xfrm_state_afinfo_lock);
        if (likely(xfrm_state_afinfo[afinfo->family] != NULL)) {
-               if (unlikely(xfrm_state_afinfo[afinfo->family] != afinfo))
+               if (rcu_access_pointer(xfrm_state_afinfo[family]) != afinfo)
                        err = -EINVAL;
                else
                        RCU_INIT_POINTER(xfrm_state_afinfo[afinfo->family], NULL);
@@ -1937,6 +1964,14 @@ int xfrm_state_unregister_afinfo(struct xfrm_state_afinfo *afinfo)
 }
 EXPORT_SYMBOL(xfrm_state_unregister_afinfo);
 
+struct xfrm_state_afinfo *xfrm_state_afinfo_get_rcu(unsigned int family)
+{
+       if (unlikely(family >= NPROTO))
+               return NULL;
+
+       return rcu_dereference(xfrm_state_afinfo[family]);
+}
+
 struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family)
 {
        struct xfrm_state_afinfo *afinfo;
@@ -1949,11 +1984,6 @@ struct xfrm_state_afinfo *xfrm_state_get_afinfo(unsigned int family)
        return afinfo;
 }
 
-void xfrm_state_put_afinfo(struct xfrm_state_afinfo *afinfo)
-{
-       rcu_read_unlock();
-}
-
 /* Temporarily located here until net/xfrm/xfrm_tunnel.c is created */
 void xfrm_state_delete_tunnel(struct xfrm_state *x)
 {
@@ -1971,16 +2001,13 @@ EXPORT_SYMBOL(xfrm_state_delete_tunnel);
 
 int xfrm_state_mtu(struct xfrm_state *x, int mtu)
 {
-       int res;
+       const struct xfrm_type *type = READ_ONCE(x->type);
 
-       spin_lock_bh(&x->lock);
        if (x->km.state == XFRM_STATE_VALID &&
-           x->type && x->type->get_mtu)
-               res = x->type->get_mtu(x, mtu);
-       else
-               res = mtu - x->props.header_len;
-       spin_unlock_bh(&x->lock);
-       return res;
+           type && type->get_mtu)
+               return type->get_mtu(x, mtu);
+
+       return mtu - x->props.header_len;
 }
 
 int __xfrm_init_state(struct xfrm_state *x, bool init_replay)
@@ -1999,7 +2026,7 @@ int __xfrm_init_state(struct xfrm_state *x, bool init_replay)
        if (afinfo->init_flags)
                err = afinfo->init_flags(x);
 
-       xfrm_state_put_afinfo(afinfo);
+       rcu_read_unlock();
 
        if (err)
                goto error;
@@ -2100,8 +2127,6 @@ int __net_init xfrm_state_init(struct net *net)
 
        net->xfrm.state_num = 0;
        INIT_WORK(&net->xfrm.state_hash_work, xfrm_hash_resize);
-       INIT_HLIST_HEAD(&net->xfrm.state_gc_list);
-       INIT_WORK(&net->xfrm.state_gc_work, xfrm_state_gc_task);
        spin_lock_init(&net->xfrm.xfrm_state_lock);
        return 0;
 
@@ -2119,7 +2144,7 @@ void xfrm_state_fini(struct net *net)
 
        flush_work(&net->xfrm.state_hash_work);
        xfrm_state_flush(net, IPSEC_PROTO_ANY, false);
-       flush_work(&net->xfrm.state_gc_work);
+       flush_work(&xfrm_state_gc_work);
 
        WARN_ON(!list_empty(&net->xfrm.state_all));