]>
Commit | Line | Data |
---|---|---|
aaaa20b6 VSO |
1 | #! /usr/bin/env python3 |
2 | """Generate coroutine wrappers for block subsystem. | |
3 | ||
The program parses one or several C source or header files named on the
command line, searches for functions declared with a 'co_wrapper'
specifier (optionally suffixed, e.g. 'co_wrapper_mixed')
and writes the corresponding wrappers to the output file.
7 | ||
8 | Usage: block-coroutine-wrapper.py generated-file.c FILE.[ch]... | |
9 | ||
10 | Copyright (c) 2020 Virtuozzo International GmbH. | |
11 | ||
12 | This program is free software; you can redistribute it and/or modify | |
13 | it under the terms of the GNU General Public License as published by | |
14 | the Free Software Foundation; either version 2 of the License, or | |
15 | (at your option) any later version. | |
16 | ||
17 | This program is distributed in the hope that it will be useful, | |
18 | but WITHOUT ANY WARRANTY; without even the implied warranty of | |
19 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | |
20 | GNU General Public License for more details. | |
21 | ||
22 | You should have received a copy of the GNU General Public License | |
23 | along with this program. If not, see <http://www.gnu.org/licenses/>. | |
24 | """ | |
25 | ||
26 | import sys | |
27 | import re | |
28 | from typing import Iterator | |
29 | ||
30 | ||
def gen_header():
    """Return the preamble written at the top of the generated C file.

    The preamble is a comment block naming this generator, followed by the
    license text taken from this module's docstring (from the word
    'Copyright' to the end) rewritten as C comment lines, and then a fixed
    list of #include directives.
    """
    # Drop everything in the docstring that precedes the copyright notice.
    notice = re.sub('^.*Copyright', 'Copyright', __doc__,
                    flags=re.DOTALL).strip()
    # Non-empty lines get a ' * ' prefix; empty lines become a bare ' *'.
    notice = re.sub('^(?=.)', ' * ', notice, flags=re.MULTILINE)
    notice = re.sub('^$', ' *', notice, flags=re.MULTILINE)

    includes = (
        'qemu/osdep.h',
        'block/coroutines.h',
        'block/block-gen.h',
        'block/block_int.h',
        'block/dirty-bitmap.h',
    )
    preamble = [
        '/*',
        ' * File is generated by scripts/block-coroutine-wrapper.py',
        ' *',
        notice,
        ' */',
        '',
    ]
    preamble.extend(f'#include "{path}"' for path in includes)
    return '\n'.join(preamble) + '\n'
48 | ||
49 | ||
class ParamDecl:
    """A single C parameter declaration such as ``BlockDriverState *bs``.

    Attributes:
        decl: the whole declaration (type followed by name), stripped
        type: everything before the name, including the trailing space or '*'
        name: the parameter name (a lowercase identifier)
    """

    # type: anything ending in a space or '*'; name: lowercase C identifier.
    param_re = re.compile(r'(?P<decl>'
                          r'(?P<type>.*[ *])'
                          r'(?P<name>[a-z][a-z0-9_]*)'
                          r')')

    def __init__(self, param_decl: str) -> None:
        """Parse *param_decl*.

        Raises:
            ValueError: if the whole stripped declaration is not of the form
                "<type><space or *><lowercase name>".
        """
        # fullmatch() instead of match(): match() only anchors the start, so
        # trailing text (e.g. an array suffix, or the tail of a mixed-case
        # name) was silently dropped and wrong C code was generated.
        m = self.param_re.fullmatch(param_decl.strip())
        if m is None:
            raise ValueError(f'Wrong parameter declaration: "{param_decl}"')
        self.decl = m.group('decl')
        self.type = m.group('type')
        self.name = m.group('name')
