]>
Commit | Line | Data |
---|---|---|
aaaa20b6 VSO |
1 | #! /usr/bin/env python3 |
2 | """Generate coroutine wrappers for block subsystem. | |
3 | ||
The program parses one or several C files given on the command line,
searches for functions with the 'co_wrapper' or 'no_co_wrapper' specifier
and writes the corresponding wrappers to the output file.
7 | ||
8 | Usage: block-coroutine-wrapper.py generated-file.c FILE.[ch]... | |
9 | ||
10 | Copyright (c) 2020 Virtuozzo International GmbH. | |
11 | ||
12 | This program is free software; you can redistribute it and/or modify | |
13 | it under the terms of the GNU General Public License as published by | |
14 | the Free Software Foundation; either version 2 of the License, or | |
15 | (at your option) any later version. | |
16 | ||
17 | This program is distributed in the hope that it will be useful, | |
18 | but WITHOUT ANY WARRANTY; without even the implied warranty of | |
19 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
20 | GNU General Public License for more details. | |
21 | ||
22 | You should have received a copy of the GNU General Public License | |
23 | along with this program. If not, see <http://www.gnu.org/licenses/>. | |
24 | """ | |
25 | ||
26 | import sys | |
27 | import re | |
28 | from typing import Iterator | |
29 | ||
30 | ||
def gen_header():
    """
    Build the preamble of the generated C file: a banner comment that embeds
    the license text of this script, followed by the #include lines that the
    generated wrappers rely on.
    """
    # Keep the module docstring from "Copyright" onwards (DOTALL lets the
    # greedy '.*' run across newlines), then reshape it into C comment lines:
    # non-empty lines get a ' * ' prefix, empty lines become a bare ' *'.
    license_text = re.sub('^.*Copyright', 'Copyright', __doc__,
                          flags=re.DOTALL).strip()
    license_text = re.sub('^(?=.)', ' * ', license_text, flags=re.MULTILINE)
    license_text = re.sub('^$', ' *', license_text, flags=re.MULTILINE)

    banner = ('/*\n'
              ' * File is generated by scripts/block-coroutine-wrapper.py\n'
              ' *\n'
              f'{license_text}\n'
              ' */\n')
    includes = ('\n'
                '#include "qemu/osdep.h"\n'
                '#include "block/coroutines.h"\n'
                '#include "block/block-gen.h"\n'
                '#include "block/block_int.h"\n'
                '#include "block/dirty-bitmap.h"\n')
    return banner + includes
48 | ||
49 | ||
class ParamDecl:
    """One C parameter declaration, e.g. ``BlockDriverState *bs``."""

    # The type part is greedy and must end in a space or '*', so the name is
    # the trailing identifier: "Error **errp" -> type "Error **", name "errp".
    param_re = re.compile(r'(?P<decl>'
                          r'(?P<type>.*[ *])'
                          r'(?P<name>[a-zA-Z_][a-zA-Z0-9_]*)'
                          r')')

    def __init__(self, param_decl: str) -> None:
        """
        Parse ``param_decl`` (surrounding whitespace is ignored).

        Raises:
            ValueError: if the text is not a plain "type name" declaration,
                e.g. "void", an empty string, or a declaration with trailing
                text such as an array suffix or a comment.
        """
        # fullmatch: trailing text must be rejected; a prefix match would
        # silently drop it (e.g. "int x[4]" would become "int x").
        m = self.param_re.fullmatch(param_decl.strip())
        if m is None:
            raise ValueError(f'Wrong parameter declaration: "{param_decl}"')
        # Full declaration text, type plus name, as used in C prototypes.
        self.decl = m.group('decl')
        # Type text including its trailing space or '*', e.g. "int64_t ".
        self.type = m.group('type')
        # Parameter identifier.
        self.name = m.group('name')
