]>
Commit | Line | Data |
---|---|---|
a4b75251 TL |
1 | import inspect |
2 | import json | |
3 | import logging | |
4 | from functools import wraps | |
5 | from typing import ClassVar, List, Optional, Type | |
6 | from urllib.parse import unquote | |
7 | ||
8 | import cherrypy | |
9 | ||
10 | from ..plugins import PLUGIN_MANAGER | |
11 | from ..services.auth import AuthManager, JwtManager | |
12 | from ..tools import get_request_body_params | |
13 | from ._helpers import _get_function_params | |
14 | from ._version import APIVersion | |
15 | ||
# Module-level logger, used when discovering controller modules and when
# controller instances are created.
logger = logging.getLogger(name=__name__)
17 | ||
18 | ||
19 | class BaseController: | |
20 | """ | |
21 | Base class for all controllers providing API endpoints. | |
22 | """ | |
23 | ||
24 | _registry: ClassVar[List[Type['BaseController']]] = [] | |
25 | _routed = False | |
26 | ||
def __init_subclass__(cls, skip_registry: bool = False, **kwargs) -> None:
    """
    Record each new subclass in ``BaseController._registry``.

    :param skip_registry: when true, the subclass is not registered.
    :param kwargs: forwarded unchanged to the parent ``__init_subclass__``.
    """
    super().__init_subclass__(**kwargs)  # type: ignore
    if skip_registry:
        return
    BaseController._registry.append(cls)
31 | ||
@classmethod
def load_controllers(cls):
    """
    Import this package's controller modules and collect the controllers.

    Every ``*.py`` file in the directory of this module is imported, except
    files whose name starts with ``_``, non-regular files and symlinks.
    Importing them runs their class definitions, so ``BaseController``
    subclasses get appended to ``BaseController._registry``. The registered
    classes whose ``_routed`` flag is truthy are returned, followed by every
    controller contributed through the ``get_controllers`` plugin hook.

    :return: list of controller classes.
    """
    import importlib
    from pathlib import Path

    def _is_controller_module(candidate):
        # Skip private helpers (``_foo.py``), non-files and symlinks.
        return (not candidate.name.startswith('_')
                and candidate.is_file()
                and not candidate.is_symlink())

    package_dir = Path(__file__).parent
    logger.debug('Controller import path: %s', package_dir)
    module_names = [candidate.stem for candidate in package_dir.glob('*.py')
                    if _is_controller_module(candidate)]
    logger.debug('Controller files found: %r', module_names)

    for name in module_names:
        importlib.import_module(f'{__package__}.{name}')

    # pylint: disable=protected-access
    routed = [ctrl for ctrl in BaseController._registry if ctrl._routed]

    # The hook result (possibly empty/None) is an iterable of controller
    # lists; each list is appended as-is.
    for plugin_ctrls in PLUGIN_MANAGER.hook.get_controllers() or []:
        routed.extend(plugin_ctrls)

    return routed
57 | ||
class Endpoint:
    """
    Routing description of one controller method exposed over HTTP.

    The per-endpoint settings are read from the ``_endpoint`` mapping that
    is attached to the handler function, or to a function reachable from it
    through a chain of ``__wrapped__`` attributes.
    """

    def __init__(self, ctrl, func):
        """
        :param ctrl: controller class owning the handler.
        :param func: handler carrying (directly or via ``__wrapped__``) the
            ``_endpoint`` settings.
        """
        self.ctrl = ctrl
        self.inst = None
        self.func = func

        # Proxy endpoints keep their handler untouched; every other endpoint
        # has it replaced on the controller by the request wrapper.
        if self.config['proxy']:
            return
        setattr(self.ctrl, func.__name__, self.function)

    @property
    def config(self):
        """
        The ``_endpoint`` settings of the handler, found by following the
        ``__wrapped__`` chain; ``None`` when no function in it has them.
        """
        candidate = self.func
        while True:
            if hasattr(candidate, '_endpoint'):
                return candidate._endpoint  # pylint: disable=protected-access
            if not hasattr(candidate, '__wrapped__'):
                return None
            candidate = candidate.__wrapped__

    @property
    def function(self):
        """The handler wrapped by the controller's ``_request_wrapper``."""
        # pylint: disable=protected-access
        wrap = self.ctrl._request_wrapper
        return wrap(self.func, self.method,
                    self.config['json_response'],
                    self.config['xml'],
                    self.config['version'])

    @property
    def method(self):
        """HTTP method configured for the endpoint (e.g. ``'GET'``)."""
        return self.config['method']

    @property
    def proxy(self):
        """The ``proxy`` setting of the endpoint."""
        return self.config['proxy']

    @property
    def url(self):
        """
        URL template of the endpoint.

        Built from the controller path ("/" counts as empty), followed by
        the configured ``path`` or, when that is ``None``, ``/<handler
        name>``, followed by one ``/{name}`` segment per path parameter not
        already present in the controller path.
        """
        base = self.ctrl.get_path()
        prefix = "" if base == "/" else base
        if self.config['path'] is None:
            url = f"{prefix}/{self.func.__name__}"
        else:
            url = f"{prefix}{self.config['path']}"

        known = self.ctrl.get_path_param_names(self.config['path'])
        extra_names = [p['name'] for p in self.path_params
                       if p['name'] not in known]
        placeholders = [f"{{{name}}}" for name in extra_names]
        if placeholders:
            url += "/" + "/".join(placeholders)

        return url

    @property
    def action(self):
        """Name of the handler function."""
        return self.func.__name__

    @property
    def path_params(self):
        """
        Handler parameters that are taken from the URL path.

        For ``GET``/``DELETE`` these are the parameters named in the
        controller path plus every required parameter not declared as a
        query parameter; explicit ``path_params`` must not be configured
        for these methods. For any other method they are the parameters
        named in the controller path or listed in ``path_params``.
        """
        ctrl_names = self.ctrl.get_path_param_names(self.config['path'])
        func_params = _get_function_params(self.func)

        if self.method not in ['GET', 'DELETE']:
            return [p for p in func_params
                    if p['name'] in ctrl_names
                    or p['name'] in self.config['path_params']]

        assert self.config['path_params'] is None
        return [p for p in func_params
                if p['name'] in ctrl_names
                or (p['name'] not in self.config['query_params']
                    and p['required'])]

    @property
    def query_params(self):
        """
        Handler parameters taken from the query string.

        For ``GET``/``DELETE``: every parameter that is not a path parameter.
        For any other method: only the parameters listed in ``query_params``.
        """
        read_like = self.method in ['GET', 'DELETE']
        func_params = _get_function_params(self.func)
        if not read_like:
            return [p for p in func_params
                    if p['name'] in self.config['query_params']]

        taken = [p['name'] for p in self.path_params]
        return [p for p in func_params if p['name'] not in taken]

    @property
    def body_params(self):
        """Handler parameters that are neither path nor query parameters."""
        func_params = _get_function_params(self.func)
        in_path = [p['name'] for p in self.path_params]
        in_query = [p['name'] for p in self.query_params]
        return [p for p in func_params
                if p['name'] not in in_path and p['name'] not in in_query]

    @property
    def group(self):
        """Name of the controller class."""
        return self.ctrl.__name__

    @property
    def is_api(self):
        """Value of the controller's ``_api_endpoint`` flag (``False`` if unset)."""
        # getattr rather than hasattr: some UI-based controllers inherit
        # ``_api_endpoint``, so its value, not its presence, decides.
        return getattr(self.ctrl, '_api_endpoint', False)

    @property
    def is_secure(self):
        """Whether the controller has ``tools.authenticate.on`` enabled."""
        cp_config = self.ctrl._cp_config  # pylint: disable=protected-access
        return cp_config['tools.authenticate.on']

    def __repr__(self):
        return f"Endpoint({self.url}, {self.method}, {self.action})"
