]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | # |
2 | # Licensed to the Apache Software Foundation (ASF) under one | |
3 | # or more contributor license agreements. See the NOTICE file | |
4 | # distributed with this work for additional information | |
5 | # regarding copyright ownership. The ASF licenses this file | |
6 | # to you under the Apache License, Version 2.0 (the | |
7 | # "License"); you may not use this file except in compliance | |
8 | # with the License. You may obtain a copy of the License at | |
9 | # | |
10 | # http://www.apache.org/licenses/LICENSE-2.0 | |
11 | # | |
12 | # Unless required by applicable law or agreed to in writing, | |
13 | # software distributed under the License is distributed on an | |
14 | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | # KIND, either express or implied. See the License for the | |
16 | # specific language governing permissions and limitations | |
17 | # under the License. | |
18 | # | |
19 | ||
20 | import logging | |
21 | import os | |
22 | import socket | |
23 | import ssl | |
24 | import sys | |
25 | import warnings | |
26 | ||
27 | from .sslcompat import _match_hostname, _match_has_ipaddress | |
28 | from thrift.transport import TSocket | |
29 | from thrift.transport.TTransport import TTransportException | |
30 | ||
# Module-level logger, plus a filter that applies the 'default' action
# (print the first occurrence per location) to DeprecationWarnings whose
# module matches this module's name.
logger = logging.getLogger(__name__)
warnings.filterwarnings(
    action='default',
    category=DeprecationWarning,
    module=__name__,
)
34 | ||
35 | ||
class TSSLBase(object):
    """Shared SSL setup for TSSLSocket (client) and TSSLServerSocket.

    Either uses a caller-supplied ``ssl.SSLContext`` or builds one from the
    individual options (``ssl_version``, ``cert_reqs``, ``ca_certs``,
    ``keyfile``, ``certfile``, ``ciphers``).  On Pythons without SSLContext
    the options are kept and passed to the legacy ``ssl.wrap_socket``.
    """

    # SSLContext is not available for Python < 2.7.9
    _has_ssl_context = sys.hexversion >= 0x020709F0

    # ciphers argument is not available for Python < 2.7.0
    _has_ciphers = sys.hexversion >= 0x020700F0

    # For python >= 2.7.9, use the latest TLS version that both client and
    # server support.
    # SSL 2.0 and 3.0 are disabled via ssl.OP_NO_SSLv2 and ssl.OP_NO_SSLv3.
    # For python < 2.7.9, use TLS 1.0 since neither TLSv1_X nor OP_NO_SSLvX
    # is available.
    _default_protocol = ssl.PROTOCOL_SSLv23 if _has_ssl_context else \
        ssl.PROTOCOL_TLSv1

    def _init_context(self, ssl_version):
        """Build a fresh context for ``ssl_version``.

        Without SSLContext support, record ``ssl_version`` for the legacy
        ``ssl.wrap_socket`` call and leave ``_context`` as None.
        """
        if self._has_ssl_context:
            self._context = ssl.SSLContext(ssl_version)
            if self._context.protocol == ssl.PROTOCOL_SSLv23:
                # Version-negotiating protocol: never fall back to the
                # obsolete SSL 2.0 / 3.0.
                self._context.options |= ssl.OP_NO_SSLv2
                self._context.options |= ssl.OP_NO_SSLv3
        else:
            self._context = None
            self._ssl_version = ssl_version

    @property
    def _should_verify(self):
        """True when the peer certificate is going to be verified.

        A caller-supplied ``ssl_context`` is authoritative.  Otherwise
        ``cert_reqs`` is the source of truth: ``_wrap_socket`` copies it into
        the internally built context only at wrap time, so the context's
        ``verify_mode`` does not reflect ``cert_reqs`` before that.
        """
        if self._has_ssl_context and self._custom_context:
            return self._context.verify_mode != ssl.CERT_NONE
        return self.cert_reqs != ssl.CERT_NONE

    @property
    def ssl_version(self):
        """Protocol of the context, or the recorded legacy protocol."""
        if self._has_ssl_context:
            return self.ssl_context.protocol
        else:
            return self._ssl_version

    @property
    def ssl_context(self):
        """The ``ssl.SSLContext`` used for wrapping (None on legacy Python)."""
        return self._context

    SSL_VERSION = _default_protocol
    """
    Default SSL version.
    For backwards compatibility, it can be modified.
    Use __init__ keyword argument "ssl_version" instead.
    """

    def _deprecated_arg(self, args, kwargs, pos, key):
        """Move deprecated positional argument ``args[pos]`` to ``kwargs[key]``.

        ``args`` are the extra positionals after ``host`` and ``port``; the
        warning numbers them counting ``host`` as 1 and ``port`` as 2, hence
        the ``+ 3``.  Does nothing when ``args`` has no element at ``pos``.

        @raise TypeError: ``key`` was also passed as a keyword argument.
        """
        if len(args) <= pos:
            return
        real_pos = pos + 3
        warnings.warn(
            '%dth positional argument is deprecated. '
            'please use keyword argument instead.'
            % real_pos, DeprecationWarning, stacklevel=3)

        if key in kwargs:
            raise TypeError(
                'Duplicate argument: %dth argument and %s keyword argument.'
                % (real_pos, key))
        kwargs[key] = args[pos]

    def _unix_socket_arg(self, host, port, args, kwargs):
        """Treat a lone extra positional as ``unix_socket`` (new signature).

        Applies only when host and port are both None and ``unix_socket``
        was not given as a keyword.  Returns True if the argument was
        consumed this way.
        """
        key = 'unix_socket'
        if host is None and port is None and len(args) == 1 and key not in kwargs:
            kwargs[key] = args[0]
            return True
        return False

    def __getattr__(self, key):
        """Fallback attribute lookup.

        Serves the deprecated ``SSL_VERSION`` name.  Normal lookup finds the
        class attribute of that name first, so this branch is reached only if
        it has been removed.  Any other missing attribute raises
        AttributeError so typos surface and ``hasattr`` behaves normally.
        """
        if key == 'SSL_VERSION':
            warnings.warn(
                'SSL_VERSION is deprecated. '
                'please use ssl_version attribute instead.',
                DeprecationWarning, stacklevel=2)
            return self.ssl_version
        raise AttributeError(
            "'%s' object has no attribute '%s'" % (type(self).__name__, key))

    def __init__(self, server_side, host, ssl_opts):
        """Validate and store the SSL options.

        @param server_side: True for the listening side of a connection.
        @param host: default ``server_hostname`` on the client side.
        @param ssl_opts: dict of SSL keyword arguments; recognised keys are
            popped from it.
        @raise ValueError: unknown or conflicting keyword arguments,
            ``ssl_context`` on a Python without SSLContext, a server without
            ``certfile``, or missing ``ca_certs`` while verifying.
        @raise IOError: ``ca_certs`` or ``certfile`` is not readable.
        """
        self._server_side = server_side
        if TSSLBase.SSL_VERSION != self._default_protocol:
            warnings.warn(
                'SSL_VERSION is deprecated. '
                'please use ssl_version keyword argument instead.',
                DeprecationWarning, stacklevel=2)
        self._context = ssl_opts.pop('ssl_context', None)
        self._server_hostname = None
        if not self._server_side:
            self._server_hostname = ssl_opts.pop('server_hostname', host)
        if self._context:
            self._custom_context = True
            if ssl_opts:
                raise ValueError(
                    'Incompatible arguments: ssl_context and %s'
                    % ' '.join(ssl_opts.keys()))
            if not self._has_ssl_context:
                raise ValueError(
                    'ssl_context is not available for this version of Python')
            # The individual options are unused with a caller-supplied
            # context; define them as None so reading them still works.
            self.cert_reqs = None
            self.ca_certs = None
            self.keyfile = None
            self._certfile = None  # bypasses the server-side certfile check
            self.ciphers = None
        else:
            self._custom_context = False
            ssl_version = ssl_opts.pop('ssl_version', TSSLBase.SSL_VERSION)
            self._init_context(ssl_version)
            self.cert_reqs = ssl_opts.pop('cert_reqs', ssl.CERT_REQUIRED)
            self.ca_certs = ssl_opts.pop('ca_certs', None)
            self.keyfile = ssl_opts.pop('keyfile', None)
            self.certfile = ssl_opts.pop('certfile', None)
            self.ciphers = ssl_opts.pop('ciphers', None)

            if ssl_opts:
                raise ValueError(
                    'Unknown keyword arguments: %s' % ' '.join(ssl_opts.keys()))

            if self._should_verify:
                if not self.ca_certs:
                    raise ValueError(
                        'ca_certs is needed when cert_reqs is not ssl.CERT_NONE')
                if not os.access(self.ca_certs, os.R_OK):
                    raise IOError('Certificate Authority ca_certs file "%s" '
                                  'is not readable, cannot validate SSL '
                                  'certificates.' % (self.ca_certs))

    @property
    def certfile(self):
        """Path of the local certificate chain file (may be None on clients)."""
        return self._certfile

    @certfile.setter
    def certfile(self, certfile):
        # Servers must present a certificate; any given path must be readable.
        if self._server_side and not certfile:
            raise ValueError('certfile is needed for server-side')
        if certfile and not os.access(certfile, os.R_OK):
            raise IOError('No such certfile found: %s' % (certfile))
        self._certfile = certfile

    def _wrap_socket(self, sock):
        """Wrap plain socket ``sock`` in TLS and return the SSL socket.

        For an internally built context, the current option attributes are
        re-applied on every call, so changes made after construction (e.g. a
        new ``certfile``) take effect for subsequently wrapped sockets.
        """
        if self._has_ssl_context:
            if not self._custom_context:
                self.ssl_context.verify_mode = self.cert_reqs
                if self.certfile:
                    self.ssl_context.load_cert_chain(self.certfile,
                                                     self.keyfile)
                if self.ciphers:
                    self.ssl_context.set_ciphers(self.ciphers)
                if self.ca_certs:
                    self.ssl_context.load_verify_locations(self.ca_certs)
            return self.ssl_context.wrap_socket(
                sock, server_side=self._server_side,
                server_hostname=self._server_hostname)
        else:
            ssl_opts = {
                'ssl_version': self._ssl_version,
                'server_side': self._server_side,
                'ca_certs': self.ca_certs,
                'keyfile': self.keyfile,
                'certfile': self.certfile,
                'cert_reqs': self.cert_reqs,
            }
            if self.ciphers:
                if self._has_ciphers:
                    ssl_opts['ciphers'] = self.ciphers
                else:
                    logger.warning(
                        'ciphers is specified but ignored due to old Python version')
            return ssl.wrap_socket(sock, **ssl_opts)
