// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/skmsg.h>
#include <linux/filter.h>
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>

#include <net/inet_common.h>
#include <net/tls.h>
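/* TCP glue for sockmap/sk_msg: the proto callbacks below let SK_MSG
 * verdict programs pass, redirect, or drop data sent on a TCP socket,
 * and let receivers consume data queued on the psock ingress list.
 */

/* Charge up to @apply_bytes of @msg to @sk and move it onto @psock's
 * ingress queue, so data redirected to the ingress side of this socket
 * can be read locally without ever hitting the TCP stack.
 */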
static int bpf_tcp_ingress(struct sock *sk, struct sk_psock *psock,
			   struct sk_msg *msg, u32 apply_bytes, int flags)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	u32 size, copied = 0;
	struct sk_msg *tmp;
	int i, ret = 0;

	tmp = kzalloc(sizeof(*tmp), __GFP_NOWARN | GFP_KERNEL);
	if (unlikely(!tmp))
		return -ENOMEM;

	lock_sock(sk);
	tmp->sg.start = msg->sg.start;
	i = msg->sg.start;
	do {
		sge = sk_msg_elem(msg, i);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		if (!sk_wmem_schedule(sk, size)) {
			if (!copied)
				ret = -ENOMEM;
			break;
		}

		sk_mem_charge(sk, size);
		sk_msg_xfer(tmp, msg, i, size);
		copied += size;
		if (sge->length)
			get_page(sk_msg_page(tmp, i));
		sk_msg_iter_var_next(i);
		tmp->sg.end = i;
		if (apply) {
			apply_bytes -= size;
			if (!apply_bytes)
				break;
		}
	} while (i != msg->sg.end);

	if (!ret) {
		msg->sg.start = i;
		sk_psock_queue_msg(psock, tmp);
		sk_psock_data_ready(sk, psock);
	} else {
		sk_msg_free(sk, tmp);
		kfree(tmp);
	}

	release_sock(sk);
	return ret;
}
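/* Transmit up to @apply_bytes of @msg through the TCP stack. Runs with
 * the socket already locked; when TLS is the Tx ULP, the locked
 * sendpage path is used with MSG_SENDPAGE_NOPOLICY so the BPF policy
 * is not applied a second time.
 */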
static int tcp_bpf_push(struct sock *sk, struct sk_msg *msg, u32 apply_bytes,
			int flags, bool uncharge)
{
	bool apply = apply_bytes;
	struct scatterlist *sge;
	struct page *page;
	int size, ret = 0;
	u32 off;

	while (1) {
		bool has_tx_ulp;

		sge = sk_msg_elem(msg, msg->sg.start);
		size = (apply && apply_bytes < sge->length) ?
			apply_bytes : sge->length;
		off  = sge->offset;
		page = sg_page(sge);

		tcp_rate_check_app_limited(sk);
retry:
		has_tx_ulp = tls_sw_has_ctx_tx(sk);
		if (has_tx_ulp) {
			flags |= MSG_SENDPAGE_NOPOLICY;
			ret = kernel_sendpage_locked(sk,
						     page, off, size, flags);
		} else {
			ret = do_tcp_sendpages(sk, page, off, size, flags);
		}

		if (ret <= 0)
			return ret;
		if (apply)
			apply_bytes -= ret;
		msg->sg.size -= ret;
		sge->offset += ret;
		sge->length -= ret;
		if (uncharge)
			sk_mem_uncharge(sk, ret);
		if (ret != size) {
			size -= ret;
			off  += ret;
			goto retry;
		}
		if (!sge->length) {
			put_page(page);
			sk_msg_iter_next(msg, start);
			sg_init_table(sge, 1);
			if (msg->sg.start == msg->sg.end)
				break;
		}
		if (apply && !apply_bytes)
			break;
	}

	return 0;
}
static int tcp_bpf_push_locked(struct sock *sk, struct sk_msg *msg,
			       u32 apply_bytes, int flags, bool uncharge)
{
	int ret;

	lock_sock(sk);
	ret = tcp_bpf_push(sk, msg, apply_bytes, flags, uncharge);
	release_sock(sk);
	return ret;
}
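/* Entry point for SK_MSG redirects: deliver @msg on the target socket
 * @sk, either onto its psock ingress queue or out through its TCP
 * stack.
 */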
int tcp_bpf_sendmsg_redir(struct sock *sk, struct sk_msg *msg,
			  u32 bytes, int flags)
{
	bool ingress = sk_msg_to_ingress(msg);
	struct sk_psock *psock = sk_psock_get(sk);
	int ret;

	if (unlikely(!psock))
		return -EPIPE;

	ret = ingress ? bpf_tcp_ingress(sk, psock, msg, bytes, flags) :
			tcp_bpf_push_locked(sk, msg, bytes, flags, false);
	sk_psock_put(sk, psock);
	return ret;
}
EXPORT_SYMBOL_GPL(tcp_bpf_sendmsg_redir);
#ifdef CONFIG_BPF_SYSCALL
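/* Sleep until either the psock ingress queue or the TCP receive queue
 * has data. Returns nonzero if data arrived or the receive side was
 * shut down, zero if the timeout expired first.
 */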
static int tcp_msg_wait_data(struct sock *sk, struct sk_psock *psock,
			     long timeo)
{
	DEFINE_WAIT_FUNC(wait, woken_wake_function);
	int ret = 0;

	if (sk->sk_shutdown & RCV_SHUTDOWN)
		return 1;

	if (!timeo)
		return ret;

	add_wait_queue(sk_sleep(sk), &wait);
	sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	ret = sk_wait_event(sk, &timeo,
			    !list_empty(&psock->ingress_msg) ||
			    !skb_queue_empty(&sk->sk_receive_queue), &wait);
	sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
	remove_wait_queue(sk_sleep(sk), &wait);
	return ret;
}
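/* recvmsg for sockets with a stream/skb verdict program attached: data
 * is consumed only from the psock ingress queue that the verdict
 * program feeds, never directly from the TCP receive queue.
 */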
static int tcp_bpf_recvmsg_parser(struct sock *sk,
				  struct msghdr *msg,
				  size_t len,
				  int nonblock,
				  int flags,
				  int *addr_len)
{
	struct sk_psock *psock;
	int copied;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);

	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		if (sock_flag(sk, SOCK_DONE))
			goto out;

		if (sk->sk_err) {
			copied = sock_error(sk);
			goto out;
		}

		if (sk->sk_shutdown & RCV_SHUTDOWN)
			goto out;

		if (sk->sk_state == TCP_CLOSE) {
			copied = -ENOTCONN;
			goto out;
		}

		timeo = sock_rcvtimeo(sk, nonblock);
		if (!timeo) {
			copied = -EAGAIN;
			goto out;
		}

		if (signal_pending(current)) {
			copied = sock_intr_errno(timeo);
			goto out;
		}

		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data && !sk_psock_queue_empty(psock))
			goto msg_bytes_ready;
		copied = -EAGAIN;
	}
out:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied;
}
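/* recvmsg for the TCP_BPF_BASE config: prefer psock ingress data, but
 * fall back to the regular tcp_recvmsg() when only the TCP receive
 * queue holds data or the psock has already been torn down.
 */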
static int tcp_bpf_recvmsg(struct sock *sk, struct msghdr *msg, size_t len,
			   int nonblock, int flags, int *addr_len)
{
	struct sk_psock *psock;
	int copied, ret;

	if (unlikely(flags & MSG_ERRQUEUE))
		return inet_recv_error(sk, msg, len, addr_len);

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	if (!skb_queue_empty(&sk->sk_receive_queue) &&
	    sk_psock_queue_empty(psock)) {
		sk_psock_put(sk, psock);
		return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
	}
	lock_sock(sk);
msg_bytes_ready:
	copied = sk_msg_recvmsg(sk, psock, msg, len, flags);
	if (!copied) {
		long timeo;
		int data;

		timeo = sock_rcvtimeo(sk, nonblock);
		data = tcp_msg_wait_data(sk, psock, timeo);
		if (data) {
			if (!sk_psock_queue_empty(psock))
				goto msg_bytes_ready;
			release_sock(sk);
			sk_psock_put(sk, psock);
			return tcp_recvmsg(sk, msg, len, nonblock, flags, addr_len);
		}
		copied = -EAGAIN;
	}
	ret = copied;
	release_sock(sk);
	sk_psock_put(sk, psock);
	return ret;
}
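/* Run the SK_MSG verdict program over @msg, honouring cork and apply
 * state, and carry out the resulting PASS/REDIRECT/DROP action. Called
 * with the socket locked; the lock is dropped around redirects. On
 * drop, *copied is reduced so dropped bytes are not reported as sent.
 */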
static int tcp_bpf_send_verdict(struct sock *sk, struct sk_psock *psock,
				struct sk_msg *msg, int *copied, int flags)
{
	bool cork = false, enospc = sk_msg_full(msg);
	struct sock *sk_redir;
	u32 tosend, delta = 0;
	u32 eval = __SK_NONE;
	int ret;

more_data:
	if (psock->eval == __SK_NONE) {
		/* Track delta in msg size to add/subtract it on SK_DROP from
		 * returned to user copied size. This ensures user doesn't
		 * get a positive return code with msg_cut_data and SK_DROP
		 * verdict.
		 */
		delta = msg->sg.size;
		psock->eval = sk_psock_msg_verdict(sk, psock, msg);
		delta -= msg->sg.size;
	}

	if (msg->cork_bytes &&
	    msg->cork_bytes > msg->sg.size && !enospc) {
		psock->cork_bytes = msg->cork_bytes - msg->sg.size;
		if (!psock->cork) {
			psock->cork = kzalloc(sizeof(*psock->cork),
					      GFP_ATOMIC | __GFP_NOWARN);
			if (!psock->cork)
				return -ENOMEM;
		}
		memcpy(psock->cork, msg, sizeof(*msg));
		return 0;
	}

	tosend = msg->sg.size;
	if (psock->apply_bytes && psock->apply_bytes < tosend)
		tosend = psock->apply_bytes;

	switch (psock->eval) {
	case __SK_PASS:
		ret = tcp_bpf_push(sk, msg, tosend, flags, true);
		if (unlikely(ret)) {
			*copied -= sk_msg_free(sk, msg);
			break;
		}
		sk_msg_apply_bytes(psock, tosend);
		break;
	case __SK_REDIRECT:
		sk_redir = psock->sk_redir;
		sk_msg_apply_bytes(psock, tosend);
		if (!psock->apply_bytes) {
			/* Clean up before releasing the sock lock. */
			eval = psock->eval;
			psock->eval = __SK_NONE;
			psock->sk_redir = NULL;
		}
		if (psock->cork) {
			cork = true;
			psock->cork = NULL;
		}
		sk_msg_return(sk, msg, msg->sg.size);
		release_sock(sk);

		ret = tcp_bpf_sendmsg_redir(sk_redir, msg, tosend, flags);

		if (eval == __SK_REDIRECT)
			sock_put(sk_redir);

		lock_sock(sk);
		if (unlikely(ret < 0)) {
			int free = sk_msg_free_nocharge(sk, msg);

			if (!cork)
				*copied -= free;
		}
		if (cork) {
			sk_msg_free(sk, msg);
			kfree(msg);
			msg = NULL;
			ret = 0;
		}
		break;
	case __SK_DROP:
	default:
		sk_msg_free_partial(sk, msg, tosend);
		sk_msg_apply_bytes(psock, tosend);
		*copied -= (tosend + delta);
		return -EACCES;
	}

	if (likely(!ret)) {
		if (!psock->apply_bytes) {
			psock->eval = __SK_NONE;
			if (psock->sk_redir) {
				sock_put(psock->sk_redir);
				psock->sk_redir = NULL;
			}
		}
		if (msg &&
		    msg->sg.data[msg->sg.start].page_link &&
		    msg->sg.data[msg->sg.start].length) {
			if (eval == __SK_REDIRECT)
				sk_mem_charge(sk, msg->sg.size);
			goto more_data;
		}
	}
	return ret;
}
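/* sendmsg for sockets with an SK_MSG program attached: copy the user
 * data into an sk_msg (respecting cork state), then hand it to
 * tcp_bpf_send_verdict(). Falls back to tcp_sendmsg() if the psock is
 * gone.
 */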
static int tcp_bpf_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
{
	struct sk_msg tmp, *msg_tx = NULL;
	int copied = 0, err = 0;
	struct sk_psock *psock;
	long timeo;
	int flags;

	/* Don't let internal do_tcp_sendpages() flags through */
	flags = (msg->msg_flags & ~MSG_SENDPAGE_DECRYPTED);
	flags |= MSG_NO_SHARED_FRAGS;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendmsg(sk, msg, size);

	lock_sock(sk);
	timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
	while (msg_data_left(msg)) {
		bool enospc = false;
		u32 copy, osize;

		if (sk->sk_err) {
			err = -sk->sk_err;
			goto out_err;
		}

		copy = msg_data_left(msg);
		if (!sk_stream_memory_free(sk))
			goto wait_for_sndbuf;
		if (psock->cork) {
			msg_tx = psock->cork;
		} else {
			msg_tx = &tmp;
			sk_msg_init(msg_tx);
		}

		osize = msg_tx->sg.size;
		err = sk_msg_alloc(sk, msg_tx, msg_tx->sg.size + copy, msg_tx->sg.end - 1);
		if (err) {
			if (err != -ENOSPC)
				goto wait_for_memory;
			enospc = true;
			copy = msg_tx->sg.size - osize;
		}

		err = sk_msg_memcopy_from_iter(sk, &msg->msg_iter, msg_tx,
					       copy);
		if (err < 0) {
			sk_msg_trim(sk, msg_tx, osize);
			goto out_err;
		}

		copied += copy;
		if (psock->cork_bytes) {
			if (size > psock->cork_bytes)
				psock->cork_bytes = 0;
			else
				psock->cork_bytes -= size;
			if (psock->cork_bytes && !enospc)
				goto out_err;
			/* All cork bytes are accounted, rerun the prog. */
			psock->eval = __SK_NONE;
			psock->cork_bytes = 0;
		}

		err = tcp_bpf_send_verdict(sk, psock, msg_tx, &copied, flags);
		if (unlikely(err < 0))
			goto out_err;
		continue;
wait_for_sndbuf:
		set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
wait_for_memory:
		err = sk_stream_wait_memory(sk, &timeo);
		if (err) {
			if (msg_tx && msg_tx != psock->cork)
				sk_msg_free(sk, msg_tx);
			goto out_err;
		}
	}
out_err:
	if (err < 0)
		err = sk_stream_error(sk, msg->msg_flags, err);
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
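/* sendpage counterpart of tcp_bpf_sendmsg(): link the page into an
 * sk_msg without copying and run the verdict over it.
 */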
static int tcp_bpf_sendpage(struct sock *sk, struct page *page, int offset,
			    size_t size, int flags)
{
	struct sk_msg tmp, *msg = NULL;
	int err = 0, copied = 0;
	struct sk_psock *psock;
	bool enospc = false;

	psock = sk_psock_get(sk);
	if (unlikely(!psock))
		return tcp_sendpage(sk, page, offset, size, flags);

	lock_sock(sk);
	if (psock->cork) {
		msg = psock->cork;
	} else {
		msg = &tmp;
		sk_msg_init(msg);
	}

	/* Catch case where ring is full and sendpage is stalled. */
	if (unlikely(sk_msg_full(msg)))
		goto out_err;

	sk_msg_page_add(msg, page, size, offset);
	sk_mem_charge(sk, size);
	copied = size;
	if (sk_msg_full(msg))
		enospc = true;
	if (psock->cork_bytes) {
		if (size > psock->cork_bytes)
			psock->cork_bytes = 0;
		else
			psock->cork_bytes -= size;
		if (psock->cork_bytes && !enospc)
			goto out_err;
		/* All cork bytes are accounted, rerun the prog. */
		psock->eval = __SK_NONE;
		psock->cork_bytes = 0;
	}

	err = tcp_bpf_send_verdict(sk, psock, msg, &copied, flags);
out_err:
	release_sock(sk);
	sk_psock_put(sk, psock);
	return copied ? copied : err;
}
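/* One proto template per address family and configuration: BASE only
 * overrides close/recvmsg/sock_is_readable, TX adds the sendmsg and
 * sendpage hooks, RX swaps in the parser-aware recvmsg, and TXRX
 * combines TX with the parser-aware recvmsg.
 */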
enum {
	TCP_BPF_IPV4,
	TCP_BPF_IPV6,
	TCP_BPF_NUM_PROTS,
};

enum {
	TCP_BPF_BASE,
	TCP_BPF_TX,
	TCP_BPF_RX,
	TCP_BPF_TXRX,
	TCP_BPF_NUM_CFGS,
};

static struct proto *tcpv6_prot_saved __read_mostly;
static DEFINE_SPINLOCK(tcpv6_prot_lock);
static struct proto tcp_bpf_prots[TCP_BPF_NUM_PROTS][TCP_BPF_NUM_CFGS];
static void tcp_bpf_rebuild_protos(struct proto prot[TCP_BPF_NUM_CFGS],
				   struct proto *base)
{
	prot[TCP_BPF_BASE]			= *base;
	prot[TCP_BPF_BASE].close		= sock_map_close;
	prot[TCP_BPF_BASE].recvmsg		= tcp_bpf_recvmsg;
	prot[TCP_BPF_BASE].sock_is_readable	= sk_msg_is_readable;

	prot[TCP_BPF_TX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_TX].sendmsg		= tcp_bpf_sendmsg;
	prot[TCP_BPF_TX].sendpage		= tcp_bpf_sendpage;

	prot[TCP_BPF_RX]			= prot[TCP_BPF_BASE];
	prot[TCP_BPF_RX].recvmsg		= tcp_bpf_recvmsg_parser;

	prot[TCP_BPF_TXRX]			= prot[TCP_BPF_TX];
	prot[TCP_BPF_TXRX].recvmsg		= tcp_bpf_recvmsg_parser;
}
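/* The IPv6 templates cannot be built at init time like the IPv4 ones,
 * since tcpv6_prot may live in a module; (re)build them lazily the
 * first time a socket shows up with an ops pointer we have not seen.
 */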
static void tcp_bpf_check_v6_needs_rebuild(struct proto *ops)
{
	if (unlikely(ops != smp_load_acquire(&tcpv6_prot_saved))) {
		spin_lock_bh(&tcpv6_prot_lock);
		if (likely(ops != tcpv6_prot_saved)) {
			tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV6], ops);
			smp_store_release(&tcpv6_prot_saved, ops);
		}
		spin_unlock_bh(&tcpv6_prot_lock);
	}
}
static int __init tcp_bpf_v4_build_proto(void)
{
	tcp_bpf_rebuild_protos(tcp_bpf_prots[TCP_BPF_IPV4], &tcp_prot);
	return 0;
}
late_initcall(tcp_bpf_v4_build_proto);
static int tcp_bpf_assert_proto_ops(struct proto *ops)
{
	/* In order to avoid retpoline, we make assumptions when we call
	 * into ops if e.g. a psock is not present. Make sure they are
	 * indeed valid assumptions.
	 */
	return ops->recvmsg  == tcp_recvmsg &&
	       ops->sendmsg  == tcp_sendmsg &&
	       ops->sendpage == tcp_sendpage ? 0 : -ENOTSUPP;
}
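/* Switch @sk to the tcp_bpf proto matching its family and attached
 * programs, or back to its original proto when @restore is set, e.g.
 * when the psock is being removed from a sockmap.
 */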
int tcp_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	int config = psock->progs.msg_parser   ? TCP_BPF_TX   : TCP_BPF_BASE;

	if (psock->progs.stream_verdict || psock->progs.skb_verdict) {
		config = (config == TCP_BPF_TX) ? TCP_BPF_TXRX : TCP_BPF_RX;
	}

	if (restore) {
		if (inet_csk_has_ulp(sk)) {
			/* TLS does not have an unhash proto in SW cases,
			 * but we need to ensure we stop using the sock_map
			 * unhash routine because the associated psock is being
			 * removed. So use the original unhash handler.
			 */
			WRITE_ONCE(sk->sk_prot->unhash, psock->saved_unhash);
			tcp_update_ulp(sk, psock->sk_proto, psock->saved_write_space);
		} else {
			sk->sk_write_space = psock->saved_write_space;
			/* Pairs with lockless read in sk_clone_lock() */
			WRITE_ONCE(sk->sk_prot, psock->sk_proto);
		}
		return 0;
	}

	if (inet_csk_has_ulp(sk))
		return -EINVAL;

	if (sk->sk_family == AF_INET6) {
		if (tcp_bpf_assert_proto_ops(psock->sk_proto))
			return -EINVAL;

		tcp_bpf_check_v6_needs_rebuild(psock->sk_proto);
	}

	/* Pairs with lockless read in sk_clone_lock() */
	WRITE_ONCE(sk->sk_prot, &tcp_bpf_prots[family][config]);
	return 0;
}
EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
/* If a child got cloned from a listening socket that had tcp_bpf
 * protocol callbacks installed, we need to restore the callbacks to
 * the default ones because the child does not inherit the psock state
 * that tcp_bpf callbacks expect.
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */