]> git.proxmox.com Git - ceph.git/blobdiff - ceph/src/pybind/mgr/rbd_support/schedule.py
import quincy beta 17.1.0
[ceph.git] / ceph / src / pybind / mgr / rbd_support / schedule.py
index 09e9a26a9ea79e45b81d0f50a48a05cd2fc83b30..615c002a043d9651dde483c95fbf23bcc2f87054 100644 (file)
@@ -1,12 +1,15 @@
+import datetime
 import json
 import rados
 import rbd
 import re
 
-from datetime import datetime, timedelta, time
 from dateutil.parser import parse
+from typing import cast, Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKING
 
 from .common import get_rbd_pools
+if TYPE_CHECKING:
+    from .module import Module
 
 SCHEDULE_INTERVAL = "interval"
 SCHEDULE_START_TIME = "start_time"
@@ -14,17 +17,22 @@ SCHEDULE_START_TIME = "start_time"
 
 class LevelSpec:
 
-    def __init__(self, name, id, pool_id, namespace, image_id=None):
+    def __init__(self,
+                 name: str,
+                 id: str,
+                 pool_id: Optional[str],
+                 namespace: Optional[str],
+                 image_id: Optional[str] = None) -> None:
         self.name = name
         self.id = id
         self.pool_id = pool_id
         self.namespace = namespace
         self.image_id = image_id
 
-    def __eq__(self, level_spec):
+    def __eq__(self, level_spec: Any) -> bool:
         return self.id == level_spec.id
 
-    def is_child_of(self, level_spec):
+    def is_child_of(self, level_spec: 'LevelSpec') -> bool:
         if level_spec.is_global():
             return not self.is_global()
         if level_spec.pool_id != self.pool_id:
@@ -37,13 +45,16 @@ class LevelSpec:
             return self.image_id is not None
         return False
 
-    def is_global(self):
+    def is_global(self) -> bool:
         return self.pool_id is None
 
-    def get_pool_id(self):
+    def get_pool_id(self) -> Optional[str]:
         return self.pool_id
 
-    def matches(self, pool_id, namespace, image_id=None):
+    def matches(self,
+                pool_id: str,
+                namespace: str,
+                image_id: Optional[str] = None) -> bool:
         if self.pool_id and self.pool_id != pool_id:
             return False
         if self.namespace and self.namespace != namespace:
@@ -52,7 +63,7 @@ class LevelSpec:
             return False
         return True
 
-    def intersects(self, level_spec):
+    def intersects(self, level_spec: 'LevelSpec') -> bool:
         if self.pool_id is None or level_spec.pool_id is None:
             return True
         if self.pool_id != level_spec.pool_id:
@@ -68,11 +79,14 @@ class LevelSpec:
         return True
 
     @classmethod
-    def make_global(cls):
+    def make_global(cls) -> 'LevelSpec':
         return LevelSpec("", "", None, None, None)
 
     @classmethod
-    def from_pool_spec(cls, pool_id, pool_name, namespace=None):
+    def from_pool_spec(cls,
+                       pool_id: int,
+                       pool_name: str,
+                       namespace: Optional[str] = None) -> 'LevelSpec':
         if namespace is None:
             id = "{}".format(pool_id)
             name = "{}/".format(pool_name)
@@ -82,8 +96,12 @@ class LevelSpec:
         return LevelSpec(name, id, str(pool_id), namespace, None)
 
     @classmethod
