]>
Commit | Line | Data |
---|---|---|
11fdf7f2 TL |
1 | # -*- coding: utf-8 -*- |
2 | # pylint: disable=protected-access,too-many-branches | |
3 | from __future__ import absolute_import | |
4 | ||
5 | import collections | |
6 | import importlib | |
7 | import inspect | |
8 | import json | |
9f95a23c | 9 | import logging |
11fdf7f2 TL |
10 | import os |
11 | import pkgutil | |
eafe8130 | 12 | import re |
11fdf7f2 TL |
13 | import sys |
14 | ||
9f95a23c TL |
15 | import six |
16 | from six.moves.urllib.parse import unquote | |
11fdf7f2 TL |
17 | |
18 | # pylint: disable=wrong-import-position | |
19 | import cherrypy | |
20 | ||
11fdf7f2 TL |
21 | from ..security import Scope, Permission |
22 | from ..tools import wraps, getargspec, TaskManager, get_request_body_params | |
23 | from ..exceptions import ScopeNotValid, PermissionNotValid | |
24 | from ..services.auth import AuthManager, JwtManager | |
25 | from ..plugins import PLUGIN_MANAGER | |
26 | ||
9f95a23c TL |
27 | try: |
28 | from typing import Any, List, Optional | |
29 | except ImportError: | |
30 | pass # For typing only | |
11fdf7f2 | 31 | |
9f95a23c TL |
32 | |
def EndpointDoc(description="", group="", parameters=None, responses=None):  # noqa: N802
    """
    Decorator factory that attaches API documentation to an endpoint function.

    The decorated function gets a ``doc_info`` dict with the keys ``summary``,
    ``tag``, ``parameters`` (flattened parameter descriptions) and
    ``response`` (parameter descriptions keyed by ``str(status_code)``).

    Each entry of ``parameters`` (and of every value of ``responses``) must be::

        name: (type or nested parameters, description, [optional], [default value])

    ``optional`` must be given (True) in order to give a default value.

    :param description: Summary text of the endpoint.
    :type description: str
    :param group: Tag / group name of the endpoint.
    :type group: str
    :param parameters: Parameter descriptions (see format above).
    :type parameters: dict
    :param responses: Mapping of status code to parameter descriptions.
    :type responses: dict
    :raises Exception: If an argument has the wrong type or a parameter
        description does not match one of the accepted formats.
    """
    if not isinstance(description, str):
        raise Exception("%s has been called with a description that is not a string: %s"
                        % (EndpointDoc.__name__, description))
    if not isinstance(group, str):
        raise Exception("%s has been called with a groupname that is not a string: %s"
                        % (EndpointDoc.__name__, group))
    if parameters and not isinstance(parameters, dict):
        raise Exception("%s has been called with parameters that is not a dict: %s"
                        % (EndpointDoc.__name__, parameters))
    if responses and not isinstance(responses, dict):
        raise Exception("%s has been called with responses that is not a dict: %s"
                        % (EndpointDoc.__name__, responses))

    if not parameters:
        parameters = {}

    def _split_param(name, p_type, description, optional=False, default_value=None,
                     nested=False):
        """Build the description dict of a single (possibly nested) parameter."""
        param = {
            'name': name,
            'description': description,
            'required': not optional,
            'nested': nested,
        }
        # Compare against None so falsy defaults (False, 0, "") are documented too.
        if default_value is not None:
            param['default'] = default_value
        if isinstance(p_type, type):
            param['type'] = p_type
        else:
            # p_type is a container describing nested parameters, e.g.
            # {'size': (int, 'Size')} or [int].
            nested_params = _split_parameters(p_type, nested=True)
            if nested_params:
                param['type'] = type(p_type)
                param['nested_params'] = nested_params
            else:
                param['type'] = p_type
        return param

    format_error = (
        "Parameter %s in %s has not correct format. Valid formats are:\n"
        "<name>: (<type>, <description>, [optional], [default value])\n"
        "<name>: (<[type]>, <description>, [optional], [default value])\n"
        "<name>: (<[nested parameters]>, <description>, [optional], [default value])\n"
        "<name>: (<{nested parameters}>, <description>, [optional], [default value])")

    # Optional must be set to True in order to set default value and parameters format must be:
    # 'name: (type or nested parameters, description, [optional], [default value])'
    def _split_dict(data, nested):
        """Describe every ``name: (...)`` entry of ``data``."""
        splitted = []
        for name, props in data.items():
            # Reject anything but a 2-, 3- or 4-tuple up front; otherwise
            # ``param`` below would be unbound or stale from a previous entry.
            if not (isinstance(name, str) and isinstance(props, tuple)
                    and 2 <= len(props) <= 4):
                raise Exception(format_error % (name, EndpointDoc.__name__))
            if len(props) == 2:
                param = _split_param(name, props[0], props[1], nested=nested)
            elif len(props) == 3:
                param = _split_param(name, props[0], props[1], optional=props[2], nested=nested)
            else:
                param = _split_param(name, props[0], props[1], props[2], props[3], nested)
            splitted.append(param)
        return splitted

    def _split_list(data, nested):
        """Describe every item of a list/tuple of nested parameter specs."""
        splitted = []  # type: List[Any]
        for item in data:
            splitted.extend(_split_parameters(item, nested))
        return splitted

    # nested = True means parameters are inside a dict or array
    def _split_parameters(data, nested=False):
        """Dispatch on the container type; anything else yields no parameters."""
        param_list = []  # type: List[Any]
        if isinstance(data, dict):
            param_list.extend(_split_dict(data, nested))
        elif isinstance(data, (list, tuple)):
            param_list.extend(_split_list(data, True))
        return param_list

    resp = {}
    if responses:
        for status_code, response_body in responses.items():
            resp[str(status_code)] = _split_parameters(response_body)

    def _wrapper(func):
        func.doc_info = {
            'summary': description,
            'tag': group,
            'parameters': _split_parameters(parameters),
            'response': resp
        }
        return func

    return _wrapper
