]> git.proxmox.com Git - mirror_ubuntu-bionic-kernel.git/blobdiff - mm/memcontrol.c
mm: memcontrol: fix transparent huge page allocations under pressure
[mirror_ubuntu-bionic-kernel.git] / mm / memcontrol.c
index ec4dcf1b9562b6299f215e754768da18a36b156e..c86cc442ada47c277bcf68f2d77a96c0272c629c 100644 (file)
@@ -292,6 +292,9 @@ struct mem_cgroup {
        /* vmpressure notifications */
        struct vmpressure vmpressure;
 
+       /* css_online() has been completed */
+       int initialized;
+
        /*
         * the counter to account for mem+swap usage.
         */
@@ -315,9 +318,6 @@ struct mem_cgroup {
        /* OOM-Killer disable */
        int             oom_kill_disable;
 
-       /* set when res.limit == memsw.limit */
-       bool            memsw_is_minimum;
-
        /* protect arrays of thresholds */
        struct mutex thresholds_lock;
 
@@ -480,14 +480,6 @@ enum res_type {
 /* Used for OOM nofiier */
 #define OOM_CONTROL            (0)
 
-/*
- * Reclaim flags for mem_cgroup_hierarchical_reclaim
- */
-#define MEM_CGROUP_RECLAIM_NOSWAP_BIT  0x0
-#define MEM_CGROUP_RECLAIM_NOSWAP      (1 << MEM_CGROUP_RECLAIM_NOSWAP_BIT)
-#define MEM_CGROUP_RECLAIM_SHRINK_BIT  0x1
-#define MEM_CGROUP_RECLAIM_SHRINK      (1 << MEM_CGROUP_RECLAIM_SHRINK_BIT)
-
 /*
  * The memcg_create_mutex will be held whenever a new cgroup is created.
  * As a consequence, any change that needs to protect against new child cgroups
@@ -646,11 +638,13 @@ int memcg_limited_groups_array_size;
 struct static_key memcg_kmem_enabled_key;
 EXPORT_SYMBOL(memcg_kmem_enabled_key);
 
+static void memcg_free_cache_id(int id);
+
 static void disarm_kmem_keys(struct mem_cgroup *memcg)
 {
        if (memcg_kmem_is_active(memcg)) {
                static_key_slow_dec(&memcg_kmem_enabled_key);
-               ida_simple_remove(&kmem_limited_groups, memcg->kmemcg_id);
+               memcg_free_cache_id(memcg->kmemcg_id);
        }
        /*
         * This check can't live in kmem destruction function,
@@ -1099,10 +1093,21 @@ skip_node:
         * skipping css reference should be safe.
         */
        if (next_css) {
-               if ((next_css == &root->css) ||
-                   ((next_css->flags & CSS_ONLINE) &&
-                    css_tryget_online(next_css)))
-                       return mem_cgroup_from_css(next_css);
+               struct mem_cgroup *memcg = mem_cgroup_from_css(next_css);
+
+               if (next_css == &root->css)
+                       return memcg;
+
+               if (css_tryget_online(next_css)) {
+                       /*
+                        * Make sure the memcg is initialized:
+                        * mem_cgroup_css_online() orders the the
+                        * initialization against setting the flag.
+                        */
+                       if (smp_load_acquire(&memcg->initialized))
+                               return memcg;
+                       css_put(next_css);
+               }
 
                prev_css = next_css;
                goto skip_node;
@@ -1792,42 +1797,6 @@ static void mem_cgroup_out_of_memory(struct mem_cgroup *memcg, gfp_t gfp_mask,
                         NULL, "Memory cgroup out of memory");
 }
 
-static unsigned long mem_cgroup_reclaim(struct mem_cgroup *memcg,
-                                       gfp_t gfp_mask,
-                                       unsigned long flags)
-{
-       unsigned long total = 0;
-       bool noswap = false;
-       int loop;
-
-       if (flags & MEM_CGROUP_RECLAIM_NOSWAP)
-               noswap = true;
-       if (!(flags & MEM_CGROUP_RECLAIM_SHRINK) && memcg->memsw_is_minimum)
-               noswap = true;
-
-       for (loop = 0; loop < MEM_CGROUP_MAX_RECLAIM_LOOPS; loop++) {
-               if (loop)
-                       drain_all_stock_async(memcg);
-               total += try_to_free_mem_cgroup_pages(memcg, gfp_mask, noswap);
-               /*
-                * Allow limit shrinkers, which are triggered directly
-                * by userspace, to catch signals and stop reclaim
-                * after minimal progress, regardless of the margin.
-                */
-               if (total && (flags & MEM_CGROUP_RECLAIM_SHRINK))
-                       break;
-               if (mem_cgroup_margin(memcg))
-                       break;
-               /*
-                * If nothing was reclaimed after two attempts, there
-                * may be no reclaimable pages in this hierarchy.
-                */
-               if (loop && !total)
-                       break;
-       }
-       return total;
-}
-
 /**
  * test_mem_cgroup_node_reclaimable
  * @memcg: the target memcg
@@ -2530,25 +2499,29 @@ static int try_charge(struct mem_cgroup *memcg, gfp_t gfp_mask,
        struct mem_cgroup *mem_over_limit;
        struct res_counter *fail_res;
        unsigned long nr_reclaimed;
-       unsigned long flags = 0;
        unsigned long long size;
+       bool may_swap = true;
+       bool drained = false;
        int ret = 0;
 
+       if (mem_cgroup_is_root(memcg))
+               goto done;
 retry:
        if (consume_stock(memcg, nr_pages))
                goto done;
 
        size = batch * PAGE_SIZE;
-       if (!res_counter_charge(&memcg->res, size, &fail_res)) {
-               if (!do_swap_account)
-                       goto done_restock;
-               if (!res_counter_charge(&memcg->memsw, size, &fail_res))
+       if (!do_swap_account ||
+           !res_counter_charge(&memcg->memsw, size, &fail_res)) {
+               if (!res_counter_charge(&memcg->res, size, &fail_res))
                        goto done_restock;
-               res_counter_uncharge(&memcg->res, size);
-               mem_over_limit = mem_cgroup_from_res_counter(fail_res, memsw);
-               flags |= MEM_CGROUP_RECLAIM_NOSWAP;
-       } else
+               if (do_swap_account)
+                       res_counter_uncharge(&memcg->memsw, size);
                mem_over_limit = mem_cgroup_from_res_counter(fail_res, res);
+       } else {
+               mem_over_limit = mem_cgroup_from_res_counter(fail_res, memsw);
+               may_swap = false;
+       }
 
        if (batch > nr_pages) {
                batch = nr_pages;
@@ -2572,11 +2545,18 @@ retry:
        if (!(gfp_mask & __GFP_WAIT))
                goto nomem;
 
-       nr_reclaimed = mem_cgroup_reclaim(mem_over_limit, gfp_mask, flags);
+       nr_reclaimed = try_to_free_mem_cgroup_pages(mem_over_limit, nr_pages,
+                                                   gfp_mask, may_swap);
 
        if (mem_cgroup_margin(mem_over_limit) >= nr_pages)
                goto retry;
 
+       if (!drained) {
+               drain_all_stock_async(mem_over_limit);
+               drained = true;
+               goto retry;
+       }
+
        if (gfp_mask & __GFP_NORETRY)
                goto nomem;
        /*
@@ -2611,9 +2591,7 @@ nomem:
        if (!(gfp_mask & __GFP_NOFAIL))
                return -ENOMEM;
 bypass:
-       memcg = root_mem_cgroup;
-       ret = -EINTR;
-       goto retry;
+       return -EINTR;
 
 done_restock:
        if (batch > nr_pages)
@@ -2626,6 +2604,9 @@ static void cancel_charge(struct mem_cgroup *memcg, unsigned int nr_pages)
 {
        unsigned long bytes = nr_pages * PAGE_SIZE;
 
+       if (mem_cgroup_is_root(memcg))
+               return;
+
        res_counter_uncharge(&memcg->res, bytes);
        if (do_swap_account)
                res_counter_uncharge(&memcg->memsw, bytes);
@@ -2640,6 +2621,9 @@ static void __mem_cgroup_cancel_local_charge(struct mem_cgroup *memcg,
 {
        unsigned long bytes = nr_pages * PAGE_SIZE;
 
+       if (mem_cgroup_is_root(memcg))
+               return;
+
        res_counter_uncharge_until(&memcg->res, memcg->res.parent, bytes);
        if (do_swap_account)
                res_counter_uncharge_until(&memcg->memsw,
@@ -2886,19 +2870,44 @@ int memcg_cache_id(struct mem_cgroup *memcg)
        return memcg ? memcg->kmemcg_id : -1;
 }
 
-static size_t memcg_caches_array_size(int num_groups)
+static int memcg_alloc_cache_id(void)
 {
-       ssize_t size;
-       if (num_groups <= 0)
-               return 0;
+       int id, size;
+       int err;
+
+       id = ida_simple_get(&kmem_limited_groups,
+                           0, MEMCG_CACHES_MAX_SIZE, GFP_KERNEL);
+       if (id < 0)
+               return id;
 
-       size = 2 * num_groups;
+       if (id < memcg_limited_groups_array_size)
+               return id;
+
+       /*
+        * There's no space for the new id in memcg_caches arrays,
+        * so we have to grow them.
+        */
+
+       size = 2 * (id + 1);
        if (size < MEMCG_CACHES_MIN_SIZE)
                size = MEMCG_CACHES_MIN_SIZE;
        else if (size > MEMCG_CACHES_MAX_SIZE)
                size = MEMCG_CACHES_MAX_SIZE;
 
-       return size;
+       mutex_lock(&memcg_slab_mutex);
+       err = memcg_update_all_caches(size);
+       mutex_unlock(&memcg_slab_mutex);
+
+       if (err) {
+               ida_simple_remove(&kmem_limited_groups, id);
+               return err;
+       }
+       return id;
+}
+
+static void memcg_free_cache_id(int id)
+{
+       ida_simple_remove(&kmem_limited_groups, id);
 }
 
 /*
@@ -2908,97 +2917,7 @@ static size_t memcg_caches_array_size(int num_groups)
  */
 void memcg_update_array_size(int num)
 {
-       if (num > memcg_limited_groups_array_size)
-               memcg_limited_groups_array_size = memcg_caches_array_size(num);
-}
-
-int memcg_update_cache_size(struct kmem_cache *s, int num_groups)
-{
-       struct memcg_cache_params *cur_params = s->memcg_params;
-
-       VM_BUG_ON(!is_root_cache(s));
-
-       if (num_groups > memcg_limited_groups_array_size) {
-               int i;
-               struct memcg_cache_params *new_params;
-               ssize_t size = memcg_caches_array_size(num_groups);
-
-               size *= sizeof(void *);
-               size += offsetof(struct memcg_cache_params, memcg_caches);
-
-               new_params = kzalloc(size, GFP_KERNEL);
-               if (!new_params)
-                       return -ENOMEM;
-
-               new_params->is_root_cache = true;
-
-               /*
-                * There is the chance it will be bigger than
-                * memcg_limited_groups_array_size, if we failed an allocation
-                * in a cache, in which case all caches updated before it, will
-                * have a bigger array.
-                *
-                * But if that is the case, the data after
-                * memcg_limited_groups_array_size is certainly unused
-                */
-               for (i = 0; i < memcg_limited_groups_array_size; i++) {
-                       if (!cur_params->memcg_caches[i])
-                               continue;
-                       new_params->memcg_caches[i] =
-                                               cur_params->memcg_caches[i];
-               }
-
-               /*
-                * Ideally, we would wait until all caches succeed, and only
-                * then free the old one. But this is not worth the extra
-                * pointer per-cache we'd have to have for this.
-                *
-                * It is not a big deal if some caches are left with a size
-                * bigger than the others. And all updates will reset this
-                * anyway.
-                */
-               rcu_assign_pointer(s->memcg_params, new_params);
-               if (cur_params)
-                       kfree_rcu(cur_params, rcu_head);
-       }
-       return 0;
-}
-
-int memcg_alloc_cache_params(struct mem_cgroup *memcg, struct kmem_cache *s,
-                            struct kmem_cache *root_cache)
-{
-       size_t size;
-
-       if (!memcg_kmem_enabled())
-               return 0;
-
-       if (!memcg) {
-               size = offsetof(struct memcg_cache_params, memcg_caches);
-               size += memcg_limited_groups_array_size * sizeof(void *);
-       } else
-               size = sizeof(struct memcg_cache_params);
-
-       s->memcg_params = kzalloc(size, GFP_KERNEL);
-       if (!s->memcg_params)
-               return -ENOMEM;
-
-       if (memcg) {
-               s->memcg_params->memcg = memcg;
-               s->memcg_params->root_cache = root_cache;
-               css_get(&memcg->css);
-       } else
-               s->memcg_params->is_root_cache = true;
-
-       return 0;
-}
-
-void memcg_free_cache_params(struct kmem_cache *s)
-{
-       if (!s->memcg_params)
-               return;
-       if (!s->memcg_params->is_root_cache)
-               css_put(&s->memcg_params->memcg->css);
-       kfree(s->memcg_params);
+       memcg_limited_groups_array_size = num;
 }
 
 static void memcg_register_cache(struct mem_cgroup *memcg,
@@ -3031,6 +2950,7 @@ static void memcg_register_cache(struct mem_cgroup *memcg,
        if (!cachep)
                return;
 
+       css_get(&memcg->css);
        list_add(&cachep->memcg_params->list, &memcg->memcg_slab_caches);
 
        /*
@@ -3064,6 +2984,9 @@ static void memcg_unregister_cache(struct kmem_cache *cachep)
        list_del(&cachep->memcg_params->list);
 
        kmem_cache_destroy(cachep);
+
+       /* drop the reference taken in memcg_register_cache */
+       css_put(&memcg->css);
 }
 
 /*
@@ -3668,7 +3591,6 @@ static int mem_cgroup_resize_limit(struct mem_cgroup *memcg,
                                unsigned long long val)
 {
        int retry_count;
-       u64 memswlimit, memlimit;
        int ret = 0;
        int children = mem_cgroup_count_children(memcg);
        u64 curusage, oldusage;
@@ -3695,31 +3617,23 @@ static int mem_cgroup_resize_limit(struct mem_cgroup *memcg,
                 * We have to guarantee memcg->res.limit <= memcg->memsw.limit.
                 */
                mutex_lock(&set_limit_mutex);
-               memswlimit = res_counter_read_u64(&memcg->memsw, RES_LIMIT);
-               if (memswlimit < val) {
+               if (res_counter_read_u64(&memcg->memsw, RES_LIMIT) < val) {
                        ret = -EINVAL;
                        mutex_unlock(&set_limit_mutex);
                        break;
                }
 
-               memlimit = res_counter_read_u64(&memcg->res, RES_LIMIT);
-               if (memlimit < val)
+               if (res_counter_read_u64(&memcg->res, RES_LIMIT) < val)
                        enlarge = 1;
 
                ret = res_counter_set_limit(&memcg->res, val);
-               if (!ret) {
-                       if (memswlimit == val)
-                               memcg->memsw_is_minimum = true;
-                       else
-                               memcg->memsw_is_minimum = false;
-               }
                mutex_unlock(&set_limit_mutex);
 
                if (!ret)
                        break;
 
-               mem_cgroup_reclaim(memcg, GFP_KERNEL,
-                                  MEM_CGROUP_RECLAIM_SHRINK);
+               try_to_free_mem_cgroup_pages(memcg, 1, GFP_KERNEL, true);
+
                curusage = res_counter_read_u64(&memcg->res, RES_USAGE);
                /* Usage is reduced ? */
                if (curusage >= oldusage)
@@ -3737,7 +3651,7 @@ static int mem_cgroup_resize_memsw_limit(struct mem_cgroup *memcg,
                                        unsigned long long val)
 {
        int retry_count;
-       u64 memlimit, memswlimit, oldusage, curusage;
+       u64 oldusage, curusage;
        int children = mem_cgroup_count_children(memcg);
        int ret = -EBUSY;
        int enlarge = 0;
@@ -3756,30 +3670,21 @@ static int mem_cgroup_resize_memsw_limit(struct mem_cgroup *memcg,
                 * We have to guarantee memcg->res.limit <= memcg->memsw.limit.
                 */
                mutex_lock(&set_limit_mutex);
-               memlimit = res_counter_read_u64(&memcg->res, RES_LIMIT);
-               if (memlimit > val) {
+               if (res_counter_read_u64(&memcg->res, RES_LIMIT) > val) {
                        ret = -EINVAL;
                        mutex_unlock(&set_limit_mutex);
                        break;
                }
-               memswlimit = res_counter_read_u64(&memcg->memsw, RES_LIMIT);
-               if (memswlimit < val)
+               if (res_counter_read_u64(&memcg->memsw, RES_LIMIT) < val)
                        enlarge = 1;
                ret = res_counter_set_limit(&memcg->memsw, val);
-               if (!ret) {
-                       if (memlimit == val)
-                               memcg->memsw_is_minimum = true;
-                       else
-                               memcg->memsw_is_minimum = false;
-               }
                mutex_unlock(&set_limit_mutex);
 
                if (!ret)
                        break;
 
-               mem_cgroup_reclaim(memcg, GFP_KERNEL,
-                                  MEM_CGROUP_RECLAIM_NOSWAP |
-                                  MEM_CGROUP_RECLAIM_SHRINK);
+               try_to_free_mem_cgroup_pages(memcg, 1, GFP_KERNEL, false);
+
                curusage = res_counter_read_u64(&memcg->memsw, RES_USAGE);
                /* Usage is reduced ? */
                if (curusage >= oldusage)
@@ -4028,8 +3933,8 @@ static int mem_cgroup_force_empty(struct mem_cgroup *memcg)
                if (signal_pending(current))
                        return -EINTR;
 
-               progress = try_to_free_mem_cgroup_pages(memcg, GFP_KERNEL,
-                                               false);
+               progress = try_to_free_mem_cgroup_pages(memcg, 1,
+                                                       GFP_KERNEL, true);
                if (!progress) {
                        nr_retries--;
                        /* maybe some writeback is necessary */
@@ -4093,6 +3998,46 @@ out:
        return retval;
 }
 
+static unsigned long mem_cgroup_recursive_stat(struct mem_cgroup *memcg,
+                                              enum mem_cgroup_stat_index idx)
+{
+       struct mem_cgroup *iter;
+       long val = 0;
+
+       /* Per-cpu values can be negative, use a signed accumulator */
+       for_each_mem_cgroup_tree(iter, memcg)
+               val += mem_cgroup_read_stat(iter, idx);
+
+       if (val < 0) /* race ? */
+               val = 0;
+       return val;
+}
+
+static inline u64 mem_cgroup_usage(struct mem_cgroup *memcg, bool swap)
+{
+       u64 val;
+
+       if (!mem_cgroup_is_root(memcg)) {
+               if (!swap)
+                       return res_counter_read_u64(&memcg->res, RES_USAGE);
+               else
+                       return res_counter_read_u64(&memcg->memsw, RES_USAGE);
+       }
+
+       /*
+        * Transparent hugepages are still accounted for in MEM_CGROUP_STAT_RSS
+        * as well as in MEM_CGROUP_STAT_RSS_HUGE.
+        */
+       val = mem_cgroup_recursive_stat(memcg, MEM_CGROUP_STAT_CACHE);
+       val += mem_cgroup_recursive_stat(memcg, MEM_CGROUP_STAT_RSS);
+
+       if (swap)
+               val += mem_cgroup_recursive_stat(memcg, MEM_CGROUP_STAT_SWAP);
+
+       return val << PAGE_SHIFT;
+}
+
+
 static u64 mem_cgroup_read_u64(struct cgroup_subsys_state *css,
                               struct cftype *cft)
 {
@@ -4102,8 +4047,12 @@ static u64 mem_cgroup_read_u64(struct cgroup_subsys_state *css,
 
        switch (type) {
        case _MEM:
+               if (name == RES_USAGE)
+                       return mem_cgroup_usage(memcg, false);
                return res_counter_read_u64(&memcg->res, name);
        case _MEMSWAP:
+               if (name == RES_USAGE)
+                       return mem_cgroup_usage(memcg, true);
                return res_counter_read_u64(&memcg->memsw, name);
        case _KMEM:
                return res_counter_read_u64(&memcg->kmem, name);
@@ -4150,23 +4099,12 @@ static int __memcg_activate_kmem(struct mem_cgroup *memcg,
        if (err)
                goto out;
 
-       memcg_id = ida_simple_get(&kmem_limited_groups,
-                                 0, MEMCG_CACHES_MAX_SIZE, GFP_KERNEL);
+       memcg_id = memcg_alloc_cache_id();
        if (memcg_id < 0) {
                err = memcg_id;
                goto out;
        }
 
-       /*
-        * Make sure we have enough space for this cgroup in each root cache's
-        * memcg_params.
-        */
-       mutex_lock(&memcg_slab_mutex);
-       err = memcg_update_all_caches(memcg_id + 1);
-       mutex_unlock(&memcg_slab_mutex);
-       if (err)
-               goto out_rmid;
-
        memcg->kmemcg_id = memcg_id;
        INIT_LIST_HEAD(&memcg->memcg_slab_caches);
 
@@ -4187,10 +4125,6 @@ static int __memcg_activate_kmem(struct mem_cgroup *memcg,
 out:
        memcg_resume_kmem_account();
        return err;
-
-out_rmid:
-       ida_simple_remove(&kmem_limited_groups, memcg_id);
-       goto out;
 }
 
 static int memcg_activate_kmem(struct mem_cgroup *memcg,
@@ -4572,10 +4506,7 @@ static void __mem_cgroup_threshold(struct mem_cgroup *memcg, bool swap)
        if (!t)
                goto unlock;
 
-       if (!swap)
-               usage = res_counter_read_u64(&memcg->res, RES_USAGE);
-       else
-               usage = res_counter_read_u64(&memcg->memsw, RES_USAGE);
+       usage = mem_cgroup_usage(memcg, swap);
 
        /*
         * current_threshold points to threshold just below or equal to usage.
@@ -4673,10 +4604,10 @@ static int __mem_cgroup_usage_register_event(struct mem_cgroup *memcg,
 
        if (type == _MEM) {
                thresholds = &memcg->thresholds;
-               usage = res_counter_read_u64(&memcg->res, RES_USAGE);
+               usage = mem_cgroup_usage(memcg, false);
        } else if (type == _MEMSWAP) {
                thresholds = &memcg->memsw_thresholds;
-               usage = res_counter_read_u64(&memcg->memsw, RES_USAGE);
+               usage = mem_cgroup_usage(memcg, true);
        } else
                BUG();
 
@@ -4762,10 +4693,10 @@ static void __mem_cgroup_usage_unregister_event(struct mem_cgroup *memcg,
 
        if (type == _MEM) {
                thresholds = &memcg->thresholds;
-               usage = res_counter_read_u64(&memcg->res, RES_USAGE);
+               usage = mem_cgroup_usage(memcg, false);
        } else if (type == _MEMSWAP) {
                thresholds = &memcg->memsw_thresholds;
-               usage = res_counter_read_u64(&memcg->memsw, RES_USAGE);
+               usage = mem_cgroup_usage(memcg, true);
        } else
                BUG();
 
@@ -5502,6 +5433,7 @@ mem_cgroup_css_online(struct cgroup_subsys_state *css)
 {
        struct mem_cgroup *memcg = mem_cgroup_from_css(css);
        struct mem_cgroup *parent = mem_cgroup_from_css(css->parent);
+       int ret;
 
        if (css->id > MEM_CGROUP_ID_MAX)
                return -ENOSPC;
@@ -5525,9 +5457,9 @@ mem_cgroup_css_online(struct cgroup_subsys_state *css)
                 * core guarantees its existence.
                 */
        } else {
-               res_counter_init(&memcg->res, &root_mem_cgroup->res);
-               res_counter_init(&memcg->memsw, &root_mem_cgroup->memsw);
-               res_counter_init(&memcg->kmem, &root_mem_cgroup->kmem);
+               res_counter_init(&memcg->res, NULL);
+               res_counter_init(&memcg->memsw, NULL);
+               res_counter_init(&memcg->kmem, NULL);
                /*
                 * Deeper hierachy with use_hierarchy == false doesn't make
                 * much sense so let cgroup subsystem know about this
@@ -5538,7 +5470,18 @@ mem_cgroup_css_online(struct cgroup_subsys_state *css)
        }
        mutex_unlock(&memcg_create_mutex);
 
-       return memcg_init_kmem(memcg, &memory_cgrp_subsys);
+       ret = memcg_init_kmem(memcg, &memory_cgrp_subsys);
+       if (ret)
+               return ret;
+
+       /*
+        * Make sure the memcg is initialized: mem_cgroup_iter()
+        * orders reading memcg->initialized against its callers
+        * reading the memcg members.
+        */
+       smp_store_release(&memcg->initialized, 1);
+
+       return 0;
 }
 
 /*
@@ -5969,8 +5912,9 @@ static void __mem_cgroup_clear_mc(void)
        /* we must fixup refcnts and charges */
        if (mc.moved_swap) {
                /* uncharge swap account from the old cgroup */
-               res_counter_uncharge(&mc.from->memsw,
-                                    PAGE_SIZE * mc.moved_swap);
+               if (!mem_cgroup_is_root(mc.from))
+                       res_counter_uncharge(&mc.from->memsw,
+                                            PAGE_SIZE * mc.moved_swap);
 
                for (i = 0; i < mc.moved_swap; i++)
                        css_put(&mc.from->css);
@@ -5979,8 +5923,9 @@ static void __mem_cgroup_clear_mc(void)
                 * we charged both to->res and to->memsw, so we should
                 * uncharge to->res.
                 */
-               res_counter_uncharge(&mc.to->res,
-                                    PAGE_SIZE * mc.moved_swap);
+               if (!mem_cgroup_is_root(mc.to))
+                       res_counter_uncharge(&mc.to->res,
+                                            PAGE_SIZE * mc.moved_swap);
                /* we've already done css_get(mc.to) */
                mc.moved_swap = 0;
        }
@@ -6345,7 +6290,8 @@ void mem_cgroup_uncharge_swap(swp_entry_t entry)
        rcu_read_lock();
        memcg = mem_cgroup_lookup(id);
        if (memcg) {
-               res_counter_uncharge(&memcg->memsw, PAGE_SIZE);
+               if (!mem_cgroup_is_root(memcg))
+                       res_counter_uncharge(&memcg->memsw, PAGE_SIZE);
                mem_cgroup_swap_statistics(memcg, false);
                css_put(&memcg->css);
        }
@@ -6509,12 +6455,15 @@ static void uncharge_batch(struct mem_cgroup *memcg, unsigned long pgpgout,
 {
        unsigned long flags;
 
-       if (nr_mem)
-               res_counter_uncharge(&memcg->res, nr_mem * PAGE_SIZE);
-       if (nr_memsw)
-               res_counter_uncharge(&memcg->memsw, nr_memsw * PAGE_SIZE);
-
-       memcg_oom_recover(memcg);
+       if (!mem_cgroup_is_root(memcg)) {
+               if (nr_mem)
+                       res_counter_uncharge(&memcg->res,
+                                            nr_mem * PAGE_SIZE);
+               if (nr_memsw)
+                       res_counter_uncharge(&memcg->memsw,
+                                            nr_memsw * PAGE_SIZE);
+               memcg_oom_recover(memcg);
+       }
 
        local_irq_save(flags);
        __this_cpu_sub(memcg->stat->count[MEM_CGROUP_STAT_RSS], nr_anon);