-    def from_name(cls, handler, name, namespace_validator=None,
-                  image_validator=None, allow_image_level=True):
+    def from_name(cls,
+                  module: 'Module',
+                  name: str,
+                  namespace_validator: Optional[Callable] = None,
+                  image_validator: Optional[Callable] = None,
+                  allow_image_level: bool = True) -> 'LevelSpec':
         # parse names like:
         # '', 'rbd/', 'rbd/ns/', 'rbd//image', 'rbd/image', 'rbd/ns/image'
         match = re.match(r'^(?:([^/]+)/(?:(?:([^/]*)/|)(?:([^/@]+))?)?)?$',
@@ -102,16 +120,15 @@ class LevelSpec:
         if match.group(1):
             pool_name = match.group(1)
             try:
-                pool_id = handler.module.rados.pool_lookup(pool_name)
+                pool_id = module.rados.pool_lookup(pool_name)
                 if pool_id is None:
                     raise ValueError("pool {} does not exist".format(pool_name))
-                if pool_id not in get_rbd_pools(handler.module):
+                if pool_id not in get_rbd_pools(module):
                     raise ValueError("{} is not an RBD pool".format(pool_name))
-                pool_id = str(pool_id)
-                id += pool_id
+                id += str(pool_id)
                 if match.group(2) is not None or match.group(3):
                     id += "/"
-                    with handler.module.rados.open_ioctx(pool_name) as ioctx:
+                    with module.rados.open_ioctx(pool_name) as ioctx:
                         namespace = match.group(2) or ""
                         if namespace:
                             namespaces = rbd.RBD().namespace_list(ioctx)
@@ -150,8 +167,11 @@ class LevelSpec:
         return LevelSpec(name, id, pool_id, namespace, image_id)
 
     @classmethod
-    def from_id(cls, handler, id, namespace_validator=None,
-                image_validator=None):
+    def from_id(cls,
+                handler: Any,
+                id: str,
+                namespace_validator: Optional[Callable] = None,
+                image_validator: Optional[Callable] = None) -> 'LevelSpec':
         # parse ids like:
         # '', '123', '123/', '123/ns', '123//image_id', '123/ns/image_id'
         match = re.match(r'^(?:(\d+)(?:/([^/]*)(?:/([^/@]+))?)?)?$', id)
@@ -209,16 +229,16 @@ class LevelSpec:
 
 class Interval:
 
-    def __init__(self, minutes):
+    def __init__(self, minutes: int) -> None:
         self.minutes = minutes
 
-    def __eq__(self, interval):
+    def __eq__(self, interval: Any) -> bool:
         return self.minutes == interval.minutes
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(self.minutes)
 
-    def to_string(self):
+    def to_string(self) -> str:
         if self.minutes % (60 * 24) == 0:
             interval = int(self.minutes / (60 * 24))
             units = 'd'
@@ -232,7 +252,7 @@ class Interval:
         return "{}{}".format(interval, units)
 
     @classmethod
-    def from_string(cls, interval):
+    def from_string(cls, interval: str) -> 'Interval':
         match = re.match(r'^(\d+)(d|h|m)?$', interval)
         if not match:
             raise ValueError("Invalid interval ({})".format(interval))
@@ -248,23 +268,27 @@ class Interval:
 
 class StartTime:
 
-    def __init__(self, hour, minute, tzinfo):
-        self.time = time(hour, minute, tzinfo=tzinfo)
+    def __init__(self,
+                 hour: int,
+                 minute: int,
+                 tzinfo: Optional[datetime.tzinfo]) -> None:
+        self.time = datetime.time(hour, minute, tzinfo=tzinfo)
         self.minutes = self.time.hour * 60 + self.time.minute
         if self.time.tzinfo:
-            self.minutes += int(self.time.utcoffset().seconds / 60)
+            utcoffset = cast(datetime.timedelta, self.time.utcoffset())
+            self.minutes += int(utcoffset.seconds / 60)
 
-    def __eq__(self, start_time):
+    def __eq__(self, start_time: Any) -> bool:
         return self.minutes == start_time.minutes
 
-    def __hash__(self):
+    def __hash__(self) -> int:
         return hash(self.minutes)
 
-    def to_string(self):
+    def to_string(self) -> str:
         return self.time.isoformat()
 
     @classmethod
-    def from_string(cls, start_time):
+    def from_string(cls, start_time: Optional[str]) -> Optional['StartTime']:
         if not start_time:
             return None
 
@@ -278,42 +302,56 @@ class StartTime:
 
 class Schedule:
 
-    def __init__(self, name):
+    def __init__(self, name: str) -> None:
         self.name = name
-        self.items = set()
+        self.items: Set[Tuple[Interval, Optional[StartTime]]] = set()
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.items)
 