123 | ||
124 | ||
class ControllerDoc(object):
    """
    Class decorator that attaches tag documentation to a controller class.

    The decorated class gets a ``doc_info`` dict with the keys ``tag``
    (the group name) and ``tag_descr`` (the description).
    """

    def __init__(self, description="", group=""):
        self.tag_descr = description
        self.tag = group

    def __call__(self, cls):
        cls.doc_info = dict(tag=self.tag, tag_descr=self.tag_descr)
        return cls
136 | ||
137 | ||
class Controller(object):
    """
    Class decorator that turns a class into a CherryPy-routed controller.

    The class is mounted at ``base_url + path``. A leading "/" is added to
    ``path`` if missing. A ``base_url`` of ``None`` or "/" is treated as
    empty, and when both end up empty the controller is mounted at "/".

    :param path: URL path of the controller, relative to ``base_url``.
    :param base_url: Optional URL prefix such as "/api".
    :param security_scope: Scope validated with ``Scope.valid_scope`` and
        stored on the class as ``_security_scope``.
    :param secure: Value of the ``tools.authenticate.on`` CherryPy setting.
    :raises ScopeNotValid: If ``security_scope`` is given but not valid.
    """

    def __init__(self, path, base_url=None, security_scope=None, secure=True):
        if security_scope and not Scope.valid_scope(security_scope):
            raise ScopeNotValid(security_scope)
        self.path = path
        self.base_url = base_url
        self.security_scope = security_scope
        self.secure = secure

        if self.path and self.path[0] != "/":
            self.path = "/" + self.path

        if self.base_url is None:
            self.base_url = ""
        elif self.base_url == "/":
            self.base_url = ""

        if self.base_url == "" and self.path == "":
            self.base_url = "/"

    def __call__(self, cls):
        cls._cp_controller_ = True
        cls._cp_path_ = "{}{}".format(self.base_url, self.path)
        cls._security_scope = self.security_scope

        config = {
            'tools.dashboard_exception_handler.on': True,
            'tools.authenticate.on': self.secure,
        }
        # ``_cp_config`` may be inherited from an already decorated parent
        # controller. Updating it in place would leak this controller's
        # settings (e.g. secure=False) into the parent, so build a fresh dict
        # owned by ``cls`` that starts from the inherited/own settings.
        merged = dict(getattr(cls, '_cp_config', {}))
        merged.update(config)
        cls._cp_config = merged
        return cls
171 | ||
172 | ||
class ApiController(Controller):
    """
    ``Controller`` mounted below the "/api" prefix.

    Decorated classes are additionally flagged with ``_api_endpoint`` (read by
    ``BaseController.Endpoint.is_api``).
    """

    def __init__(self, path, security_scope=None, secure=True):
        super(ApiController, self).__init__(path=path,
                                            base_url="/api",
                                            security_scope=security_scope,
                                            secure=secure)

    def __call__(self, cls):
        decorated = super(ApiController, self).__call__(cls)
        decorated._api_endpoint = True
        return decorated
183 | ||
184 | ||
class UiApiController(Controller):
    """``Controller`` mounted below the "/ui-api" prefix."""

    def __init__(self, path, security_scope=None, secure=True):
        super(UiApiController, self).__init__(path=path,
                                              base_url="/ui-api",
                                              security_scope=security_scope,
                                              secure=secure)