63 | ||
64 | ||
class FuncDecl:
    """
    A parsed co_wrapper* / no_co_wrapper* declaration.

    The attributes hold the C text fragments that the gen_*_wrapper()
    functions interpolate into their templates.
    """

    def __init__(self, wrapper_type: str, return_type: str, name: str,
                 args: str, variant: str) -> None:
        """
        Args:
            wrapper_type: 'co' (co_wrapper*) or 'no_co' (no_co_wrapper*).
            return_type: C return type text, e.g. 'int ' or
                'BlockDriverState *'; surrounding whitespace is stripped.
            name: name of the wrapper function being declared.
            args: comma-separated C parameter list, without parentheses.
            variant: text following the wrapper keyword, e.g. '_mixed',
                '_bdrv_rdlock', '_mixed_bdrv_rdlock' or ''.

        Raises:
            ValueError: for an unknown wrapper type, an invalid wrapper
                name, an unsupported variant combination, or a malformed
                parameter.
        """
        self.return_type = return_type.strip()
        self.name = name.strip()
        # Name of the C struct that carries the arguments (and result).
        self.struct_name = snake_to_camel(self.name)
        self.args = [ParamDecl(arg.strip()) for arg in args.split(',')]
        self.create_only_co = 'mixed' not in variant
        self.graph_rdlock = 'bdrv_rdlock' in variant
        self.graph_wrlock = 'bdrv_wrlock' in variant

        self.wrapper_type = wrapper_type

        # target_name is the function that the generated wrapper calls.
        if wrapper_type == 'co':
            # "subsys_foo" wraps the coroutine "subsys_co_foo".
            if self.graph_wrlock:
                raise ValueError(f"co function can't be wrlock: {self.name}")
            if '_co_' in self.name:
                raise ValueError("co function name can't contain '_co_': "
                                 f"{self.name}")
            parts = self.name.split('_', 1)
            if len(parts) != 2:
                raise ValueError(f"Invalid co function name: {self.name}")
            subsystem, subname = parts
            self.target_name = f'{subsystem}_co_{subname}'
        elif wrapper_type == 'no_co':
            # "subsys_co_foo" wraps the non-coroutine "subsys_foo".
            parts = self.name.split('_', 2)
            if len(parts) != 3 or parts[1] != 'co':
                raise ValueError(f"Invalid no_co function name: {self.name}")
            subsystem, _, subname = parts
            if not self.create_only_co:
                raise ValueError(f"no_co function can't be mixed: {self.name}")
            if self.graph_rdlock and self.graph_wrlock:
                raise ValueError("function can't be both rdlock and wrlock: "
                                 f"{self.name}")
            self.target_name = f'{subsystem}_{subname}'
        else:
            raise ValueError(f"Unknown wrapper type {wrapper_type!r}: "
                             f"{self.name}")

        self.ctx = self.gen_ctx()

        # C fragments for storing/returning the result; all empty for void.
        self.get_result = 's->ret = '
        self.ret = 'return s.ret;'
        self.co_ret = 'return '
        self.return_field = self.return_type + " ret;"
        if self.return_type == 'void':
            self.get_result = ''
            self.ret = ''
            self.co_ret = ''
            self.return_field = ''

    def gen_ctx(self, prefix: str = '') -> str:
        """
        Return a C expression for the AioContext to use.

        It is derived from the first parameter when that is a
        BlockDriverState *, BdrvChild * or BlockBackend *; otherwise it is
        qemu_get_aio_context(). ``prefix`` is prepended to the parameter
        name, e.g. 's->' to read it from the argument struct.
        """
        t = self.args[0].type
        name = self.args[0].name
        if t == 'BlockDriverState *':
            return f'bdrv_get_aio_context({prefix}{name})'
        elif t == 'BdrvChild *':
            return f'bdrv_get_aio_context({prefix}{name}->bs)'
        elif t == 'BlockBackend *':
            return f'blk_get_aio_context({prefix}{name})'
        else:
            return 'qemu_get_aio_context()'

    # NOTE: the parameter name `format` shadows the builtin but is kept so
    # keyword callers keep working.
    def gen_list(self, format: str) -> str:
        """Apply `format` (str.format_map over decl/type/name) per parameter,
        joined with ', '."""
        return ', '.join(format.format_map(arg.__dict__) for arg in self.args)

    def gen_block(self, format: str) -> str:
        """Apply `format` (str.format_map over decl/type/name) per parameter,
        one per line."""
        return '\n'.join(format.format_map(arg.__dict__) for arg in self.args)