175 | ||
def __init__(self):
    """Log which controller is being created, then continue the MRO chain."""
    controller_name = self.__class__.__name__
    logger.info('Initializing controller: %s -> %s',
                controller_name, self._cp_path_)  # type: ignore
    super().__init__()
180 | ||
def _has_permissions(self, permissions, scope=None):
    """
    Ask ``AuthManager`` whether the ``JwtManager.LOCAL_USER`` user holds
    ``permissions`` in ``scope``.

    :param permissions: a single permission or a list of permissions; a
        non-list value is wrapped in a one-element list.
    :param scope: security scope to check; defaults to the controller's
        ``_security_scope`` attribute.
    :return: the result of ``AuthManager.authorize`` (presumably a bool —
        confirm against AuthManager).
    :raises Exception: if the controller does not have
        ``tools.authenticate.on`` enabled, or no scope can be determined.
    """
    if not self._cp_config['tools.authenticate.on']:  # type: ignore
        raise Exception("Cannot verify permission in non secured "
                        "controllers")

    perms = permissions if isinstance(permissions, list) else [permissions]

    effective_scope = (getattr(self, '_security_scope', None)
                       if scope is None else scope)
    if effective_scope is None:
        raise Exception("Cannot verify permissions without scope security"
                        " defined")

    return AuthManager.authorize(JwtManager.LOCAL_USER.username,
                                 effective_scope, perms)
196 | ||
@classmethod
def get_path_param_names(cls, path_extension=None):
    """
    Return the names of the path parameters in the controller route.

    ``_cp_path_`` is joined with ``path_extension`` and each ``/``-separated
    segment is inspected. A segment names a parameter when it is written as
    ``:name``, ``{name}`` or ``{name:pattern}``; for the braced form only
    the text before the first ``:`` is kept. Segments whose extracted name
    is empty (``:``, ``{}``) are ignored.

    :param path_extension: optional sub-path appended to the controller
        path; ``None`` is treated as ``""``.
    :return: parameter names in the order they appear.
    :rtype: list[str]
    """
    if path_extension is None:
        path_extension = ""
    # Empty segments (leading, trailing or doubled "/") are skipped below,
    # so split the path as-is instead of assuming ``_cp_path_`` starts with
    # "/" and unconditionally dropping its first character (which lost the
    # first parameter of a path like "{cluster}/status").
    full_path = cls._cp_path_ + path_extension  # type: ignore
    path_params = []
    for step in full_path.split('/'):
        if step.startswith(':'):
            param = step[1:]
        elif step.startswith('{') and step.endswith('}'):
            param = step[1:-1].partition(':')[0]
        else:
            continue
        if param:
            path_params.append(param)
    return path_params
214 | ||
@classmethod
def get_path(cls):
    """Return ``_cp_path_``, the base URL path of this controller."""
    path = cls._cp_path_  # type: ignore
    return path
218 | ||
@classmethod
def endpoints(cls):
    """
    Build an ``Endpoint`` for every callable member carrying an
    ``_endpoint`` attribute (i.e. decorated with ``@endpoint``).

    Members are visited in the order returned by ``inspect.getmembers``.

    :return: A list of endpoint objects
    :rtype: list[BaseController.Endpoint]
    """
    return [cls.Endpoint(cls, member)
            for _, member in inspect.getmembers(cls, predicate=callable)
            if hasattr(member, '_endpoint')]