190 | ||
191 | ||
def Endpoint(method=None, path=None, path_params=None, query_params=None,  # noqa: N802
             json_response=True, proxy=False, xml=False):
    """
    Decorator factory marking a controller method as an HTTP endpoint.

    The decorated function gets an ``_endpoint`` dict holding the normalized
    settings (method, path, path_params, query_params, json_response, proxy,
    xml).

    :param method: One of 'GET', 'POST', 'DELETE', 'PUT' (case-insensitive);
        defaults to 'GET'.
    :param path: Endpoint path. A leading "/" is added if missing and "/"
        becomes "". ``None`` keeps the default, except for ``__call__``
        methods, which get "".
    :param path_params: Only allowed for POST/PUT; names of (non-optional)
        function parameters that are part of the URL path.
    :param query_params: Names of parameters passed in the query string.
    :raises TypeError: On an invalid method, on ``path_params`` given for
        GET/DELETE, or on ``path_params`` that reference optional parameters.
    """
    if method is None:
        method = 'GET'
    elif not isinstance(method, str) or \
            method.upper() not in ('GET', 'POST', 'DELETE', 'PUT'):
        raise TypeError("Possible values for method are: 'GET', 'POST', "
                        "'DELETE', or 'PUT'")
    method = method.upper()

    # POST and PUT are the only methods that carry a request body.
    has_body = method in ('POST', 'PUT')

    if not has_body and path_params is not None:
        raise TypeError("path_params should not be used for {} "
                        "endpoints. All function params are considered"
                        " path parameters by default".format(method))
    if has_body and path_params is None:
        path_params = []

    query_params = [] if query_params is None else query_params

    def _decorate(func):
        if has_body:
            for fparam in _get_function_params(func):
                if fparam['name'] in path_params and not fparam['required']:
                    raise TypeError("path_params can only reference "
                                    "non-optional function parameters")

        if path is None and func.__name__ == '__call__':
            endpoint_path = ""
        else:
            endpoint_path = path

        if endpoint_path is not None:
            endpoint_path = endpoint_path.strip()
            if endpoint_path == "/":
                endpoint_path = ""
            elif endpoint_path and not endpoint_path.startswith("/"):
                endpoint_path = "/" + endpoint_path

        func._endpoint = {
            'method': method,
            'path': endpoint_path,
            'path_params': path_params,
            'query_params': query_params,
            'json_response': json_response,
            'proxy': proxy,
            'xml': xml
        }
        return func

    return _decorate
248 | ||
249 | ||
def Proxy(path=None):  # noqa: N802
    """
    Shortcut for a proxy ``Endpoint`` that captures the remaining URL.

    The endpoint path is ``path`` (``None`` or "/" meaning "") followed by
    "/{path:.*}".
    """
    prefix = path
    if prefix is None or prefix == "/":
        prefix = ""
    return Endpoint(path=prefix + "/{path:.*}", proxy=True)
257 | ||
258 | ||
def load_controllers():
    """
    Import every module of this package and collect its controller classes.

    A class is collected when it derives from ``BaseController`` and carries
    ``_cp_controller_`` (set by the ``Controller`` decorators). Classes whose
    ``_cp_path_`` starts with ':' are logged as errors and skipped. Controllers
    returned by the plugin hook ``get_controllers`` are appended at the end.

    :return: List of controller classes.
    """
    logger = logging.getLogger('controller.load')

    # setting sys.path properly when not running under the mgr
    controllers_dir = os.path.dirname(os.path.realpath(__file__))
    dashboard_dir = os.path.dirname(controllers_dir)
    mgr_dir = os.path.dirname(dashboard_dir)
    logger.debug("controllers_dir=%s", controllers_dir)
    logger.debug("dashboard_dir=%s", dashboard_dir)
    logger.debug("mgr_dir=%s", mgr_dir)
    if mgr_dir not in sys.path:
        sys.path.append(mgr_dir)

    mods = [name for _, name, _ in pkgutil.iter_modules([controllers_dir])]
    logger.debug("mods=%s", mods)

    found = []
    for mod_name in mods:
        module = importlib.import_module('.controllers.{}'.format(mod_name),
                                         package='dashboard')
        for candidate in module.__dict__.values():
            # Controllers MUST be derived from the class BaseController.
            is_controller = (inspect.isclass(candidate)
                             and issubclass(candidate, BaseController)
                             and hasattr(candidate, '_cp_controller_'))
            if not is_controller:
                continue
            if candidate._cp_path_.startswith(':'):
                # invalid _cp_path_ value
                logger.error("Invalid url prefix '%s' for controller '%s'",
                             candidate._cp_path_, candidate.__name__)
                continue
            found.append(candidate)

    for plugin_controllers in PLUGIN_MANAGER.hook.get_controllers() or []:
        found.extend(plugin_controllers)

    return found
292 | ||
293 | ||
# Endpoint URL (without the global url prefix) -> list of endpoints
# registered for it; filled by generate_controller_routes().
ENDPOINT_MAP = collections.defaultdict(list)  # type: dict