124 | ||
125 | ||
# Match wrappers declared with a co_wrapper mark, e.g.
#   int co_wrapper_mixed_bdrv_rdlock bdrv_foo(BlockDriverState *bs, int x);
# With re.MULTILINE, '^' and '$' match at line boundaries, so a declaration
# must start at column 0 and end with ');' at the end of a line. The argument
# list may span several lines ([^)] also matches newlines) but may not
# contain ')'.
func_decl_re = re.compile(r'^(?P<return_type>[a-zA-Z][a-zA-Z0-9_]* [\*]?)'
                          # optional coroutine_fn annotation
                          r'(\s*coroutine_fn)?'
                          # 'co' or 'no_co', followed by '_wrapper'
                          r'\s*(?P<wrapper_type>(no_)?co)_wrapper'
                          # optional suffix such as '_mixed_bdrv_rdlock'
                          r'(?P<variant>(_[a-z][a-z0-9_]*)?)\s*'
                          r'(?P<wrapper_name>[a-z][a-z0-9_]*)'
                          r'\((?P<args>[^)]*)\);$', re.MULTILINE)
133 | ||
134 | ||
def func_decl_iter(text: str) -> Iterator:
    """
    Lazily yield a FuncDecl for each (no_)co_wrapper declaration that
    func_decl_re finds in ``text``, in order of appearance.
    """
    for match in func_decl_re.finditer(text):
        fields = match.groupdict()
        yield FuncDecl(wrapper_type=fields['wrapper_type'],
                       return_type=fields['return_type'],
                       name=fields['wrapper_name'],
                       args=fields['args'],
                       variant=fields['variant'])
aaaa20b6 VSO |
142 | |
143 | ||
def snake_to_camel(func_name: str) -> str:
    """
    Convert underscore names like 'some_function_name' to camel-case like
    'SomeFunctionName'

    Only the first character of each '_'-separated word is upper-cased; the
    rest of the word is kept unchanged. Empty words (from leading, trailing
    or doubled underscores, or an empty name) contribute nothing instead of
    raising IndexError.
    """
    # word[:1] is '' for an empty word, whereas word[0] would raise.
    return ''.join(word[:1].upper() + word[1:]
                   for word in func_name.split('_'))
152 | ||
153 | ||
76a2f554 EGE |
def create_mixed_wrapper(func: FuncDecl) -> str:
    """
    Checks if we are already in coroutine

    Generate the C body of a "mixed" wrapper named func.name. When called in
    coroutine context (qemu_in_coroutine()), it calls func.target_name
    directly. Otherwise it packs the arguments into a func.struct_name,
    creates a coroutine running <target_name>_entry and waits for it with
    bdrv_poll_co(). For non-void functions it then returns s.ret.
    """
    name = func.target_name
    struct_name = func.struct_name
    # For *_bdrv_rdlock variants, assume_graph_lock() is emitted before the
    # direct call. NOTE(review): presumably this tells the lock checker that
    # the coroutine caller already holds the graph rdlock; confirm.
    graph_assume_lock = 'assume_graph_lock();' if func.graph_rdlock else ''

    # Literal braces in the generated C are doubled ({{ }}) because this is
    # an f-string.
    return f"""\
{func.return_type} {func.name}({ func.gen_list('{decl}') })
{{
    if (qemu_in_coroutine()) {{
        {graph_assume_lock}
        {func.co_ret}{name}({ func.gen_list('{name}') });
    }} else {{
        {struct_name} s = {{
            .poll_state.ctx = {func.ctx},
            .poll_state.in_progress = true,

{ func.gen_block('            .{name} = {name},') }
        }};

        s.poll_state.co = qemu_coroutine_create({name}_entry, &s);

        bdrv_poll_co(&s.poll_state);
        {func.ret}
    }}
}}"""
182 | ||
183 | ||
def create_co_wrapper(func: FuncDecl) -> str:
    """
    Assumes we are not in coroutine, and creates one

    Generate the C body of a wrapper named func.name that may only be called
    outside coroutine context (the generated code asserts
    !qemu_in_coroutine()). It packs the arguments into a func.struct_name,
    creates a coroutine running <target_name>_entry and waits for it with
    bdrv_poll_co(). For non-void functions it then returns s.ret.
    """
    name = func.target_name
    struct_name = func.struct_name
    # Literal braces in the generated C are doubled ({{ }}) because this is
    # an f-string.
    return f"""\
{func.return_type} {func.name}({ func.gen_list('{decl}') })
{{
    {struct_name} s = {{
        .poll_state.ctx = {func.ctx},
        .poll_state.in_progress = true,

{ func.gen_block('        .{name} = {name},') }
    }};
    assert(!qemu_in_coroutine());

    s.poll_state.co = qemu_coroutine_create({name}_entry, &s);

    bdrv_poll_co(&s.poll_state);
    {func.ret}
}}"""
206 | ||
207 | ||
def gen_co_wrapper(func: FuncDecl) -> str:
    """
    Generate all C code for one co_wrapper* declaration.

    The generated code consists of:
    - the argument struct func.struct_name (poll state, optional result,
      one field per parameter);
    - the coroutine entry point <target_name>_entry, which calls
      func.target_name (between bdrv_graph_co_rdlock() and
      bdrv_graph_co_rdunlock() for *_bdrv_rdlock variants), clears
      poll_state.in_progress and calls aio_wait_kick();
    - the wrapper itself, from create_co_wrapper() or, for *_mixed
      variants, create_mixed_wrapper().

    Raises:
        ValueError: if func.name contains '_co_'; a co_wrapper must be
            declared on the non-coroutine name.
    """
    # Explicit check rather than assert: this validates names read from the
    # input headers, and assert statements are stripped under "python -O".
    if '_co_' in func.name:
        raise ValueError("co_wrapper function name can't contain '_co_': "
                         f"{func.name}")
    assert func.wrapper_type == 'co'

    name = func.target_name
    struct_name = func.struct_name

    if func.graph_rdlock:
        graph_lock = '    bdrv_graph_co_rdlock();'
        graph_unlock = '    bdrv_graph_co_rdunlock();'
    else:
        graph_lock = graph_unlock = ''

    creation_function = (create_co_wrapper if func.create_only_co
                         else create_mixed_wrapper)

    return f"""\
/*
 * Wrappers for {name}
 */

typedef struct {struct_name} {{
    BdrvPollCo poll_state;
    {func.return_field}
{ func.gen_block('    {decl};') }
}} {struct_name};

static void coroutine_fn {name}_entry(void *opaque)
{{
    {struct_name} *s = opaque;

{graph_lock}
    {func.get_result}{name}({ func.gen_list('s->{name}') });
{graph_unlock}
    s->poll_state.in_progress = false;

    aio_wait_kick();
}}

{creation_function(func)}"""
aaaa20b6 VSO |
249 | |
250 | ||
d6ee2e32 KW |
def gen_no_co_wrapper(func: FuncDecl) -> str:
    """
    Generate all C code for one no_co_wrapper* declaration.

    The generated code consists of:
    - the argument struct func.struct_name (the waiting Coroutine *, an
      optional result, and one field per parameter);
    - a bottom half <target_name>_bh, which calls func.target_name with the
      AioContext from func.gen_ctx('s->') acquired (and between the
      main-loop graph rdlock or wrlock calls for *_bdrv_rdlock /
      *_bdrv_wrlock variants), then calls aio_co_wake() on the stored
      coroutine;
    - the coroutine_fn wrapper func.name, which asserts it runs in a
      coroutine, schedules the BH in qemu_get_aio_context() and yields.
      For non-void functions it then returns s.ret.
    """
    # FuncDecl only builds no_co wrappers for names of the form
    # "<subsystem>_co_<name>", so these are internal invariants.
    assert '_co_' in func.name
    assert func.wrapper_type == 'no_co'

    name = func.target_name
    struct_name = func.struct_name

    # FuncDecl rejects declarations that are both rdlock and wrlock.
    graph_lock=''
    graph_unlock=''
    if func.graph_rdlock:
        graph_lock='    bdrv_graph_rdlock_main_loop();'
        graph_unlock='    bdrv_graph_rdunlock_main_loop();'
    elif func.graph_wrlock:
        graph_lock='    bdrv_graph_wrlock(NULL);'
        graph_unlock='    bdrv_graph_wrunlock(NULL);'

    # `s` is a local of the generated coroutine_fn and is handed to the BH by
    # pointer. NOTE(review): this presumably relies on the coroutine staying
    # suspended in qemu_coroutine_yield() until the BH's aio_co_wake(), so
    # that `s` outlives the BH; confirm.
    return f"""\
/*
 * Wrappers for {name}
 */

typedef struct {struct_name} {{
    Coroutine *co;
    {func.return_field}
{ func.gen_block('    {decl};') }
}} {struct_name};

static void {name}_bh(void *opaque)
{{
    {struct_name} *s = opaque;
    AioContext *ctx = {func.gen_ctx('s->')};

{graph_lock}
    aio_context_acquire(ctx);
    {func.get_result}{name}({ func.gen_list('s->{name}') });
    aio_context_release(ctx);
{graph_unlock}

    aio_co_wake(s->co);
}}

{func.return_type} coroutine_fn {func.name}({ func.gen_list('{decl}') })
{{
    {struct_name} s = {{
        .co = qemu_coroutine_self(),
{ func.gen_block('        .{name} = {name},') }
    }};
    assert(qemu_in_coroutine());

    aio_bh_schedule_oneshot(qemu_get_aio_context(), {name}_bh, &s);
    qemu_coroutine_yield();

    {func.ret}
}}"""
305 | ||
306 | ||
aaaa20b6 VSO |
def gen_wrappers(input_code: str) -> str:
    """
    Return the generated C code for every (no_)co_wrapper declaration found
    in ``input_code``, in order of appearance.

    Each wrapper is preceded by three newlines. Input without declarations
    yields ''.
    """
    # Collect the pieces and join once; growing a str with += in a loop is
    # quadratic in the general case.
    parts = []
    for func in func_decl_iter(input_code):
        generate = (gen_co_wrapper if func.wrapper_type == 'co'
                    else gen_no_co_wrapper)
        parts.append('\n\n\n' + generate(func))
    return ''.join(parts)
317 | ||
318 | ||
if __name__ == '__main__':
    # sys.exit() rather than the exit() builtin: exit() is installed by the
    # site module and is not guaranteed to exist (e.g. under "python -S").
    if len(sys.argv) < 3:
        sys.exit(f'Usage: {sys.argv[0]} OUT_FILE.c IN_FILE.[ch]...')

    # Build the whole output in memory first. Opening OUT_FILE.c up front
    # would truncate it, so an unreadable input or a malformed declaration
    # would leave a partial generated file behind.
    chunks = [gen_header()]
    for fname in sys.argv[2:]:
        with open(fname, encoding='utf-8') as f_in:
            chunks.append(gen_wrappers(f_in.read()))
        chunks.append('\n')

    with open(sys.argv[1], 'w', encoding='utf-8') as f_out:
        f_out.write(''.join(chunks))