]> git.proxmox.com Git - mirror_ubuntu-jammy-kernel.git/blame - net/ipv4/bpf_tcp_ca.c
bpf: tcp: Support tcp_congestion_ops in bpf
[mirror_ubuntu-jammy-kernel.git] / net / ipv4 / bpf_tcp_ca.c
CommitLineData
0baf26b0
MKL
1// SPDX-License-Identifier: GPL-2.0
2/* Copyright (c) 2019 Facebook */
3
4#include <linux/types.h>
5#include <linux/bpf_verifier.h>
6#include <linux/bpf.h>
7#include <linux/btf.h>
8#include <linux/filter.h>
9#include <net/tcp.h>
10
11static u32 optional_ops[] = {
12 offsetof(struct tcp_congestion_ops, init),
13 offsetof(struct tcp_congestion_ops, release),
14 offsetof(struct tcp_congestion_ops, set_state),
15 offsetof(struct tcp_congestion_ops, cwnd_event),
16 offsetof(struct tcp_congestion_ops, in_ack_event),
17 offsetof(struct tcp_congestion_ops, pkts_acked),
18 offsetof(struct tcp_congestion_ops, min_tso_segs),
19 offsetof(struct tcp_congestion_ops, sndbuf_expand),
20 offsetof(struct tcp_congestion_ops, cong_control),
21};
22
23static u32 unsupported_ops[] = {
24 offsetof(struct tcp_congestion_ops, get_info),
25};
26
/* BTF ids of "sock" and "tcp_sock", resolved once from vmlinux BTF in
 * bpf_tcp_ca_init() and used to promote sock pointers to tcp_sock.
 */
static const struct btf_type *tcp_sock_type;
static u32 tcp_sock_id, sock_id;
29
30static int bpf_tcp_ca_init(struct btf *btf)
31{
32 s32 type_id;
33
34 type_id = btf_find_by_name_kind(btf, "sock", BTF_KIND_STRUCT);
35 if (type_id < 0)
36 return -EINVAL;
37 sock_id = type_id;
38
39 type_id = btf_find_by_name_kind(btf, "tcp_sock", BTF_KIND_STRUCT);
40 if (type_id < 0)
41 return -EINVAL;
42 tcp_sock_id = type_id;
43 tcp_sock_type = btf_type_by_id(btf, tcp_sock_id);
44
45 return 0;
46}
47
48static bool is_optional(u32 member_offset)
49{
50 unsigned int i;
51
52 for (i = 0; i < ARRAY_SIZE(optional_ops); i++) {
53 if (member_offset == optional_ops[i])
54 return true;
55 }
56
57 return false;
58}
59
60static bool is_unsupported(u32 member_offset)
61{
62 unsigned int i;
63
64 for (i = 0; i < ARRAY_SIZE(unsupported_ops); i++) {
65 if (member_offset == unsupported_ops[i])
66 return true;
67 }
68
69 return false;
70}
71
/* vmlinux BTF, used below to resolve func ptr members of the struct_ops. */
extern struct btf *btf_vmlinux;
73
74static bool bpf_tcp_ca_is_valid_access(int off, int size,
75 enum bpf_access_type type,
76 const struct bpf_prog *prog,
77 struct bpf_insn_access_aux *info)
78{
79 if (off < 0 || off >= sizeof(__u64) * MAX_BPF_FUNC_ARGS)
80 return false;
81 if (type != BPF_READ)
82 return false;
83 if (off % size != 0)
84 return false;
85
86 if (!btf_ctx_access(off, size, type, prog, info))
87 return false;
88
89 if (info->reg_type == PTR_TO_BTF_ID && info->btf_id == sock_id)
90 /* promote it to tcp_sock */
91 info->btf_id = tcp_sock_id;
92
93 return true;
94}
95
/* Verifier check for accesses through a PTR_TO_BTF_ID pointer.
 *
 * Reads are delegated to the generic btf_struct_access().  Writes are
 * permitted only into tcp_sock (via the tcp_sock_type check) and only to
 * an allow-list of members, and must not run past the end of the member
 * being written.  Returns NOT_INIT for an accepted write, -EACCES on
 * rejection, or btf_struct_access()'s result for reads.
 */