def generate_controller_routes(endpoint, mapper, base_url):
    """
    Register ``endpoint`` in ``mapper`` (with and without a trailing "/").

    :param endpoint: A ``BaseController.Endpoint`` whose ``inst`` is set.
    :param mapper: Routes dispatcher providing ``connect``.
    :param base_url: Empty, or a URL path starting with "/".
    :return: The "parent" URL: ``base_url`` plus the first segment of the
        endpoint URL, with a trailing "/{...}" path parameter removed
        ("/" when nothing is left).
    """
    inst = endpoint.inst
    ctrl_class = endpoint.ctrl

    # Proxy endpoints accept any HTTP method.
    conditions = None if endpoint.proxy else dict(method=[endpoint.method])

    # Drop a trailing "/" so base_url can be concatenated with the endpoint URL.
    if base_url.endswith("/"):
        base_url = base_url[:-1]

    endp_url = endpoint.url

    second_slash = endp_url.find("/", 1)
    first_segment = endp_url if second_slash == -1 else endp_url[:second_slash]
    parent_url = "{}{}".format(base_url, first_segment)
    # parent_url may end in a path parameter definition such as "/{id}"
    parent_url = re.sub(r'(?:/\{[^}]+\})$', '', parent_url) or "/"

    url = "{}{}".format(base_url, endp_url)

    logger = logging.getLogger('controller')
    logger.debug("Mapped [%s] to %s:%s restricted to %s",
                 url, ctrl_class.__name__, endpoint.action,
                 endpoint.method)

    ENDPOINT_MAP[endp_url].append(endpoint)

    route_name = ctrl_class.__name__ + ":" + endpoint.action
    # The second pass registers the route with a trailing slash.
    for suffix in ("", "/"):
        mapper.connect(route_name + suffix, url + suffix, controller=inst,
                       action=endpoint.action, conditions=conditions)

    return parent_url
11fdf7f2 TL |
345 | |
346 | ||
def generate_routes(url_prefix):
    """
    Build the CherryPy routes dispatcher for every loaded controller.

    Each controller class is instantiated once. Its endpoints are collected,
    sorted by URL and registered below ``url_prefix``.

    :return: Tuple ``(mapper, parent_urls)`` where ``parent_urls`` is the set
        of parent URLs returned by ``generate_controller_routes``.
    """
    mapper = cherrypy.dispatch.RoutesDispatcher()

    all_endpoints = []
    for ctrl_cls in load_controllers():
        instance = ctrl_cls()
        for endpoint in ctrl_cls.endpoints():
            endpoint.inst = instance
            all_endpoints.append(endpoint)

    all_endpoints.sort(key=lambda ep: ep.url)
    prefix = "{}".format(url_prefix)
    parent_urls = {generate_controller_routes(ep, mapper, prefix)
                   for ep in all_endpoints}

    logger = logging.getLogger('controller')
    logger.debug("list of parent paths: %s", parent_urls)
    return mapper, parent_urls
368 | ||
369 | ||
def json_error_page(status, message, traceback, version):
    """
    Render an error as a JSON document and set the JSON content type.

    The signature matches a CherryPy ``error_page`` handler; presumably it is
    registered as one elsewhere (confirm at the call site).

    :return: JSON string with the keys status, detail, traceback and version.
    """
    cherrypy.response.headers['Content-Type'] = 'application/json'
    payload = {
        'status': status,
        'detail': message,
        'traceback': traceback,
        'version': version,
    }
    return json.dumps(payload)
374 | ||
375 | ||
def _get_function_params(func):
    """
    Describe the parameters declared by ``func``, skipping its first
    positional argument (``self``).

    Each parameter is represented as a dict with the keys:
        * name (str): the name of the parameter
        * required (bool): False when the parameter has a default value
        * default (obj): the default value (only for non-required parameters)
    """
    spec = getargspec(func)
    defaults = spec.defaults or ()
    first_optional = len(spec.args) - len(defaults)

    described = [{'name': name, 'required': True}
                 for name in spec.args[1:first_optional]]
    described.extend({'name': name, 'required': False, 'default': value}
                     for name, value in zip(spec.args[first_optional:], defaults))
    return described
400 | ||
401 | ||
class Task(object):
    """
    Decorator that executes the decorated method through ``TaskManager.run``.

    :param name: Task name.
    :param metadata: Dict of metadata values, or a list of "{param}"
        placeholders (turned into ``{'param': '{param}'}``). A placeholder
        value "{x}" is resolved from the call arguments: by 1-based position
        when ``x`` is an int, by parameter name otherwise, and "{a.b}" indexes
        into argument ``a`` with key ``b``.
    :param wait_for: Passed to ``task.wait`` (presumably seconds — confirm in
        TaskManager).
    :param exception_handler: Passed to ``TaskManager.run``.
    """

    def __init__(self, name, metadata, wait_for=5.0, exception_handler=None):
        self.name = name
        if isinstance(metadata, list):
            self.metadata = dict((entry[1:-1], entry) for entry in metadata)
        else:
            self.metadata = metadata
        self.wait_for = wait_for
        self.exception_handler = exception_handler

    def _gen_arg_map(self, func, args, kwargs):
        """Map every parameter of ``func`` (by name and by 1-based index) to its value."""
        positional = args[1:]  # exclude self
        arg_map = {}
        for idx, param in enumerate(_get_function_params(func)):
            pname = param['name']
            if idx < len(positional):
                arg_map[pname] = positional[idx]
            elif pname in kwargs:
                arg_map[pname] = kwargs[pname]
            else:
                assert not param['required'], "{0} is required".format(pname)
                arg_map[pname] = param['default']
            # This is not a type error. We are using the index here.
            arg_map[idx + 1] = arg_map[pname]
        return arg_map

    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            arg_map = self._gen_arg_map(func, args, kwargs)

            metadata = {}
            for key, template in self.metadata.items():
                is_placeholder = (isinstance(template, str) and template
                                  and template[0] == '{' and template[-1] == '}')
                if not is_placeholder:
                    metadata[key] = template
                    continue
                ref = template[1:-1]
                try:
                    metadata[key] = arg_map[int(ref)]
                except ValueError:
                    # Not an index: a parameter name, optionally followed by
                    # ".key" lookups into that parameter's value.
                    segments = ref.split('.')
                    resolved = arg_map[segments[0]]
                    for segment in segments[1:]:
                        resolved = resolved[segment]
                    metadata[key] = resolved

            task = TaskManager.run(self.name, metadata, func, args, kwargs,
                                   exception_handler=self.exception_handler)
            try:
                status, value = task.wait(self.wait_for)
            except Exception as ex:
                if not task.ret_value:
                    raise ex
                # exception was handled by task.exception_handler
                if 'status' in task.ret_value:
                    status = task.ret_value['status']
                else:
                    status = getattr(ex, 'status', 500)
                cherrypy.response.status = status
                return task.ret_value

            if status != TaskManager.VALUE_EXECUTING:
                return value
            # Still running: report "202 Accepted" with the task identity.
            cherrypy.response.status = 202
            return {'name': self.name, 'metadata': metadata}
        return wrapper