-    def add(self, interval, start_time=None):
+    def add(self,
+            interval: Interval,
+            start_time: Optional[StartTime] = None) -> None:
         self.items.add((interval, start_time))
 
-    def remove(self, interval, start_time=None):
+    def remove(self,
+               interval: Interval,
+               start_time: Optional[StartTime] = None) -> None:
         self.items.discard((interval, start_time))
 
-    def next_run(self, now):
+    def next_run(self, now: datetime.datetime) -> str:
         schedule_time = None
-        for item in self.items:
-            period = timedelta(minutes=item[0].minutes)
-            start_time = datetime(1970, 1, 1)
-            if item[1]:
-                start_time += timedelta(minutes=item[1].minutes)
+        for interval, opt_start in self.items:
+            period = datetime.timedelta(minutes=interval.minutes)
+            start_time = datetime.datetime(1970, 1, 1)
+            if opt_start:
+                start = cast(StartTime, opt_start)
+                start_time += datetime.timedelta(minutes=start.minutes)
             time = start_time + \
                 (int((now - start_time) / period) + 1) * period
             if schedule_time is None or time < schedule_time:
                 schedule_time = time
-        return datetime.strftime(schedule_time, "%Y-%m-%d %H:%M:00")
-
-    def to_list(self):
-        return [{SCHEDULE_INTERVAL: i[0].to_string(),
-                 SCHEDULE_START_TIME: i[1] and i[1].to_string() or None}
-                for i in self.items]
+        if schedule_time is None:
+            raise ValueError('no items is added')
+        return datetime.datetime.strftime(schedule_time, "%Y-%m-%d %H:%M:00")
+
+    def to_list(self) -> List[Dict[str, Optional[str]]]:
+        def item_to_dict(interval: Interval,
+                         start_time: Optional[StartTime]) -> Dict[str, Optional[str]]:
+            if start_time:
+                schedule_start_time: Optional[str] = start_time.to_string()
+            else:
+                schedule_start_time = None
+            return {SCHEDULE_INTERVAL: interval.to_string(),
+                    SCHEDULE_START_TIME: schedule_start_time}
+        return [item_to_dict(interval, start_time)
+                for interval, start_time in self.items]
 
-    def to_json(self):
+    def to_json(self) -> str:
         return json.dumps(self.to_list(), indent=4, sort_keys=True)
 
     @classmethod
-    def from_json(cls, name, val):
+    def from_json(cls, name: str, val: str) -> 'Schedule':
         try:
             items = json.loads(val)
             schedule = Schedule(name)
@@ -331,17 +369,20 @@ class Schedule:
         except TypeError as e:
             raise ValueError("Invalid schedule format ({})".format(str(e)))
 
+
 class Schedules:
 
-    def __init__(self, handler):
+    def __init__(self, handler: Any) -> None:
         self.handler = handler
-        self.level_specs = {}
-        self.schedules = {}
+        self.level_specs: Dict[str, LevelSpec] = {}
+        self.schedules: Dict[str, Schedule] = {}
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.schedules)
 
-    def load(self, namespace_validator=None, image_validator=None):
+    def load(self,
+             namespace_validator: Optional[Callable] = None,
+             image_validator: Optional[Callable] = None) -> None:
 
         schedule_cfg = self.handler.module.get_module_option(
             self.handler.MODULE_OPTION_NAME, '')
@@ -378,10 +419,13 @@ class Schedules:
                     "Failed to load schedules for pool {}: {}".format(
                         pool_name, e))
 
-    def load_from_pool(self, ioctx, namespace_validator, image_validator):
+    def load_from_pool(self,
+                       ioctx: rados.Ioctx,
+                       namespace_validator: Optional[Callable],
+                       image_validator: Optional[Callable]) -> None:
         pool_id = ioctx.get_pool_id()
         pool_name = ioctx.get_pool_name()
-        stale_keys = ()
+        stale_keys = []
         start_after = ''
         try:
             while True:
