// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>

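/* Redirect msg data to the ingress queue of the psock attached to sk:
 * up to apply_bytes of the scatterlist entries are transferred into a
 * freshly allocated sk_msg, charged to the receiving socket, queued on
 * psock->ingress_msg, and the socket is marked data-ready.
 */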
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}

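/* Push up to apply_bytes of msg out on the egress path of sk via the
 * sendpage path, so pages are passed by reference rather than copied.
 * When a TLS TX ULP is active, MSG_SENDPAGE_NOPOLICY bypasses the ULP
 * policy and the locked sendpage variant is used, since the caller
 * already holds the socket lock.
 */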
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}

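/* Same as tcp_bpf_push() but takes and releases the socket lock. */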
static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}

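/* Entry point for SK_MSG redirects: deliver msg either to the ingress
 * queue or out the egress path of the target socket sk, depending on
 * the BPF_F_INGRESS flag recorded in the msg.
 */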
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);

#ifdef CONFIG_BPF_SYSCALL
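/* Sleep until data shows up on either the psock ingress queue or the
 * socket receive queue, or until the timeout expires. Returns 1
 * immediately if the socket is shut down for receive.
 */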
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}

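/* recvmsg used when a stream/skb verdict program is attached: data is
 * consumed from the psock ingress queue only. Unlike tcp_bpf_recvmsg()
 * below, it does not fall back to tcp_recvmsg() while the psock is
 * attached, and it replicates the usual TCP error/shutdown/state
 * checks before blocking.
 */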
static int tcp_bpf_recvmsg_parser(struct sock *sk,
				  struct msghdr *msg,
				  size_t len,
				  int nonblock,
				  int flags,
				  int *addr_len)
{
	struct sk_psock *psock;
	int copied;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		if (sock_flag(sk, SOCK_DONE))
			goto out;

		if (sk->sk_err) {
			copied = sock_error(sk);
			goto out;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			goto out;

		if (sk->sk_state == TCP_CLOSE) {
			copied = -ENOTCONN;
			goto out;
		}

		timeo = sock_rcvtimeo(sk, nonblock);
		if (!timeo) {
			copied = -EAGAIN;
			goto out;
		}

		if (signal_pending(current)) {
			copied = sock_intr_errno(timeo);
			goto out;
		}

		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data && !sk_psock_queue_empty(psock))
			goto msg_bytes_ready;
		copied = -EAGAIN;
	}
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied;
}

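/* recvmsg for the base sockmap config: prefer data queued on the psock
 * ingress queue, but fall back to plain tcp_recvmsg() when only the
 * regular receive queue holds data.
 */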
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		copied = -EAGAIN;
	}
	ret = copied;
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}

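/* Run the attached msg verdict program over msg and act on the result:
 * __SK_PASS sends the data out on sk itself, __SK_REDIRECT hands it to
 * another socket via tcp_bpf_sendmsg_redir(), and __SK_DROP frees the
 * data and returns -EACCES. Corking and apply_bytes state on the psock
 * are honoured, and *copied is adjusted so the value returned to user
 * space matches what was actually consumed.
 */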
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	u32 eval = __SK_NONE;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, msg->sg.size);
		release_sock(sk);

		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length) {
			if (eval == __SK_REDIRECT)
				sk_mem_charge(sk, msg->sg.size);
			goto more_data;
		}
	}
	return ret;
}

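/* sendmsg replacement installed for sockets with a msg parser program:
 * copy the user data into an sk_msg (psock->cork if corking is in
 * progress) and pass it to tcp_bpf_send_verdict() for a verdict.
 */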
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

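/* sendpage counterpart of tcp_bpf_sendmsg(): add the page to an sk_msg
 * by reference, without copying, and run the verdict program over it.
 */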
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}

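/* tcp_bpf proto templates, one per address family (IPv4/IPv6) and per
 * callback configuration (base, TX, RX, TX+RX). The IPv4 table is
 * built at late_initcall time from tcp_prot; the IPv6 table is built
 * lazily the first time an AF_INET6 socket shows up, see
 * tcp_bpf_check_v6_needs_rebuild().
 */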
enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_RX,
	TCP_BPF_TXRX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];

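/* Clone *base and override the callbacks sockmap needs to intercept:
 * close/recvmsg for all configs, sendmsg/sendpage for the TX configs,
 * and the parser-aware recvmsg for the RX configs.
 */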
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE] = *base;
	prot[TCP_BPF_BASE].close = sock_map_close;
	prot[TCP_BPF_BASE].recvmsg = tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].sock_is_readable = sk_msg_is_readable;

	prot[TCP_BPF_TX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg = tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage = tcp_bpf_sendpage;

	prot[TCP_BPF_RX] = prot[TCP_BPF_BASE];
	prot[TCP_BPF_RX].recvmsg = tcp_bpf_recvmsg_parser;

	prot[TCP_BPF_TXRX] = prot[TCP_BPF_TX];
	prot[TCP_BPF_TXRX].recvmsg = tcp_bpf_recvmsg_parser;
}

static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}

static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);

static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg == tcp_recvmsg &&
	       ops->sendmsg == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}

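/* Swap sk's proto ops for the matching tcp_bpf variant, or restore the
 * originals when the psock is being torn down (restore == true). The
 * config is picked from which BPF programs are attached to the psock;
 * sockets that already have a ULP installed are rejected.
 */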
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser ? TCP_BPF_TX : TCP_BPF_BASE;

	if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
		config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
	}

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			/* TLS does not have an unhash proto in SW cases,
			 * but we need to ensure we stop using the sock_map
			 * unhash routine because the associated psock is being
			 * removed. So use the original unhash handler.
			 */
			WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			WRITE_ONCE(sk->sk_prot, psock->sk_proto);
		}
		return 0;
	}

	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);

/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */