]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/pybind/mgr/rbd_support/task.py
import quincy beta 17.1.0
[ceph.git] / ceph / src / pybind / mgr / rbd_support / task.py
index 87d43eca15a6e97bd103d088df2a793614e6d18b..d283962a365e3f07f1c84fa9da352360d0caf97b 100644 (file)
@@ -10,9 +10,10 @@ from contextlib import contextmanager
 from datetime import datetime, timedelta
 from functools import partial, wraps
 from threading import Condition, Lock, Thread
+from typing import cast, Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar
 
 from .common import (authorize_request, extract_pool_key, get_rbd_pools,
-                     is_authorized)
+                     is_authorized, GLOBAL_POOL_KEY)
 
 
 RBD_TASK_OID = "rbd_task"
@@ -53,52 +54,59 @@ TASK_MAX_RETRY_INTERVAL = timedelta(seconds=300)
 MAX_COMPLETED_TASKS = 50
 
 
+T = TypeVar('T')
+FuncT = TypeVar('FuncT', bound=Callable[..., Any])
+
+
 class Throttle:
-    def __init__(self, throttle_period):
+    def __init__(self: Any, throttle_period: timedelta) -> None:
         self.throttle_period = throttle_period
         self.time_of_last_call = datetime.min
 
-    def __call__(self, fn):
+    def __call__(self: 'Throttle', fn: FuncT) -> FuncT:
         @wraps(fn)
-        def wrapper(*args, **kwargs):
+        def wrapper(*args: Any, **kwargs: Any) -> Any:
             now = datetime.now()
             if self.time_of_last_call + self.throttle_period <= now:
                 self.time_of_last_call = now
                 return fn(*args, **kwargs)
-        return wrapper
+        return cast(FuncT, wrapper)
+
+
+TaskRefsT = Dict[str, str]
 
 
 class Task:
-    def __init__(self, sequence, task_id, message, refs):
+    def __init__(self, sequence: int, task_id: str, message: str, refs: TaskRefsT):
         self.sequence = sequence
         self.task_id = task_id
         self.message = message
         self.refs = refs
-        self.retry_message = None
+        self.retry_message: Optional[str] = None
         self.retry_attempts = 0
-        self.retry_time = None
+        self.retry_time: Optional[datetime] = None
         self.in_progress = False
         self.progress = 0.0
         self.canceled = False
         self.failed = False
         self.progress_posted = False
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self.to_json()
 
     @property
-    def sequence_key(self):
-        return "{0:016X}".format(self.sequence)
+    def sequence_key(self) -> bytes:
+        return "{0:016X}".format(self.sequence).encode()
 
-    def cancel(self):
+    def cancel(self) -> None:
         self.canceled = True
         self.fail("Operation canceled")
 
-    def fail(self, message):
+    def fail(self, message: str) -> None:
         self.failed = True
         self.failure_message = message
 
