// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2021 Cong Wang <cong.wang@bytedance.com> */
#include <linux/skmsg.h>
#include <linux/bpf.h>
#include <net/sock.h>
#include <net/af_unix.h>
9 #define unix_sk_has_data(__sk, __psock) \
10 ({ !skb_queue_empty(&__sk->sk_receive_queue) || \
11 !skb_queue_empty(&__psock->ingress_skb) || \
12 !list_empty(&__psock->ingress_msg); \
15 static int unix_msg_wait_data(struct sock
*sk
, struct sk_psock
*psock
,
18 DEFINE_WAIT_FUNC(wait
, woken_wake_function
);
19 struct unix_sock
*u
= unix_sk(sk
);
22 if (sk
->sk_shutdown
& RCV_SHUTDOWN
)
28 add_wait_queue(sk_sleep(sk
), &wait
);
29 sk_set_bit(SOCKWQ_ASYNC_WAITDATA
, sk
);
30 if (!unix_sk_has_data(sk
, psock
)) {
31 mutex_unlock(&u
->iolock
);
32 wait_woken(&wait
, TASK_INTERRUPTIBLE
, timeo
);
33 mutex_lock(&u
->iolock
);
34 ret
= unix_sk_has_data(sk
, psock
);
36 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA
, sk
);
37 remove_wait_queue(sk_sleep(sk
), &wait
);
41 static int __unix_recvmsg(struct sock
*sk
, struct msghdr
*msg
,
42 size_t len
, int flags
)
44 if (sk
->sk_type
== SOCK_DGRAM
)
45 return __unix_dgram_recvmsg(sk
, msg
, len
, flags
);
47 return __unix_stream_recvmsg(sk
, msg
, len
, flags
);
50 static int unix_bpf_recvmsg(struct sock
*sk
, struct msghdr
*msg
,
51 size_t len
, int nonblock
, int flags
,
54 struct unix_sock
*u
= unix_sk(sk
);
55 struct sk_psock
*psock
;
58 psock
= sk_psock_get(sk
);
60 return __unix_recvmsg(sk
, msg
, len
, flags
);
62 mutex_lock(&u
->iolock
);
63 if (!skb_queue_empty(&sk
->sk_receive_queue
) &&
64 sk_psock_queue_empty(psock
)) {
65 mutex_unlock(&u
->iolock
);
66 sk_psock_put(sk
, psock
);
67 return __unix_recvmsg(sk
, msg
, len
, flags
);
71 copied
= sk_msg_recvmsg(sk
, psock
, msg
, len
, flags
);
76 timeo
= sock_rcvtimeo(sk
, nonblock
);
77 data
= unix_msg_wait_data(sk
, psock
, timeo
);
79 if (!sk_psock_queue_empty(psock
))
81 mutex_unlock(&u
->iolock
);
82 sk_psock_put(sk
, psock
);
83 return __unix_recvmsg(sk
, msg
, len
, flags
);
87 mutex_unlock(&u
->iolock
);
88 sk_psock_put(sk
, psock
);
92 static struct proto
*unix_dgram_prot_saved __read_mostly
;
93 static DEFINE_SPINLOCK(unix_dgram_prot_lock
);
94 static struct proto unix_dgram_bpf_prot
;
96 static struct proto
*unix_stream_prot_saved __read_mostly
;
97 static DEFINE_SPINLOCK(unix_stream_prot_lock
);
98 static struct proto unix_stream_bpf_prot
;
100 static void unix_dgram_bpf_rebuild_protos(struct proto
*prot
, const struct proto
*base
)
103 prot
->close
= sock_map_close
;
104 prot
->recvmsg
= unix_bpf_recvmsg
;
105 prot
->sock_is_readable
= sk_msg_is_readable
;
108 static void unix_stream_bpf_rebuild_protos(struct proto
*prot
,
109 const struct proto
*base
)
112 prot
->close
= sock_map_close
;
113 prot
->recvmsg
= unix_bpf_recvmsg
;
114 prot
->sock_is_readable
= sk_msg_is_readable
;
115 prot
->unhash
= sock_map_unhash
;
118 static void unix_dgram_bpf_check_needs_rebuild(struct proto
*ops
)
120 if (unlikely(ops
!= smp_load_acquire(&unix_dgram_prot_saved
))) {
121 spin_lock_bh(&unix_dgram_prot_lock
);
122 if (likely(ops
!= unix_dgram_prot_saved
)) {
123 unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot
, ops
);
124 smp_store_release(&unix_dgram_prot_saved
, ops
);
126 spin_unlock_bh(&unix_dgram_prot_lock
);
130 static void unix_stream_bpf_check_needs_rebuild(struct proto
*ops
)
132 if (unlikely(ops
!= smp_load_acquire(&unix_stream_prot_saved
))) {
133 spin_lock_bh(&unix_stream_prot_lock
);
134 if (likely(ops
!= unix_stream_prot_saved
)) {
135 unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot
, ops
);
136 smp_store_release(&unix_stream_prot_saved
, ops
);
138 spin_unlock_bh(&unix_stream_prot_lock
);
142 int unix_dgram_bpf_update_proto(struct sock
*sk
, struct sk_psock
*psock
, bool restore
)
144 if (sk
->sk_type
!= SOCK_DGRAM
)
148 sk
->sk_write_space
= psock
->saved_write_space
;
149 WRITE_ONCE(sk
->sk_prot
, psock
->sk_proto
);
153 unix_dgram_bpf_check_needs_rebuild(psock
->sk_proto
);
154 WRITE_ONCE(sk
->sk_prot
, &unix_dgram_bpf_prot
);
158 int unix_stream_bpf_update_proto(struct sock
*sk
, struct sk_psock
*psock
, bool restore
)
161 sk
->sk_write_space
= psock
->saved_write_space
;
162 WRITE_ONCE(sk
->sk_prot
, psock
->sk_proto
);
166 unix_stream_bpf_check_needs_rebuild(psock
->sk_proto
);
167 WRITE_ONCE(sk
->sk_prot
, &unix_stream_bpf_prot
);
171 void __init
unix_bpf_build_proto(void)
173 unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot
, &unix_dgram_proto
);
174 unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot
, &unix_stream_proto
);