git.proxmox.com Git - mirror_ubuntu-hirsute-kernel.git/blobdiff - net/netfilter/ipset/ip_set_hash_gen.h
netfilter: ipset: Count non-static extension memory for userspace
[mirror_ubuntu-hirsute-kernel.git] / net / netfilter / ipset / ip_set_hash_gen.h
index d32fd6b036bfa8f3301fd5ab15a515b413442512..0746405a1d140098e8a9cdeb3a894581b18fc018 100644 (file)
@@ -275,7 +275,6 @@ htable_bits(u32 hashsize)
 struct htype {
        struct htable __rcu *table; /* the hash table */
        u32 maxelem;            /* max elements in the hash */
-       u32 elements;           /* current element (vs timeout) */
        u32 initval;            /* random jhash init value */
 #ifdef IP_SET_HASH_WITH_MARKMASK
        u32 markmask;           /* markmask value for mark mask to store */
@@ -344,21 +343,13 @@ mtype_del_cidr(struct htype *h, u8 cidr, u8 nets_length, u8 n)
 /* Calculate the actual memory size of the set data */
 static size_t
 mtype_ahash_memsize(const struct htype *h, const struct htable *t,
-                   u8 nets_length, size_t dsize)
+                   u8 nets_length)
 {
-       u32 i;
-       struct hbucket *n;
        size_t memsize = sizeof(*h) + sizeof(*t);
 
 #ifdef IP_SET_HASH_WITH_NETS
        memsize += sizeof(struct net_prefixes) * nets_length;
 #endif
-       for (i = 0; i < jhash_size(t->htable_bits); i++) {
-               n = rcu_dereference_bh(hbucket(t, i));
-               if (!n)
-                       continue;
-               memsize += sizeof(struct hbucket) + n->size * dsize;
-       }
 
        return memsize;
 }
@@ -400,7 +391,8 @@ mtype_flush(struct ip_set *set)
 #ifdef IP_SET_HASH_WITH_NETS
        memset(h->nets, 0, sizeof(struct net_prefixes) * NLEN(set->family));
 #endif
-       h->elements = 0;
+       set->elements = 0;
+       set->ext_size = 0;
 }
 
 /* Destroy the hashtable part of the set */
@@ -506,7 +498,7 @@ mtype_expire(struct ip_set *set, struct htype *h, u8 nets_length, size_t dsize)
                                                nets_length, k);
 #endif
                                ip_set_ext_destroy(set, data);
-                               h->elements--;
+                               set->elements--;
                                d++;
                        }
                }
@@ -532,6 +524,7 @@ mtype_expire(struct ip_set *set, struct htype *h, u8 nets_length, size_t dsize)
                                d++;
                        }
                        tmp->pos = d;
+                       set->ext_size -= AHASH_INIT_SIZE * dsize;
                        rcu_assign_pointer(hbucket(t, i), tmp);
                        kfree_rcu(n, rcu);
                }
@@ -563,7 +556,7 @@ mtype_resize(struct ip_set *set, bool retried)
        struct htype *h = set->data;
        struct htable *t, *orig;
        u8 htable_bits;
-       size_t dsize = set->dsize;
+       size_t extsize, dsize = set->dsize;
 #ifdef IP_SET_HASH_WITH_NETS
        u8 flags;
        struct mtype_elem *tmp;
@@ -606,6 +599,7 @@ retry:
        /* There can't be another parallel resizing, but dumping is possible */
        atomic_set(&orig->ref, 1);
        atomic_inc(&orig->uref);
+       extsize = 0;
        pr_debug("attempt to resize set %s from %u to %u, t %p\n",
                 set->name, orig->htable_bits, htable_bits, orig);
        for (i = 0; i < jhash_size(orig->htable_bits); i++) {
@@ -636,6 +630,7 @@ retry:
                                        goto cleanup;
                                }
                                m->size = AHASH_INIT_SIZE;
+                               extsize = sizeof(*m) + AHASH_INIT_SIZE * dsize;
                                RCU_INIT_POINTER(hbucket(t, key), m);
                        } else if (m->pos >= m->size) {
                                struct hbucket *ht;
@@ -655,6 +650,7 @@ retry:
                                memcpy(ht, m, sizeof(struct hbucket) +
                                              m->size * dsize);
                                ht->size = m->size + AHASH_INIT_SIZE;
+                               extsize += AHASH_INIT_SIZE * dsize;
                                kfree(m);
                                m = ht;
                                RCU_INIT_POINTER(hbucket(t, key), ht);
@@ -668,6 +664,7 @@ retry:
                }
        }
        rcu_assign_pointer(h->table, t);
+       set->ext_size = extsize;
 
        spin_unlock_bh(&set->lock);
 
@@ -715,11 +712,11 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
        bool deleted = false, forceadd = false, reuse = false;
        u32 key, multi = 0;
 