-    def to_dict(self):
+    def to_dict(self) -> Dict[str, Any]:
         d = {TASK_SEQUENCE: self.sequence,
              TASK_ID: self.task_id,
              TASK_MESSAGE: self.message,
@@ -117,11 +125,11 @@ class Task:
             d[TASK_CANCELED] = True
         return d
 
-    def to_json(self):
+    def to_json(self) -> str:
         return str(json.dumps(self.to_dict()))
 
     @classmethod
-    def from_json(cls, val):
+    def from_json(cls, val: str) -> 'Task':
         try:
             d = json.loads(val)
             action = d.get(TASK_REFS, {}).get(TASK_REF_ACTION)
@@ -135,20 +143,26 @@ class Task:
             raise ValueError("Invalid task format (missing key {})".format(str(e)))
 
 
+# pool_name, namespace, image_name
+ImageSpecT = Tuple[str, str, str]
+# pool_name, namespace
+PoolSpecT = Tuple[str, str]
+MigrationStatusT = Dict[str, str]
+
 class TaskHandler:
     lock = Lock()
     condition = Condition(lock)
     thread = None
 
     in_progress_task = None
-    tasks_by_sequence = dict()
-    tasks_by_id = dict()
+    tasks_by_sequence: Dict[int, Task] = dict()
+    tasks_by_id: Dict[str, Task] = dict()
 
-    completed_tasks = []
+    completed_tasks: List[Task] = []
 
     sequence = 0
 
-    def __init__(self, module):
+    def __init__(self, module: Any) -> None:
         self.module = module
         self.log = module.log
 
@@ -159,16 +173,16 @@ class TaskHandler:
         self.thread.start()
 
     @property
-    def default_pool_name(self):
+    def default_pool_name(self) -> str:
         return self.module.get_ceph_option("rbd_default_pool")
 
-    def extract_pool_spec(self, pool_spec):
+    def extract_pool_spec(self, pool_spec: str) -> PoolSpecT:
         pool_spec = extract_pool_key(pool_spec)
         if pool_spec == GLOBAL_POOL_KEY:
             pool_spec = (self.default_pool_name, '')
-        return pool_spec
+        return cast(PoolSpecT, pool_spec)
 
-    def extract_image_spec(self, image_spec):
+    def extract_image_spec(self, image_spec: str) -> ImageSpecT:
         match = re.match(r'^(?:([^/]+)/(?:([^/]+)/)?)?([^/@]+)$',
                          image_spec or '')
         if not match:
@@ -176,7 +190,7 @@ class TaskHandler:
         return (match.group(1) or self.default_pool_name, match.group(2) or '',
                 match.group(3))
 
-    def run(self):
+    def run(self) -> None:
         try:
             self.log.info("TaskHandler: starting")
             while True:
@@ -195,7 +209,7 @@ class TaskHandler:
                 ex, traceback.format_exc()))
 
     @contextmanager
-    def open_ioctx(self, spec):
+    def open_ioctx(self, spec: PoolSpecT) -> Iterator[rados.Ioctx]:
         try:
             with self.module.rados.open_ioctx(spec[0]) as ioctx:
                 ioctx.set_namespace(spec[1])
@@ -205,7 +219,7 @@ class TaskHandler:
             raise
 
     @classmethod
-    def format_image_spec(cls, image_spec):
+    def format_image_spec(cls, image_spec: ImageSpecT) -> str:
         image = image_spec[2]
         if image_spec[1]:
             image = "{}/{}".format(image_spec[1], image)
@@ -213,7 +227,7 @@ class TaskHandler:
             image = "{}/{}".format(image_spec[0], image)
         return image
 
-    def init_task_queue(self):
+    def init_task_queue(self) -> None:
         for pool_id, pool_name in get_rbd_pools(self.module).items():
             try:
                 with self.module.rados.open_ioctx2(int(pool_id)) as ioctx:
@@ -239,7 +253,7 @@ class TaskHandler:
         self.log.debug("sequence={}, tasks_by_sequence={}, tasks_by_id={}".format(
             self.sequence, str(self.tasks_by_sequence), str(self.tasks_by_id)))
 
-    def load_task_queue(self, ioctx, pool_name):
+    def load_task_queue(self, ioctx: rados.Ioctx, pool_name: str) -> None:
         pool_spec = pool_name
         if ioctx.nspace:
             pool_spec += "/{}".format(ioctx.nspace)
@@ -274,11 +288,11 @@ class TaskHandler:
             # rbd_task DNE
             pass
 
-    def append_task(self, task):
+    def append_task(self, task: Task) -> None:
         self.tasks_by_sequence[task.sequence] = task
         self.tasks_by_id[task.task_id] = task
 
-    def task_refs_match(self, task_refs, refs):
+    def task_refs_match(self, task_refs: TaskRefsT, refs: TaskRefsT) -> bool:
         if TASK_REF_IMAGE_ID not in refs and TASK_REF_IMAGE_ID in task_refs:
             task_refs = task_refs.copy()
             del task_refs[TASK_REF_IMAGE_ID]
@@ -286,7 +300,7 @@ class TaskHandler:
         self.log.debug("task_refs_match: ref1={}, ref2={}".format(task_refs, refs))
         return task_refs == refs
 
-    def find_task(self, refs):
+    def find_task(self, refs: TaskRefsT) -> Optional[Task]:
         self.log.debug("find_task: refs={}".format(refs))
 
         # search for dups and return the original