473 | ||
474 | ||
class BaseController(object):
    """
    Base class for all controllers providing API endpoints.
    """

    class Endpoint(object):
        """
        An instance of this class represents an endpoint.

        It wraps a controller class (``ctrl``) and a function decorated with
        ``Endpoint(...)``. Unless it is a proxy endpoint, the function is
        replaced on the controller class by its request wrapper.
        """

        def __init__(self, ctrl, func):
            self.ctrl = ctrl
            self.inst = None
            self.func = func

            if not self.config['proxy']:
                setattr(self.ctrl, func.__name__, self.function)

        @property
        def config(self):
            """The ``_endpoint`` dict, searched through ``__wrapped__`` chains (or None)."""
            func = self.func
            while not hasattr(func, '_endpoint'):
                if hasattr(func, "__wrapped__"):
                    func = func.__wrapped__
                else:
                    return None
            return func._endpoint

        @property
        def function(self):
            return self.ctrl._request_wrapper(self.func, self.method,
                                              self.config['json_response'],
                                              self.config['xml'])

        @property
        def method(self):
            return self.config['method']

        @property
        def proxy(self):
            return self.config['proxy']

        @property
        def url(self):
            """Controller path + endpoint path + "/{param}" for each extra path param."""
            ctrl_path = self.ctrl.get_path()
            if ctrl_path == "/":
                ctrl_path = ""
            if self.config['path'] is not None:
                url = "{}{}".format(ctrl_path, self.config['path'])
            else:
                url = "{}/{}".format(ctrl_path, self.func.__name__)

            ctrl_path_params = self.ctrl.get_path_param_names(
                self.config['path'])
            path_params = [p['name'] for p in self.path_params
                           if p['name'] not in ctrl_path_params]
            path_params = ["{{{}}}".format(p) for p in path_params]
            if path_params:
                url += "/{}".format("/".join(path_params))

            return url

        @property
        def action(self):
            return self.func.__name__

        @property
        def path_params(self):
            """
            Parameters taken from the URL path.

            For GET/DELETE: every path parameter of the controller plus every
            required parameter not listed in ``query_params``. For POST/PUT:
            controller path parameters plus the explicit ``path_params``.
            """
            ctrl_path_params = self.ctrl.get_path_param_names(
                self.config['path'])
            func_params = _get_function_params(self.func)

            if self.method in ['GET', 'DELETE']:
                assert self.config['path_params'] is None

                return [p for p in func_params if p['name'] in ctrl_path_params
                        or (p['name'] not in self.config['query_params']
                            and p['required'])]

            # elif self.method in ['POST', 'PUT']:
            return [p for p in func_params if p['name'] in ctrl_path_params
                    or p['name'] in self.config['path_params']]

        @property
        def query_params(self):
            """GET/DELETE: all non-path parameters; POST/PUT: the declared ones."""
            if self.method in ['GET', 'DELETE']:
                func_params = _get_function_params(self.func)
                path_params = [p['name'] for p in self.path_params]
                return [p for p in func_params if p['name'] not in path_params]

            # elif self.method in ['POST', 'PUT']:
            func_params = _get_function_params(self.func)
            return [p for p in func_params
                    if p['name'] in self.config['query_params']]

        @property
        def body_params(self):
            """Parameters that are neither path nor query parameters."""
            func_params = _get_function_params(self.func)
            path_params = [p['name'] for p in self.path_params]
            query_params = [p['name'] for p in self.query_params]
            return [p for p in func_params
                    if p['name'] not in path_params
                    and p['name'] not in query_params]

        @property
        def group(self):
            return self.ctrl.__name__

        @property
        def is_api(self):
            return hasattr(self.ctrl, '_api_endpoint')

        @property
        def is_secure(self):
            return self.ctrl._cp_config['tools.authenticate.on']

        def __repr__(self):
            return "Endpoint({}, {}, {})".format(self.url, self.method,
                                                 self.action)

    def __init__(self):
        logger = logging.getLogger('controller')
        logger.info('Initializing controller: %s -> %s',
                    self.__class__.__name__, self._cp_path_)  # type: ignore
        super(BaseController, self).__init__()

    def _has_permissions(self, permissions, scope=None):
        """
        Check whether the current local user holds ``permissions`` on ``scope``.

        :param permissions: A permission or list of permissions.
        :param scope: Security scope; defaults to the controller's scope.
        :raises Exception: If the controller is not secured or has no scope.
        """
        if not self._cp_config['tools.authenticate.on']:  # type: ignore
            raise Exception("Cannot verify permission in non secured "
                            "controllers")

        if not isinstance(permissions, list):
            permissions = [permissions]

        if scope is None:
            scope = getattr(self, '_security_scope', None)
        if scope is None:
            raise Exception("Cannot verify permissions without scope security"
                            " defined")
        username = JwtManager.LOCAL_USER.username
        return AuthManager.authorize(username, scope, permissions)

    @classmethod
    def get_path_param_names(cls, path_extension=None):
        """
        Names of the path parameters in ``_cp_path_ + path_extension``.

        Both ":name" and "{name}" / "{name:regex}" segments are recognized.
        """
        if path_extension is None:
            path_extension = ""
        full_path = cls._cp_path_[1:] + path_extension  # type: ignore
        path_params = []
        for step in full_path.split('/'):
            param = None
            if not step:
                continue
            if step[0] == ':':
                param = step[1:]
            elif step[0] == '{' and step[-1] == '}':
                param, _, _ = step[1:-1].partition(':')
            if param:
                path_params.append(param)
        return path_params

    @classmethod
    def get_path(cls):
        return cls._cp_path_  # type: ignore

    @classmethod
    def endpoints(cls):
        """
        This method iterates over all the methods decorated with ``@endpoint``
        and creates an Endpoint object for each one of the methods.

        :return: A list of endpoint objects
        :rtype: list[BaseController.Endpoint]
        """
        result = []
        for _, func in inspect.getmembers(cls, predicate=callable):
            if hasattr(func, '_endpoint'):
                result.append(cls.Endpoint(cls, func))
        return result

    @staticmethod
    def _request_wrapper(func, method, json_response, xml):  # pylint: disable=unused-argument
        """
        Wrap ``func`` so that it reads request-body parameters and encodes its result.

        Text keyword arguments are URL-unquoted, body parameters are merged
        into the keyword arguments, and the result is returned as XML bytes
        (``xml``), JSON bytes (``json_response``) or unchanged.
        """
        @wraps(func)
        def inner(*args, **kwargs):
            for key, value in kwargs.items():
                if isinstance(value, six.text_type):
                    kwargs[key] = unquote(value)

            # Process method arguments.
            params = get_request_body_params(cherrypy.request)
            kwargs.update(params)

            ret = func(*args, **kwargs)
            if isinstance(ret, bytes):
                ret = ret.decode('utf-8')
            if xml:
                cherrypy.response.headers['Content-Type'] = 'application/xml'
                return ret.encode('utf8')
            if json_response:
                cherrypy.response.headers['Content-Type'] = 'application/json'
                ret = json.dumps(ret).encode('utf8')
            return ret
        return inner

    @property
    def _request(self):
        return self.Request(cherrypy.request)

    class Request(object):
        """Read-only view of scheme/host/port/path of a CherryPy request."""

        def __init__(self, cherrypy_req):
            self._creq = cherrypy_req

        @property
        def scheme(self):
            return self._creq.scheme

        def _split_netloc(self):
            """
            Split the "<host>[:<port>]" part of ``base`` into (host, port).

            ``port`` is the text after the separating ':' or None when there
            is none. IPv6 literals such as "[::1]:8443" keep their brackets
            in ``host``; the colons inside the brackets are not port separators.
            """
            # base looks like "<scheme>://<host>[:<port>]"
            netloc = self._creq.base[len(self.scheme) + 3:]
            if netloc.startswith('['):
                end = netloc.find(']')
                if end != -1:
                    rest = netloc[end + 1:]
                    port = rest[1:] if rest.startswith(':') else None
                    return netloc[:end + 1], port
            host, sep, port = netloc.partition(':')
            return host, (port if sep else None)

        @property
        def host(self):
            return self._split_netloc()[0]

        @property
        def port(self):
            port = self._split_netloc()[1]
            if not port:
                # No (or an empty) port: use the scheme's default port.
                return 443 if self.scheme == 'https' else 80
            return int(port)

        @property
        def path_info(self):
            return self._creq.path_info