@@ -405,9 +449,9 @@ class Schedules:
                                     image_validator)
                             except ValueError:
                                 self.handler.log.debug(
-                                    "Stail schedule key {} in pool".format(
-                                        k, pool_name))
-                                stale_keys += (k,)
+                                    "Stale schedule key %s in pool %s",
+                                    k, pool_name)
+                                stale_keys.append(k)
                                 continue
 
                             self.level_specs[level_spec.id] = level_spec
@@ -430,7 +474,7 @@ class Schedules:
                 ioctx.remove_omap_keys(write_op, stale_keys)
                 ioctx.operate_write_op(write_op, self.handler.SCHEDULE_OID)
 
-    def save(self, level_spec, schedule):
+    def save(self, level_spec: LevelSpec, schedule: Optional[Schedule]) -> None:
         if level_spec.is_global():
             schedule_cfg = schedule and schedule.to_json() or None
             self.handler.module.set_module_option(
@@ -438,6 +482,7 @@ class Schedules:
             return
 
         pool_id = level_spec.get_pool_id()
+        assert pool_id
         with self.handler.module.rados.open_ioctx2(int(pool_id)) as ioctx:
             with rados.WriteOpCtx() as write_op:
                 if schedule:
@@ -447,8 +492,10 @@ class Schedules:
                     ioctx.remove_omap_keys(write_op, (level_spec.id, ))
                 ioctx.operate_write_op(write_op, self.handler.SCHEDULE_OID)
 
-
-    def add(self, level_spec, interval, start_time):
+    def add(self,
+            level_spec: LevelSpec,
+            interval: str,
+            start_time: Optional[str]) -> None:
         schedule = self.schedules.get(level_spec.id, Schedule(level_spec.name))
         schedule.add(Interval.from_string(interval),
                      StartTime.from_string(start_time))
@@ -456,41 +503,51 @@ class Schedules:
         self.level_specs[level_spec.id] = level_spec
         self.save(level_spec, schedule)
 
-    def remove(self, level_spec, interval, start_time):
+    def remove(self,
+               level_spec: LevelSpec,
+               interval: Optional[str],
+               start_time: Optional[str]) -> None:
         schedule = self.schedules.pop(level_spec.id, None)
         if schedule:
             if interval is None:
                 schedule = None
             else:
-                schedule.remove(Interval.from_string(interval),
-                                StartTime.from_string(start_time))
-                if schedule:
-                    self.schedules[level_spec.id] = schedule
+                try:
+                    schedule.remove(Interval.from_string(interval),
+                                    StartTime.from_string(start_time))
+                finally:
+                    if schedule:
+                        self.schedules[level_spec.id] = schedule
             if not schedule:
                 del self.level_specs[level_spec.id]
         self.save(level_spec, schedule)
 
-    def find(self, pool_id, namespace, image_id=None):
-        levels = [None, pool_id, namespace]
+    def find(self,
+             pool_id: str,
+             namespace: str,
+             image_id: Optional[str] = None) -> Optional['Schedule']:
+        levels = [pool_id, namespace]
         if image_id:
             levels.append(image_id)
-
-        while levels:
-            level_spec_id = "/".join(levels[1:])
-            if level_spec_id in self.schedules:
-                return self.schedules[level_spec_id]
-            del levels[-1]
+        nr_levels = len(levels)
+        while nr_levels >= 0:
+            # an empty spec id implies global schedule
+            level_spec_id = "/".join(levels[:nr_levels])
+            found = self.schedules.get(level_spec_id)
+            if found is not None:
+                return found
+            nr_levels -= 1
         return None
 
-    def intersects(self, level_spec):
+    def intersects(self, level_spec: LevelSpec) -> bool:
         for ls in self.level_specs.values():
             if ls.intersects(level_spec):
                 return True
         return False
 
-    def to_list(self, level_spec):
+    def to_list(self, level_spec: LevelSpec) -> Dict[str, dict]:
         if level_spec.id in self.schedules:
-            parent = level_spec
+            parent: Optional[LevelSpec] = level_spec
         else:
             # try to find existing parent
             parent = None
@@ -515,4 +572,3 @@ class Schedules:
                     'schedule' : schedule.to_list(),
                 }
         return result
-