-       if (h->elements >= h->maxelem) {
+       if (set->elements >= h->maxelem) {
                if (SET_WITH_TIMEOUT(set))
                        /* FIXME: when set is full, we slow down here */
                        mtype_expire(set, h, NLEN(set->family), set->dsize);
-               if (h->elements >= h->maxelem && SET_WITH_FORCEADD(set))
+               if (set->elements >= h->maxelem && SET_WITH_FORCEADD(set))
                        forceadd = true;
        }
 
@@ -732,7 +729,7 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                                pr_warn("Set %s is full, maxelem %u reached\n",
                                        set->name, h->maxelem);
                        return -IPSET_ERR_HASH_FULL;
-               } else if (h->elements >= h->maxelem) {
+               } else if (set->elements >= h->maxelem) {
                        goto set_full;
                }
                old = NULL;
@@ -741,6 +738,7 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                if (!n)
                        return -ENOMEM;
                n->size = AHASH_INIT_SIZE;
+               set->ext_size += sizeof(*n) + AHASH_INIT_SIZE * set->dsize;
                goto copy_elem;
        }
        for (i = 0; i < n->pos; i++) {
@@ -781,11 +779,11 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                                        NLEN(set->family), i);
 #endif
                        ip_set_ext_destroy(set, data);
-                       h->elements--;
+                       set->elements--;
                }
                goto copy_data;
        }
-       if (h->elements >= h->maxelem)
+       if (set->elements >= h->maxelem)
                goto set_full;
        /* Create a new slot */
        if (n->pos >= n->size) {
@@ -804,13 +802,14 @@ mtype_add(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                memcpy(n, old, sizeof(struct hbucket) +
                       old->size * set->dsize);
                n->size = old->size + AHASH_INIT_SIZE;
+               set->ext_size += AHASH_INIT_SIZE * set->dsize;
        }
 
 copy_elem:
        j = n->pos++;
        data = ahash_data(n, j, set->dsize);
 copy_data:
-       h->elements++;
+       set->elements++;
 #ifdef IP_SET_HASH_WITH_NETS
        for (i = 0; i < IPSET_NET_COUNT; i++)
                mtype_add_cidr(h, NCIDR_PUT(DCIDR_GET(d->cidr, i)),
@@ -824,7 +823,7 @@ overwrite_extensions:
        if (SET_WITH_COUNTER(set))
                ip_set_init_counter(ext_counter(data, set), ext);
        if (SET_WITH_COMMENT(set))
-               ip_set_init_comment(ext_comment(data, set), ext);
+               ip_set_init_comment(set, ext_comment(data, set), ext);
        if (SET_WITH_SKBINFO(set))
                ip_set_init_skbinfo(ext_skbinfo(data, set), ext);
        /* Must come last for the case when timed out entry is reused */
@@ -883,7 +882,7 @@ mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                smp_mb__after_atomic();
                if (i + 1 == n->pos)
                        n->pos--;
-               h->elements--;
+               set->elements--;
 #ifdef IP_SET_HASH_WITH_NETS
                for (j = 0; j < IPSET_NET_COUNT; j++)
                        mtype_del_cidr(h, NCIDR_PUT(DCIDR_GET(d->cidr, j)),
@@ -896,6 +895,7 @@ mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                                k++;
                }
                if (n->pos == 0 && k == 0) {
+                       set->ext_size -= sizeof(*n) + n->size * dsize;
                        rcu_assign_pointer(hbucket(t, key), NULL);
                        kfree_rcu(n, rcu);
                } else if (k >= AHASH_INIT_SIZE) {
@@ -914,6 +914,7 @@ mtype_del(struct ip_set *set, void *value, const struct ip_set_ext *ext,
                                k++;
                        }
                        tmp->pos = k;
+                       set->ext_size -= AHASH_INIT_SIZE * dsize;
                        rcu_assign_pointer(hbucket(t, key), tmp);
                        kfree_rcu(n, rcu);
                }
@@ -1062,7 +1063,7 @@ mtype_head(struct ip_set *set, struct sk_buff *skb)
 
        rcu_read_lock_bh();
        t = rcu_dereference_bh_nfnl(h->table);
-       memsize = mtype_ahash_memsize(h, t, NLEN(set->family), set->dsize);
+       memsize = mtype_ahash_memsize(h, t, NLEN(set->family)) + set->ext_size;
        htable_bits = t->htable_bits;
        rcu_read_unlock_bh();
 
@@ -1083,7 +1084,8 @@ mtype_head(struct ip_set *set, struct sk_buff *skb)
                goto nla_put_failure;
 #endif
        if (nla_put_net32(skb, IPSET_ATTR_REFERENCES, htonl(set->ref)) ||
-           nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)))
+           nla_put_net32(skb, IPSET_ATTR_MEMSIZE, htonl(memsize)) ||
+           nla_put_net32(skb, IPSET_ATTR_ELEMENTS, htonl(set->elements)))
                goto nla_put_failure;
        if (unlikely(ip_set_put_flags(skb, set)))
                goto nla_put_failure;