706 | ||
707 | ||
class RESTController(BaseController):
    """
    Base class for exposing a resource through a RESTful interface.

    Derive from this class and implement any subset of the following methods;
    each one is mapped to an HTTP method and status code by
    ``_method_mapping``:

    * list()
    * bulk_set(data)
    * create(data)
    * bulk_delete()
    * get(key)
    * set(data, key)
    * singleton_set(data)
    * delete(key)

    Additional routes can be added with the ``Collection`` and ``Resource``
    decorators.

    Test with curl:

    curl -H "Content-Type: application/json" -X POST \
         -d '{"username":"xyz","password":"xyz"}' https://127.0.0.1:8443/foo
    curl https://127.0.0.1:8443/foo
    curl https://127.0.0.1:8443/foo/0

    """

    # Resource id parameter(s) used by get, set, delete and @Resource methods;
    # meant to be overridden by subclasses. A composite id is written with '/',
    # e.g. "param1/param2". When left as None, the id is inferred from the
    # signature of the first implemented resource method.
    RESOURCE_ID = None  # type: Optional[str]

    _permission_map = {
        'GET': Permission.READ,
        'POST': Permission.CREATE,
        'PUT': Permission.UPDATE,
        'DELETE': Permission.DELETE
    }

    _method_mapping = collections.OrderedDict([
        ('list', {'method': 'GET', 'resource': False, 'status': 200}),
        ('create', {'method': 'POST', 'resource': False, 'status': 201}),
        ('bulk_set', {'method': 'PUT', 'resource': False, 'status': 200}),
        ('bulk_delete', {'method': 'DELETE', 'resource': False, 'status': 204}),
        ('get', {'method': 'GET', 'resource': True, 'status': 200}),
        ('delete', {'method': 'DELETE', 'resource': True, 'status': 204}),
        ('set', {'method': 'PUT', 'resource': True, 'status': 200}),
        ('singleton_set', {'method': 'PUT', 'resource': False, 'status': 200})
    ])

    @classmethod
    def infer_resource_id(cls):
        """
        Names of the resource id parameters, or None if they cannot be inferred.

        Uses ``RESOURCE_ID`` when set. Otherwise it takes the required
        parameters (excluding controller path parameters) of the first
        implemented resource method in ``_method_mapping`` order.
        """
        if cls.RESOURCE_ID is not None:
            return cls.RESOURCE_ID.split('/')
        for meth_name, spec in cls._method_mapping.items():
            impl = getattr(cls, meth_name, None)
            while hasattr(impl, "__wrapped__"):
                impl = impl.__wrapped__
            if not (spec['resource'] and impl):
                continue
            ctrl_path_params = cls.get_path_param_names()
            impl_params = _get_function_params(impl)
            return [p['name'] for p in impl_params
                    if p['required'] and p['name'] not in ctrl_path_params]
        return None

    @classmethod
    def endpoints(cls):
        """
        Endpoints of the base class plus one endpoint per REST method.

        A REST method is either named in ``_method_mapping`` or decorated with
        ``Collection``/``Resource``. Each one is wrapped so it sets its status
        code, gets an ``Endpoint`` decoration, and receives the default
        permission for its HTTP method unless it declares explicit permissions.

        :raises TypeError: If a resource method exists but the resource id
            parameters cannot be inferred.
        """
        result = super(RESTController, cls).endpoints()
        res_id_params = cls.infer_resource_id()

        for _, func in inspect.getmembers(cls, predicate=callable):
            explicit_perms = hasattr(func, '_security_permissions')
            missing_res_id = False
            query_params = None
            path = ""

            if func.__name__ in cls._method_mapping:
                spec = cls._method_mapping[func.__name__]  # type: dict
                if spec['resource']:
                    if res_id_params:
                        path += "/" + "/".join("{{{}}}".format(p) for p in res_id_params)
                    else:
                        missing_res_id = True
                status = spec['status']
                method = spec['method']
            elif hasattr(func, "_collection_method_"):
                spec = func._collection_method_
                path = spec['path'] or "/{}".format(func.__name__)
                status = spec['status']
                method = spec['method']
                query_params = spec['query_params']
            elif hasattr(func, "_resource_method_"):
                spec = func._resource_method_
                if res_id_params:
                    path += "/" + "/".join("{{{}}}".format(p) for p in res_id_params)
                else:
                    missing_res_id = True
                path += spec['path'] or "/{}".format(func.__name__)
                status = spec['status']
                method = spec['method']
                query_params = spec['query_params']
            else:
                continue

            permission = None if explicit_perms else cls._permission_map[method]

            if missing_res_id:
                raise TypeError("Could not infer the resource ID parameters for"
                                " method {} of controller {}. "
                                "Please specify the resource ID parameters "
                                "using the RESOURCE_ID class property"
                                .format(func.__name__, cls.__name__))

            if method in ('GET', 'DELETE'):
                func_params = _get_function_params(func)
                res_id_params = res_id_params or []
                if query_params is None:
                    # everything but the resource id goes into the query string
                    query_params = [p['name'] for p in func_params
                                    if p['name'] not in res_id_params]

            status_func = cls._status_code_wrapper(func, status)
            endp_func = Endpoint(method, path=path,
                                 query_params=query_params)(status_func)
            if permission:
                _set_func_permissions(endp_func, [permission])
            result.append(cls.Endpoint(cls, endp_func))

        return result

    @classmethod
    def _status_code_wrapper(cls, func, status_code):
        """Wrap ``func`` so the response status is set to ``status_code`` first."""
        @wraps(func)
        def wrapper(*vpath, **params):
            cherrypy.response.status = status_code
            return func(*vpath, **params)

        return wrapper

    @staticmethod
    def Resource(method=None, path=None, status=None, query_params=None):  # noqa: N802
        """
        Mark a method as an extra per-resource route ("/{id}" + path).

        ``method`` defaults to 'GET' and ``status`` to 200. ``path`` defaults
        to "/<method name>" when the endpoints are generated.
        """
        spec = {
            'method': method or 'GET',
            'path': path,
            'status': 200 if status is None else status,
            'query_params': query_params
        }

        def _wrapper(func):
            func._resource_method_ = dict(spec)
            return func
        return _wrapper

    @staticmethod
    def Collection(method=None, path=None, status=None, query_params=None):  # noqa: N802
        """
        Mark a method as an extra collection-level route.

        ``method`` defaults to 'GET' and ``status`` to 200. ``path`` defaults
        to "/<method name>" when the endpoints are generated.
        """
        spec = {
            'method': method or 'GET',
            'path': path,
            'status': 200 if status is None else status,
            'query_params': query_params
        }

        def _wrapper(func):
            func._collection_method_ = dict(spec)
            return func
        return _wrapper