63 | ||
64 | ||
class FuncDecl:
    """A parsed co_wrapper declaration plus the names derived from it.

    Attributes:
        return_type: C return type of the wrapper, whitespace-stripped
        name: name of the wrapper function, whitespace-stripped
        struct_name: CamelCase form of ``name``, used as the struct typedef
        args: one ParamDecl per comma-separated parameter
        create_only_co: True unless the variant contains 'mixed'
        graph_rdlock: True if the variant contains 'bdrv_rdlock'
        co_name: ``name`` with '_co_' inserted after its first word
        ctx: C expression for the AioContext, picked from the type of the
            first parameter
    """

    def __init__(self, return_type: str, name: str, args: str,
                 variant: str) -> None:
        self.return_type = return_type.strip()
        self.name = name.strip()
        self.struct_name = snake_to_camel(self.name)
        self.args = [ParamDecl(arg.strip()) for arg in args.split(',')]
        self.create_only_co = 'mixed' not in variant
        self.graph_rdlock = 'bdrv_rdlock' in variant

        # e.g. 'bdrv_foo_bar' -> 'bdrv_co_foo_bar'
        prefix, rest = self.name.split('_', 1)
        self.co_name = f'{prefix}_co_{rest}'

        # NOTE: these expressions assume the first parameter is named
        # bs / child / blk respectively; the name itself is not checked.
        ctx_by_type = {
            'BlockDriverState *': 'bdrv_get_aio_context(bs)',
            'BdrvChild *': 'bdrv_get_aio_context(child->bs)',
            'BlockBackend *': 'blk_get_aio_context(blk)',
        }
        first_arg_type = self.args[0].type
        self.ctx = ctx_by_type.get(first_arg_type, 'qemu_get_aio_context()')

    def gen_list(self, format: str) -> str:
        """Render every parameter through *format* and join with ', '."""
        rendered = [format.format_map(vars(arg)) for arg in self.args]
        return ', '.join(rendered)

    def gen_block(self, format: str) -> str:
        """Render every parameter through *format*, one per line."""
        rendered = [format.format_map(vars(arg)) for arg in self.args]
        return '\n'.join(rendered)
94 | ||
95 | ||
# Match wrappers declared with a co_wrapper mark, for example
#   int co_wrapper_mixed bdrv_foo(BlockDriverState *bs, int x);
# Named groups:
#   return_type  - identifier, one space, optional '*'
#   variant      - optional '_...' suffix glued to 'co_wrapper'
#   wrapper_name - name of the wrapper function to generate
#   args         - raw text between the parentheses
# With re.MULTILINE the declaration must start at the beginning of a line
# and end with ');' at the end of a line; [^)] also matches newlines, so
# the argument list may span several lines.
func_decl_re = re.compile(
    r'^(?P<return_type>[a-zA-Z][a-zA-Z0-9_]* [\*]?)'
    r'\s*co_wrapper'
    r'(?P<variant>(_[a-z][a-z0-9_]*)?)\s*'
    r'(?P<wrapper_name>[a-z][a-z0-9_]*)'
    r'\((?P<args>[^)]*)\);$',
    flags=re.MULTILINE,
)
102 | ||
103 | ||
def func_decl_iter(text: str) -> Iterator:
    """Lazily yield a FuncDecl for each co_wrapper declaration in *text*."""
    for match in func_decl_re.finditer(text):
        fields = match.groupdict()
        yield FuncDecl(return_type=fields['return_type'],
                       name=fields['wrapper_name'],
                       args=fields['args'],
                       variant=fields['variant'])
aaaa20b6 VSO |
110 | |
111 | ||
def snake_to_camel(func_name: str) -> str:
    """
    Convert underscore names like 'some_function_name' to camel-case like
    'SomeFunctionName'.

    The first character of each underscore-separated word is upper-cased
    and the rest of the word is kept unchanged. Empty words (from leading,
    trailing or doubled underscores) contribute nothing.
    """
    words = func_name.split('_')
    # word[:1] is '' for an empty word, whereas word[0] raised IndexError.
    return ''.join(word[:1].upper() + word[1:] for word in words)