201 | ||
202 | ||
class TSSLSocket(TSocket.TSocket, TSSLBase):
    """
    SSL implementation of TSocket

    This class creates outbound sockets wrapped using the
    python standard ssl module for encrypted connections.
    """

    # New signature
    # def __init__(self, host='localhost', port=9090, unix_socket=None,
    #              **ssl_args):
    # Deprecated signature
    # def __init__(self, host='localhost', port=9090, validate=True,
    #              ca_certs=None, keyfile=None, certfile=None,
    #              unix_socket=None, ciphers=None):
    def __init__(self, host='localhost', port=9090, *args, **kwargs):
        """Positional arguments: ``host``, ``port``, ``unix_socket``

        Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``,
                           ``ssl_version``, ``ca_certs``,
                           ``ciphers`` (Python 2.7.0 or later),
                           ``server_hostname`` (Python 2.7.9 or later)
        Passed to ssl.wrap_socket. See ssl.wrap_socket documentation.
        ``validate`` is still accepted but deprecated; use ``cert_reqs``.

        Alternative keyword arguments: (Python 2.7.9 or later)
          ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
          ``server_hostname``: Passed to SSLContext.wrap_socket

        Common keyword argument:
          ``validate_callback`` (cert, hostname) -> None:
              Called after SSL handshake. Can raise when hostname does not
              match the cert.
          ``socket_keepalive`` enable TCP keepalive, default off.
        """
        self.is_valid = False
        self.peercert = None

        if args:
            if len(args) > 6:
                raise TypeError('Too many positional arguments')
            if not self._unix_socket_arg(host, port, args, kwargs):
                self._deprecated_arg(args, kwargs, 0, 'validate')
                self._deprecated_arg(args, kwargs, 1, 'ca_certs')
                self._deprecated_arg(args, kwargs, 2, 'keyfile')
                self._deprecated_arg(args, kwargs, 3, 'certfile')
                self._deprecated_arg(args, kwargs, 4, 'unix_socket')
                self._deprecated_arg(args, kwargs, 5, 'ciphers')

        # Translate the deprecated boolean ``validate`` into ``cert_reqs``.
        validate = kwargs.pop('validate', None)
        if validate is not None:
            cert_reqs_name = 'CERT_REQUIRED' if validate else 'CERT_NONE'
            warnings.warn(
                'validate is deprecated. please use cert_reqs=ssl.%s instead'
                % cert_reqs_name,
                DeprecationWarning, stacklevel=2)
            if 'cert_reqs' in kwargs:
                raise TypeError('Cannot specify both validate and cert_reqs')
            kwargs['cert_reqs'] = ssl.CERT_REQUIRED if validate else ssl.CERT_NONE

        unix_socket = kwargs.pop('unix_socket', None)
        socket_keepalive = kwargs.pop('socket_keepalive', False)
        self._validate_callback = kwargs.pop('validate_callback', _match_hostname)
        TSSLBase.__init__(self, False, host, kwargs)
        TSocket.TSocket.__init__(self, host, port, unix_socket,
                                 socket_keepalive=socket_keepalive)

    def close(self):
        """Attempt a TLS shutdown, then close via the base class.

        Safe to call when no connection is open (``handle`` is None).
        """
        if self.handle is not None:
            try:
                # Bound the close_notify exchange so an unresponsive peer
                # cannot stall close().
                self.handle.settimeout(0.001)
                self.handle = self.handle.unwrap()
            except (ssl.SSLError, socket.error, OSError):
                # could not complete shutdown in a reasonable amount of time. bail.
                pass
        TSocket.TSocket.close(self)

    @property
    def validate(self):
        """Deprecated boolean view of ``cert_reqs != ssl.CERT_NONE``."""
        warnings.warn('validate is deprecated. please use cert_reqs instead',
                      DeprecationWarning, stacklevel=2)
        return self.cert_reqs != ssl.CERT_NONE

    @validate.setter
    def validate(self, value):
        warnings.warn('validate is deprecated. please use cert_reqs instead',
                      DeprecationWarning, stacklevel=2)
        self.cert_reqs = ssl.CERT_REQUIRED if value else ssl.CERT_NONE

    def _do_open(self, family, socktype):
        """Create a socket of ``family``/``socktype`` and wrap it in TLS.

        Not called within this file; presumably the hook the base TSocket
        open logic uses per resolved address.
        @raise TTransportException: NOT_OPEN if wrapping fails; the plain
            socket is closed first.
        """
        plain_sock = socket.socket(family, socktype)
        try:
            return self._wrap_socket(plain_sock)
        except Exception as ex:
            plain_sock.close()
            msg = 'failed to initialize SSL'
            logger.exception(msg)
            raise TTransportException(type=TTransportException.NOT_OPEN, message=msg, inner=ex)

    def open(self):
        """Connect, then validate the peer certificate when verifying.

        On validation failure the connection is closed (so ``open`` can be
        retried) and a TTransportException is raised.
        """
        super(TSSLSocket, self).open()
        if not self._should_verify:
            return
        self.peercert = self.handle.getpeercert()
        try:
            self._validate_callback(self.peercert, self._server_hostname)
        except TTransportException:
            self.close()
            raise
        except Exception as ex:
            self.close()
            raise TTransportException(message=str(ex), inner=ex)
        self.is_valid = True
