1 /* Copyright (c) 2017 Covalent IO, Inc. http://covalent.io
2 *
3 * This program is free software; you can redistribute it and/or
4 * modify it under the terms of version 2 of the GNU General Public
5 * License as published by the Free Software Foundation.
6 *
7 * This program is distributed in the hope that it will be useful, but
8 * WITHOUT ANY WARRANTY; without even the implied warranty of
9 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
10 * General Public License for more details.
11 */
12
13 /* A BPF sock_map is used to store sock objects. This is primarily used
14 * for doing socket redirect with BPF helper routines.
15 *
16 * A sock map may have BPF programs attached to it, currently a program
17 * used to parse packets and a program to provide a verdict and redirect
18 * decision on the packet are supported. Any programs attached to a sock
19 * map are inherited by sock objects when they are added to the map. If
20 * no BPF programs are attached the sock object may only be used for sock
21 * redirect.
22 *
23 * A sock object may be in multiple maps, but can only inherit a single
24 * parse or verdict program. If adding a sock object to a map would result
25 * in having multiple parsing programs the update will return an EBUSY error.
26 *
27 * For reference, this map is similar to the devmap used in the XDP context;
28 * reviewing these together may be useful. For an example please review
29 * ./samples/bpf/sockmap/ and the usage sketch below this comment.
30 */
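/*
 * A minimal userspace sketch of the flow described above. This is
 * illustrative only: the map size, the program and socket file descriptors
 * (and how they were loaded) are assumptions, and the helpers shown are
 * the ones provided by tools/lib/bpf/bpf.h.
 *
 *	int map_fd, parse_fd, verdict_fd, sock_fd, key = 0;
 *
 *	// 4-byte key and 4-byte value (sock fd), see sock_map_alloc()
 *	map_fd = bpf_create_map(BPF_MAP_TYPE_SOCKMAP,
 *				sizeof(int), sizeof(int), 64, 0);
 *
 *	// Programs attached to the map are inherited by added socks.
 *	bpf_prog_attach(parse_fd, map_fd, BPF_SK_SKB_STREAM_PARSER, 0);
 *	bpf_prog_attach(verdict_fd, map_fd, BPF_SK_SKB_STREAM_VERDICT, 0);
 *
 *	// Add a connected TCP socket; its receive path now runs the
 *	// parser/verdict programs and may be redirected.
 *	bpf_map_update_elem(map_fd, &key, &sock_fd, BPF_ANY);
 */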
31 #include <linux/bpf.h>
32 #include <net/sock.h>
33 #include <linux/filter.h>
34 #include <linux/errno.h>
35 #include <linux/file.h>
36 #include <linux/kernel.h>
37 #include <linux/net.h>
38 #include <linux/skbuff.h>
39 #include <linux/workqueue.h>
40 #include <linux/list.h>
41 #include <linux/mm.h>
42 #include <net/strparser.h>
43 #include <net/tcp.h>
44
45 #define SOCK_CREATE_FLAG_MASK \
46 (BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
47
48 struct bpf_stab {
49 struct bpf_map map;
50 struct sock **sock_map;
51 struct bpf_prog *bpf_tx_msg;
52 struct bpf_prog *bpf_parse;
53 struct bpf_prog *bpf_verdict;
54 };
55
56 enum smap_psock_state {
57 SMAP_TX_RUNNING,
58 };
59
60 struct smap_psock_map_entry {
61 struct list_head list;
62 struct sock **entry;
63 };
64
65 struct smap_psock {
66 struct rcu_head rcu;
67 refcount_t refcnt;
68
69 /* datapath variables */
70 struct sk_buff_head rxqueue;
71 bool strp_enabled;
72
73 /* datapath error path cache across tx work invocations */
74 int save_rem;
75 int save_off;
76 struct sk_buff *save_skb;
77
78 /* datapath variables for tx_msg ULP */
79 struct sock *sk_redir;
80 int apply_bytes;
81 int cork_bytes;
82 int sg_size;
83 int eval;
84 struct sk_msg_buff *cork;
85
86 struct strparser strp;
87 struct bpf_prog *bpf_tx_msg;
88 struct bpf_prog *bpf_parse;
89 struct bpf_prog *bpf_verdict;
90 struct list_head maps;
91
92 /* Back reference used when sock callbacks trigger sockmap operations */
93 struct sock *sock;
94 unsigned long state;
95
96 struct work_struct tx_work;
97 struct work_struct gc_work;
98
99 struct proto *sk_proto;
100 void (*save_close)(struct sock *sk, long timeout);
101 void (*save_data_ready)(struct sock *sk);
102 void (*save_write_space)(struct sock *sk);
103 };
104
105 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
106 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size);
107 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
108 int offset, size_t size, int flags);
109
110 static inline struct smap_psock *smap_psock_sk(const struct sock *sk)
111 {
112 return rcu_dereference_sk_user_data(sk);
113 }
114
115 static struct proto tcp_bpf_proto;
116 static int bpf_tcp_init(struct sock *sk)
117 {
118 struct smap_psock *psock;
119
120 rcu_read_lock();
121 psock = smap_psock_sk(sk);
122 if (unlikely(!psock)) {
123 rcu_read_unlock();
124 return -EINVAL;
125 }
126
127 if (unlikely(psock->sk_proto)) {
128 rcu_read_unlock();
129 return -EBUSY;
130 }
131
132 psock->save_close = sk->sk_prot->close;
133 psock->sk_proto = sk->sk_prot;
134
135 if (psock->bpf_tx_msg) {
136 tcp_bpf_proto.sendmsg = bpf_tcp_sendmsg;
137 tcp_bpf_proto.sendpage = bpf_tcp_sendpage;
138 }
139
140 sk->sk_prot = &tcp_bpf_proto;
141 rcu_read_unlock();
142 return 0;
143 }
144
145 static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
146 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
147
148 static void bpf_tcp_release(struct sock *sk)
149 {
150 struct smap_psock *psock;
151
152 rcu_read_lock();
153 psock = smap_psock_sk(sk);
154 if (unlikely(!psock))
155 goto out;
156
157 if (psock->cork) {
158 free_start_sg(psock->sock, psock->cork);
159 kfree(psock->cork);
160 psock->cork = NULL;
161 }
162
163 sk->sk_prot = psock->sk_proto;
164 psock->sk_proto = NULL;
165 out:
166 rcu_read_unlock();
167 }
168
169 static void bpf_tcp_close(struct sock *sk, long timeout)
170 {
171 void (*close_fun)(struct sock *sk, long timeout);
172 struct smap_psock_map_entry *e, *tmp;
173 struct smap_psock *psock;
174 struct sock *osk;
175
176 rcu_read_lock();
177 psock = smap_psock_sk(sk);
178 if (unlikely(!psock)) {
179 rcu_read_unlock();
180 return sk->sk_prot->close(sk, timeout);
181 }
182
183 /* The psock may be destroyed anytime after exiting the RCU critical
184 * section so by the time we use close_fun the psock may no longer
185 * be valid. However, bpf_tcp_close is called with the sock lock
186 * held so the close hook and sk are still valid.
187 */
188 close_fun = psock->save_close;
189
190 write_lock_bh(&sk->sk_callback_lock);
191 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
192 osk = cmpxchg(e->entry, sk, NULL);
193 if (osk == sk) {
194 list_del(&e->list);
195 smap_release_sock(psock, sk);
196 }
197 }
198 write_unlock_bh(&sk->sk_callback_lock);
199 rcu_read_unlock();
200 close_fun(sk, timeout);
201 }
202
203 enum __sk_action {
204 __SK_DROP = 0,
205 __SK_PASS,
206 __SK_REDIRECT,
207 __SK_NONE,
208 };
209
210 static struct tcp_ulp_ops bpf_tcp_ulp_ops __read_mostly = {
211 .name = "bpf_tcp",
212 .uid = TCP_ULP_BPF,
213 .user_visible = false,
214 .owner = NULL,
215 .init = bpf_tcp_init,
216 .release = bpf_tcp_release,
217 };
218
219 static int memcopy_from_iter(struct sock *sk,
220 struct sk_msg_buff *md,
221 struct iov_iter *from, int bytes)
222 {
223 struct scatterlist *sg = md->sg_data;
224 int i = md->sg_curr, rc = -ENOSPC;
225
226 do {
227 int copy;
228 char *to;
229
230 if (md->sg_copybreak >= sg[i].length) {
231 md->sg_copybreak = 0;
232
233 if (++i == MAX_SKB_FRAGS)
234 i = 0;
235
236 if (i == md->sg_end)
237 break;
238 }
239
240 copy = sg[i].length - md->sg_copybreak;
241 to = sg_virt(&sg[i]) + md->sg_copybreak;
242 md->sg_copybreak += copy;
243
244 if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
245 rc = copy_from_iter_nocache(to, copy, from);
246 else
247 rc = copy_from_iter(to, copy, from);
248
249 if (rc != copy) {
250 rc = -EFAULT;
251 goto out;
252 }
253
254 bytes -= copy;
255 if (!bytes)
256 break;
257
258 md->sg_copybreak = 0;
259 if (++i == MAX_SKB_FRAGS)
260 i = 0;
261 } while (i != md->sg_end);
262 out:
263 md->sg_curr = i;
264 return rc;
265 }
266
267 static int bpf_tcp_push(struct sock *sk, int apply_bytes,
268 struct sk_msg_buff *md,
269 int flags, bool uncharge)
270 {
271 bool apply = apply_bytes;
272 struct scatterlist *sg;
273 int offset, ret = 0;
274 struct page *p;
275 size_t size;
276
277 while (1) {
278 sg = md->sg_data + md->sg_start;
279 size = (apply && apply_bytes < sg->length) ?
280 apply_bytes : sg->length;
281 offset = sg->offset;
282
283 tcp_rate_check_app_limited(sk);
284 p = sg_page(sg);
285 retry:
286 ret = do_tcp_sendpages(sk, p, offset, size, flags);
287 if (ret != size) {
288 if (ret > 0) {
289 if (apply)
290 apply_bytes -= ret;
291 size -= ret;
292 offset += ret;
293 if (uncharge)
294 sk_mem_uncharge(sk, ret);
295 goto retry;
296 }
297
298 sg->length = size;
299 sg->offset = offset;
300 return ret;
301 }
302
303 if (apply)
304 apply_bytes -= ret;
305 sg->offset += ret;
306 sg->length -= ret;
307 if (uncharge)
308 sk_mem_uncharge(sk, ret);
309
310 if (!sg->length) {
311 put_page(p);
312 md->sg_start++;
313 if (md->sg_start == MAX_SKB_FRAGS)
314 md->sg_start = 0;
315 memset(sg, 0, sizeof(*sg));
316
317 if (md->sg_start == md->sg_end)
318 break;
319 }
320
321 if (apply && !apply_bytes)
322 break;
323 }
324 return 0;
325 }
326
327 static inline void bpf_compute_data_pointers_sg(struct sk_msg_buff *md)
328 {
329 struct scatterlist *sg = md->sg_data + md->sg_start;
330
331 if (md->sg_copy[md->sg_start]) {
332 md->data = md->data_end = 0;
333 } else {
334 md->data = sg_virt(sg);
335 md->data_end = md->data + sg->length;
336 }
337 }
338
339 static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
340 {
341 struct scatterlist *sg = md->sg_data;
342 int i = md->sg_start;
343
344 do {
345 int uncharge = (bytes < sg[i].length) ? bytes : sg[i].length;
346
347 sk_mem_uncharge(sk, uncharge);
348 bytes -= uncharge;
349 if (!bytes)
350 break;
351 i++;
352 if (i == MAX_SKB_FRAGS)
353 i = 0;
354 } while (i != md->sg_end);
355 }
356
357 static void free_bytes_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
358 {
359 struct scatterlist *sg = md->sg_data;
360 int i = md->sg_start, free;
361
362 while (bytes && sg[i].length) {
363 free = sg[i].length;
364 if (bytes < free) {
365 sg[i].length -= bytes;
366 sg[i].offset += bytes;
367 sk_mem_uncharge(sk, bytes);
368 break;
369 }
370
371 sk_mem_uncharge(sk, sg[i].length);
372 put_page(sg_page(&sg[i]));
373 bytes -= sg[i].length;
374 sg[i].length = 0;
375 sg[i].page_link = 0;
376 sg[i].offset = 0;
377 i++;
378
379 if (i == MAX_SKB_FRAGS)
380 i = 0;
381 }
382 }
383
384 static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
385 {
386 struct scatterlist *sg = md->sg_data;
387 int i = start, free = 0;
388
389 while (sg[i].length) {
390 free += sg[i].length;
391 sk_mem_uncharge(sk, sg[i].length);
392 put_page(sg_page(&sg[i]));
393 sg[i].length = 0;
394 sg[i].page_link = 0;
395 sg[i].offset = 0;
396 i++;
397
398 if (i == MAX_SKB_FRAGS)
399 i = 0;
400 }
401
402 return free;
403 }
404
405 static int free_start_sg(struct sock *sk, struct sk_msg_buff *md)
406 {
407 int free = free_sg(sk, md->sg_start, md);
408
409 md->sg_start = md->sg_end;
410 return free;
411 }
412
413 static int free_curr_sg(struct sock *sk, struct sk_msg_buff *md)
414 {
415 return free_sg(sk, md->sg_curr, md);
416 }
417
418 static int bpf_map_msg_verdict(int _rc, struct sk_msg_buff *md)
419 {
420 return ((_rc == SK_PASS) ?
421 (md->map ? __SK_REDIRECT : __SK_PASS) :
422 __SK_DROP);
423 }
424
425 static unsigned int smap_do_tx_msg(struct sock *sk,
426 struct smap_psock *psock,
427 struct sk_msg_buff *md)
428 {
429 struct bpf_prog *prog;
430 unsigned int rc, _rc;
431
432 preempt_disable();
433 rcu_read_lock();
434
435 /* If the policy was removed mid-send then default to 'accept' */
436 prog = READ_ONCE(psock->bpf_tx_msg);
437 if (unlikely(!prog)) {
438 _rc = SK_PASS;
439 goto verdict;
440 }
441
442 bpf_compute_data_pointers_sg(md);
443 rc = (*prog->bpf_func)(md, prog->insnsi);
444 psock->apply_bytes = md->apply_bytes;
445
446 /* Moving return codes from UAPI namespace into internal namespace */
447 _rc = bpf_map_msg_verdict(rc, md);
448
449 /* The psock has a refcount on the sock but not on the map and because
450 * we need to drop the rcu read lock here, it's possible the map could be
451 * removed between here and when we need it to execute the sock
452 * redirect. So do the map lookup now for future use.
453 */
454 if (_rc == __SK_REDIRECT) {
455 if (psock->sk_redir)
456 sock_put(psock->sk_redir);
457 psock->sk_redir = do_msg_redirect_map(md);
458 if (!psock->sk_redir) {
459 _rc = __SK_DROP;
460 goto verdict;
461 }
462 sock_hold(psock->sk_redir);
463 }
464 verdict:
465 rcu_read_unlock();
466 preempt_enable();
467
468 return _rc;
469 }
470
471 static int bpf_tcp_sendmsg_do_redirect(struct sock *sk, int send,
472 struct sk_msg_buff *md,
473 int flags)
474 {
475 struct smap_psock *psock;
476 struct scatterlist *sg;
477 int i, err, free = 0;
478
479 sg = md->sg_data;
480
481 rcu_read_lock();
482 psock = smap_psock_sk(sk);
483 if (unlikely(!psock))
484 goto out_rcu;
485
486 if (!refcount_inc_not_zero(&psock->refcnt))
487 goto out_rcu;
488
489 rcu_read_unlock();
490 lock_sock(sk);
491 err = bpf_tcp_push(sk, send, md, flags, false);
492 release_sock(sk);
493 smap_release_sock(psock, sk);
494 if (unlikely(err))
495 goto out;
496 return 0;
497 out_rcu:
498 rcu_read_unlock();
499 out:
500 i = md->sg_start;
501 while (sg[i].length) {
502 free += sg[i].length;
503 put_page(sg_page(&sg[i]));
504 sg[i].length = 0;
505 i++;
506 if (i == MAX_SKB_FRAGS)
507 i = 0;
508 }
509 return free;
510 }
511
512 static inline void bpf_md_init(struct smap_psock *psock)
513 {
514 if (!psock->apply_bytes) {
515 psock->eval = __SK_NONE;
516 if (psock->sk_redir) {
517 sock_put(psock->sk_redir);
518 psock->sk_redir = NULL;
519 }
520 }
521 }
522
523 static void apply_bytes_dec(struct smap_psock *psock, int i)
524 {
525 if (psock->apply_bytes) {
526 if (psock->apply_bytes < i)
527 psock->apply_bytes = 0;
528 else
529 psock->apply_bytes -= i;
530 }
531 }
532
533 static int bpf_exec_tx_verdict(struct smap_psock *psock,
534 struct sk_msg_buff *m,
535 struct sock *sk,
536 int *copied, int flags)
537 {
538 bool cork = false, enospc = (m->sg_start == m->sg_end);
539 struct sock *redir;
540 int err = 0;
541 int send;
542
543 more_data:
544 if (psock->eval == __SK_NONE)
545 psock->eval = smap_do_tx_msg(sk, psock, m);
546
547 if (m->cork_bytes &&
548 m->cork_bytes > psock->sg_size && !enospc) {
549 psock->cork_bytes = m->cork_bytes - psock->sg_size;
550 if (!psock->cork) {
551 psock->cork = kcalloc(1,
552 sizeof(struct sk_msg_buff),
553 GFP_ATOMIC | __GFP_NOWARN);
554
555 if (!psock->cork) {
556 err = -ENOMEM;
557 goto out_err;
558 }
559 }
560 memcpy(psock->cork, m, sizeof(*m));
561 goto out_err;
562 }
563
564 send = psock->sg_size;
565 if (psock->apply_bytes && psock->apply_bytes < send)
566 send = psock->apply_bytes;
567
568 switch (psock->eval) {
569 case __SK_PASS:
570 err = bpf_tcp_push(sk, send, m, flags, true);
571 if (unlikely(err)) {
572 *copied -= free_start_sg(sk, m);
573 break;
574 }
575
576 apply_bytes_dec(psock, send);
577 psock->sg_size -= send;
578 break;
579 case __SK_REDIRECT:
580 redir = psock->sk_redir;
581 apply_bytes_dec(psock, send);
582
583 if (psock->cork) {
584 cork = true;
585 psock->cork = NULL;
586 }
587
588 return_mem_sg(sk, send, m);
589 release_sock(sk);
590
591 err = bpf_tcp_sendmsg_do_redirect(redir, send, m, flags);
592 lock_sock(sk);
593
594 if (cork) {
595 free_start_sg(sk, m);
596 kfree(m);
597 m = NULL;
598 }
599 if (unlikely(err))
600 *copied -= err;
601 else
602 psock->sg_size -= send;
603 break;
604 case __SK_DROP:
605 default:
606 free_bytes_sg(sk, send, m);
607 apply_bytes_dec(psock, send);
608 *copied -= send;
609 psock->sg_size -= send;
610 err = -EACCES;
611 break;
612 }
613
614 if (likely(!err)) {
615 bpf_md_init(psock);
616 if (m &&
617 m->sg_data[m->sg_start].page_link &&
618 m->sg_data[m->sg_start].length)
619 goto more_data;
620 }
621
622 out_err:
623 return err;
624 }
625
626 static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
627 {
628 int flags = msg->msg_flags | MSG_NO_SHARED_FRAGS;
629 struct sk_msg_buff md = {0};
630 unsigned int sg_copy = 0;
631 struct smap_psock *psock;
632 int copied = 0, err = 0;
633 struct scatterlist *sg;
634 long timeo;
635
636 /* It's possible a sock event or user removed the psock _but_ the ops
637 * have not been reprogrammed yet so we get here. In this case fall back
638 * to tcp_sendmsg. Note this only works because we _only_ ever allow
639 * a single ULP; there is no hierarchy here.
640 */
641 rcu_read_lock();
642 psock = smap_psock_sk(sk);
643 if (unlikely(!psock)) {
644 rcu_read_unlock();
645 return tcp_sendmsg(sk, msg, size);
646 }
647
648 /* Increment the psock refcnt to ensure it's not released while sending a
649 * message. Required because sk lookup and bpf programs are used in
650 * separate rcu critical sections. It's OK if we lose the map entry
651 * but we can't lose the sock reference.
652 */
653 if (!refcount_inc_not_zero(&psock->refcnt)) {
654 rcu_read_unlock();
655 return tcp_sendmsg(sk, msg, size);
656 }
657
658 sg = md.sg_data;
659 sg_init_table(sg, MAX_SKB_FRAGS);
660 rcu_read_unlock();
661
662 lock_sock(sk);
663 timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
664
665 while (msg_data_left(msg)) {
666 struct sk_msg_buff *m;
667 bool enospc = false;
668 int copy;
669
670 if (sk->sk_err) {
671 err = sk->sk_err;
672 goto out_err;
673 }
674
675 copy = msg_data_left(msg);
676 if (!sk_stream_memory_free(sk))
677 goto wait_for_sndbuf;
678
679 m = psock->cork_bytes ? psock->cork : &md;
680 m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
681 err = sk_alloc_sg(sk, copy, m->sg_data,
682 m->sg_start, &m->sg_end, &sg_copy,
683 m->sg_end - 1);
684 if (err) {
685 if (err != -ENOSPC)
686 goto wait_for_memory;
687 enospc = true;
688 copy = sg_copy;
689 }
690
691 err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
692 if (err < 0) {
693 free_curr_sg(sk, m);
694 goto out_err;
695 }
696
697 psock->sg_size += copy;
698 copied += copy;
699 sg_copy = 0;
700
701 /* When bytes are being corked skip running the BPF program and
702 * applying the verdict unless there is no more buffer space. In
703 * the ENOSPC case simply run the BPF program with the currently
704 * accumulated data. We don't have much choice at this point;
705 * we could try extending the page frags or chaining complex
706 * frags but even in these cases _eventually_ we will hit an
707 * OOM scenario. More complex recovery schemes may be
708 * implemented in the future, but BPF programs must handle
709 * the case where apply_cork requests are not honored. The
710 * canonical method to verify this is to check data length; see the sketch below this comment.
711 */
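/* A minimal sk_msg program sketch of that length check (illustrative
 * only; WANTED_LEN is a hypothetical constant chosen by the program
 * author, and bpf_msg_cork_bytes() is the corking helper):
 *
 *	SEC("sk_msg")
 *	int cork_example(struct sk_msg_md *msg)
 *	{
 *		void *data = msg->data, *data_end = msg->data_end;
 *
 *		// Request that WANTED_LEN bytes be accumulated before
 *		// the verdict is applied ...
 *		bpf_msg_cork_bytes(msg, WANTED_LEN);
 *
 *		if (data + WANTED_LEN <= data_end) {
 *			// ... and only parse the payload when the full
 *			// request is actually present.
 *		}
 *		// Otherwise the cork request was not honored (ENOSPC);
 *		// pass the partial data through gracefully.
 *		return SK_PASS;
 *	}
 */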
712 if (psock->cork_bytes) {
713 if (copy > psock->cork_bytes)
714 psock->cork_bytes = 0;
715 else
716 psock->cork_bytes -= copy;
717
718 if (psock->cork_bytes && !enospc)
719 goto out_cork;
720
721 /* All cork bytes accounted for, re-run filter */
722 psock->eval = __SK_NONE;
723 psock->cork_bytes = 0;
724 }
725
726 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
727 if (unlikely(err < 0))
728 goto out_err;
729 continue;
730 wait_for_sndbuf:
731 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
732 wait_for_memory:
733 err = sk_stream_wait_memory(sk, &timeo);
734 if (err)
735 goto out_err;
736 }
737 out_err:
738 if (err < 0)
739 err = sk_stream_error(sk, msg->msg_flags, err);
740 out_cork:
741 release_sock(sk);
742 smap_release_sock(psock, sk);
743 return copied ? copied : err;
744 }
745
746 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
747 int offset, size_t size, int flags)
748 {
749 struct sk_msg_buff md = {0}, *m = NULL;
750 int err = 0, copied = 0;
751 struct smap_psock *psock;
752 struct scatterlist *sg;
753 bool enospc = false;
754
755 rcu_read_lock();
756 psock = smap_psock_sk(sk);
757 if (unlikely(!psock))
758 goto accept;
759
760 if (!refcount_inc_not_zero(&psock->refcnt))
761 goto accept;
762 rcu_read_unlock();
763
764 lock_sock(sk);
765
766 if (psock->cork_bytes)
767 m = psock->cork;
768 else
769 m = &md;
770
771 /* Catch case where ring is full and sendpage is stalled. */
772 if (unlikely(m->sg_end == m->sg_start &&
773 m->sg_data[m->sg_end].length))
774 goto out_err;
775
776 psock->sg_size += size;
777 sg = &m->sg_data[m->sg_end];
778 sg_set_page(sg, page, size, offset);
779 get_page(page);
780 m->sg_copy[m->sg_end] = true;
781 sk_mem_charge(sk, size);
782 m->sg_end++;
783 copied = size;
784
785 if (m->sg_end == MAX_SKB_FRAGS)
786 m->sg_end = 0;
787
788 if (m->sg_end == m->sg_start)
789 enospc = true;
790
791 if (psock->cork_bytes) {
792 if (size > psock->cork_bytes)
793 psock->cork_bytes = 0;
794 else
795 psock->cork_bytes -= size;
796
797 if (psock->cork_bytes && !enospc)
798 goto out_err;
799
800 /* All cork bytes accounted for, re-run filter */
801 psock->eval = __SK_NONE;
802 psock->cork_bytes = 0;
803 }
804
805 err = bpf_exec_tx_verdict(psock, m, sk, &copied, flags);
806 out_err:
807 release_sock(sk);
808 smap_release_sock(psock, sk);
809 return copied ? copied : err;
810 accept:
811 rcu_read_unlock();
812 return tcp_sendpage(sk, page, offset, size, flags);
813 }
814
815 static void bpf_tcp_msg_add(struct smap_psock *psock,
816 struct sock *sk,
817 struct bpf_prog *tx_msg)
818 {
819 struct bpf_prog *orig_tx_msg;
820
821 orig_tx_msg = xchg(&psock->bpf_tx_msg, tx_msg);
822 if (orig_tx_msg)
823 bpf_prog_put(orig_tx_msg);
824 }
825
826 static int bpf_tcp_ulp_register(void)
827 {
828 tcp_bpf_proto = tcp_prot;
829 tcp_bpf_proto.close = bpf_tcp_close;
830 /* Once BPF TX ULP is registered it is never unregistered. It
831 * will be in the ULP list for the lifetime of the system. Duplicate
832 * registration is not a problem.
833 */
834 return tcp_register_ulp(&bpf_tcp_ulp_ops);
835 }
836
837 static int smap_verdict_func(struct smap_psock *psock, struct sk_buff *skb)
838 {
839 struct bpf_prog *prog = READ_ONCE(psock->bpf_verdict);
840 int rc;
841
842 if (unlikely(!prog))
843 return __SK_DROP;
844
845 skb_orphan(skb);
846 /* We need to ensure that BPF metadata for maps is also cleared
847 * when we orphan the skb so that we don't have the possibility
848 * of referencing a stale map.
849 */
850 TCP_SKB_CB(skb)->bpf.map = NULL;
851 skb->sk = psock->sock;
852 bpf_compute_data_pointers(skb);
853 preempt_disable();
854 rc = (*prog->bpf_func)(skb, prog->insnsi);
855 preempt_enable();
856 skb->sk = NULL;
857
858 /* Moving return codes from UAPI namespace into internal namespace */
859 return rc == SK_PASS ?
860 (TCP_SKB_CB(skb)->bpf.map ? __SK_REDIRECT : __SK_PASS) :
861 __SK_DROP;
862 }
863
864 static void smap_do_verdict(struct smap_psock *psock, struct sk_buff *skb)
865 {
866 struct sock *sk;
867 int rc;
868
869 rc = smap_verdict_func(psock, skb);
870 switch (rc) {
871 case __SK_REDIRECT:
872 sk = do_sk_redirect_map(skb);
873 if (likely(sk)) {
874 struct smap_psock *peer = smap_psock_sk(sk);
875
876 if (likely(peer &&
877 test_bit(SMAP_TX_RUNNING, &peer->state) &&
878 !sock_flag(sk, SOCK_DEAD) &&
879 sock_writeable(sk))) {
880 skb_set_owner_w(skb, sk);
881 skb_queue_tail(&peer->rxqueue, skb);
882 schedule_work(&peer->tx_work);
883 break;
884 }
885 }
886 /* Fall through and free skb otherwise */
887 case __SK_DROP:
888 default:
889 kfree_skb(skb);
890 }
891 }
892
893 static void smap_report_sk_error(struct smap_psock *psock, int err)
894 {
895 struct sock *sk = psock->sock;
896
897 sk->sk_err = err;
898 sk->sk_error_report(sk);
899 }
900
901 static void smap_read_sock_strparser(struct strparser *strp,
902 struct sk_buff *skb)
903 {
904 struct smap_psock *psock;
905
906 rcu_read_lock();
907 psock = container_of(strp, struct smap_psock, strp);
908 smap_do_verdict(psock, skb);
909 rcu_read_unlock();
910 }
911
912 /* Called with lock held on socket */
913 static void smap_data_ready(struct sock *sk)
914 {
915 struct smap_psock *psock;
916
917 rcu_read_lock();
918 psock = smap_psock_sk(sk);
919 if (likely(psock)) {
920 write_lock_bh(&sk->sk_callback_lock);
921 strp_data_ready(&psock->strp);
922 write_unlock_bh(&sk->sk_callback_lock);
923 }
924 rcu_read_unlock();
925 }
926
927 static void smap_tx_work(struct work_struct *w)
928 {
929 struct smap_psock *psock;
930 struct sk_buff *skb;
931 int rem, off, n;
932
933 psock = container_of(w, struct smap_psock, tx_work);
934
935 /* lock sock to avoid losing sk_socket at some point during the loop */
936 lock_sock(psock->sock);
937 if (psock->save_skb) {
938 skb = psock->save_skb;
939 rem = psock->save_rem;
940 off = psock->save_off;
941 psock->save_skb = NULL;
942 goto start;
943 }
944
945 while ((skb = skb_dequeue(&psock->rxqueue))) {
946 rem = skb->len;
947 off = 0;
948 start:
949 do {
950 if (likely(psock->sock->sk_socket))
951 n = skb_send_sock_locked(psock->sock,
952 skb, off, rem);
953 else
954 n = -EINVAL;
955 if (n <= 0) {
956 if (n == -EAGAIN) {
957 /* Retry when space is available */
958 psock->save_skb = skb;
959 psock->save_rem = rem;
960 psock->save_off = off;
961 goto out;
962 }
963 /* Hard errors break pipe and stop xmit */
964 smap_report_sk_error(psock, n ? -n : EPIPE);
965 clear_bit(SMAP_TX_RUNNING, &psock->state);
966 kfree_skb(skb);
967 goto out;
968 }
969 rem -= n;
970 off += n;
971 } while (rem);
972 kfree_skb(skb);
973 }
974 out:
975 release_sock(psock->sock);
976 }
977
978 static void smap_write_space(struct sock *sk)
979 {
980 struct smap_psock *psock;
981
982 rcu_read_lock();
983 psock = smap_psock_sk(sk);
984 if (likely(psock && test_bit(SMAP_TX_RUNNING, &psock->state)))
985 schedule_work(&psock->tx_work);
986 rcu_read_unlock();
987 }
988
989 static void smap_stop_sock(struct smap_psock *psock, struct sock *sk)
990 {
991 if (!psock->strp_enabled)
992 return;
993 sk->sk_data_ready = psock->save_data_ready;
994 sk->sk_write_space = psock->save_write_space;
995 psock->save_data_ready = NULL;
996 psock->save_write_space = NULL;
997 strp_stop(&psock->strp);
998 psock->strp_enabled = false;
999 }
1000
1001 static void smap_destroy_psock(struct rcu_head *rcu)
1002 {
1003 struct smap_psock *psock = container_of(rcu,
1004 struct smap_psock, rcu);
1005
1006 /* Now that a grace period has passed there is no longer
1007 * any reference to this sock in the sockmap so we can
1008 * destroy the psock, strparser, and bpf programs. But,
1009 * because we use workqueue sync operations we cannot
1010 * do it in rcu context.
1011 */
1012 schedule_work(&psock->gc_work);
1013 }
1014
1015 static void smap_release_sock(struct smap_psock *psock, struct sock *sock)
1016 {
1017 if (refcount_dec_and_test(&psock->refcnt)) {
1018 tcp_cleanup_ulp(sock);
1019 smap_stop_sock(psock, sock);
1020 clear_bit(SMAP_TX_RUNNING, &psock->state);
1021 rcu_assign_sk_user_data(sock, NULL);
1022 call_rcu_sched(&psock->rcu, smap_destroy_psock);
1023 }
1024 }
1025
1026 static int smap_parse_func_strparser(struct strparser *strp,
1027 struct sk_buff *skb)
1028 {
1029 struct smap_psock *psock;
1030 struct bpf_prog *prog;
1031 int rc;
1032
1033 rcu_read_lock();
1034 psock = container_of(strp, struct smap_psock, strp);
1035 prog = READ_ONCE(psock->bpf_parse);
1036
1037 if (unlikely(!prog)) {
1038 rcu_read_unlock();
1039 return skb->len;
1040 }
1041
1042 /* Attach the socket for the bpf program to use if needed. We can do
1043 * this because strparser clones the skb before handing it to an upper
1044 * layer, meaning skb_orphan has been called. We NULL sk on the
1045 * way out to ensure we don't trigger a BUG_ON in skb/sk operations
1046 * later and because we are not charging the memory of this skb to
1047 * any socket yet.
1048 */
1049 skb->sk = psock->sock;
1050 bpf_compute_data_pointers(skb);
1051 rc = (*prog->bpf_func)(skb, prog->insnsi);
1052 skb->sk = NULL;
1053 rcu_read_unlock();
1054 return rc;
1055 }
1056
1057 static int smap_read_sock_done(struct strparser *strp, int err)
1058 {
1059 return err;
1060 }
1061
1062 static int smap_init_sock(struct smap_psock *psock,
1063 struct sock *sk)
1064 {
1065 static const struct strp_callbacks cb = {
1066 .rcv_msg = smap_read_sock_strparser,
1067 .parse_msg = smap_parse_func_strparser,
1068 .read_sock_done = smap_read_sock_done,
1069 };
1070
1071 return strp_init(&psock->strp, sk, &cb);
1072 }
1073
1074 static void smap_init_progs(struct smap_psock *psock,
1075 struct bpf_stab *stab,
1076 struct bpf_prog *verdict,
1077 struct bpf_prog *parse)
1078 {
1079 struct bpf_prog *orig_parse, *orig_verdict;
1080
1081 orig_parse = xchg(&psock->bpf_parse, parse);
1082 orig_verdict = xchg(&psock->bpf_verdict, verdict);
1083
1084 if (orig_verdict)
1085 bpf_prog_put(orig_verdict);
1086 if (orig_parse)
1087 bpf_prog_put(orig_parse);
1088 }
1089
1090 static void smap_start_sock(struct smap_psock *psock, struct sock *sk)
1091 {
1092 if (sk->sk_data_ready == smap_data_ready)
1093 return;
1094 psock->save_data_ready = sk->sk_data_ready;
1095 psock->save_write_space = sk->sk_write_space;
1096 sk->sk_data_ready = smap_data_ready;
1097 sk->sk_write_space = smap_write_space;
1098 psock->strp_enabled = true;
1099 }
1100
1101 static void sock_map_remove_complete(struct bpf_stab *stab)
1102 {
1103 bpf_map_area_free(stab->sock_map);
1104 kfree(stab);
1105 }
1106
1107 static void smap_gc_work(struct work_struct *w)
1108 {
1109 struct smap_psock_map_entry *e, *tmp;
1110 struct smap_psock *psock;
1111
1112 psock = container_of(w, struct smap_psock, gc_work);
1113
1114 /* no callback lock needed because we already detached sockmap ops */
1115 if (psock->strp_enabled)
1116 strp_done(&psock->strp);
1117
1118 cancel_work_sync(&psock->tx_work);
1119 __skb_queue_purge(&psock->rxqueue);
1120
1121 /* At this point all strparser and xmit work must be complete */
1122 if (psock->bpf_parse)
1123 bpf_prog_put(psock->bpf_parse);
1124 if (psock->bpf_verdict)
1125 bpf_prog_put(psock->bpf_verdict);
1126 if (psock->bpf_tx_msg)
1127 bpf_prog_put(psock->bpf_tx_msg);
1128
1129 if (psock->cork) {
1130 free_start_sg(psock->sock, psock->cork);
1131 kfree(psock->cork);
1132 }
1133
1134 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1135 list_del(&e->list);
1136 kfree(e);
1137 }
1138
1139 if (psock->sk_redir)
1140 sock_put(psock->sk_redir);
1141
1142 sock_put(psock->sock);
1143 kfree(psock);
1144 }
1145
1146 static struct smap_psock *smap_init_psock(struct sock *sock,
1147 struct bpf_stab *stab)
1148 {
1149 struct smap_psock *psock;
1150
1151 psock = kzalloc_node(sizeof(struct smap_psock),
1152 GFP_ATOMIC | __GFP_NOWARN,
1153 stab->map.numa_node);
1154 if (!psock)
1155 return ERR_PTR(-ENOMEM);
1156
1157 psock->eval = __SK_NONE;
1158 psock->sock = sock;
1159 skb_queue_head_init(&psock->rxqueue);
1160 INIT_WORK(&psock->tx_work, smap_tx_work);
1161 INIT_WORK(&psock->gc_work, smap_gc_work);
1162 INIT_LIST_HEAD(&psock->maps);
1163 refcount_set(&psock->refcnt, 1);
1164
1165 rcu_assign_sk_user_data(sock, psock);
1166 sock_hold(sock);
1167 return psock;
1168 }
1169
1170 static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
1171 {
1172 struct bpf_stab *stab;
1173 u64 cost;
1174 int err;
1175
1176 if (!capable(CAP_NET_ADMIN))
1177 return ERR_PTR(-EPERM);
1178
1179 /* check sanity of attributes */
1180 if (attr->max_entries == 0 || attr->key_size != 4 ||
1181 attr->value_size != 4 || attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
1182 return ERR_PTR(-EINVAL);
1183
1184 if (attr->value_size > KMALLOC_MAX_SIZE)
1185 return ERR_PTR(-E2BIG);
1186
1187 err = bpf_tcp_ulp_register();
1188 if (err && err != -EEXIST)
1189 return ERR_PTR(err);
1190
1191 stab = kzalloc(sizeof(*stab), GFP_USER);
1192 if (!stab)
1193 return ERR_PTR(-ENOMEM);
1194
1195 bpf_map_init_from_attr(&stab->map, attr);
1196
1197 /* make sure page count doesn't overflow */
1198 cost = (u64) stab->map.max_entries * sizeof(struct sock *);
1199 err = -EINVAL;
1200 if (cost >= U32_MAX - PAGE_SIZE)
1201 goto free_stab;
1202
1203 stab->map.pages = round_up(cost, PAGE_SIZE) >> PAGE_SHIFT;
1204
1205 /* if map size is larger than memlock limit, reject it early */
1206 err = bpf_map_precharge_memlock(stab->map.pages);
1207 if (err)
1208 goto free_stab;
1209
1210 err = -ENOMEM;
1211 stab->sock_map = bpf_map_area_alloc(stab->map.max_entries *
1212 sizeof(struct sock *),
1213 stab->map.numa_node);
1214 if (!stab->sock_map)
1215 goto free_stab;
1216
1217 return &stab->map;
1218 free_stab:
1219 kfree(stab);
1220 return ERR_PTR(err);
1221 }
1222
1223 static void smap_list_remove(struct smap_psock *psock, struct sock **entry)
1224 {
1225 struct smap_psock_map_entry *e, *tmp;
1226
1227 list_for_each_entry_safe(e, tmp, &psock->maps, list) {
1228 if (e->entry == entry) {
1229 list_del(&e->list);
1230 break;
1231 }
1232 }
1233 }
1234
1235 static void sock_map_free(struct bpf_map *map)
1236 {
1237 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1238 int i;
1239
1240 synchronize_rcu();
1241
1242 /* At this point no update, lookup or delete operations can happen.
1243 * However, be aware we can still get socket state event updates and
1244 * data ready callbacks that reference the psock from sk_user_data.
1245 * Also psock worker threads are still in flight. So smap_release_sock
1246 * will only free the psock after cancel_sync on the worker threads
1247 * and a grace period expires to ensure the psock is really safe to remove.
1248 */
1249 rcu_read_lock();
1250 for (i = 0; i < stab->map.max_entries; i++) {
1251 struct smap_psock *psock;
1252 struct sock *sock;
1253
1254 sock = xchg(&stab->sock_map[i], NULL);
1255 if (!sock)
1256 continue;
1257
1258 write_lock_bh(&sock->sk_callback_lock);
1259 psock = smap_psock_sk(sock);
1260 /* This check handles a racing sock event that can get the
1261 * sk_callback_lock before this case but after the xchg happens,
1262 * causing the refcnt to hit zero and sock user data (psock)
1263 * to be null and queued for garbage collection.
1264 */
1265 if (likely(psock)) {
1266 smap_list_remove(psock, &stab->sock_map[i]);
1267 smap_release_sock(psock, sock);
1268 }
1269 write_unlock_bh(&sock->sk_callback_lock);
1270 }
1271 rcu_read_unlock();
1272
1273 sock_map_remove_complete(stab);
1274 }
1275
1276 static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next_key)
1277 {
1278 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1279 u32 i = key ? *(u32 *)key : U32_MAX;
1280 u32 *next = (u32 *)next_key;
1281
1282 if (i >= stab->map.max_entries) {
1283 *next = 0;
1284 return 0;
1285 }
1286
1287 if (i == stab->map.max_entries - 1)
1288 return -ENOENT;
1289
1290 *next = i + 1;
1291 return 0;
1292 }
1293
1294 struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
1295 {
1296 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1297
1298 if (key >= map->max_entries)
1299 return NULL;
1300
1301 return READ_ONCE(stab->sock_map[key]);
1302 }
1303
1304 static int sock_map_delete_elem(struct bpf_map *map, void *key)
1305 {
1306 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1307 struct smap_psock *psock;
1308 int k = *(u32 *)key;
1309 struct sock *sock;
1310
1311 if (k >= map->max_entries)
1312 return -EINVAL;
1313
1314 sock = xchg(&stab->sock_map[k], NULL);
1315 if (!sock)
1316 return -EINVAL;
1317
1318 write_lock_bh(&sock->sk_callback_lock);
1319 psock = smap_psock_sk(sock);
1320 if (!psock)
1321 goto out;
1322
1323 if (psock->bpf_parse)
1324 smap_stop_sock(psock, sock);
1325 smap_list_remove(psock, &stab->sock_map[k]);
1326 smap_release_sock(psock, sock);
1327 out:
1328 write_unlock_bh(&sock->sk_callback_lock);
1329 return 0;
1330 }
1331
1332 /* Locking notes: Concurrent updates, deletes, and lookups are allowed and are
1333 * done inside rcu critical sections. This ensures on updates that the psock
1334 * will not be released via smap_release_sock() until concurrent updates/deletes
1335 * complete. All operations operate on sock_map using cmpxchg and xchg
1336 * operations to ensure we do not get stale references. Any reads into the
1337 * map must be done with READ_ONCE() because of this.
1338 *
1339 * A psock is destroyed via call_rcu and after any worker threads are cancelled
1340 * and synced so we are certain all references from the update/lookup/delete
1341 * operations as well as references in the data path are no longer in use.
1342 *
1343 * Psocks may exist in multiple maps, but only a single set of parse/verdict
1344 * programs may be inherited from the maps it belongs to. A reference count
1345 * is kept with the total number of references to the psock from all maps. The
1346 * psock will not be released until this reaches zero. The psock and sock
1347 * user data use the sk_callback_lock to protect critical data structures
1348 * from concurrent access. This prevents two updates from modifying the
1349 * user data in the sock at the same time; the lock is required anyway for
1350 * modifying callbacks, we simply increase its scope slightly.
1351 *
1352 * Rules to follow,
1353 * - psock must always be read inside RCU critical section
1354 * - sk_user_data must only be modified inside sk_callback_lock and read
1355 * inside RCU critical section.
1356 * - psock->maps list must only be read & modified inside sk_callback_lock
1357 * - sock_map must use READ_ONCE and (cmp)xchg operations
1358 * - BPF verdict/parse programs must use READ_ONCE and xchg operations (a read-side sketch follows this comment)
1359 */
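/*
 * A minimal sketch of the read side these rules imply, using names from
 * this file (illustrative only; 'i' is any in-bounds index):
 *
 *	rcu_read_lock();
 *	sock = READ_ONCE(stab->sock_map[i]);
 *	psock = sock ? smap_psock_sk(sock) : NULL;
 *	if (psock && refcount_inc_not_zero(&psock->refcnt)) {
 *		// psock may be used after rcu_read_unlock(); drop the
 *		// reference with smap_release_sock() when done
 *	}
 *	rcu_read_unlock();
 */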
1360 static int sock_map_ctx_update_elem(struct bpf_sock_ops_kern *skops,
1361 struct bpf_map *map,
1362 void *key, u64 flags)
1363 {
1364 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1365 struct smap_psock_map_entry *e = NULL;
1366 struct bpf_prog *verdict, *parse, *tx_msg;
1367 struct sock *osock, *sock;
1368 struct smap_psock *psock;
1369 u32 i = *(u32 *)key;
1370 bool new = false;
1371 int err;
1372
1373 if (unlikely(flags > BPF_EXIST))
1374 return -EINVAL;
1375
1376 if (unlikely(i >= stab->map.max_entries))
1377 return -E2BIG;
1378
1379 sock = READ_ONCE(stab->sock_map[i]);
1380 if (flags == BPF_EXIST && !sock)
1381 return -ENOENT;
1382 else if (flags == BPF_NOEXIST && sock)
1383 return -EEXIST;
1384
1385 sock = skops->sk;
1386
1387 /* 1. If the sock map has BPF programs those will be inherited by the
1388 * sock being added. If the sock is already attached to BPF programs
1389 * this results in an error.
1390 */
1391 verdict = READ_ONCE(stab->bpf_verdict);
1392 parse = READ_ONCE(stab->bpf_parse);
1393 tx_msg = READ_ONCE(stab->bpf_tx_msg);
1394
1395 if (parse && verdict) {
1396 /* bpf prog refcnt may be zero if a concurrent attach operation
1397 * removes the program after the above READ_ONCE() but before
1398 * we increment the refcnt. If this is the case abort with an
1399 * error.
1400 */
1401 verdict = bpf_prog_inc_not_zero(stab->bpf_verdict);
1402 if (IS_ERR(verdict))
1403 return PTR_ERR(verdict);
1404
1405 parse = bpf_prog_inc_not_zero(stab->bpf_parse);
1406 if (IS_ERR(parse)) {
1407 bpf_prog_put(verdict);
1408 return PTR_ERR(parse);
1409 }
1410 }
1411
1412 if (tx_msg) {
1413 tx_msg = bpf_prog_inc_not_zero(stab->bpf_tx_msg);
1414 if (IS_ERR(tx_msg)) {
1415 if (verdict)
1416 bpf_prog_put(verdict);
1417 if (parse)
1418 bpf_prog_put(parse);
1419 return PTR_ERR(tx_msg);
1420 }
1421 }
1422
1423 write_lock_bh(&sock->sk_callback_lock);
1424 psock = smap_psock_sk(sock);
1425
1426 /* 2. Do not allow inheriting programs if psock exists and has
1427 * already inherited programs. This would create confusion about
1428 * which parser/verdict program is running. If no psock exists
1429 * create one. Inside sk_callback_lock to ensure concurrent create
1430 * doesn't update user data.
1431 */
1432 if (psock) {
1433 if (READ_ONCE(psock->bpf_parse) && parse) {
1434 err = -EBUSY;
1435 goto out_progs;
1436 }
1437 if (READ_ONCE(psock->bpf_tx_msg) && tx_msg) {
1438 err = -EBUSY;
1439 goto out_progs;
1440 }
1441 if (!refcount_inc_not_zero(&psock->refcnt)) {
1442 err = -EAGAIN;
1443 goto out_progs;
1444 }
1445 } else {
1446 psock = smap_init_psock(sock, stab);
1447 if (IS_ERR(psock)) {
1448 err = PTR_ERR(psock);
1449 goto out_progs;
1450 }
1451
1452 set_bit(SMAP_TX_RUNNING, &psock->state);
1453 new = true;
1454 }
1455
1456 e = kzalloc(sizeof(*e), GFP_ATOMIC | __GFP_NOWARN);
1457 if (!e) {
1458 err = -ENOMEM;
1459 goto out_progs;
1460 }
1461 e->entry = &stab->sock_map[i];
1462
1463 /* 3. At this point we have a reference to a valid psock that is
1464 * running. Attach any BPF programs needed.
1465 */
1466 if (tx_msg)
1467 bpf_tcp_msg_add(psock, sock, tx_msg);
1468 if (new) {
1469 err = tcp_set_ulp_id(sock, TCP_ULP_BPF);
1470 if (err)
1471 goto out_free;
1472 }
1473
1474 if (parse && verdict && !psock->strp_enabled) {
1475 err = smap_init_sock(psock, sock);
1476 if (err)
1477 goto out_free;
1478 smap_init_progs(psock, stab, verdict, parse);
1479 smap_start_sock(psock, sock);
1480 }
1481
1482 /* 4. Place psock in sockmap for use and stop any programs on
1483 * the old sock, assuming it's not the same sock we are replacing
1484 * it with. Because we can only have a single set of programs, if
1485 * old_sock has a strp we can stop it.
1486 */
1487 list_add_tail(&e->list, &psock->maps);
1488 write_unlock_bh(&sock->sk_callback_lock);
1489
1490 osock = xchg(&stab->sock_map[i], sock);
1491 if (osock) {
1492 struct smap_psock *opsock = smap_psock_sk(osock);
1493
1494 write_lock_bh(&osock->sk_callback_lock);
1495 smap_list_remove(opsock, &stab->sock_map[i]);
1496 smap_release_sock(opsock, osock);
1497 write_unlock_bh(&osock->sk_callback_lock);
1498 }
1499 return 0;
1500 out_free:
1501 smap_release_sock(psock, sock);
1502 out_progs:
1503 if (verdict)
1504 bpf_prog_put(verdict);
1505 if (parse)
1506 bpf_prog_put(parse);
1507 if (tx_msg)
1508 bpf_prog_put(tx_msg);
1509 write_unlock_bh(&sock->sk_callback_lock);
1510 kfree(e);
1511 return err;
1512 }
1513
1514 int sock_map_prog(struct bpf_map *map, struct bpf_prog *prog, u32 type)
1515 {
1516 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1517 struct bpf_prog *orig;
1518
1519 if (unlikely(map->map_type != BPF_MAP_TYPE_SOCKMAP))
1520 return -EINVAL;
1521
1522 switch (type) {
1523 case BPF_SK_MSG_VERDICT:
1524 orig = xchg(&stab->bpf_tx_msg, prog);
1525 break;
1526 case BPF_SK_SKB_STREAM_PARSER:
1527 orig = xchg(&stab->bpf_parse, prog);
1528 break;
1529 case BPF_SK_SKB_STREAM_VERDICT:
1530 orig = xchg(&stab->bpf_verdict, prog);
1531 break;
1532 default:
1533 return -EOPNOTSUPP;
1534 }
1535
1536 if (orig)
1537 bpf_prog_put(orig);
1538
1539 return 0;
1540 }
1541
1542 static void *sock_map_lookup(struct bpf_map *map, void *key)
1543 {
1544 return NULL;
1545 }
1546
1547 static int sock_map_update_elem(struct bpf_map *map,
1548 void *key, void *value, u64 flags)
1549 {
1550 struct bpf_sock_ops_kern skops;
1551 u32 fd = *(u32 *)value;
1552 struct socket *socket;
1553 int err;
1554
1555 socket = sockfd_lookup(fd, &err);
1556 if (!socket)
1557 return err;
1558
1559 skops.sk = socket->sk;
1560 if (!skops.sk) {
1561 fput(socket->file);
1562 return -EINVAL;
1563 }
1564
1565 if (skops.sk->sk_type != SOCK_STREAM ||
1566 skops.sk->sk_protocol != IPPROTO_TCP) {
1567 fput(socket->file);
1568 return -EOPNOTSUPP;
1569 }
1570
1571 err = sock_map_ctx_update_elem(&skops, map, key, flags);
1572 fput(socket->file);
1573 return err;
1574 }
1575
1576 static void sock_map_release(struct bpf_map *map, struct file *map_file)
1577 {
1578 struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
1579 struct bpf_prog *orig;
1580
1581 orig = xchg(&stab->bpf_parse, NULL);
1582 if (orig)
1583 bpf_prog_put(orig);
1584 orig = xchg(&stab->bpf_verdict, NULL);
1585 if (orig)
1586 bpf_prog_put(orig);
1587
1588 orig = xchg(&stab->bpf_tx_msg, NULL);
1589 if (orig)
1590 bpf_prog_put(orig);
1591 }
1592
1593 const struct bpf_map_ops sock_map_ops = {
1594 .map_alloc = sock_map_alloc,
1595 .map_free = sock_map_free,
1596 .map_lookup_elem = sock_map_lookup,
1597 .map_get_next_key = sock_map_get_next_key,
1598 .map_update_elem = sock_map_update_elem,
1599 .map_delete_elem = sock_map_delete_elem,
1600 .map_release = sock_map_release,
1601 };
1602
1603 BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, bpf_sock,
1604 struct bpf_map *, map, void *, key, u64, flags)
1605 {
1606 WARN_ON_ONCE(!rcu_read_lock_held());
1607 return sock_map_ctx_update_elem(bpf_sock, map, key, flags);
1608 }
1609
1610 const struct bpf_func_proto bpf_sock_map_update_proto = {
1611 .func = bpf_sock_map_update,
1612 .gpl_only = false,
1613 .pkt_access = true,
1614 .ret_type = RET_INTEGER,
1615 .arg1_type = ARG_PTR_TO_CTX,
1616 .arg2_type = ARG_CONST_MAP_PTR,
1617 .arg3_type = ARG_PTR_TO_MAP_KEY,
1618 .arg4_type = ARG_ANYTHING,
1619 };