@@ -299,8 +313,13 @@ class TaskHandler:
         for task in reversed(self.completed_tasks):
             if self.task_refs_match(task.refs, refs):
                 return task
+        else:
+            return None
 
-    def add_task(self, ioctx, message, refs):
+    def add_task(self,
+                 ioctx: rados.Ioctx,
+                 message: str,
+                 refs: TaskRefsT) -> str:
         self.log.debug("add_task: message={}, refs={}".format(message, refs))
 
         # ensure unique uuid across all pools
@@ -316,7 +335,9 @@ class TaskHandler:
         task_json = task.to_json()
         omap_keys = (task.sequence_key, )
         omap_vals = (str.encode(task_json), )
-        self.log.info("adding task: {} {}".format(omap_keys[0], omap_vals[0]))
+        self.log.info("adding task: %s %s",
+                      omap_keys[0].decode(),
+                      omap_vals[0].decode())
 
         with rados.WriteOpCtx() as write_op:
             ioctx.set_omap(write_op, omap_keys, omap_vals)
@@ -326,7 +347,10 @@ class TaskHandler:
         self.condition.notify()
         return task_json
 
-    def remove_task(self, ioctx, task, remove_in_memory=True):
+    def remove_task(self,
+                    ioctx: rados.Ioctx,
+                    task: Task,
+                    remove_in_memory: bool = True) -> None:
         self.log.info("remove_task: task={}".format(str(task)))
         omap_keys = (task.sequence_key, )
         try:
@@ -351,7 +375,7 @@ class TaskHandler:
             except KeyError:
                 pass
 
-    def execute_task(self, sequence):
+    def execute_task(self, sequence: int) -> None:
         task = self.tasks_by_sequence[sequence]
         self.log.info("execute_task: task={}".format(str(task)))
 
@@ -414,7 +438,7 @@ class TaskHandler:
                 TASK_RETRY_INTERVAL * task.retry_attempts,
                 TASK_MAX_RETRY_INTERVAL)
 
-    def progress_callback(self, task, current, total):
+    def progress_callback(self, task: Task, current: int, total: int) -> int:
         progress = float(current) / float(total)
         self.log.debug("progress_callback: task={}, progress={}".format(
             str(task), progress))
@@ -438,7 +462,7 @@ class TaskHandler:
 
         return 0
 