901 | ||
902 | # Role-based access permissions decorators | |
903 | ||
def _set_func_permissions(func, permissions):
    """
    Validate ``permissions`` and record them in ``func._security_permissions``.

    When ``func`` already has recorded permissions, the new ones are merged
    with them. Duplicates are removed and the order is kept deterministic
    (new permissions first, then the previously recorded ones).

    :param func: Function to annotate.
    :param permissions: A single permission or a list of permissions. The
        caller's list is never modified.
    :raises PermissionNotValid: If a permission is not a valid permission.
    """
    # Work on a copy so merging below never mutates the caller's list.
    if isinstance(permissions, list):
        permissions = list(permissions)
    else:
        permissions = [permissions]

    for perm in permissions:
        if not Permission.valid_permission(perm):
            logger = logging.getLogger('controller.set_func_perms')
            logger.debug("Invalid security permission: %s\n "
                         "Possible values: %s", perm,
                         Permission.all_permissions())
            raise PermissionNotValid(perm)

    if not hasattr(func, '_security_permissions'):
        func._security_permissions = permissions
    else:
        # Order-preserving de-duplication; ``list(set(...))`` would make the
        # resulting order depend on string hashing.
        merged = []
        for perm in permissions + list(func._security_permissions):
            if perm not in merged:
                merged.append(perm)
        func._security_permissions = merged
