]> git.proxmox.com Git - mirror_ubuntu-jammy-kernel.git/blobdiff - net/netfilter/nf_tables_api.c
netfilter: nf_tables: disallow jump to implicit chain from set element
[mirror_ubuntu-jammy-kernel.git] / net / netfilter / nf_tables_api.c
index b03794114fee0b4d461d6e88c3d7e8d92a45a815..8bc4460b627aef8671d7d6267fba579a824a8e1e 100644 (file)
@@ -32,7 +32,6 @@ static LIST_HEAD(nf_tables_objects);
 static LIST_HEAD(nf_tables_flowtables);
 static LIST_HEAD(nf_tables_destroy_list);
 static DEFINE_SPINLOCK(nf_tables_destroy_list_lock);
-static u64 table_handle;
 
 enum {
        NFT_VALIDATE_SKIP       = 0,
@@ -153,6 +152,7 @@ static struct nft_trans *nft_trans_alloc_gfp(const struct nft_ctx *ctx,
        if (trans == NULL)
                return NULL;
 
+       INIT_LIST_HEAD(&trans->list);
        trans->msg_type = msg_type;
        trans->ctx      = *ctx;
 
@@ -836,7 +836,7 @@ static int nf_tables_dump_tables(struct sk_buff *skb,
 
        rcu_read_lock();
        nft_net = nft_pernet(net);
-       cb->seq = nft_net->base_seq;
+       cb->seq = READ_ONCE(nft_net->base_seq);
 
        list_for_each_entry_rcu(table, &nft_net->tables, list) {
                if (family != NFPROTO_UNSPEC && family != table->family)
@@ -1155,7 +1155,7 @@ static int nf_tables_newtable(struct sk_buff *skb, const struct nfnl_info *info,
        INIT_LIST_HEAD(&table->flowtables);
        table->family = family;
        table->flags = flags;
-       table->handle = ++table_handle;
+       table->handle = ++nft_net->table_handle;
        if (table->flags & NFT_TABLE_F_OWNER)
                table->nlpid = NETLINK_CB(skb).portid;
 
@@ -1625,7 +1625,7 @@ static int nf_tables_dump_chains(struct sk_buff *skb,
 
        rcu_read_lock();
        nft_net = nft_pernet(net);
-       cb->seq = nft_net->base_seq;
+       cb->seq = READ_ONCE(nft_net->base_seq);
 
        list_for_each_entry_rcu(table, &nft_net->tables, list) {
                if (family != NFPROTO_UNSPEC && family != table->family)
@@ -2072,7 +2072,7 @@ static int nft_basechain_init(struct nft_base_chain *basechain, u8 family,
        chain->flags |= NFT_CHAIN_BASE | flags;
        basechain->policy = NF_ACCEPT;
        if (chain->flags & NFT_CHAIN_HW_OFFLOAD &&
-           nft_chain_offload_priority(basechain) < 0)
+           !nft_chain_offload_support(basechain))
                return -EOPNOTSUPP;
 
        flow_block_init(&basechain->flow_block);
@@ -2101,9 +2101,9 @@ static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask,
                              struct netlink_ext_ack *extack)
 {
        const struct nlattr * const *nla = ctx->nla;
+       struct nft_stats __percpu *stats = NULL;
        struct nft_table *table = ctx->table;
        struct nft_base_chain *basechain;
-       struct nft_stats __percpu *stats;
        struct net *net = ctx->net;
        char name[NFT_NAME_MAXLEN];
        struct nft_trans *trans;
@@ -2140,7 +2140,6 @@ static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask,
                                return PTR_ERR(stats);
                        }
                        rcu_assign_pointer(basechain->stats, stats);
-                       static_branch_inc(&nft_counters_enabled);
                }
 
                err = nft_basechain_init(basechain, family, &hook, flags);
@@ -2223,6 +2222,9 @@ static int nf_tables_addchain(struct nft_ctx *ctx, u8 family, u8 genmask,
                goto err_unregister_hook;
        }
 
+       if (stats)
+               static_branch_inc(&nft_counters_enabled);
+
        table->use++;
 
        return 0;
@@ -2478,6 +2480,9 @@ static int nf_tables_newchain(struct sk_buff *skb, const struct nfnl_info *info,
        nft_ctx_init(&ctx, net, skb, info->nlh, family, table, chain, nla);
 
        if (chain != NULL) {
+               if (chain->flags & NFT_CHAIN_BINDING)
+                       return -EINVAL;
+
                if (info->nlh->nlmsg_flags & NLM_F_EXCL) {
                        NL_SET_BAD_ATTR(extack, attr);
                        return -EEXIST;
@@ -3053,7 +3058,7 @@ static int nf_tables_dump_rules(struct sk_buff *skb,
 
        rcu_read_lock();
        nft_net = nft_pernet(net);
-       cb->seq = nft_net->base_seq;
+       cb->seq = READ_ONCE(nft_net->base_seq);
 
        list_for_each_entry_rcu(table, &nft_net->tables, list) {
                if (family != NFPROTO_UNSPEC && family != table->family)
@@ -3809,7 +3814,7 @@ cont:
                list_for_each_entry(i, &ctx->table->sets, list) {
                        int tmp;
 
-                       if (!nft_is_active_next(ctx->net, set))
+                       if (!nft_is_active_next(ctx->net, i))
                                continue;
                        if (!sscanf(i->name, name, &tmp))
                                continue;
@@ -4035,7 +4040,7 @@ static int nf_tables_dump_sets(struct sk_buff *skb, struct netlink_callback *cb)
 
        rcu_read_lock();
        nft_net = nft_pernet(net);
-       cb->seq = nft_net->base_seq;
+       cb->seq = READ_ONCE(nft_net->base_seq);
 
        list_for_each_entry_rcu(table, &nft_net->tables, list) {
                if (ctx->family != NFPROTO_UNSPEC &&
@@ -4353,6 +4358,11 @@ static int nf_tables_newset(struct sk_buff *skb, const struct nfnl_info *info,
                err = nf_tables_set_desc_parse(&desc, nla[NFTA_SET_DESC]);
                if (err < 0)
                        return err;
+
+               if (desc.field_count > 1 && !(flags & NFT_SET_CONCAT))
+                       return -EINVAL;
+       } else if (flags & NFT_SET_CONCAT) {
+               return -EINVAL;
        }
 
        if (nla[NFTA_SET_EXPR] || nla[NFTA_SET_EXPRESSIONS])
@@ -4963,6 +4973,8 @@ static int nf_tables_dump_set(struct sk_buff *skb, struct netlink_callback *cb)
 
        rcu_read_lock();
        nft_net = nft_pernet(net);
+       cb->seq = READ_ONCE(nft_net->base_seq);
+
        list_for_each_entry_rcu(table, &nft_net->tables, list) {
                if (dump_ctx->ctx.family != NFPROTO_UNSPEC &&
                    dump_ctx->ctx.family != table->family)
@@ -5098,6 +5110,9 @@ static int nft_setelem_parse_flags(const struct nft_set *set,
        if (!(set->flags & NFT_SET_INTERVAL) &&
            *flags & NFT_SET_ELEM_INTERVAL_END)
                return -EINVAL;
+       if ((*flags & (NFT_SET_ELEM_INTERVAL_END | NFT_SET_ELEM_CATCHALL)) ==
+           (NFT_SET_ELEM_INTERVAL_END | NFT_SET_ELEM_CATCHALL))
+               return -EINVAL;
 
        return 0;
 }
@@ -5105,19 +5120,13 @@ static int nft_setelem_parse_flags(const struct nft_set *set,
 static int nft_setelem_parse_key(struct nft_ctx *ctx, struct nft_set *set,
                                 struct nft_data *key, struct nlattr *attr)
 {
-       struct nft_data_desc desc;
-       int err;
-
-       err = nft_data_init(ctx, key, NFT_DATA_VALUE_MAXLEN, &desc, attr);
-       if (err < 0)
-               return err;
-
-       if (desc.type != NFT_DATA_VALUE || desc.len != set->klen) {
-               nft_data_release(key, desc.type);
-               return -EINVAL;
-       }
+       struct nft_data_desc desc = {
+               .type   = NFT_DATA_VALUE,
+               .size   = NFT_DATA_VALUE_MAXLEN,
+               .len    = set->klen,
+       };
 
-       return 0;
+       return nft_data_init(ctx, key, &desc, attr);
 }
 
 static int nft_setelem_parse_data(struct nft_ctx *ctx, struct nft_set *set,
@@ -5126,24 +5135,18 @@ static int nft_setelem_parse_data(struct nft_ctx *ctx, struct nft_set *set,
                                  struct nlattr *attr)
 {
        u32 dtype;
-       int err;
-
-       err = nft_data_init(ctx, data, NFT_DATA_VALUE_MAXLEN, desc, attr);
-       if (err < 0)
-               return err;
 
        if (set->dtype == NFT_DATA_VERDICT)
                dtype = NFT_DATA_VERDICT;
        else
                dtype = NFT_DATA_VALUE;
 
-       if (dtype != desc->type ||
-           set->dlen != desc->len) {
-               nft_data_release(data, desc->type);
-               return -EINVAL;
-       }
+       desc->type = dtype;
+       desc->size = NFT_DATA_VALUE_MAXLEN;
+       desc->len = set->dlen;
+       desc->flags = NFT_DATA_DESC_SETELEM;
 
-       return 0;
+       return nft_data_init(ctx, data, desc, attr);
 }
 
 static void *nft_setelem_catchall_get(const struct net *net,
@@ -5476,7 +5479,7 @@ int nft_set_elem_expr_clone(const struct nft_ctx *ctx, struct nft_set *set,
 
                err = nft_expr_clone(expr, set->exprs[i]);
                if (err < 0) {
-                       nft_expr_destroy(ctx, expr);
+                       kfree(expr);
                        goto err_expr;
                }
                expr_array[i] = expr;
@@ -5708,6 +5711,24 @@ static void nft_setelem_remove(const struct net *net,
                set->ops->remove(net, set, elem);
 }
 
+static bool nft_setelem_valid_key_end(const struct nft_set *set,
+                                     struct nlattr **nla, u32 flags)
+{
+       if ((set->flags & (NFT_SET_CONCAT | NFT_SET_INTERVAL)) ==
+                         (NFT_SET_CONCAT | NFT_SET_INTERVAL)) {
+               if (flags & NFT_SET_ELEM_INTERVAL_END)
+                       return false;
+               if (!nla[NFTA_SET_ELEM_KEY_END] &&
+                   !(flags & NFT_SET_ELEM_CATCHALL))
+                       return false;
+       } else {
+               if (nla[NFTA_SET_ELEM_KEY_END])
+                       return false;
+       }
+
+       return true;
+}
+
 static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
                            const struct nlattr *attr, u32 nlmsg_flags)
 {
@@ -5743,8 +5764,11 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
        if (!nla[NFTA_SET_ELEM_KEY] && !(flags & NFT_SET_ELEM_CATCHALL))
                return -EINVAL;
 
-       if (flags != 0)
-               nft_set_ext_add(&tmpl, NFT_SET_EXT_FLAGS);
+       if (flags != 0) {
+               err = nft_set_ext_add(&tmpl, NFT_SET_EXT_FLAGS);
+               if (err < 0)
+                       return err;
+       }
 
        if (set->flags & NFT_SET_MAP) {
                if (nla[NFTA_SET_ELEM_DATA] == NULL &&
@@ -5755,6 +5779,18 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
                        return -EINVAL;
        }
 
+       if (set->flags & NFT_SET_OBJECT) {
+               if (!nla[NFTA_SET_ELEM_OBJREF] &&
+                   !(flags & NFT_SET_ELEM_INTERVAL_END))
+                       return -EINVAL;
+       } else {
+               if (nla[NFTA_SET_ELEM_OBJREF])
+                       return -EINVAL;
+       }
+
+       if (!nft_setelem_valid_key_end(set, nla, flags))
+               return -EINVAL;
+
        if ((flags & NFT_SET_ELEM_INTERVAL_END) &&
             (nla[NFTA_SET_ELEM_DATA] ||
              nla[NFTA_SET_ELEM_OBJREF] ||
@@ -5762,6 +5798,7 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
              nla[NFTA_SET_ELEM_EXPIRATION] ||
              nla[NFTA_SET_ELEM_USERDATA] ||
              nla[NFTA_SET_ELEM_EXPR] ||
+             nla[NFTA_SET_ELEM_KEY_END] ||
              nla[NFTA_SET_ELEM_EXPRESSIONS]))
                return -EINVAL;
 
@@ -5853,7 +5890,9 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
                if (err < 0)
                        goto err_set_elem_expr;
 
-               nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY, set->klen);
+               err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY, set->klen);
+               if (err < 0)
+                       goto err_parse_key;
        }
 
        if (nla[NFTA_SET_ELEM_KEY_END]) {
@@ -5862,29 +5901,34 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
                if (err < 0)
                        goto err_parse_key;
 
-               nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY_END, set->klen);
+               err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY_END, set->klen);
+               if (err < 0)
+                       goto err_parse_key_end;
        }
 
        if (timeout > 0) {
-               nft_set_ext_add(&tmpl, NFT_SET_EXT_EXPIRATION);
-               if (timeout != set->timeout)
-                       nft_set_ext_add(&tmpl, NFT_SET_EXT_TIMEOUT);
+               err = nft_set_ext_add(&tmpl, NFT_SET_EXT_EXPIRATION);
+               if (err < 0)
+                       goto err_parse_key_end;
+
+               if (timeout != set->timeout) {
+                       err = nft_set_ext_add(&tmpl, NFT_SET_EXT_TIMEOUT);
+                       if (err < 0)
+                               goto err_parse_key_end;
+               }
        }
 
        if (num_exprs) {
                for (i = 0; i < num_exprs; i++)
                        size += expr_array[i]->ops->size;
 
-               nft_set_ext_add_length(&tmpl, NFT_SET_EXT_EXPRESSIONS,
-                                      sizeof(struct nft_set_elem_expr) +
-                                      size);
+               err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_EXPRESSIONS,
+                                            sizeof(struct nft_set_elem_expr) + size);
+               if (err < 0)
+                       goto err_parse_key_end;
        }
 
        if (nla[NFTA_SET_ELEM_OBJREF] != NULL) {
-               if (!(set->flags & NFT_SET_OBJECT)) {
-                       err = -EINVAL;
-                       goto err_parse_key_end;
-               }
                obj = nft_obj_lookup(ctx->net, ctx->table,
                                     nla[NFTA_SET_ELEM_OBJREF],
                                     set->objtype, genmask);
@@ -5892,7 +5936,9 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
                        err = PTR_ERR(obj);
                        goto err_parse_key_end;
                }
-               nft_set_ext_add(&tmpl, NFT_SET_EXT_OBJREF);
+               err = nft_set_ext_add(&tmpl, NFT_SET_EXT_OBJREF);
+               if (err < 0)
+                       goto err_parse_key_end;
        }
 
        if (nla[NFTA_SET_ELEM_DATA] != NULL) {
@@ -5926,7 +5972,9 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
                                                          NFT_VALIDATE_NEED);
                }
 
-               nft_set_ext_add_length(&tmpl, NFT_SET_EXT_DATA, desc.len);
+               err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_DATA, desc.len);
+               if (err < 0)
+                       goto err_parse_data;
        }
 
        /* The full maximum length of userdata can exceed the maximum
@@ -5936,9 +5984,12 @@ static int nft_add_set_elem(struct nft_ctx *ctx, struct nft_set *set,
        ulen = 0;
        if (nla[NFTA_SET_ELEM_USERDATA] != NULL) {
                ulen = nla_len(nla[NFTA_SET_ELEM_USERDATA]);
-               if (ulen > 0)
-                       nft_set_ext_add_length(&tmpl, NFT_SET_EXT_USERDATA,
-                                              ulen);
+               if (ulen > 0) {
+                       err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_USERDATA,
+                                                    ulen);
+                       if (err < 0)
+                               goto err_parse_data;
+               }
        }
 
        err = -ENOMEM;
@@ -6162,10 +6213,16 @@ static int nft_del_setelem(struct nft_ctx *ctx, struct nft_set *set,
        if (!nla[NFTA_SET_ELEM_KEY] && !(flags & NFT_SET_ELEM_CATCHALL))
                return -EINVAL;
 
+       if (!nft_setelem_valid_key_end(set, nla, flags))
+               return -EINVAL;
+
        nft_set_ext_prepare(&tmpl);
 
-       if (flags != 0)
-               nft_set_ext_add(&tmpl, NFT_SET_EXT_FLAGS);
+       if (flags != 0) {
+               err = nft_set_ext_add(&tmpl, NFT_SET_EXT_FLAGS);
+               if (err < 0)
+                       return err;
+       }
 
        if (nla[NFTA_SET_ELEM_KEY]) {
                err = nft_setelem_parse_key(ctx, set, &elem.key.val,
@@ -6173,16 +6230,20 @@ static int nft_del_setelem(struct nft_ctx *ctx, struct nft_set *set,
                if (err < 0)
                        return err;
 
-               nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY, set->klen);
+               err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY, set->klen);
+               if (err < 0)
+                       goto fail_elem;
        }
 
        if (nla[NFTA_SET_ELEM_KEY_END]) {
                err = nft_setelem_parse_key(ctx, set, &elem.key_end.val,
                                            nla[NFTA_SET_ELEM_KEY_END]);
                if (err < 0)
-                       return err;
+                       goto fail_elem;
 
-               nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY_END, set->klen);
+               err = nft_set_ext_add_length(&tmpl, NFT_SET_EXT_KEY_END, set->klen);
+               if (err < 0)
+                       goto fail_elem_key_end;
        }
 
        err = -ENOMEM;
@@ -6190,7 +6251,7 @@ static int nft_del_setelem(struct nft_ctx *ctx, struct nft_set *set,
                                      elem.key_end.val.data, NULL, 0, 0,
                                      GFP_KERNEL);
        if (elem.priv == NULL)
-               goto fail_elem;
+               goto fail_elem_key_end;
 
        ext = nft_set_elem_ext(set, elem.priv);
        if (flags)
@@ -6214,6 +6275,8 @@ fail_ops:
        kfree(trans);
 fail_trans:
        kfree(elem.priv);
+fail_elem_key_end:
+       nft_data_release(&elem.key_end.val, NFT_DATA_VALUE);
 fail_elem:
        nft_data_release(&elem.key.val, NFT_DATA_VALUE);
        return err;
@@ -6765,7 +6828,7 @@ static int nf_tables_dump_obj(struct sk_buff *skb, struct netlink_callback *cb)
 
        rcu_read_lock();
        nft_net = nft_pernet(net);
-       cb->seq = nft_net->base_seq;
+       cb->seq = READ_ONCE(nft_net->base_seq);
 
        list_for_each_entry_rcu(table, &nft_net->tables, list) {
                if (family != NFPROTO_UNSPEC && family != table->family)
@@ -7346,11 +7409,15 @@ static int nft_flowtable_update(struct nft_ctx *ctx, const struct nlmsghdr *nlh,
 
        if (nla[NFTA_FLOWTABLE_FLAGS]) {
                flags = ntohl(nla_get_be32(nla[NFTA_FLOWTABLE_FLAGS]));
-               if (flags & ~NFT_FLOWTABLE_MASK)
-                       return -EOPNOTSUPP;
+               if (flags & ~NFT_FLOWTABLE_MASK) {
+                       err = -EOPNOTSUPP;
+                       goto err_flowtable_update_hook;
+               }
                if ((flowtable->data.flags & NFT_FLOWTABLE_HW_OFFLOAD) ^
-                   (flags & NFT_FLOWTABLE_HW_OFFLOAD))
-                       return -EOPNOTSUPP;
+                   (flags & NFT_FLOWTABLE_HW_OFFLOAD)) {
+                       err = -EOPNOTSUPP;
+                       goto err_flowtable_update_hook;
+               }
        } else {
                flags = flowtable->data.flags;
        }
@@ -7693,7 +7760,7 @@ static int nf_tables_dump_flowtable(struct sk_buff *skb,
 
        rcu_read_lock();
        nft_net = nft_pernet(net);
-       cb->seq = nft_net->base_seq;
+       cb->seq = READ_ONCE(nft_net->base_seq);
 
        list_for_each_entry_rcu(table, &nft_net->tables, list) {
                if (family != NFPROTO_UNSPEC && family != table->family)
@@ -8238,6 +8305,9 @@ static void nft_commit_release(struct nft_trans *trans)
                nf_tables_chain_destroy(&trans->ctx);
                break;
        case NFT_MSG_DELRULE:
+               if (trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)
+                       nft_flow_rule_destroy(nft_trans_flow_rule(trans));
+
                nf_tables_rule_destroy(&trans->ctx, nft_trans_rule(trans));
                break;
        case NFT_MSG_DELSET:
@@ -8574,6 +8644,7 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb)
        struct nft_trans_elem *te;
        struct nft_chain *chain;
        struct nft_table *table;
+       unsigned int base_seq;
        LIST_HEAD(adl);
        int err;
 
@@ -8623,9 +8694,12 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb)
         * Bump generation counter, invalidate any dump in progress.
         * Cannot fail after this point.
         */
-       while (++nft_net->base_seq == 0)
+       base_seq = READ_ONCE(nft_net->base_seq);
+       while (++base_seq == 0)
                ;
 
+       WRITE_ONCE(nft_net->base_seq, base_seq);
+
        /* step 3. Start new generation, rules_gen_X now in use. */
        net->nft.gencursor = nft_gencursor_next(net);
 
@@ -8677,6 +8751,9 @@ static int nf_tables_commit(struct net *net, struct sk_buff *skb)
                        nf_tables_rule_notify(&trans->ctx,
                                              nft_trans_rule(trans),
                                              NFT_MSG_NEWRULE);
+                       if (trans->ctx.chain->flags & NFT_CHAIN_HW_OFFLOAD)
+                               nft_flow_rule_destroy(nft_trans_flow_rule(trans));
+
                        nft_trans_destroy(trans);
                        break;
                case NFT_MSG_DELRULE:
@@ -9428,6 +9505,9 @@ static int nft_verdict_init(const struct nft_ctx *ctx, struct nft_data *data,
                        return PTR_ERR(chain);
                if (nft_is_base_chain(chain))
                        return -EOPNOTSUPP;
+               if (desc->flags & NFT_DATA_DESC_SETELEM &&
+                   chain->flags & NFT_CHAIN_BINDING)
+                       return -EINVAL;
 
                chain->use++;
                data->verdict.chain = chain;
@@ -9435,7 +9515,7 @@ static int nft_verdict_init(const struct nft_ctx *ctx, struct nft_data *data,
        }
 
        desc->len = sizeof(data->verdict);
-       desc->type = NFT_DATA_VERDICT;
+
        return 0;
 }
 
@@ -9488,20 +9568,25 @@ nla_put_failure:
 }
 
 static int nft_value_init(const struct nft_ctx *ctx,
-                         struct nft_data *data, unsigned int size,
-                         struct nft_data_desc *desc, const struct nlattr *nla)
+                         struct nft_data *data, struct nft_data_desc *desc,
+                         const struct nlattr *nla)
 {
        unsigned int len;
 
        len = nla_len(nla);
        if (len == 0)
                return -EINVAL;
-       if (len > size)
+       if (len > desc->size)
                return -EOVERFLOW;
+       if (desc->len) {
+               if (len != desc->len)
+                       return -EINVAL;
+       } else {
+               desc->len = len;
+       }
 
        nla_memcpy(data->data, nla, len);
-       desc->type = NFT_DATA_VALUE;
-       desc->len  = len;
+
        return 0;
 }
 
@@ -9521,7 +9606,6 @@ static const struct nla_policy nft_data_policy[NFTA_DATA_MAX + 1] = {
  *
  *     @ctx: context of the expression using the data
  *     @data: destination struct nft_data
- *     @size: maximum data length
  *     @desc: data description
  *     @nla: netlink attribute containing data
  *
@@ -9531,24 +9615,35 @@ static const struct nla_policy nft_data_policy[NFTA_DATA_MAX + 1] = {
  *     The caller can indicate that it only wants to accept data of type
  *     NFT_DATA_VALUE by passing NULL for the ctx argument.
  */
-int nft_data_init(const struct nft_ctx *ctx,
-                 struct nft_data *data, unsigned int size,
+int nft_data_init(const struct nft_ctx *ctx, struct nft_data *data,
                  struct nft_data_desc *desc, const struct nlattr *nla)
 {
        struct nlattr *tb[NFTA_DATA_MAX + 1];
        int err;
 
+       if (WARN_ON_ONCE(!desc->size))
+               return -EINVAL;
+
        err = nla_parse_nested_deprecated(tb, NFTA_DATA_MAX, nla,
                                          nft_data_policy, NULL);
        if (err < 0)
                return err;
 
-       if (tb[NFTA_DATA_VALUE])
-               return nft_value_init(ctx, data, size, desc,
-                                     tb[NFTA_DATA_VALUE]);
-       if (tb[NFTA_DATA_VERDICT] && ctx != NULL)
-               return nft_verdict_init(ctx, data, desc, tb[NFTA_DATA_VERDICT]);
-       return -EINVAL;
+       if (tb[NFTA_DATA_VALUE]) {
+               if (desc->type != NFT_DATA_VALUE)
+                       return -EINVAL;
+
+               err = nft_value_init(ctx, data, desc, tb[NFTA_DATA_VALUE]);
+       } else if (tb[NFTA_DATA_VERDICT] && ctx != NULL) {
+               if (desc->type != NFT_DATA_VERDICT)
+                       return -EINVAL;
+
+               err = nft_verdict_init(ctx, data, desc, tb[NFTA_DATA_VERDICT]);
+       } else {
+               err = -EINVAL;
+       }
+
+       return err;
 }
 EXPORT_SYMBOL_GPL(nft_data_init);