static int bpf_tcp_ca_btf_struct_access(struct bpf_verifier_log *log,
					const struct btf_type *t, int off,
					int size, enum bpf_access_type atype,
					u32 *next_btf_id)
{
	size_t end;

	if (atype == BPF_READ)
		return btf_struct_access(log, t, off, size, atype, next_btf_id);

	/* Writes are only allowed to tcp_sock itself. */
	if (t != tcp_sock_type) {
		bpf_log(log, "only read is supported\n");
		return -EACCES;
	}

	/* Map the write offset to the end of the member it falls in.
	 * icsk_ca_priv uses a GCC case range (bpf_ctx_range) since a write
	 * may target any offset inside the private CA scratch area.
	 */
	switch (off) {
	case bpf_ctx_range(struct inet_connection_sock, icsk_ca_priv):
		end = offsetofend(struct inet_connection_sock, icsk_ca_priv);
		break;
	case offsetof(struct inet_connection_sock, icsk_ack.pending):
		end = offsetofend(struct inet_connection_sock,
				  icsk_ack.pending);
		break;
	case offsetof(struct tcp_sock, snd_cwnd):
		end = offsetofend(struct tcp_sock, snd_cwnd);
		break;
	case offsetof(struct tcp_sock, snd_cwnd_cnt):
		end = offsetofend(struct tcp_sock, snd_cwnd_cnt);
		break;
	case offsetof(struct tcp_sock, snd_ssthresh):
		end = offsetofend(struct tcp_sock, snd_ssthresh);
		break;
	case offsetof(struct tcp_sock, ecn_flags):
		end = offsetofend(struct tcp_sock, ecn_flags);
		break;
	default:
		bpf_log(log, "no write support to tcp_sock at off %d\n", off);
		return -EACCES;
	}

	/* Reject writes that spill past the end of the target member. */
	if (off + size > end) {
		bpf_log(log,
			"write access at off %d with size %d beyond the member of tcp_sock ended at %zu\n",
			off, size, end);
		return -EACCES;
	}

	return NOT_INIT;
}
145
/* bpf CA progs only get the base (always-available) helper set. */
static const struct bpf_func_proto *
bpf_tcp_ca_get_func_proto(enum bpf_func_id func_id,
			  const struct bpf_prog *prog)
{
	return bpf_base_func_proto(func_id);
}
152
153static const struct bpf_verifier_ops bpf_tcp_ca_verifier_ops = {
154 .get_func_proto = bpf_tcp_ca_get_func_proto,
155 .is_valid_access = bpf_tcp_ca_is_valid_access,
156 .btf_struct_access = bpf_tcp_ca_btf_struct_access,
157};
158
/* struct_ops callback: initialize one tcp_congestion_ops member from the
 * user-supplied image (@udata) into the kernel copy (@kdata).
 *
 * Return 1 if the member was fully handled here, 0 to let the generic
 * struct_ops code handle it (e.g. install the bpf prog for a func ptr),
 * or a negative errno on rejection.
 */
static int bpf_tcp_ca_init_member(const struct btf_type *t,
				  const struct btf_member *member,
				  void *kdata, const void *udata)
{
	const struct tcp_congestion_ops *utcp_ca;
	struct tcp_congestion_ops *tcp_ca;
	size_t tcp_ca_name_len;
	int prog_fd;
	u32 moff;

	utcp_ca = (const struct tcp_congestion_ops *)udata;
	tcp_ca = (struct tcp_congestion_ops *)kdata;

	/* Byte offset of this member within tcp_congestion_ops. */
	moff = btf_member_bit_offset(t, member) / 8;
	switch (moff) {
	case offsetof(struct tcp_congestion_ops, flags):
		/* Only the public TCP_CONG_* flag bits may be set. */
		if (utcp_ca->flags & ~TCP_CONG_MASK)
			return -EINVAL;
		tcp_ca->flags = utcp_ca->flags;
		return 1;
	case offsetof(struct tcp_congestion_ops, name):
		/* Name must be non-empty, NUL-terminated within the buffer,
		 * and not already registered as a CA.
		 */
		tcp_ca_name_len = strnlen(utcp_ca->name, sizeof(utcp_ca->name));
		if (!tcp_ca_name_len ||
		    tcp_ca_name_len == sizeof(utcp_ca->name))
			return -EINVAL;
		if (tcp_ca_find(utcp_ca->name))
			return -EEXIST;
		memcpy(tcp_ca->name, utcp_ca->name, sizeof(tcp_ca->name));
		return 1;
	}

	/* Non func ptr members beyond flags/name: nothing to do here. */
	if (!btf_type_resolve_func_ptr(btf_vmlinux, member->type, NULL))
		return 0;

	/* Ensure bpf_prog is provided for compulsory func ptr */
	prog_fd = (int)(*(unsigned long *)(udata + moff));
	if (!prog_fd && !is_optional(moff) && !is_unsupported(moff))
		return -EINVAL;

	return 0;
}
200
201static int bpf_tcp_ca_check_member(const struct btf_type *t,
202 const struct btf_member *member)
203{
204 if (is_unsupported(btf_member_bit_offset(t, member) / 8))
205 return -ENOTSUPP;
206 return 0;
207}
208
/* struct_ops callback: register the initialized CA with the TCP stack. */
static int bpf_tcp_ca_reg(void *kdata)
{
	return tcp_register_congestion_control(kdata);
}
213
/* struct_ops callback: remove the CA from the TCP stack on detach. */
static void bpf_tcp_ca_unreg(void *kdata)
{
	tcp_unregister_congestion_control(kdata);
}
218
219/* Avoid sparse warning. It is only used in bpf_struct_ops.c. */
220extern struct bpf_struct_ops bpf_tcp_congestion_ops;
221
222struct bpf_struct_ops bpf_tcp_congestion_ops = {
223 .verifier_ops = &bpf_tcp_ca_verifier_ops,
224 .reg = bpf_tcp_ca_reg,
225 .unreg = bpf_tcp_ca_unreg,
226 .check_member = bpf_tcp_ca_check_member,
227 .init_member = bpf_tcp_ca_init_member,
228 .init = bpf_tcp_ca_init,
229 .name = "tcp_congestion_ops",
230};