]> git.proxmox.com Git - mirror_ubuntu-jammy-kernel.git/blobdiff - kernel/ucount.c
x86/retpoline: Cleanup some #ifdefery
[mirror_ubuntu-jammy-kernel.git] / kernel / ucount.c
index bb51849e6375288493d1429e07a49ee8de925986..a1d67261501a6d1b90bd1226844add4cc88c734a 100644 (file)
@@ -184,6 +184,7 @@ struct ucounts *alloc_ucounts(struct user_namespace *ns, kuid_t uid)
                        kfree(new);
                } else {
                        hlist_add_head(&new->node, hashent);
+                       get_user_ns(new->ns);
                        spin_unlock_irq(&ucounts_lock);
                        return new;
                }
@@ -204,6 +205,7 @@ void put_ucounts(struct ucounts *ucounts)
        if (atomic_dec_and_lock_irqsave(&ucounts->count, &ucounts_lock, flags)) {
                hlist_del_init(&ucounts->node);
                spin_unlock_irqrestore(&ucounts_lock, flags);
+               put_user_ns(ucounts->ns);
                kfree(ucounts);
        }
 }
@@ -258,15 +260,16 @@ void dec_ucount(struct ucounts *ucounts, enum ucount_type type)
 long inc_rlimit_ucounts(struct ucounts *ucounts, enum ucount_type type, long v)
 {
        struct ucounts *iter;
+       long max = LONG_MAX;
        long ret = 0;
 
        for (iter = ucounts; iter; iter = iter->ns->ucounts) {
-               long max = READ_ONCE(iter->ns->ucount_max[type]);
                long new = atomic_long_add_return(v, &iter->ucount[type]);
                if (new < 0 || new > max)
                        ret = LONG_MAX;
                else if (iter == ucounts)
                        ret = new;
+               max = READ_ONCE(iter->ns->ucount_max[type]);
        }
        return ret;
 }
@@ -284,15 +287,67 @@ bool dec_rlimit_ucounts(struct ucounts *ucounts, enum ucount_type type, long v)
        return (new == 0);
 }
 
-bool is_ucounts_overlimit(struct ucounts *ucounts, enum ucount_type type, unsigned long max)
+static void do_dec_rlimit_put_ucounts(struct ucounts *ucounts,
+                               struct ucounts *last, enum ucount_type type)
+{
+       struct ucounts *iter, *next;
+       for (iter = ucounts; iter != last; iter = next) {
+               long dec = atomic_long_add_return(-1, &iter->ucount[type]);
+               WARN_ON_ONCE(dec < 0);
+               next = iter->ns->ucounts;
+               if (dec == 0)
+                       put_ucounts(iter);
+       }
+}
+
+void dec_rlimit_put_ucounts(struct ucounts *ucounts, enum ucount_type type)
 {
+       do_dec_rlimit_put_ucounts(ucounts, NULL, type);
+}
+
+long inc_rlimit_get_ucounts(struct ucounts *ucounts, enum ucount_type type)
+{
+       /* Caller must hold a reference to ucounts */
        struct ucounts *iter;
-       if (get_ucounts_value(ucounts, type) > max)
-               return true;
+       long max = LONG_MAX;
+       long dec, ret = 0;
+
        for (iter = ucounts; iter; iter = iter->ns->ucounts) {
+               long new = atomic_long_add_return(1, &iter->ucount[type]);
+               if (new < 0 || new > max)
+                       goto unwind;
+               if (iter == ucounts)
+                       ret = new;
                max = READ_ONCE(iter->ns->ucount_max[type]);
-               if (get_ucounts_value(iter, type) > max)
+               /*
+                * Grab an extra ucount reference for the caller when
+                * the rlimit count was previously 0.
+                */
+               if (new != 1)
+                       continue;
+               if (!get_ucounts(iter))
+                       goto dec_unwind;
+       }
+       return ret;
+dec_unwind:
+       dec = atomic_long_add_return(-1, &iter->ucount[type]);
+       WARN_ON_ONCE(dec < 0);
+unwind:
+       do_dec_rlimit_put_ucounts(ucounts, iter, type);
+       return 0;
+}
+
+bool is_ucounts_overlimit(struct ucounts *ucounts, enum ucount_type type, unsigned long rlimit)
+{
+       struct ucounts *iter;
+       long max = rlimit;
+       if (rlimit > LONG_MAX)
+               max = LONG_MAX;
+       for (iter = ucounts; iter; iter = iter->ns->ucounts) {
+               long val = get_ucounts_value(iter, type);
+               if (val < 0 || val > max)
                        return true;
+               max = READ_ONCE(iter->ns->ucount_max[type]);
        }
        return false;
 }