X-Git-Url: https://git.proxmox.com/?a=blobdiff_plain;f=scripts%2Fblock-coroutine-wrapper.py;h=a38e5833fb3362776500f92c277ac860d6499253;hb=6bc0bcc89f847839cf3d459a55290dda8801d9d3;hp=85dbeb9ecf9c84abe042157ceace70ebb7a38d00;hpb=bb43694872c344e27d498c0980c50c7effcb448a;p=mirror_qemu.git diff --git a/scripts/block-coroutine-wrapper.py b/scripts/block-coroutine-wrapper.py index 85dbeb9ecf..a38e5833fb 100644 --- a/scripts/block-coroutine-wrapper.py +++ b/scripts/block-coroutine-wrapper.py @@ -2,7 +2,7 @@ """Generate coroutine wrappers for block subsystem. The program parses one or several concatenated c files from stdin, -searches for functions with the 'generated_co_wrapper' specifier +searches for functions with the 'co_wrapper' specifier and generates corresponding wrappers on stdout. Usage: block-coroutine-wrapper.py generated-file.c FILE.[ch]... @@ -42,7 +42,8 @@ def gen_header(): #include "qemu/osdep.h" #include "block/coroutines.h" #include "block/block-gen.h" -#include "block/block_int.h"\ +#include "block/block_int.h" +#include "block/dirty-bitmap.h" """ @@ -62,10 +63,58 @@ class ParamDecl: class FuncDecl: - def __init__(self, return_type: str, name: str, args: str) -> None: + def __init__(self, wrapper_type: str, return_type: str, name: str, + args: str, variant: str) -> None: self.return_type = return_type.strip() self.name = name.strip() + self.struct_name = snake_to_camel(self.name) self.args = [ParamDecl(arg.strip()) for arg in args.split(',')] + self.create_only_co = 'mixed' not in variant + self.graph_rdlock = 'bdrv_rdlock' in variant + self.graph_wrlock = 'bdrv_wrlock' in variant + + self.wrapper_type = wrapper_type + + if wrapper_type == 'co': + if self.graph_wrlock: + raise ValueError(f"co function can't be wrlock: {self.name}") + subsystem, subname = self.name.split('_', 1) + self.target_name = f'{subsystem}_co_{subname}' + else: + assert wrapper_type == 'no_co' + subsystem, co_infix, subname = self.name.split('_', 2) + if co_infix != 'co': + raise ValueError(f"Invalid no_co function name: {self.name}") + if not self.create_only_co: + raise ValueError(f"no_co function can't be mixed: {self.name}") + if self.graph_rdlock and self.graph_wrlock: + raise ValueError("function can't be both rdlock and wrlock: " + f"{self.name}") + self.target_name = f'{subsystem}_{subname}' + + self.ctx = self.gen_ctx() + + self.get_result = 's->ret = ' + self.ret = 'return s.ret;' + self.co_ret = 'return ' + self.return_field = self.return_type + " ret;" + if self.return_type == 'void': + self.get_result = '' + self.ret = '' + self.co_ret = '' + self.return_field = '' + + def gen_ctx(self, prefix: str = '') -> str: + t = self.args[0].type + name = self.args[0].name + if t == 'BlockDriverState *': + return f'bdrv_get_aio_context({prefix}{name})' + elif t == 'BdrvChild *': + return f'bdrv_get_aio_context({prefix}{name}->bs)' + elif t == 'BlockBackend *': + return f'blk_get_aio_context({prefix}{name})' + else: + return 'qemu_get_aio_context()' def gen_list(self, format: str) -> str: return ', '.join(format.format_map(arg.__dict__) for arg in self.args) @@ -74,17 +123,22 @@ class FuncDecl: return '\n'.join(format.format_map(arg.__dict__) for arg in self.args) -# Match wrappers declared with a generated_co_wrapper mark -func_decl_re = re.compile(r'^int\s*generated_co_wrapper\s*' +# Match wrappers declared with a co_wrapper mark +func_decl_re = re.compile(r'^(?P[a-zA-Z][a-zA-Z0-9_]* [\*]?)' + r'(\s*coroutine_fn)?' + r'\s*(?P(no_)?co)_wrapper' + r'(?P(_[a-z][a-z0-9_]*)?)\s*' r'(?P[a-z][a-z0-9_]*)' r'\((?P[^)]*)\);$', re.MULTILINE) def func_decl_iter(text: str) -> Iterator: for m in func_decl_re.finditer(text): - yield FuncDecl(return_type='int', + yield FuncDecl(wrapper_type=m.group('wrapper_type'), + return_type=m.group('return_type'), name=m.group('wrapper_name'), - args=m.group('args')) + args=m.group('args'), + variant=m.group('variant')) def snake_to_camel(func_name: str) -> str: @@ -97,16 +151,76 @@ def snake_to_camel(func_name: str) -> str: return ''.join(words) -def gen_wrapper(func: FuncDecl) -> str: +def create_mixed_wrapper(func: FuncDecl) -> str: + """ + Checks if we are already in coroutine + """ + name = func.target_name + struct_name = func.struct_name + graph_assume_lock = 'assume_graph_lock();' if func.graph_rdlock else '' + + return f"""\ +{func.return_type} {func.name}({ func.gen_list('{decl}') }) +{{ + if (qemu_in_coroutine()) {{ + {graph_assume_lock} + {func.co_ret}{name}({ func.gen_list('{name}') }); + }} else {{ + {struct_name} s = {{ + .poll_state.ctx = {func.ctx}, + .poll_state.in_progress = true, + +{ func.gen_block(' .{name} = {name},') } + }}; + + s.poll_state.co = qemu_coroutine_create({name}_entry, &s); + + bdrv_poll_co(&s.poll_state); + {func.ret} + }} +}}""" + + +def create_co_wrapper(func: FuncDecl) -> str: + """ + Assumes we are not in coroutine, and creates one + """ + name = func.target_name + struct_name = func.struct_name + return f"""\ +{func.return_type} {func.name}({ func.gen_list('{decl}') }) +{{ + {struct_name} s = {{ + .poll_state.ctx = {func.ctx}, + .poll_state.in_progress = true, + +{ func.gen_block(' .{name} = {name},') } + }}; + assert(!qemu_in_coroutine()); + + s.poll_state.co = qemu_coroutine_create({name}_entry, &s); + + bdrv_poll_co(&s.poll_state); + {func.ret} +}}""" + + +def gen_co_wrapper(func: FuncDecl) -> str: assert not '_co_' in func.name - assert func.return_type == 'int' - assert func.args[0].type in ['BlockDriverState *', 'BdrvChild *'] + assert func.wrapper_type == 'co' - subsystem, subname = func.name.split('_', 1) + name = func.target_name + struct_name = func.struct_name - name = f'{subsystem}_co_{subname}' - bs = 'bs' if func.args[0].type == 'BlockDriverState *' else 'child->bs' - struct_name = snake_to_camel(name) + graph_lock='' + graph_unlock='' + if func.graph_rdlock: + graph_lock=' bdrv_graph_co_rdlock();' + graph_unlock=' bdrv_graph_co_rdunlock();' + + creation_function = create_mixed_wrapper + if func.create_only_co: + creation_function = create_co_wrapper return f"""\ /* @@ -115,6 +229,7 @@ def gen_wrapper(func: FuncDecl) -> str: typedef struct {struct_name} {{ BdrvPollCo poll_state; + {func.return_field} { func.gen_block(' {decl};') } }} {struct_name}; @@ -122,28 +237,70 @@ static void coroutine_fn {name}_entry(void *opaque) {{ {struct_name} *s = opaque; - s->poll_state.ret = {name}({ func.gen_list('s->{name}') }); +{graph_lock} + {func.get_result}{name}({ func.gen_list('s->{name}') }); +{graph_unlock} s->poll_state.in_progress = false; aio_wait_kick(); }} -int {func.name}({ func.gen_list('{decl}') }) +{creation_function(func)}""" + + +def gen_no_co_wrapper(func: FuncDecl) -> str: + assert '_co_' in func.name + assert func.wrapper_type == 'no_co' + + name = func.target_name + struct_name = func.struct_name + + graph_lock='' + graph_unlock='' + if func.graph_rdlock: + graph_lock=' bdrv_graph_rdlock_main_loop();' + graph_unlock=' bdrv_graph_rdunlock_main_loop();' + elif func.graph_wrlock: + graph_lock=' bdrv_graph_wrlock(NULL);' + graph_unlock=' bdrv_graph_wrunlock(NULL);' + + return f"""\ +/* + * Wrappers for {name} + */ + +typedef struct {struct_name} {{ + Coroutine *co; + {func.return_field} +{ func.gen_block(' {decl};') } +}} {struct_name}; + +static void {name}_bh(void *opaque) {{ - if (qemu_in_coroutine()) {{ - return {name}({ func.gen_list('{name}') }); - }} else {{ - {struct_name} s = {{ - .poll_state.bs = {bs}, - .poll_state.in_progress = true, + {struct_name} *s = opaque; + AioContext *ctx = {func.gen_ctx('s->')}; -{ func.gen_block(' .{name} = {name},') } - }}; +{graph_lock} + aio_context_acquire(ctx); + {func.get_result}{name}({ func.gen_list('s->{name}') }); + aio_context_release(ctx); +{graph_unlock} - s.poll_state.co = qemu_coroutine_create({name}_entry, &s); + aio_co_wake(s->co); +}} - return bdrv_poll_co(&s.poll_state); - }} +{func.return_type} coroutine_fn {func.name}({ func.gen_list('{decl}') }) +{{ + {struct_name} s = {{ + .co = qemu_coroutine_self(), +{ func.gen_block(' .{name} = {name},') } + }}; + assert(qemu_in_coroutine()); + + aio_bh_schedule_oneshot(qemu_get_aio_context(), {name}_bh, &s); + qemu_coroutine_yield(); + + {func.ret} }}""" @@ -151,7 +308,10 @@ def gen_wrappers(input_code: str) -> str: res = '' for func in func_decl_iter(input_code): res += '\n\n\n' - res += gen_wrapper(func) + if func.wrapper_type == 'co': + res += gen_co_wrapper(func) + else: + res += gen_no_co_wrapper(func) return res