// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/btf_ids.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
struct bpf_stab {
        struct bpf_map map;
        struct sock **sks;
        struct sk_psock_progs progs;
        raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK                           \
        (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
                                struct bpf_prog *old, u32 which);
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map);
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
        struct bpf_stab *stab;

        if (!capable(CAP_NET_ADMIN))
                return ERR_PTR(-EPERM);
        if (attr->max_entries == 0 ||
            attr->key_size    != 4 ||
            (attr->value_size != sizeof(u32) &&
             attr->value_size != sizeof(u64)) ||
            attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
                return ERR_PTR(-EINVAL);

        stab = kzalloc(sizeof(*stab), GFP_USER | __GFP_ACCOUNT);
        if (!stab)
                return ERR_PTR(-ENOMEM);

        bpf_map_init_from_attr(&stab->map, attr);
        raw_spin_lock_init(&stab->lock);

        stab->sks = bpf_map_area_alloc((u64) stab->map.max_entries *
                                       sizeof(struct sock *),
                                       stab->map.numa_node);
        if (!stab->sks) {
                kfree(stab);
                return ERR_PTR(-ENOMEM);
        }

        return &stab->map;
}
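/* For illustration only (not part of this file): a user-space sketch of
 * creating and populating a sockmap. Assumes libbpf >= 0.7 for
 * bpf_map_create(); sock_fd is a hypothetical connected TCP socket FD. The
 * value written is a socket FD; with value_size == 8 a lookup reads back as
 * a 64-bit socket cookie (see sock_map_lookup_sys() below).
 *
 *      int map_fd = bpf_map_create(BPF_MAP_TYPE_SOCKMAP, NULL,
 *                                  sizeof(__u32), sizeof(__u64), 1024, NULL);
 *      __u32 key = 0;
 *      __u64 val = sock_fd;
 *      bpf_map_update_elem(map_fd, &key, &val, BPF_ANY);
 */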
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
        u32 ufd = attr->target_fd;
        struct bpf_map *map;
        struct fd f;
        int ret;

        if (attr->attach_flags || attr->replace_bpf_fd)
                return -EINVAL;

        f = fdget(ufd);
        map = __bpf_map_get(f);
        if (IS_ERR(map))
                return PTR_ERR(map);
        ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
        fdput(f);
        return ret;
}
int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
        u32 ufd = attr->target_fd;
        struct bpf_prog *prog;
        struct bpf_map *map;
        struct fd f;
        int ret;

        if (attr->attach_flags || attr->replace_bpf_fd)
                return -EINVAL;

        f = fdget(ufd);
        map = __bpf_map_get(f);
        if (IS_ERR(map))
                return PTR_ERR(map);

        prog = bpf_prog_get(attr->attach_bpf_fd);
        if (IS_ERR(prog)) {
                ret = PTR_ERR(prog);
                goto put_map;
        }

        if (prog->type != ptype) {
                ret = -EINVAL;
                goto put_prog;
        }

        ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
        bpf_prog_put(prog);
put_map:
        fdput(f);
        return ret;
}
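/* For illustration only (not part of this file): the two functions above back
 * the BPF_PROG_ATTACH/BPF_PROG_DETACH syscall paths for sockmap/sockhash. A
 * user-space sketch with libbpf, assuming prog_fd holds a loaded
 * BPF_PROG_TYPE_SK_SKB program and map_fd a sockmap:
 *
 *      bpf_prog_attach(prog_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *      ...
 *      bpf_prog_detach2(prog_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT);
 */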
static void sock_map_sk_acquire(struct sock *sk)
        __acquires(&sk->sk_lock.slock)
{
        lock_sock(sk);
        rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
        __releases(&sk->sk_lock.slock)
{
        rcu_read_unlock();
        release_sock(sk);
}
static void sock_map_add_link(struct sk_psock *psock,
                              struct sk_psock_link *link,
                              struct bpf_map *map, void *link_raw)
{
        link->link_raw = link_raw;
        link->map = map;
        spin_lock_bh(&psock->link_lock);
        list_add_tail(&link->list, &psock->link);
        spin_unlock_bh(&psock->link_lock);
}
static void sock_map_del_link(struct sock *sk,
                              struct sk_psock *psock, void *link_raw)
{
        bool strp_stop = false, verdict_stop = false;
        struct sk_psock_link *link, *tmp;

        spin_lock_bh(&psock->link_lock);
        list_for_each_entry_safe(link, tmp, &psock->link, list) {
                if (link->link_raw == link_raw) {
                        struct bpf_map *map = link->map;
                        struct bpf_stab *stab = container_of(map, struct bpf_stab,
                                                             map);
                        if (psock->saved_data_ready && stab->progs.stream_parser)
                                strp_stop = true;
                        if (psock->saved_data_ready && stab->progs.stream_verdict)
                                verdict_stop = true;
                        if (psock->saved_data_ready && stab->progs.skb_verdict)
                                verdict_stop = true;
                        list_del(&link->list);
                        sk_psock_free_link(link);
                }
        }
        spin_unlock_bh(&psock->link_lock);
        if (strp_stop || verdict_stop) {
                write_lock_bh(&sk->sk_callback_lock);
                if (strp_stop)
                        sk_psock_stop_strp(sk, psock);
                else
                        sk_psock_stop_verdict(sk, psock);
                if (psock->psock_update_sk_prot)
                        psock->psock_update_sk_prot(sk, psock, false);
                write_unlock_bh(&sk->sk_callback_lock);
        }
}
static void sock_map_unref(struct sock *sk, void *link_raw)
{
        struct sk_psock *psock = sk_psock(sk);

        if (psock) {
                sock_map_del_link(sk, psock, link_raw);
                sk_psock_put(sk, psock);
        }
}
static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
        if (!sk->sk_prot->psock_update_sk_prot)
                return -EINVAL;
        psock->psock_update_sk_prot = sk->sk_prot->psock_update_sk_prot;
        return sk->sk_prot->psock_update_sk_prot(sk, psock, false);
}
static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (psock) {
                if (sk->sk_prot->close != sock_map_close) {
                        psock = ERR_PTR(-EBUSY);
                        goto out;
                }

                if (!refcount_inc_not_zero(&psock->refcnt))
                        psock = ERR_PTR(-EBUSY);
        }
out:
        rcu_read_unlock();
        return psock;
}
static int sock_map_link(struct bpf_map *map, struct sock *sk)
{
        struct sk_psock_progs *progs = sock_map_progs(map);
        struct bpf_prog *stream_verdict = NULL;
        struct bpf_prog *stream_parser = NULL;
        struct bpf_prog *skb_verdict = NULL;
        struct bpf_prog *msg_parser = NULL;
        struct sk_psock *psock;
        int ret;

        stream_verdict = READ_ONCE(progs->stream_verdict);
        if (stream_verdict) {
                stream_verdict = bpf_prog_inc_not_zero(stream_verdict);
                if (IS_ERR(stream_verdict))
                        return PTR_ERR(stream_verdict);
        }

        stream_parser = READ_ONCE(progs->stream_parser);
        if (stream_parser) {
                stream_parser = bpf_prog_inc_not_zero(stream_parser);
                if (IS_ERR(stream_parser)) {
                        ret = PTR_ERR(stream_parser);
                        goto out_put_stream_verdict;
                }
        }

        msg_parser = READ_ONCE(progs->msg_parser);
        if (msg_parser) {
                msg_parser = bpf_prog_inc_not_zero(msg_parser);
                if (IS_ERR(msg_parser)) {
                        ret = PTR_ERR(msg_parser);
                        goto out_put_stream_parser;
                }
        }

        skb_verdict = READ_ONCE(progs->skb_verdict);
        if (skb_verdict) {
                skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
                if (IS_ERR(skb_verdict)) {
                        ret = PTR_ERR(skb_verdict);
                        goto out_put_msg_parser;
                }
        }

        psock = sock_map_psock_get_checked(sk);
        if (IS_ERR(psock)) {
                ret = PTR_ERR(psock);
                goto out_progs;
        }

        if (psock) {
                if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
                    (stream_parser && READ_ONCE(psock->progs.stream_parser)) ||
                    (skb_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
                    (skb_verdict && READ_ONCE(psock->progs.stream_verdict)) ||
                    (stream_verdict && READ_ONCE(psock->progs.skb_verdict)) ||
                    (stream_verdict && READ_ONCE(psock->progs.stream_verdict))) {
                        sk_psock_put(sk, psock);
                        ret = -EBUSY;
                        goto out_progs;
                }
        } else {
                psock = sk_psock_init(sk, map->numa_node);
                if (IS_ERR(psock)) {
                        ret = PTR_ERR(psock);
                        goto out_progs;
                }
        }

        if (msg_parser)
                psock_set_prog(&psock->progs.msg_parser, msg_parser);
        if (stream_parser)
                psock_set_prog(&psock->progs.stream_parser, stream_parser);
        if (stream_verdict)
                psock_set_prog(&psock->progs.stream_verdict, stream_verdict);
        if (skb_verdict)
                psock_set_prog(&psock->progs.skb_verdict, skb_verdict);

        /* msg_* and stream_* programs references tracked in psock after this
         * point. Reference dec and cleanup will occur through psock destructor
         */
        ret = sock_map_init_proto(sk, psock);
        if (ret < 0) {
                sk_psock_put(sk, psock);
                goto out;
        }

        write_lock_bh(&sk->sk_callback_lock);
        if (stream_parser && stream_verdict && !psock->saved_data_ready) {
                ret = sk_psock_init_strp(sk, psock);
                if (ret) {
                        write_unlock_bh(&sk->sk_callback_lock);
                        sk_psock_put(sk, psock);
                        goto out;
                }
                sk_psock_start_strp(sk, psock);
        } else if (!stream_parser && stream_verdict && !psock->saved_data_ready) {
                sk_psock_start_verdict(sk, psock);
        } else if (!stream_verdict && skb_verdict && !psock->saved_data_ready) {
                sk_psock_start_verdict(sk, psock);
        }
        write_unlock_bh(&sk->sk_callback_lock);
        return 0;
out_progs:
        if (skb_verdict)
                bpf_prog_put(skb_verdict);
out_put_msg_parser:
        if (msg_parser)
                bpf_prog_put(msg_parser);
out_put_stream_parser:
        if (stream_parser)
                bpf_prog_put(stream_parser);
out_put_stream_verdict:
        if (stream_verdict)
                bpf_prog_put(stream_verdict);
out:
        return ret;
}
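/* Summary of the flow above: sock_map_link() takes its own reference on each
 * program currently attached to the map, rejects sockets whose psock already
 * carries a conflicting program (-EBUSY), switches the socket over to the
 * psock proto ops, and only then installs the data_ready hooks. Each
 * out_put_* label unwinds exactly the references taken before the failure
 * point, in reverse order of acquisition.
 */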
static void sock_map_free(struct bpf_map *map)
{
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
        int i;

        /* After the sync no updates or deletes will be in-flight so it
         * is safe to walk map and remove entries without risking a race
         * in EEXIST update case.
         */
        synchronize_rcu();
        for (i = 0; i < stab->map.max_entries; i++) {
                struct sock **psk = &stab->sks[i];
                struct sock *sk;

                sk = xchg(psk, NULL);
                if (sk) {
                        lock_sock(sk);
                        rcu_read_lock();
                        sock_map_unref(sk, psk);
                        rcu_read_unlock();
                        release_sock(sk);
                }
        }

        /* wait for psock readers accessing its map link */
        synchronize_rcu();

        bpf_map_area_free(stab->sks);
        kfree(stab);
}
static void sock_map_release_progs(struct bpf_map *map)
{
        psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}
static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

        WARN_ON_ONCE(!rcu_read_lock_held());

        if (unlikely(key >= map->max_entries))
                return NULL;
        return READ_ONCE(stab->sks[key]);
}
static void *sock_map_lookup(struct bpf_map *map, void *key)
{
        struct sock *sk;

        sk = __sock_map_lookup_elem(map, *(u32 *)key);
        if (!sk)
                return NULL;
        if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
                return NULL;
        return sk;
}
static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
        struct sock *sk;

        if (map->value_size != sizeof(u64))
                return ERR_PTR(-ENOSPC);

        sk = __sock_map_lookup_elem(map, *(u32 *)key);
        if (!sk)
                return ERR_PTR(-ENOENT);

        __sock_gen_cookie(sk);
        return &sk->sk_cookie;
}
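/* For illustration only (not part of this file): the syscall path cannot hand
 * a kernel socket pointer to user space, so a map created with
 * value_size == 8 reads back as a socket cookie. Hypothetical sketch:
 *
 *      __u64 cookie;
 *      __u32 key = 0;
 *
 *      if (!bpf_map_lookup_elem(map_fd, &key, &cookie))
 *              printf("slot %u -> cookie %llu\n", key, cookie);
 */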
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
                             struct sock **psk)
{
        struct sock *sk;
        int err = 0;

        raw_spin_lock_bh(&stab->lock);
        sk = *psk;
        if (!sk_test || sk_test == sk)
                sk = xchg(psk, NULL);

        if (likely(sk))
                sock_map_unref(sk, psk);
        else
                err = -ENOENT;

        raw_spin_unlock_bh(&stab->lock);
        return err;
}
static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
                                      void *link_raw)
{
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

        __sock_map_delete(stab, sk, link_raw);
}
static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
        u32 i = *(u32 *)key;
        struct sock **psk;

        if (unlikely(i >= map->max_entries))
                return -EINVAL;

        psk = &stab->sks[i];
        return __sock_map_delete(stab, NULL, psk);
}
static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
        u32 i = key ? *(u32 *)key : U32_MAX;
        u32 *key_next = next;

        if (i == stab->map.max_entries - 1)
                return -ENOENT;
        if (i >= stab->map.max_entries)
                *key_next = 0;
        else
                *key_next = i + 1;
        return 0;
}
static int sock_map_update_common(struct bpf_map *map, u32 idx,
                                  struct sock *sk, u64 flags)
{
        struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
        struct sk_psock_link *link;
        struct sk_psock *psock;
        struct sock *osk;
        int ret;

        WARN_ON_ONCE(!rcu_read_lock_held());
        if (unlikely(flags > BPF_EXIST))
                return -EINVAL;
        if (unlikely(idx >= map->max_entries))
                return -E2BIG;

        link = sk_psock_init_link();
        if (!link)
                return -ENOMEM;

        ret = sock_map_link(map, sk);
        if (ret < 0)
                goto out_free;

        psock = sk_psock(sk);
        WARN_ON_ONCE(!psock);

        raw_spin_lock_bh(&stab->lock);
        osk = stab->sks[idx];
        if (osk && flags == BPF_NOEXIST) {
                ret = -EEXIST;
                goto out_unlock;
        } else if (!osk && flags == BPF_EXIST) {
                ret = -ENOENT;
                goto out_unlock;
        }

        sock_map_add_link(psock, link, map, &stab->sks[idx]);
        stab->sks[idx] = sk;
        if (osk)
                sock_map_unref(osk, &stab->sks[idx]);
        raw_spin_unlock_bh(&stab->lock);
        return 0;
out_unlock:
        raw_spin_unlock_bh(&stab->lock);
        if (psock)
                sk_psock_put(sk, psock);
out_free:
        sk_psock_free_link(link);
        return ret;
}
static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
        return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
               ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
               ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
        return sk->sk_type == SOCK_STREAM &&
               sk->sk_protocol == IPPROTO_TCP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
        if (sk_is_tcp(sk))
                return sk->sk_state != TCP_LISTEN;
        else
                return sk->sk_state == TCP_ESTABLISHED;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
        return !!sk->sk_prot->psock_update_sk_prot;
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
        if (sk_is_tcp(sk))
                return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
        return true;
}
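/* Taken together, the helpers above gate the two entry points: insertion from
 * the syscall side requires a TCP socket to be established or listening
 * (other protocols are gated only by sock_map_sk_is_suitable()), while a
 * redirect target must be able to receive data, i.e. anything but a listening
 * socket for TCP, or an established socket otherwise.
 */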
static int sock_hash_update_common(struct bpf_map *map, void *key,
                                   struct sock *sk, u64 flags);
int sock_map_update_elem_sys(struct bpf_map *map, void *key, void *value,
                             u64 flags)
{
        struct socket *sock;
        struct sock *sk;
        int ret;
        u64 ufd;

        if (map->value_size == sizeof(u64))
                ufd = *(u64 *)value;
        else
                ufd = *(u32 *)value;
        if (ufd > S32_MAX)
                return -EINVAL;

        sock = sockfd_lookup(ufd, &ret);
        if (!sock)
                return ret;
        sk = sock->sk;
        if (!sk) {
                ret = -EINVAL;
                goto out;
        }
        if (!sock_map_sk_is_suitable(sk)) {
                ret = -EOPNOTSUPP;
                goto out;
        }

        sock_map_sk_acquire(sk);
        if (!sock_map_sk_state_allowed(sk))
                ret = -EOPNOTSUPP;
        else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
                ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
        else
                ret = sock_hash_update_common(map, key, sk, flags);
        sock_map_sk_release(sk);
out:
        fput(sock->file);
        return ret;
}
static int sock_map_update_elem(struct bpf_map *map, void *key,
                                void *value, u64 flags)
{
        struct sock *sk = (struct sock *)value;
        int ret;

        if (unlikely(!sk || !sk_fullsock(sk)))
                return -EINVAL;

        if (!sock_map_sk_is_suitable(sk))
                return -EOPNOTSUPP;

        local_bh_disable();
        bh_lock_sock(sk);
        if (!sock_map_sk_state_allowed(sk))
                ret = -EOPNOTSUPP;
        else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
                ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
        else
                ret = sock_hash_update_common(map, key, sk, flags);
        bh_unlock_sock(sk);
        local_bh_enable();
        return ret;
}
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
           struct bpf_map *, map, void *, key, u64, flags)
{
        WARN_ON_ONCE(!rcu_read_lock_held());

        if (likely(sock_map_sk_is_suitable(sops->sk) &&
                   sock_map_op_okay(sops)))
                return sock_map_update_common(map, *(u32 *)key, sops->sk,
                                              flags);
        return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
        .func           = bpf_sock_map_update,
        .gpl_only       = false,
        .pkt_access     = true,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_PTR_TO_MAP_KEY,
        .arg4_type      = ARG_ANYTHING,
};
BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
           struct bpf_map *, map, u32, key, u64, flags)
{
        struct sock *sk;

        if (unlikely(flags & ~(BPF_F_INGRESS)))
                return SK_DROP;

        sk = __sock_map_lookup_elem(map, key);
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;

        skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
        return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
        .func           = bpf_sk_redirect_map,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_ANYTHING,
        .arg4_type      = ARG_ANYTHING,
};
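/* For illustration only (not part of this file): a minimal
 * BPF_PROG_TYPE_SK_SKB verdict program using this helper. The map and
 * function names here are hypothetical:
 *
 *      struct {
 *              __uint(type, BPF_MAP_TYPE_SOCKMAP);
 *              __uint(max_entries, 2);
 *              __type(key, __u32);
 *              __type(value, __u64);
 *      } sock_map SEC(".maps");
 *
 *      SEC("sk_skb/stream_verdict")
 *      int prog_stream_verdict(struct __sk_buff *skb)
 *      {
 *              __u32 idx = 0;
 *
 *              return bpf_sk_redirect_map(skb, &sock_map, idx, 0);
 *      }
 */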
BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
           struct bpf_map *, map, u32, key, u64, flags)
{
        struct sock *sk;

        if (unlikely(flags & ~(BPF_F_INGRESS)))
                return SK_DROP;

        sk = __sock_map_lookup_elem(map, key);
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;

        msg->flags = flags;
        msg->sk_redir = sk;
        return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
        .func           = bpf_msg_redirect_map,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_ANYTHING,
        .arg4_type      = ARG_ANYTHING,
};
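/* For illustration only (not part of this file): the sk_msg counterpart,
 * run on the sendmsg path of a socket that has a msg_parser program attached
 * (BPF_SK_MSG_VERDICT). Hypothetical sketch:
 *
 *      SEC("sk_msg")
 *      int prog_msg_verdict(struct sk_msg_md *msg)
 *      {
 *              __u32 idx = 0;
 *
 *              return bpf_msg_redirect_map(msg, &sock_map, idx, BPF_F_INGRESS);
 *      }
 */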
struct sock_map_seq_info {
        struct bpf_map *map;
        struct sock *sk;
        u32 index;
};

struct bpf_iter__sockmap {
        __bpf_md_ptr(struct bpf_iter_meta *, meta);
        __bpf_md_ptr(struct bpf_map *, map);
        __bpf_md_ptr(void *, key);
        __bpf_md_ptr(struct sock *, sk);
};

DEFINE_BPF_ITER_FUNC(sockmap, struct bpf_iter_meta *meta,
                     struct bpf_map *map, void *key,
                     struct sock *sk)
static void *sock_map_seq_lookup_elem(struct sock_map_seq_info *info)
{
        if (unlikely(info->index >= info->map->max_entries))
                return NULL;

        info->sk = __sock_map_lookup_elem(info->map, info->index);

        /* can't return sk directly, since that might be NULL */
        return info;
}
static void *sock_map_seq_start(struct seq_file *seq, loff_t *pos)
        __acquires(rcu)
{
        struct sock_map_seq_info *info = seq->private;

        if (*pos == 0)
                ++*pos;

        /* pairs with sock_map_seq_stop */
        rcu_read_lock();
        return sock_map_seq_lookup_elem(info);
}
static void *sock_map_seq_next(struct seq_file *seq, void *v, loff_t *pos)
        __must_hold(rcu)
{
        struct sock_map_seq_info *info = seq->private;

        ++*pos;
        ++info->index;

        return sock_map_seq_lookup_elem(info);
}
static int sock_map_seq_show(struct seq_file *seq, void *v)
        __must_hold(rcu)
{
        struct sock_map_seq_info *info = seq->private;
        struct bpf_iter__sockmap ctx = {};
        struct bpf_iter_meta meta;
        struct bpf_prog *prog;

        meta.seq = seq;
        prog = bpf_iter_get_info(&meta, !v);
        if (!prog)
                return 0;

        ctx.meta = &meta;
        ctx.map = info->map;
        if (v) {
                ctx.key = &info->index;
                ctx.sk = info->sk;
        }

        return bpf_iter_run_prog(prog, &ctx);
}
static void sock_map_seq_stop(struct seq_file *seq, void *v)
        __releases(rcu)
{
        if (!v)
                (void)sock_map_seq_show(seq, NULL);

        /* pairs with sock_map_seq_start */
        rcu_read_unlock();
}
static const struct seq_operations sock_map_seq_ops = {
        .start  = sock_map_seq_start,
        .next   = sock_map_seq_next,
        .stop   = sock_map_seq_stop,
        .show   = sock_map_seq_show,
};
static int sock_map_init_seq_private(void *priv_data,
                                     struct bpf_iter_aux_info *aux)
{
        struct sock_map_seq_info *info = priv_data;

        info->map = aux->map;
        return 0;
}
static const struct bpf_iter_seq_info sock_map_iter_seq_info = {
        .seq_ops                = &sock_map_seq_ops,
        .init_seq_private       = sock_map_init_seq_private,
        .seq_priv_size          = sizeof(struct sock_map_seq_info),
};
static int sock_map_btf_id;
const struct bpf_map_ops sock_map_ops = {
        .map_meta_equal         = bpf_map_meta_equal,
        .map_alloc              = sock_map_alloc,
        .map_free               = sock_map_free,
        .map_get_next_key       = sock_map_get_next_key,
        .map_lookup_elem_sys_only = sock_map_lookup_sys,
        .map_update_elem        = sock_map_update_elem,
        .map_delete_elem        = sock_map_delete_elem,
        .map_lookup_elem        = sock_map_lookup,
        .map_release_uref       = sock_map_release_progs,
        .map_check_btf          = map_check_no_btf,
        .map_btf_name           = "bpf_stab",
        .map_btf_id             = &sock_map_btf_id,
        .iter_seq_info          = &sock_map_iter_seq_info,
};
struct bpf_shtab_elem {
        struct rcu_head rcu;
        u32 hash;
        struct sock *sk;
        struct hlist_node node;
        u8 key[];
};

struct bpf_shtab_bucket {
        struct hlist_head head;
        raw_spinlock_t lock;
};

struct bpf_shtab {
        struct bpf_map map;
        struct bpf_shtab_bucket *buckets;
        u32 buckets_num;
        u32 elem_size;
        struct sk_psock_progs progs;
        atomic_t count;
};
static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
        return jhash(key, len, 0);
}

static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
                                                        u32 hash)
{
        return &htab->buckets[hash & (htab->buckets_num - 1)];
}
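/* sock_hash_alloc() rounds buckets_num up to a power of two, so masking with
 * (buckets_num - 1) above selects a bucket without a modulo operation.
 */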
static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
                          u32 key_size)
{
        struct bpf_shtab_elem *elem;

        hlist_for_each_entry_rcu(elem, head, node) {
                if (elem->hash == hash &&
                    !memcmp(&elem->key, key, key_size))
                        return elem;
        }

        return NULL;
}
static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
        struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
        u32 key_size = map->key_size, hash;
        struct bpf_shtab_bucket *bucket;
        struct bpf_shtab_elem *elem;

        WARN_ON_ONCE(!rcu_read_lock_held());

        hash = sock_hash_bucket_hash(key, key_size);
        bucket = sock_hash_select_bucket(htab, hash);
        elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

        return elem ? elem->sk : NULL;
}
static void sock_hash_free_elem(struct bpf_shtab *htab,
                                struct bpf_shtab_elem *elem)
{
        atomic_dec(&htab->count);
        kfree_rcu(elem, rcu);
}
static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
                                       void *link_raw)
{
        struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
        struct bpf_shtab_elem *elem_probe, *elem = link_raw;
        struct bpf_shtab_bucket *bucket;

        WARN_ON_ONCE(!rcu_read_lock_held());
        bucket = sock_hash_select_bucket(htab, elem->hash);

        /* elem may be deleted in parallel from the map, but access here
         * is okay since it's going away only after RCU grace period.
         * However, we need to check whether it's still present.
         */
        raw_spin_lock_bh(&bucket->lock);
        elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
                                               elem->key, map->key_size);
        if (elem_probe && elem_probe == elem) {
                hlist_del_rcu(&elem->node);
                sock_map_unref(elem->sk, elem);
                sock_hash_free_elem(htab, elem);
        }
        raw_spin_unlock_bh(&bucket->lock);
}
static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
        struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
        u32 hash, key_size = map->key_size;
        struct bpf_shtab_bucket *bucket;
        struct bpf_shtab_elem *elem;
        int ret = -ENOENT;

        hash = sock_hash_bucket_hash(key, key_size);
        bucket = sock_hash_select_bucket(htab, hash);

        raw_spin_lock_bh(&bucket->lock);
        elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
        if (elem) {
                hlist_del_rcu(&elem->node);
                sock_map_unref(elem->sk, elem);
                sock_hash_free_elem(htab, elem);
                ret = 0;
        }
        raw_spin_unlock_bh(&bucket->lock);
        return ret;
}
static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
                                                   void *key, u32 key_size,
                                                   u32 hash, struct sock *sk,
                                                   struct bpf_shtab_elem *old)
{
        struct bpf_shtab_elem *new;

        if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
                if (!old) {
                        atomic_dec(&htab->count);
                        return ERR_PTR(-E2BIG);
                }
        }

        new = bpf_map_kmalloc_node(&htab->map, htab->elem_size,
                                   GFP_ATOMIC | __GFP_NOWARN,
                                   htab->map.numa_node);
        if (!new) {
                atomic_dec(&htab->count);
                return ERR_PTR(-ENOMEM);
        }
        memcpy(new->key, key, key_size);
        new->sk = sk;
        new->hash = hash;
        return new;
}
static int sock_hash_update_common(struct bpf_map *map, void *key,
                                   struct sock *sk, u64 flags)
{
        struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
        u32 key_size = map->key_size, hash;
        struct bpf_shtab_elem *elem, *elem_new;
        struct bpf_shtab_bucket *bucket;
        struct sk_psock_link *link;
        struct sk_psock *psock;
        int ret;

        WARN_ON_ONCE(!rcu_read_lock_held());
        if (unlikely(flags > BPF_EXIST))
                return -EINVAL;

        link = sk_psock_init_link();
        if (!link)
                return -ENOMEM;

        ret = sock_map_link(map, sk);
        if (ret < 0)
                goto out_free;

        psock = sk_psock(sk);
        WARN_ON_ONCE(!psock);

        hash = sock_hash_bucket_hash(key, key_size);
        bucket = sock_hash_select_bucket(htab, hash);

        raw_spin_lock_bh(&bucket->lock);
        elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
        if (elem && flags == BPF_NOEXIST) {
                ret = -EEXIST;
                goto out_unlock;
        } else if (!elem && flags == BPF_EXIST) {
                ret = -ENOENT;
                goto out_unlock;
        }

        elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
        if (IS_ERR(elem_new)) {
                ret = PTR_ERR(elem_new);
                goto out_unlock;
        }

        sock_map_add_link(psock, link, map, elem_new);
        /* Add new element to the head of the list, so that
         * concurrent search will find it before old elem.
         */
        hlist_add_head_rcu(&elem_new->node, &bucket->head);
        if (elem) {
                hlist_del_rcu(&elem->node);
                sock_map_unref(elem->sk, elem);
                sock_hash_free_elem(htab, elem);
        }
        raw_spin_unlock_bh(&bucket->lock);
        return 0;
out_unlock:
        raw_spin_unlock_bh(&bucket->lock);
        sk_psock_put(sk, psock);
out_free:
        sk_psock_free_link(link);
        return ret;
}
static int sock_hash_get_next_key(struct bpf_map *map, void *key,
                                  void *key_next)
{
        struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
        struct bpf_shtab_elem *elem, *elem_next;
        u32 hash, key_size = map->key_size;
        struct hlist_head *head;
        int i = 0;

        if (!key)
                goto find_first_elem;
        hash = sock_hash_bucket_hash(key, key_size);
        head = &sock_hash_select_bucket(htab, hash)->head;
        elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
        if (!elem)
                goto find_first_elem;

        elem_next = hlist_entry_safe(rcu_dereference(hlist_next_rcu(&elem->node)),
                                     struct bpf_shtab_elem, node);
        if (elem_next) {
                memcpy(key_next, elem_next->key, key_size);
                return 0;
        }

        i = hash & (htab->buckets_num - 1);
        i++;
find_first_elem:
        for (; i < htab->buckets_num; i++) {
                head = &sock_hash_select_bucket(htab, i)->head;
                elem_next = hlist_entry_safe(rcu_dereference(hlist_first_rcu(head)),
                                             struct bpf_shtab_elem, node);
                if (elem_next) {
                        memcpy(key_next, elem_next->key, key_size);
                        return 0;
                }
        }

        return -ENOENT;
}
static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
        struct bpf_shtab *htab;
        int i, err;

        if (!capable(CAP_NET_ADMIN))
                return ERR_PTR(-EPERM);
        if (attr->max_entries == 0 ||
            attr->key_size    == 0 ||
            (attr->value_size != sizeof(u32) &&
             attr->value_size != sizeof(u64)) ||
            attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
                return ERR_PTR(-EINVAL);
        if (attr->key_size > MAX_BPF_STACK)
                return ERR_PTR(-E2BIG);

        htab = kzalloc(sizeof(*htab), GFP_USER | __GFP_ACCOUNT);
        if (!htab)
                return ERR_PTR(-ENOMEM);

        bpf_map_init_from_attr(&htab->map, attr);

        htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
        htab->elem_size = sizeof(struct bpf_shtab_elem) +
                          round_up(htab->map.key_size, 8);
        if (htab->buckets_num == 0 ||
            htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
                err = -EINVAL;
                goto free_htab;
        }

        htab->buckets = bpf_map_area_alloc(htab->buckets_num *
                                           sizeof(struct bpf_shtab_bucket),
                                           htab->map.numa_node);
        if (!htab->buckets) {
                err = -ENOMEM;
                goto free_htab;
        }

        for (i = 0; i < htab->buckets_num; i++) {
                INIT_HLIST_HEAD(&htab->buckets[i].head);
                raw_spin_lock_init(&htab->buckets[i].lock);
        }

        return &htab->map;
free_htab:
        kfree(htab);
        return ERR_PTR(err);
}
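/* Note on sizing: buckets_num is max_entries rounded up to a power of two,
 * which the mask in sock_hash_select_bucket() relies on, and elem_size
 * reserves the key inline after struct bpf_shtab_elem, rounded up to 8 bytes
 * so the flexible key array stays naturally aligned.
 */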
static void sock_hash_free(struct bpf_map *map)
{
        struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
        struct bpf_shtab_bucket *bucket;
        struct hlist_head unlink_list;
        struct bpf_shtab_elem *elem;
        struct hlist_node *node;
        int i;

        /* After the sync no updates or deletes will be in-flight so it
         * is safe to walk map and remove entries without risking a race
         * in EEXIST update case.
         */
        synchronize_rcu();
        for (i = 0; i < htab->buckets_num; i++) {
                bucket = sock_hash_select_bucket(htab, i);

                /* We are racing with sock_hash_delete_from_link to
                 * enter the spin-lock critical section. Every socket on
                 * the list is still linked to sockhash. Since link
                 * exists, psock exists and holds a ref to socket. That
                 * lets us to grab a socket ref too.
                 */
                raw_spin_lock_bh(&bucket->lock);
                hlist_for_each_entry(elem, &bucket->head, node)
                        sock_hold(elem->sk);
                hlist_move_list(&bucket->head, &unlink_list);
                raw_spin_unlock_bh(&bucket->lock);

                /* Process removed entries out of atomic context to
                 * block for socket lock before deleting the psock's
                 * link to sockhash.
                 */
                hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
                        hlist_del(&elem->node);
                        lock_sock(elem->sk);
                        rcu_read_lock();
                        sock_map_unref(elem->sk, elem);
                        rcu_read_unlock();
                        release_sock(elem->sk);
                        sock_put(elem->sk);
                        sock_hash_free_elem(htab, elem);
                }
        }

        /* wait for psock readers accessing its map link */
        synchronize_rcu();

        bpf_map_area_free(htab->buckets);
        kfree(htab);
}
static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
        struct sock *sk;

        if (map->value_size != sizeof(u64))
                return ERR_PTR(-ENOSPC);

        sk = __sock_hash_lookup_elem(map, key);
        if (!sk)
                return ERR_PTR(-ENOENT);

        __sock_gen_cookie(sk);
        return &sk->sk_cookie;
}
static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
        struct sock *sk;

        sk = __sock_hash_lookup_elem(map, key);
        if (!sk)
                return NULL;
        if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
                return NULL;
        return sk;
}
static void sock_hash_release_progs(struct bpf_map *map)
{
        psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}
BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
           struct bpf_map *, map, void *, key, u64, flags)
{
        WARN_ON_ONCE(!rcu_read_lock_held());

        if (likely(sock_map_sk_is_suitable(sops->sk) &&
                   sock_map_op_okay(sops)))
                return sock_hash_update_common(map, key, sops->sk, flags);
        return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
        .func           = bpf_sock_hash_update,
        .gpl_only       = false,
        .pkt_access     = true,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_PTR_TO_MAP_KEY,
        .arg4_type      = ARG_ANYTHING,
};
BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
           struct bpf_map *, map, void *, key, u64, flags)
{
        struct sock *sk;

        if (unlikely(flags & ~(BPF_F_INGRESS)))
                return SK_DROP;

        sk = __sock_hash_lookup_elem(map, key);
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;

        skb_bpf_set_redir(skb, sk, flags & BPF_F_INGRESS);
        return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
        .func           = bpf_sk_redirect_hash,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_PTR_TO_MAP_KEY,
        .arg4_type      = ARG_ANYTHING,
};
BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
           struct bpf_map *, map, void *, key, u64, flags)
{
        struct sock *sk;

        if (unlikely(flags & ~(BPF_F_INGRESS)))
                return SK_DROP;

        sk = __sock_hash_lookup_elem(map, key);
        if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
                return SK_DROP;

        msg->flags = flags;
        msg->sk_redir = sk;
        return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
        .func           = bpf_msg_redirect_hash,
        .gpl_only       = false,
        .ret_type       = RET_INTEGER,
        .arg1_type      = ARG_PTR_TO_CTX,
        .arg2_type      = ARG_CONST_MAP_PTR,
        .arg3_type      = ARG_PTR_TO_MAP_KEY,
        .arg4_type      = ARG_ANYTHING,
};
struct sock_hash_seq_info {
        struct bpf_map *map;
        struct bpf_shtab *htab;
        u32 bucket_id;
};
static void *sock_hash_seq_find_next(struct sock_hash_seq_info *info,
                                     struct bpf_shtab_elem *prev_elem)
{
        const struct bpf_shtab *htab = info->htab;
        struct bpf_shtab_bucket *bucket;
        struct bpf_shtab_elem *elem;
        struct hlist_node *node;

        /* try to find next elem in the same bucket */
        if (prev_elem) {
                node = rcu_dereference(hlist_next_rcu(&prev_elem->node));
                elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
                if (elem)
                        return elem;

                /* no more elements, continue in the next bucket */
                info->bucket_id++;
        }

        for (; info->bucket_id < htab->buckets_num; info->bucket_id++) {
                bucket = &htab->buckets[info->bucket_id];
                node = rcu_dereference(hlist_first_rcu(&bucket->head));
                elem = hlist_entry_safe(node, struct bpf_shtab_elem, node);
                if (elem)
                        return elem;
        }

        return NULL;
}
static void *sock_hash_seq_start(struct seq_file *seq, loff_t *pos)
        __acquires(rcu)
{
        struct sock_hash_seq_info *info = seq->private;

        if (*pos == 0)
                ++*pos;

        /* pairs with sock_hash_seq_stop */
        rcu_read_lock();
        return sock_hash_seq_find_next(info, NULL);
}
static void *sock_hash_seq_next(struct seq_file *seq, void *v, loff_t *pos)
        __must_hold(rcu)
{
        struct sock_hash_seq_info *info = seq->private;

        ++*pos;
        return sock_hash_seq_find_next(info, v);
}
static int sock_hash_seq_show(struct seq_file *seq, void *v)
        __must_hold(rcu)
{
        struct sock_hash_seq_info *info = seq->private;
        struct bpf_iter__sockmap ctx = {};
        struct bpf_shtab_elem *elem = v;
        struct bpf_iter_meta meta;
        struct bpf_prog *prog;

        meta.seq = seq;
        prog = bpf_iter_get_info(&meta, !elem);
        if (!prog)
                return 0;

        ctx.meta = &meta;
        ctx.map = info->map;
        if (elem) {
                ctx.key = elem->key;
                ctx.sk = elem->sk;
        }

        return bpf_iter_run_prog(prog, &ctx);
}
static void sock_hash_seq_stop(struct seq_file *seq, void *v)
        __releases(rcu)
{
        if (!v)
                (void)sock_hash_seq_show(seq, NULL);

        /* pairs with sock_hash_seq_start */
        rcu_read_unlock();
}
static const struct seq_operations sock_hash_seq_ops = {
        .start  = sock_hash_seq_start,
        .next   = sock_hash_seq_next,
        .stop   = sock_hash_seq_stop,
        .show   = sock_hash_seq_show,
};
static int sock_hash_init_seq_private(void *priv_data,
                                      struct bpf_iter_aux_info *aux)
{
        struct sock_hash_seq_info *info = priv_data;

        info->map = aux->map;
        info->htab = container_of(aux->map, struct bpf_shtab, map);
        return 0;
}
static const struct bpf_iter_seq_info sock_hash_iter_seq_info = {
        .seq_ops                = &sock_hash_seq_ops,
        .init_seq_private       = sock_hash_init_seq_private,
        .seq_priv_size          = sizeof(struct sock_hash_seq_info),
};
static int sock_hash_map_btf_id;
const struct bpf_map_ops sock_hash_ops = {
        .map_meta_equal         = bpf_map_meta_equal,
        .map_alloc              = sock_hash_alloc,
        .map_free               = sock_hash_free,
        .map_get_next_key       = sock_hash_get_next_key,
        .map_update_elem        = sock_map_update_elem,
        .map_delete_elem        = sock_hash_delete_elem,
        .map_lookup_elem        = sock_hash_lookup,
        .map_lookup_elem_sys_only = sock_hash_lookup_sys,
        .map_release_uref       = sock_hash_release_progs,
        .map_check_btf          = map_check_no_btf,
        .map_btf_name           = "bpf_shtab",
        .map_btf_id             = &sock_hash_map_btf_id,
        .iter_seq_info          = &sock_hash_iter_seq_info,
};
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
        switch (map->map_type) {
        case BPF_MAP_TYPE_SOCKMAP:
                return &container_of(map, struct bpf_stab, map)->progs;
        case BPF_MAP_TYPE_SOCKHASH:
                return &container_of(map, struct bpf_shtab, map)->progs;
        default:
                break;
        }

        return NULL;
}
static int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
                                struct bpf_prog *old, u32 which)
{
        struct sk_psock_progs *progs = sock_map_progs(map);
        struct bpf_prog **pprog;

        if (!progs)
                return -EOPNOTSUPP;

        switch (which) {
        case BPF_SK_MSG_VERDICT:
                pprog = &progs->msg_parser;
                break;
#if IS_ENABLED(CONFIG_BPF_STREAM_PARSER)
        case BPF_SK_SKB_STREAM_PARSER:
                pprog = &progs->stream_parser;
                break;
#endif
        case BPF_SK_SKB_STREAM_VERDICT:
                if (progs->skb_verdict)
                        return -EBUSY;
                pprog = &progs->stream_verdict;
                break;
        case BPF_SK_SKB_VERDICT:
                if (progs->stream_verdict)
                        return -EBUSY;
                pprog = &progs->skb_verdict;
                break;
        default:
                return -EOPNOTSUPP;
        }

        if (old)
                return psock_replace_prog(pprog, prog, old);

        psock_set_prog(pprog, prog);
        return 0;
}
static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
        switch (link->map->map_type) {
        case BPF_MAP_TYPE_SOCKMAP:
                return sock_map_delete_from_link(link->map, sk,
                                                 link->link_raw);
        case BPF_MAP_TYPE_SOCKHASH:
                return sock_hash_delete_from_link(link->map, sk,
                                                  link->link_raw);
        default:
                break;
        }
}
static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
        struct sk_psock_link *link;

        while ((link = sk_psock_link_pop(psock))) {
                sock_map_unlink(sk, link);
                sk_psock_free_link(link);
        }
}
void sock_map_unhash(struct sock *sk)
{
        void (*saved_unhash)(struct sock *sk);
        struct sk_psock *psock;

        rcu_read_lock();
        psock = sk_psock(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                if (sk->sk_prot->unhash)
                        sk->sk_prot->unhash(sk);
                return;
        }

        saved_unhash = psock->saved_unhash;
        sock_map_remove_links(sk, psock);
        rcu_read_unlock();
        saved_unhash(sk);
}
EXPORT_SYMBOL_GPL(sock_map_unhash);
void sock_map_close(struct sock *sk, long timeout)
{
        void (*saved_close)(struct sock *sk, long timeout);
        struct sk_psock *psock;

        lock_sock(sk);
        rcu_read_lock();
        psock = sk_psock_get(sk);
        if (unlikely(!psock)) {
                rcu_read_unlock();
                release_sock(sk);
                return sk->sk_prot->close(sk, timeout);
        }

        saved_close = psock->saved_close;
        sock_map_remove_links(sk, psock);
        rcu_read_unlock();
        sk_psock_stop(psock, true);
        sk_psock_put(sk, psock);
        release_sock(sk);
        saved_close(sk, timeout);
}
EXPORT_SYMBOL_GPL(sock_map_close);
static int sock_map_iter_attach_target(struct bpf_prog *prog,
                                       union bpf_iter_link_info *linfo,
                                       struct bpf_iter_aux_info *aux)
{
        struct bpf_map *map;
        int err = -EINVAL;

        if (!linfo->map.map_fd)
                return -EBADF;

        map = bpf_map_get_with_uref(linfo->map.map_fd);
        if (IS_ERR(map))
                return PTR_ERR(map);

        if (map->map_type != BPF_MAP_TYPE_SOCKMAP &&
            map->map_type != BPF_MAP_TYPE_SOCKHASH)
                goto put_map;

        if (prog->aux->max_rdonly_access > map->key_size) {
                err = -EACCES;
                goto put_map;
        }

        aux->map = map;
        return 0;

put_map:
        bpf_map_put_with_uref(map);
        return err;
}
*aux
)
1566 bpf_map_put_with_uref(aux
->map
);
static struct bpf_iter_reg sock_map_iter_reg = {
        .target                 = "sockmap",
        .attach_target          = sock_map_iter_attach_target,
        .detach_target          = sock_map_iter_detach_target,
        .show_fdinfo            = bpf_iter_map_show_fdinfo,
        .fill_link_info         = bpf_iter_map_fill_link_info,
        .ctx_arg_info_size      = 2,
        .ctx_arg_info           = {
                { offsetof(struct bpf_iter__sockmap, key),
                  PTR_TO_RDONLY_BUF_OR_NULL },
                { offsetof(struct bpf_iter__sockmap, sk),
                  PTR_TO_BTF_ID_OR_NULL },
        },
};
static int __init bpf_sockmap_iter_init(void)
{
        sock_map_iter_reg.ctx_arg_info[1].btf_id =
                btf_sock_ids[BTF_SOCK_TYPE_SOCK];
        return bpf_iter_reg_target(&sock_map_iter_reg);
}
late_initcall(bpf_sockmap_iter_init);
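/* For illustration only (not part of this file): a minimal sockmap iterator
 * program consuming the bpf_iter__sockmap context registered above, similar
 * in spirit to the kernel selftests. Names are hypothetical:
 *
 *      SEC("iter/sockmap")
 *      int count_elems(struct bpf_iter__sockmap *ctx)
 *      {
 *              struct sock *sk = ctx->sk;
 *              __u32 *key = ctx->key;
 *
 *              if (key && sk)
 *                      bpf_printk("key %u has a socket", *key);
 *              return 0;
 *      }
 */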