233 | ||
@staticmethod
def _request_wrapper(func, method, json_response, xml,  # pylint: disable=unused-argument
                     version: Optional[APIVersion]):
    """
    Wrap an endpoint handler for dispatch by cherrypy.

    The returned callable keeps ``func``'s metadata (``functools.wraps``) and:

    * replaces every ``str`` keyword argument with its ``unquote``-d value,
      then merges in the parameters read from the request body;
    * when ``version`` is not ``None``, parses the ``Accept`` header into an
      ``APIVersion`` and answers HTTP 415 if that fails or the version is
      not supported;
    * decodes a ``bytes`` result as UTF-8 and, for XML or JSON endpoints,
      sets ``Content-Type`` and returns UTF-8 encoded bytes (JSON results
      are passed through ``json.dumps`` first).

    :param method: not used by the wrapper.
    """
    # pylint: disable=too-many-branches
    def content_type(subtype):
        # Versioned endpoints advertise the vendor-specific media type.
        if version:
            return 'application/vnd.ceph.api.v{}+{}'.format(version, subtype)
        return 'application/{}'.format(subtype)

    def call_checked(args, kwargs):
        if version is None:
            return func(*args, **kwargs)
        try:
            requested = APIVersion.from_mime_type(
                cherrypy.request.headers['Accept'])
        except Exception:
            raise cherrypy.HTTPError(
                415, "Unable to find version in request header")
        if not version.supports(requested):
            raise cherrypy.HTTPError(
                415,
                f"Incorrect version: endpoint is '{version!s}', "
                f"client requested '{requested!s}'"
            )
        return func(*args, **kwargs)

    @wraps(func)
    def inner(*args, **kwargs):
        # Decode percent-escapes in string keyword arguments; the body
        # parameters merged afterwards are not decoded.
        kwargs.update({key: unquote(val) for key, val in kwargs.items()
                       if isinstance(val, str)})
        kwargs.update(get_request_body_params(cherrypy.request))

        result = call_checked(args, kwargs)
        if isinstance(result, bytes):
            result = result.decode('utf-8')

        if xml:
            cherrypy.response.headers['Content-Type'] = content_type('xml')
            return result.encode('utf8')
        if json_response:
            cherrypy.response.headers['Content-Type'] = content_type('json')
            result = json.dumps(result).encode('utf8')
        return result

    return inner
286 | ||
@property
def _request(self):
    """A ``Request`` adapter around the current ``cherrypy.request``."""
    adapter_cls = self.Request
    return adapter_cls(cherrypy.request)
290 | ||
291 | class Request(object): | |
def __init__(self, cherrypy_req):
    """
    :param cherrypy_req: the cherrypy request object whose ``scheme``,
        ``base`` and ``path_info`` attributes this adapter exposes.
    """
    self._creq = cherrypy_req
294 | ||
@property
def scheme(self):
    """URL scheme of the wrapped request (e.g. ``'https'``)."""
    wrapped = self._creq
    return wrapped.scheme
298 | ||
@property
def host(self):
    """
    Host part of the request base URL, without the port.

    The base URL is expected to look like ``<scheme>://<host>[:<port>]``.
    IPv6 literals are returned with their brackets (``[::1]``), because the
    colons inside them are not port separators.
    """
    netloc = self._creq.base[len(self.scheme) + 3:]  # strip "<scheme>://"
    # A bracketed IPv6 literal with no port ends with "]"; otherwise only
    # the last colon can separate the port (the first one may be inside an
    # IPv6 literal).
    if netloc.endswith(']'):
        return netloc
    hostname, sep, _ = netloc.rpartition(':')
    return hostname if sep else netloc
304 | ||
@property
def port(self):
    """
    Port of the request base URL.

    The base URL is expected to look like ``<scheme>://<host>[:<port>]``.
    Without an explicit port, 443 is returned for ``https`` and 80 for any
    other scheme. IPv6 literals (``[::1]:8080``) are supported.

    :raises ValueError: if the text after the port separator is not an
        integer.
    """
    netloc = self._creq.base[len(self.scheme) + 3:]  # strip "<scheme>://"
    default_port = 443 if self.scheme == 'https' else 80
    # Colons inside a bracketed IPv6 literal are not port separators: a
    # netloc ending in "]" has no port, otherwise only the last colon can
    # introduce one.
    if netloc.endswith(']'):
        return default_port
    _, sep, port_text = netloc.rpartition(':')
    return int(port_text) if sep else default_port
311 | ||
@property
def path_info(self):
    """``path_info`` of the wrapped cherrypy request."""
    wrapped = self._creq
    return wrapped.path_info