311 | ||
312 | ||
class TSSLServerSocket(TSocket.TServerSocket, TSSLBase):
    """SSL implementation of TServerSocket

    This uses the ssl module's wrap_socket() method to provide SSL
    negotiated encryption.
    """

    # New signature
    # def __init__(self, host='localhost', port=9090, unix_socket=None, **ssl_args):
    # Deprecated signature
    # def __init__(self, host=None, port=9090, certfile='cert.pem', unix_socket=None, ciphers=None):
    def __init__(self, host=None, port=9090, *args, **kwargs):
        """Positional arguments: ``host``, ``port``, ``unix_socket``

        Keyword arguments: ``keyfile``, ``certfile``, ``cert_reqs``, ``ssl_version``,
                           ``ca_certs``, ``ciphers`` (Python 2.7.0 or later)
        See ssl.wrap_socket documentation.
        Defaults when ``ssl_context`` is not given: ``cert_reqs=ssl.CERT_NONE``,
        ``certfile='cert.pem'``.

        Alternative keyword argument: (Python 2.7.9 or later)
          ``ssl_context``: ssl.SSLContext to be used for SSLContext.wrap_socket
        (``server_hostname`` is client-only and is rejected here.)

        Common keyword argument:
          ``validate_callback`` (cert, hostname) -> None:
              Called after SSL handshake. Can raise when hostname does not
              match the cert.
        """
        if args:
            if len(args) > 3:
                raise TypeError('Too many positional arguments')
            if not self._unix_socket_arg(host, port, args, kwargs):
                self._deprecated_arg(args, kwargs, 0, 'certfile')
                self._deprecated_arg(args, kwargs, 1, 'unix_socket')
                self._deprecated_arg(args, kwargs, 2, 'ciphers')

        if 'ssl_context' not in kwargs:
            # Preserve existing behaviors for default values
            if 'cert_reqs' not in kwargs:
                kwargs['cert_reqs'] = ssl.CERT_NONE
            if 'certfile' not in kwargs:
                kwargs['certfile'] = 'cert.pem'

        unix_socket = kwargs.pop('unix_socket', None)
        self._validate_callback = \
            kwargs.pop('validate_callback', _match_hostname)
        TSSLBase.__init__(self, True, None, kwargs)
        TSocket.TServerSocket.__init__(self, host, port, unix_socket)
        if self._should_verify and not _match_has_ipaddress:
            raise ValueError('Need ipaddress and backports.ssl_match_hostname '
                             'module to verify client certificate')

    def setCertfile(self, certfile):
        """Set or change the server certificate file used to wrap new
        connections.

        @param certfile: The filename of the server certificate,
                         i.e. '/etc/certs/server.pem'
        @type certfile: str

        Raises an IOError exception if the certfile is not present or unreadable.
        """
        warnings.warn(
            'setCertfile is deprecated. please use certfile property instead.',
            DeprecationWarning, stacklevel=2)
        self.certfile = certfile

    def accept(self):
        """Accept one client and complete the TLS handshake.

        Returns a TSocket whose handle is the SSL socket, or None when the
        handshake fails or (when verifying) the client certificate does not
        validate against the client address; the connection is closed in
        those cases.
        """
        plain_client, addr = self.handle.accept()
        try:
            client = self._wrap_socket(plain_client)
        except (ssl.SSLError, socket.error, OSError):
            logger.exception('Error while accepting from %s', addr)
            # failed handshake/ssl wrap, close socket to client
            plain_client.close()
            # We can't raise the exception, because it kills most TServer
            # derived serve() methods.
            # Instead, return None, and let the TServer instance deal with it
            # in other exception handling. (but TSimpleServer dies anyway)
            return None

        if self._should_verify:
            client.peercert = client.getpeercert()
            try:
                self._validate_callback(client.peercert, addr[0])
                client.is_valid = True
            except Exception:
                logger.warning('Failed to validate client certificate address: %s',
                               addr[0], exc_info=True)
                client.close()
                plain_client.close()
                return None

        result = TSocket.TSocket()
        result.handle = client
        return result