-    def execute_flatten(self, ioctx, task):
+    def execute_flatten(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_flatten: task={}".format(str(task)))
 
         try:
@@ -451,7 +475,7 @@ class TaskHandler:
             task.fail("Image does not exist")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_remove(self, ioctx, task):
+    def execute_remove(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_remove: task={}".format(str(task)))
 
         try:
@@ -461,7 +485,7 @@ class TaskHandler:
             task.fail("Image does not exist")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_trash_remove(self, ioctx, task):
+    def execute_trash_remove(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_trash_remove: task={}".format(str(task)))
 
         try:
@@ -471,7 +495,7 @@ class TaskHandler:
             task.fail("Image does not exist")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_migration_execute(self, ioctx, task):
+    def execute_migration_execute(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_migration_execute: task={}".format(str(task)))
 
         try:
@@ -484,7 +508,7 @@ class TaskHandler:
             task.fail("Image is not migrating")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_migration_commit(self, ioctx, task):
+    def execute_migration_commit(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_migration_commit: task={}".format(str(task)))
 
         try:
@@ -497,7 +521,7 @@ class TaskHandler:
             task.fail("Image is not migrating or migration not executed")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def execute_migration_abort(self, ioctx, task):
+    def execute_migration_abort(self, ioctx: rados.Ioctx, task: Task) -> None:
         self.log.info("execute_migration_abort: task={}".format(str(task)))
 
         try:
@@ -510,7 +534,7 @@ class TaskHandler:
             task.fail("Image is not migrating")
             self.log.info("{}: task={}".format(task.failure_message, str(task)))
 
-    def complete_progress(self, task):
+    def complete_progress(self, task: Task) -> None:
         if not task.progress_posted:
             # ensure progress event exists before we complete/fail it
             self.post_progress(task, 0)
@@ -526,7 +550,7 @@ class TaskHandler:
             # progress module is disabled
             pass
 
-    def _update_progress(self, task, progress):
+    def _update_progress(self, task: Task, progress: float) -> None:
         self.log.debug("update_progress: task={}, progress={}".format(str(task), progress))
         try:
             refs = {"origin": "rbd_support"}
@@ -538,19 +562,19 @@ class TaskHandler:
             # progress module is disabled
             pass
 
-    def post_progress(self, task, progress):
+    def post_progress(self, task: Task, progress: float) -> None:
         self._update_progress(task, progress)
         task.progress_posted = True
 
-    def update_progress(self, task, progress):
+    def update_progress(self, task: Task, progress: float) -> None:
         if task.progress_posted:
             self._update_progress(task, progress)
 
     @Throttle(timedelta(seconds=1))
-    def throttled_update_progress(self, task, progress):
+    def throttled_update_progress(self, task: Task, progress: float) -> None:
         self.update_progress(task, progress)
 
-    def queue_flatten(self, image_spec):
+    def queue_flatten(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -561,7 +585,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             try:
                 with rbd.Image(ioctx, image_spec[2]) as image:
                     refs[TASK_REF_IMAGE_ID] = image.id()
@@ -590,7 +614,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ""
 
-    def queue_remove(self, image_spec):
+    def queue_remove(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -601,7 +625,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             try:
                 with rbd.Image(ioctx, image_spec[2]) as image:
                     refs[TASK_REF_IMAGE_ID] = image.id()
@@ -626,7 +650,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ''
 
-    def queue_trash_remove(self, image_id_spec):
+    def queue_trash_remove(self, image_id_spec: str) -> Tuple[int, str, str]:
         image_id_spec = self.extract_image_spec(image_id_spec)
 
         authorize_request(self.module, image_id_spec[0], image_id_spec[1])
@@ -641,7 +665,7 @@ class TaskHandler:
             return 0, task.to_json(), ''
 
         # verify that image exists in trash
-        with self.open_ioctx(image_id_spec) as ioctx:
+        with self.open_ioctx(image_id_spec[:2]) as ioctx:
             rbd.RBD().trash_get(ioctx, image_id_spec[2])
 
             return 0, self.add_task(ioctx,
@@ -649,25 +673,29 @@ class TaskHandler:
                                         self.format_image_spec(image_id_spec)),
                                     refs), ''
 
-    def get_migration_status(self, ioctx, image_spec):
+    def get_migration_status(self,
+                             ioctx: rados.Ioctx,
+                             image_spec: ImageSpecT) -> Optional[MigrationStatusT]:
         try:
             return rbd.RBD().migration_status(ioctx, image_spec[2])
         except (rbd.InvalidArgument, rbd.ImageNotFound):
             return None
 
-    def validate_image_migrating(self, image_spec, migration_status):
+    def validate_image_migrating(self,
+                                 image_spec: ImageSpecT,
+                                 migration_status: Optional[MigrationStatusT]) -> None:
         if not migration_status:
             raise rbd.InvalidArgument("Image {} is not migrating".format(
                 self.format_image_spec(image_spec)), errno=errno.EINVAL)
 
-    def resolve_pool_name(self, pool_id):
+    def resolve_pool_name(self, pool_id: str) -> str:
         osd_map = self.module.get('osd_map')
         for pool in osd_map['pools']:
             if pool['pool'] == pool_id:
                 return pool['pool_name']
         return '<unknown>'
 
-    def queue_migration_execute(self, image_spec):
+    def queue_migration_execute(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -678,7 +706,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             status = self.get_migration_status(ioctx, image_spec)
             if status:
                 refs[TASK_REF_IMAGE_ID] = status['dest_image_id']
@@ -688,6 +716,7 @@ class TaskHandler:
                 return 0, task.to_json(), ''
 
             self.validate_image_migrating(image_spec, status)
+            assert status
             if status['state'] not in [rbd.RBD_IMAGE_MIGRATION_STATE_PREPARED,
                                        rbd.RBD_IMAGE_MIGRATION_STATE_EXECUTING]:
                 raise rbd.InvalidArgument("Image {} is not in ready state".format(
@@ -705,7 +734,7 @@ class TaskHandler:
                                                                 status['dest_image_name']))),
                                     refs), ''
 
-    def queue_migration_commit(self, image_spec):
+    def queue_migration_commit(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -716,7 +745,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             status = self.get_migration_status(ioctx, image_spec)
             if status:
                 refs[TASK_REF_IMAGE_ID] = status['dest_image_id']
@@ -726,6 +755,7 @@ class TaskHandler:
                 return 0, task.to_json(), ''
 
             self.validate_image_migrating(image_spec, status)
+            assert status
             if status['state'] != rbd.RBD_IMAGE_MIGRATION_STATE_EXECUTED:
                 raise rbd.InvalidArgument("Image {} has not completed migration".format(
                     self.format_image_spec(image_spec)), errno=errno.EINVAL)
@@ -735,7 +765,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ''
 
-    def queue_migration_abort(self, image_spec):
+    def queue_migration_abort(self, image_spec: str) -> Tuple[int, str, str]:
         image_spec = self.extract_image_spec(image_spec)
 
         authorize_request(self.module, image_spec[0], image_spec[1])
@@ -746,7 +776,7 @@ class TaskHandler:
                 TASK_REF_POOL_NAMESPACE: image_spec[1],
                 TASK_REF_IMAGE_NAME: image_spec[2]}
 
-        with self.open_ioctx(image_spec) as ioctx:
+        with self.open_ioctx(image_spec[:2]) as ioctx:
             status = self.get_migration_status(ioctx, image_spec)
             if status:
                 refs[TASK_REF_IMAGE_ID] = status['dest_image_id']
@@ -761,7 +791,7 @@ class TaskHandler:
                                         self.format_image_spec(image_spec)),
                                     refs), ''
 
-    def task_cancel(self, task_id):
+    def task_cancel(self, task_id: str) -> Tuple[int, str, str]:
         self.log.info("task_cancel: {}".format(task_id))
 
         task = self.tasks_by_id.get(task_id)
@@ -787,7 +817,7 @@ class TaskHandler:
 
         return 0, "", ""
 
-    def task_list(self, task_id):
+    def task_list(self, task_id: Optional[str]) -> Tuple[int, str, str]:
         self.log.info("task_list: {}".format(task_id))
 
         if task_id:
@@ -797,35 +827,14 @@ class TaskHandler:
                                              task.refs[TASK_REF_POOL_NAMESPACE]):
                 return -errno.ENOENT, '', "No such task {}".format(task_id)
 
-            result = task.to_dict()
+            return 0, json.dumps(task.to_dict(), indent=4, sort_keys=True), ""
         else:
-            result = []
+            tasks = []
             for sequence in sorted(self.tasks_by_sequence.keys()):
                 task = self.tasks_by_sequence[sequence]
                 if is_authorized(self.module,
                                  task.refs[TASK_REF_POOL_NAME],
                                  task.refs[TASK_REF_POOL_NAMESPACE]):
-                    result.append(task.to_dict())
+                    tasks.append(task.to_dict())
 
-        return 0, json.dumps(result, indent=4, sort_keys=True), ""
-
-    def handle_command(self, inbuf, prefix, cmd):
-        with self.lock:
-            if prefix == 'add flatten':
-                return self.queue_flatten(cmd['image_spec'])
-            elif prefix == 'add remove':
-                return self.queue_remove(cmd['image_spec'])
-            elif prefix == 'add trash remove':
-                return self.queue_trash_remove(cmd['image_id_spec'])
-            elif prefix == 'add migration execute':
-                return self.queue_migration_execute(cmd['image_spec'])
-            elif prefix == 'add migration commit':
-                return self.queue_migration_commit(cmd['image_spec'])
-            elif prefix == 'add migration abort':
-                return self.queue_migration_abort(cmd['image_spec'])
-            elif prefix == 'cancel':
-                return self.task_cancel(cmd['task_id'])
-            elif prefix == 'list':
-                return self.task_list(cmd.get('task_id'))
-
-        raise NotImplementedError(cmd['prefix'])
+            return 0, json.dumps(tasks, indent=4, sort_keys=True), ""