921 | ||
922 | ||
9f95a23c TL |
def ReadPermission(func):  # noqa: N802
    """
    Add ``Permission.READ`` to the permissions recorded on ``func``.

    :raises PermissionNotValid: If ``Permission.READ`` is not a valid permission.
    """
    required = [Permission.READ]
    _set_func_permissions(func, required)
    return func
929 | ||
930 | ||
9f95a23c TL |
def CreatePermission(func):  # noqa: N802
    """
    Add ``Permission.CREATE`` to the permissions recorded on ``func``.

    :raises PermissionNotValid: If ``Permission.CREATE`` is not a valid permission.
    """
    required = [Permission.CREATE]
    _set_func_permissions(func, required)
    return func
937 | ||
938 | ||
9f95a23c TL |
def DeletePermission(func):  # noqa: N802
    """
    Add ``Permission.DELETE`` to the permissions recorded on ``func``.

    :raises PermissionNotValid: If ``Permission.DELETE`` is not a valid permission.
    """
    required = [Permission.DELETE]
    _set_func_permissions(func, required)
    return func
945 | ||
946 | ||
9f95a23c TL |
def UpdatePermission(func):  # noqa: N802
    """
    Add ``Permission.UPDATE`` to the permissions recorded on ``func``.

    :raises PermissionNotValid: If ``Permission.UPDATE`` is not a valid permission.
    """
    required = [Permission.UPDATE]
    _set_func_permissions(func, required)
    return func
f91f0fd5 TL |
953 | |
954 | ||
955 | # Empty request body decorator | |
956 | ||
def allow_empty_body(func):  # noqa: N802
    """
    Allow the decorated POST/PUT handler to receive an empty request body.

    Sets ``tools.json_in.force`` to False in ``func._cp_config``, creating
    that dict when ``func`` has none yet.
    """
    try:
        cp_config = func._cp_config
        cp_config['tools.json_in.force'] = False
    except (AttributeError, KeyError):
        func._cp_config = {'tools.json_in.force': False}
    return func