120 | ||
121 | ||
76a2f554 EGE |
def create_mixed_wrapper(func: FuncDecl) -> str:
    """
    Return the C wrapper used for 'mixed' co_wrapper variants.

    When the generated function runs inside a coroutine, it calls the
    coroutine function directly. Otherwise it packs its arguments into the
    per-function struct, creates a coroutine running ``<co_name>_entry``,
    hands it to bdrv_poll_co(), and returns the stored result.
    """
    co_fn = func.co_name
    struct_name = func.struct_name
    # assume_graph_lock() is emitted only for bdrv_rdlock variants.
    assume_lock = 'assume_graph_lock();' if func.graph_rdlock else ''
    params = func.gen_list('{decl}')
    call_args = func.gen_list('{name}')
    field_inits = func.gen_block('            .{name} = {name},')

    return f"""\
{func.return_type} {func.name}({params})
{{
    if (qemu_in_coroutine()) {{
        {assume_lock}
        return {co_fn}({call_args});
    }} else {{
        {struct_name} s = {{
            .poll_state.ctx = {func.ctx},
            .poll_state.in_progress = true,

{field_inits}
        }};

        s.poll_state.co = qemu_coroutine_create({co_fn}_entry, &s);

        bdrv_poll_co(&s.poll_state);
        return s.ret;
    }}
}}"""
150 | ||
151 | ||
def create_co_wrapper(func: FuncDecl) -> str:
    """
    Return the C wrapper used for plain (non-mixed) co_wrapper variants.

    The generated function asserts that it is not running inside a
    coroutine. It then packs its arguments into the per-function struct,
    creates a coroutine running ``<co_name>_entry``, hands it to
    bdrv_poll_co(), and returns the stored result.
    """
    co_fn = func.co_name
    params = func.gen_list('{decl}')
    field_inits = func.gen_block('        .{name} = {name},')

    return f"""\
{func.return_type} {func.name}({params})
{{
    {func.struct_name} s = {{
        .poll_state.ctx = {func.ctx},
        .poll_state.in_progress = true,

{field_inits}
    }};
    assert(!qemu_in_coroutine());

    s.poll_state.co = qemu_coroutine_create({co_fn}_entry, &s);

    bdrv_poll_co(&s.poll_state);
    return s.ret;
}}"""
174 | ||
175 | ||
def gen_wrapper(func: FuncDecl) -> str:
    """
    Return the generated C code for one co_wrapper declaration.

    The output contains three parts:

    - a struct holding the poll state, the return value and a copy of every
      argument;
    - a coroutine_fn ``<co_name>_entry`` that calls the coroutine function
      with the stored arguments (between bdrv_graph_co_rdlock() and
      bdrv_graph_co_rdunlock() for bdrv_rdlock variants), stores the
      result, clears ``in_progress`` and calls aio_wait_kick();
    - the wrapper function itself, built by create_co_wrapper() or
      create_mixed_wrapper() depending on the variant.

    Raises:
        ValueError: if the wrapper name already contains '_co_'. The
            wrapper must be the non-coroutine name, because the coroutine
            function's name is derived from it.
    """
    # Explicit raise instead of assert, so the check survives python -O.
    if '_co_' in func.name:
        raise ValueError(f'co_wrapper function name must not contain '
                         f'"_co_": {func.name}')

    name = func.co_name
    struct_name = func.struct_name

    if func.graph_rdlock:
        graph_lock = '    bdrv_graph_co_rdlock();'
        graph_unlock = '    bdrv_graph_co_rdunlock();'
    else:
        graph_lock = ''
        graph_unlock = ''

    if func.create_only_co:
        creation_function = create_co_wrapper
    else:
        creation_function = create_mixed_wrapper

    return f"""\
/*
 * Wrappers for {name}
 */

typedef struct {struct_name} {{
    BdrvPollCo poll_state;
    {func.return_type} ret;
{ func.gen_block('    {decl};') }
}} {struct_name};

static void coroutine_fn {name}_entry(void *opaque)
{{
    {struct_name} *s = opaque;

{graph_lock}
    s->ret = {name}({ func.gen_list('s->{name}') });
{graph_unlock}
    s->poll_state.in_progress = false;

    aio_wait_kick();
}}

{creation_function(func)}"""
aaaa20b6 VSO |
216 | |
217 | ||
def gen_wrappers(input_code: str) -> str:
    """
    Return the wrappers for every co_wrapper declaration in *input_code*.

    Each wrapper is preceded by three newlines. The result is '' when the
    input contains no declarations.
    """
    # str.join builds the result in one pass instead of re-copying the
    # accumulated string on every += (quadratic in the worst case).
    return ''.join('\n\n\n' + gen_wrapper(func)
                   for func in func_decl_iter(input_code))
225 | ||
226 | ||
if __name__ == '__main__':
    if len(sys.argv) < 3:
        # sys.exit() rather than the bare exit() helper: exit() is injected
        # by the site module and is missing when Python runs with -S.
        sys.exit(f'Usage: {sys.argv[0]} OUT_FILE.c IN_FILE.[ch]...')

    with open(sys.argv[1], 'w', encoding='utf-8') as f_out:
        f_out.write(gen_header())
        for fname in sys.argv[2:]:
            with open(fname, encoding='utf-8') as f_in:
                f_out.write(gen_wrappers(f_in.read()))
                f_out.write('\n')