]> git.proxmox.com Git - mirror_ubuntu-jammy-kernel.git/blobdiff - kernel/auditfilter.c
audit: Add typespecific uid and gid comparators
[mirror_ubuntu-jammy-kernel.git] / kernel / auditfilter.c
index a6c3f1abd206c9d9736cbe5834483e36fd1d62ff..b30320cea26f0aad984f7df0c4add14736c28156 100644 (file)
@@ -342,6 +342,8 @@ static struct audit_entry *audit_rule_to_entry(struct audit_rule *rule)
 
                f->type = rule->fields[i] & ~(AUDIT_NEGATE|AUDIT_OPERATORS);
                f->val = rule->values[i];
+               f->uid = INVALID_UID;
+               f->gid = INVALID_GID;
 
                err = -EINVAL;
                if (f->op == Audit_bad)
@@ -350,16 +352,32 @@ static struct audit_entry *audit_rule_to_entry(struct audit_rule *rule)
                switch(f->type) {
                default:
                        goto exit_free;
-               case AUDIT_PID:
                case AUDIT_UID:
                case AUDIT_EUID:
                case AUDIT_SUID:
                case AUDIT_FSUID:
+               case AUDIT_LOGINUID:
+                       /* bit ops not implemented for uid comparisons */
+                       if (f->op == Audit_bitmask || f->op == Audit_bittest)
+                               goto exit_free;
+
+                       f->uid = make_kuid(current_user_ns(), f->val);
+                       if (!uid_valid(f->uid))
+                               goto exit_free;
+                       break;
                case AUDIT_GID:
                case AUDIT_EGID:
                case AUDIT_SGID:
                case AUDIT_FSGID:
-               case AUDIT_LOGINUID:
+                       /* bit ops not implemented for gid comparisons */
+                       if (f->op == Audit_bitmask || f->op == Audit_bittest)
+                               goto exit_free;
+
+                       f->gid = make_kgid(current_user_ns(), f->val);
+                       if (!gid_valid(f->gid))
+                               goto exit_free;
+                       break;
+               case AUDIT_PID:
                case AUDIT_PERS:
                case AUDIT_MSGTYPE:
                case AUDIT_PPID:
@@ -437,19 +455,39 @@ static struct audit_entry *audit_data_to_entry(struct audit_rule_data *data,
 
                f->type = data->fields[i];
                f->val = data->values[i];
+               f->uid = INVALID_UID;
+               f->gid = INVALID_GID;
                f->lsm_str = NULL;
                f->lsm_rule = NULL;
                switch(f->type) {
-               case AUDIT_PID:
                case AUDIT_UID:
                case AUDIT_EUID:
                case AUDIT_SUID:
                case AUDIT_FSUID:
+               case AUDIT_LOGINUID:
+               case AUDIT_OBJ_UID:
+                       /* bit ops not implemented for uid comparisons */
+                       if (f->op == Audit_bitmask || f->op == Audit_bittest)
+                               goto exit_free;
+
+                       f->uid = make_kuid(current_user_ns(), f->val);
+                       if (!uid_valid(f->uid))
+                               goto exit_free;
+                       break;
                case AUDIT_GID:
                case AUDIT_EGID:
                case AUDIT_SGID:
                case AUDIT_FSGID:
-               case AUDIT_LOGINUID:
+               case AUDIT_OBJ_GID:
+                       /* bit ops not implemented for gid comparisons */
+                       if (f->op == Audit_bitmask || f->op == Audit_bittest)
+                               goto exit_free;
+
+                       f->gid = make_kgid(current_user_ns(), f->val);
+                       if (!gid_valid(f->gid))
+                               goto exit_free;
+                       break;
+               case AUDIT_PID:
                case AUDIT_PERS:
                case AUDIT_MSGTYPE:
                case AUDIT_PPID:
@@ -461,8 +499,6 @@ static struct audit_entry *audit_data_to_entry(struct audit_rule_data *data,
                case AUDIT_ARG1:
                case AUDIT_ARG2:
                case AUDIT_ARG3:
-               case AUDIT_OBJ_UID:
-               case AUDIT_OBJ_GID:
                        break;
                case AUDIT_ARCH:
                        entry->rule.arch_f = f;
@@ -707,6 +743,23 @@ static int audit_compare_rule(struct audit_krule *a, struct audit_krule *b)
                        if (strcmp(a->filterkey, b->filterkey))
                                return 1;
                        break;
+               case AUDIT_UID:
+               case AUDIT_EUID:
+               case AUDIT_SUID:
+               case AUDIT_FSUID:
+               case AUDIT_LOGINUID:
+               case AUDIT_OBJ_UID:
+                       if (!uid_eq(a->fields[i].uid, b->fields[i].uid))
+                               return 1;
+                       break;
+               case AUDIT_GID:
+               case AUDIT_EGID:
+               case AUDIT_SGID:
+               case AUDIT_FSGID:
+               case AUDIT_OBJ_GID:
+                       if (!gid_eq(a->fields[i].gid, b->fields[i].gid))
+                               return 1;
+                       break;
                default:
                        if (a->fields[i].val != b->fields[i].val)
                                return 1;
@@ -1098,7 +1151,7 @@ static void audit_log_rule_change(uid_t loginuid, u32 sessionid, u32 sid,
  * @sessionid: sessionid for netlink audit message
  * @sid: SE Linux Security ID of sender
  */
-int audit_receive_filter(int type, int pid, int uid, int seq, void *data,
+int audit_receive_filter(int type, int pid, int seq, void *data,
                         size_t datasz, uid_t loginuid, u32 sessionid, u32 sid)
 {
        struct task_struct *tsk;
@@ -1198,6 +1251,52 @@ int audit_comparator(u32 left, u32 op, u32 right)
        }
 }
 
+int audit_uid_comparator(kuid_t left, u32 op, kuid_t right)
+{
+       switch (op) {
+       case Audit_equal:
+               return uid_eq(left, right);
+       case Audit_not_equal:
+               return !uid_eq(left, right);
+       case Audit_lt:
+               return uid_lt(left, right);
+       case Audit_le:
+               return uid_lte(left, right);
+       case Audit_gt:
+               return uid_gt(left, right);
+       case Audit_ge:
+               return uid_gte(left, right);
+       case Audit_bitmask:
+       case Audit_bittest:
+       default:
+               BUG();
+               return 0;
+       }
+}
+
+int audit_gid_comparator(kgid_t left, u32 op, kgid_t right)
+{
+       switch (op) {
+       case Audit_equal:
+               return gid_eq(left, right);
+       case Audit_not_equal:
+               return !gid_eq(left, right);
+       case Audit_lt:
+               return gid_lt(left, right);
+       case Audit_le:
+               return gid_lte(left, right);
+       case Audit_gt:
+               return gid_gt(left, right);
+       case Audit_ge:
+               return gid_gte(left, right);
+       case Audit_bitmask:
+       case Audit_bittest:
+       default:
+               BUG();
+               return 0;
+       }
+}
+
 /* Compare given dentry name with last component in given path,
  * return of 0 indicates a match. */
 int audit_compare_dname_path(const char *dname, const char *path,
@@ -1236,8 +1335,7 @@ int audit_compare_dname_path(const char *dname, const char *path,
        return strncmp(p, dname, dlen);
 }
 
-static int audit_filter_user_rules(struct netlink_skb_parms *cb,
-                                  struct audit_krule *rule,
+static int audit_filter_user_rules(struct audit_krule *rule,
                                   enum audit_state *state)
 {
        int i;
@@ -1249,17 +1347,17 @@ static int audit_filter_user_rules(struct netlink_skb_parms *cb,
 
                switch (f->type) {
                case AUDIT_PID:
-                       result = audit_comparator(cb->creds.pid, f->op, f->val);
+                       result = audit_comparator(task_pid_vnr(current), f->op, f->val);
                        break;
                case AUDIT_UID:
-                       result = audit_comparator(cb->creds.uid, f->op, f->val);
+                       result = audit_uid_comparator(current_uid(), f->op, f->uid);
                        break;
                case AUDIT_GID:
-                       result = audit_comparator(cb->creds.gid, f->op, f->val);
+                       result = audit_gid_comparator(current_gid(), f->op, f->gid);
                        break;
                case AUDIT_LOGINUID:
-                       result = audit_comparator(audit_get_loginuid(current),
-                                                 f->op, f->val);
+                       result = audit_uid_comparator(audit_get_loginuid(current),
+                                                 f->op, f->uid);
                        break;
                case AUDIT_SUBJ_USER:
                case AUDIT_SUBJ_ROLE:
@@ -1287,7 +1385,7 @@ static int audit_filter_user_rules(struct netlink_skb_parms *cb,
        return 1;
 }
 
-int audit_filter_user(struct netlink_skb_parms *cb)
+int audit_filter_user(void)
 {
        enum audit_state state = AUDIT_DISABLED;
        struct audit_entry *e;
@@ -1295,7 +1393,7 @@ int audit_filter_user(struct netlink_skb_parms *cb)
 
        rcu_read_lock();
        list_for_each_entry_rcu(e, &audit_filter_list[AUDIT_FILTER_USER], list) {
-               if (audit_filter_user_rules(cb, &e->rule, &state)) {
+               if (audit_filter_user_rules(&e->rule, &state)) {
                        if (state == AUDIT_DISABLED)
                                ret = 0;
                        break;