]> git.proxmox.com Git - ceph.git/blob - ceph/src/jaegertracing/thrift/lib/cpp/src/thrift/transport/TSSLSocket.cpp
update source to Ceph Pacific 16.2.2
[ceph.git] / ceph / src / jaegertracing / thrift / lib / cpp / src / thrift / transport / TSSLSocket.cpp
1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #include <thrift/thrift-config.h>
21
22 #include <cstring>
23 #include <errno.h>
24 #include <memory>
25 #include <string>
26 #ifdef HAVE_ARPA_INET_H
27 #include <arpa/inet.h>
28 #endif
29 #include <sys/types.h>
30 #ifdef HAVE_SYS_SOCKET_H
31 #include <sys/socket.h>
32 #endif
33 #ifdef HAVE_SYS_POLL_H
34 #include <sys/poll.h>
35 #endif
36 #ifdef HAVE_FCNTL_H
37 #include <fcntl.h>
38 #endif
39
40 #define OPENSSL_VERSION_NO_THREAD_ID_BEFORE 0x10000000L
41 #define OPENSSL_ENGINE_CLEANUP_REQUIRED_BEFORE 0x10100000L
42
43 #include <boost/shared_array.hpp>
44 #include <openssl/opensslv.h>
45 #if (OPENSSL_VERSION_NUMBER < OPENSSL_ENGINE_CLEANUP_REQUIRED_BEFORE)
46 #include <openssl/engine.h>
47 #endif
48 #include <openssl/err.h>
49 #include <openssl/rand.h>
50 #include <openssl/ssl.h>
51 #include <openssl/x509v3.h>
52 #include <thrift/concurrency/Mutex.h>
53 #include <thrift/transport/TSSLSocket.h>
54 #include <thrift/transport/PlatformSocket.h>
55 #include <thrift/TToString.h>
56
57 using namespace apache::thrift::concurrency;
58 using std::string;
59
// Application-side definition of OpenSSL's dynamic-lock type: a single Mutex.
// Instances are allocated by dyn_create(), locked/unlocked by dyn_lock() and
// deleted by dyn_destroy(), which initializeOpenSSL() registers with OpenSSL.
struct CRYPTO_dynlock_value {
  Mutex mutex;
};
63
64 namespace apache {
65 namespace thrift {
66 namespace transport {
67
68 // OpenSSL initialization/cleanup
69
70 static bool openSSLInitialized = false;
71 static boost::shared_array<Mutex> mutexes;
72
73 static void callbackLocking(int mode, int n, const char*, int) {
74 if (mode & CRYPTO_LOCK) {
75 // assertion of (px != 0) here typically means that a TSSLSocket's lifetime
76 // exceeded the lifetime of the TSSLSocketFactory that created it, and the
77 // TSSLSocketFactory already ran cleanupOpenSSL(), which deleted "mutexes".
78 mutexes[n].lock();
79 } else {
80 mutexes[n].unlock();
81 }
82 }
83
84 #if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID_BEFORE)
85 static unsigned long callbackThreadID() {
86 #ifdef _WIN32
87 return (unsigned long)GetCurrentThreadId();
88 #else
89 return (unsigned long)pthread_self();
90 #endif
91 }
92 #endif
93
94 static CRYPTO_dynlock_value* dyn_create(const char*, int) {
95 return new CRYPTO_dynlock_value;
96 }
97
98 static void dyn_lock(int mode, struct CRYPTO_dynlock_value* lock, const char*, int) {
99 if (lock != nullptr) {
100 if (mode & CRYPTO_LOCK) {
101 lock->mutex.lock();
102 } else {
103 lock->mutex.unlock();
104 }
105 }
106 }
107
108 static void dyn_destroy(struct CRYPTO_dynlock_value* lock, const char*, int) {
109 delete lock;
110 }
111
/**
 * Initializes the OpenSSL library: loads SSL algorithms and error strings,
 * allocates one Mutex per static lock (CRYPTO_num_locks()) and registers the
 * static and dynamic locking callbacks defined above.
 *
 * Idempotent until cleanupOpenSSL() runs. openSSLInitialized is a plain bool,
 * so concurrent callers must serialize; TSSLSocketFactory calls this while
 * holding its mutex_.
 */
void initializeOpenSSL() {
  if (openSSLInitialized) {
    return;
  }
  openSSLInitialized = true;
  SSL_library_init();
  SSL_load_error_strings();
  ERR_load_crypto_strings();

  // static locking
  // newer versions of OpenSSL changed CRYPTO_num_locks - see THRIFT-3878
#ifdef CRYPTO_num_locks
  mutexes = boost::shared_array<Mutex>(new Mutex[CRYPTO_num_locks()]);
#else
  mutexes = boost::shared_array<Mutex>(new Mutex[ ::CRYPTO_num_locks()]);
#endif

#if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID_BEFORE)
  CRYPTO_set_id_callback(callbackThreadID);
#endif

  CRYPTO_set_locking_callback(callbackLocking);

  // dynamic locking
  CRYPTO_set_dynlock_create_callback(dyn_create);
  CRYPTO_set_dynlock_lock_callback(dyn_lock);
  CRYPTO_set_dynlock_destroy_callback(dyn_destroy);
}

/**
 * Releases the global OpenSSL state set up by initializeOpenSSL(); a no-op if
 * it was not initialized. Like initializeOpenSSL(), callers must serialize.
 *
 * NOTE(review): the locking callback registered by initializeOpenSSL() is not
 * unregistered here, yet the mutex array it indexes is freed on the last line.
 * Any OpenSSL call that takes a static lock afterwards (e.g. from a TSSLSocket
 * that outlived its factory) reaches callbackLocking() with an empty array.
 */
void cleanupOpenSSL() {
  if (!openSSLInitialized) {
    return;
  }
  openSSLInitialized = false;

  // https://wiki.openssl.org/index.php/Library_Initialization#Cleanup
  // we purposefully do NOT call FIPS_mode_set(0) and leave it up to the enclosing application to manage FIPS entirely
#if (OPENSSL_VERSION_NUMBER < OPENSSL_ENGINE_CLEANUP_REQUIRED_BEFORE)
  ENGINE_cleanup(); // https://www.openssl.org/docs/man1.1.0/crypto/ENGINE_cleanup.html - cleanup call is needed before 1.1.0
#endif
  CONF_modules_unload(1);
  EVP_cleanup();
  CRYPTO_cleanup_all_ex_data();
  ERR_remove_state(0);
  ERR_free_strings();

  // Freed last: the cleanup calls above may still take static locks.
  mutexes.reset();
}
160
161 static void buildErrors(string& message, int errno_copy = 0, int sslerrno = 0);
162 static bool matchName(const char* host, const char* pattern, int size);
163 static char uppercase(char c);
164
165 // SSLContext implementation
166 SSLContext::SSLContext(const SSLProtocol& protocol) {
167 if (protocol == SSLTLS) {
168 ctx_ = SSL_CTX_new(SSLv23_method());
169 #ifndef OPENSSL_NO_SSL3
170 } else if (protocol == SSLv3) {
171 ctx_ = SSL_CTX_new(SSLv3_method());
172 #endif
173 } else if (protocol == TLSv1_0) {
174 ctx_ = SSL_CTX_new(TLSv1_method());
175 } else if (protocol == TLSv1_1) {
176 ctx_ = SSL_CTX_new(TLSv1_1_method());
177 } else if (protocol == TLSv1_2) {
178 ctx_ = SSL_CTX_new(TLSv1_2_method());
179 } else {
180 /// UNKNOWN PROTOCOL!
181 throw TSSLException("SSL_CTX_new: Unknown protocol");
182 }
183
184 if (ctx_ == nullptr) {
185 string errors;
186 buildErrors(errors);
187 throw TSSLException("SSL_CTX_new: " + errors);
188 }
189 SSL_CTX_set_mode(ctx_, SSL_MODE_AUTO_RETRY);
190
191 // Disable horribly insecure SSLv2 and SSLv3 protocols but allow a handshake
192 // with older clients so they get a graceful denial.
193 if (protocol == SSLTLS) {
194 SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv2);
195 SSL_CTX_set_options(ctx_, SSL_OP_NO_SSLv3); // THRIFT-3164
196 }
197 }
198
199 SSLContext::~SSLContext() {
200 if (ctx_ != nullptr) {
201 SSL_CTX_free(ctx_);
202 ctx_ = nullptr;
203 }
204 }
205
206 SSL* SSLContext::createSSL() {
207 SSL* ssl = SSL_new(ctx_);
208 if (ssl == nullptr) {
209 string errors;
210 buildErrors(errors);
211 throw TSSLException("SSL_new: " + errors);
212 }
213 return ssl;
214 }
215
216 // TSSLSocket implementation
217 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx)
218 : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
219 init();
220 }
221
222 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, std::shared_ptr<THRIFT_SOCKET> interruptListener)
223 : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx) {
224 init();
225 interruptListener_ = interruptListener;
226 }
227
228 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket)
229 : TSocket(socket), server_(false), ssl_(nullptr), ctx_(ctx) {
230 init();
231 }
232
233 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener)
234 : TSocket(socket, interruptListener), server_(false), ssl_(nullptr), ctx_(ctx) {
235 init();
236 }
237
238 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port)
239 : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
240 init();
241 }
242
243 TSSLSocket::TSSLSocket(std::shared_ptr<SSLContext> ctx, string host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener)
244 : TSocket(host, port), server_(false), ssl_(nullptr), ctx_(ctx) {
245 init();
246 interruptListener_ = interruptListener;
247 }
248
249 TSSLSocket::~TSSLSocket() {
250 close();
251 }
252
253 bool TSSLSocket::hasPendingDataToRead() {
254 if (!isOpen()) {
255 return false;
256 }
257 initializeHandshake();
258 if (!checkHandshake())
259 throw TSSLException("TSSLSocket::hasPendingDataToRead: Handshake is not completed");
260 // data may be available in SSL buffers (note: SSL_pending does not have a failure mode)
261 return SSL_pending(ssl_) > 0 || TSocket::hasPendingDataToRead();
262 }
263
264 void TSSLSocket::init() {
265 handshakeCompleted_ = false;
266 readRetryCount_ = 0;
267 eventSafe_ = false;
268 }
269
270 bool TSSLSocket::isOpen() const {
271 if (ssl_ == nullptr || !TSocket::isOpen()) {
272 return false;
273 }
274 int shutdown = SSL_get_shutdown(ssl_);
275 // "!!" is squelching C4800 "forcing bool -> true or false" performance warning
276 bool shutdownReceived = !!(shutdown & SSL_RECEIVED_SHUTDOWN);
277 bool shutdownSent = !!(shutdown & SSL_SENT_SHUTDOWN);
278 if (shutdownReceived && shutdownSent) {
279 return false;
280 }
281 return true;
282 }
283
284 /*
285 * Note: This method is not libevent safe.
286 */
287 bool TSSLSocket::peek() {
288 if (!isOpen()) {
289 return false;
290 }
291 initializeHandshake();
292 if (!checkHandshake())
293 throw TSSLException("SSL_peek: Handshake is not completed");
294 int rc;
295 do {
296 uint8_t byte;
297 rc = SSL_peek(ssl_, &byte, 1);
298 if (rc < 0) {
299 int errno_copy = THRIFT_GET_SOCKET_ERROR;
300 int error = SSL_get_error(ssl_, rc);
301 switch (error) {
302 case SSL_ERROR_SYSCALL:
303 if ((errno_copy != THRIFT_EINTR)
304 && (errno_copy != THRIFT_EAGAIN)) {
305 break;
306 }
307 // fallthrough
308 case SSL_ERROR_WANT_READ:
309 case SSL_ERROR_WANT_WRITE:
310 // in the case of SSL_ERROR_SYSCALL we want to wait for an read event again
311 waitForEvent(error != SSL_ERROR_WANT_WRITE);
312 continue;
313 default:;// do nothing
314 }
315 string errors;
316 buildErrors(errors, errno_copy, error);
317 throw TSSLException("SSL_peek: " + errors);
318 } else if (rc == 0) {
319 ERR_clear_error();
320 break;
321 } else {
322 break;
323 }
324 } while (true);
325 return (rc > 0);
326 }
327
328 void TSSLSocket::open() {
329 if (isOpen() || server()) {
330 throw TTransportException(TTransportException::BAD_ARGS);
331 }
332 TSocket::open();
333 }
334
335 /*
336 * Note: This method is not libevent safe.
337 */
338 void TSSLSocket::close() {
339 if (ssl_ != nullptr) {
340 try {
341 int rc;
342 int errno_copy = 0;
343 int error = 0;
344
345 do {
346 rc = SSL_shutdown(ssl_);
347 if (rc <= 0) {
348 errno_copy = THRIFT_GET_SOCKET_ERROR;
349 error = SSL_get_error(ssl_, rc);
350 switch (error) {
351 case SSL_ERROR_SYSCALL:
352 if ((errno_copy != THRIFT_EINTR)
353 && (errno_copy != THRIFT_EAGAIN)) {
354 break;
355 }
356 // fallthrough
357 case SSL_ERROR_WANT_READ:
358 case SSL_ERROR_WANT_WRITE:
359 // in the case of SSL_ERROR_SYSCALL we want to wait for an write/read event again
360 waitForEvent(error == SSL_ERROR_WANT_READ);
361 rc = 2;
362 default:;// do nothing
363 }
364 }
365 } while (rc == 2);
366
367 if (rc < 0) {
368 string errors;
369 buildErrors(errors, errno_copy, error);
370 GlobalOutput(("SSL_shutdown: " + errors).c_str());
371 }
372 } catch (TTransportException& te) {
373 // Don't emit an exception because this method is called by the
374 // destructor. There's also not much that a user can do to recover, so
375 // just clean up as much as possible without throwing, similar to the rc
376 // < 0 case above.
377 GlobalOutput.printf("SSL_shutdown: %s", te.what());
378 }
379 SSL_free(ssl_);
380 ssl_ = nullptr;
381 handshakeCompleted_ = false;
382 ERR_remove_state(0);
383 }
384 TSocket::close();
385 }
386
387 /*
388 * Returns number of bytes read in SSL Socket.
389 * If eventSafe is set, and it may returns 0 bytes then read method
390 * needs to be called again until it is successfull or it throws
391 * exception incase of failure.
392 */
393 uint32_t TSSLSocket::read(uint8_t* buf, uint32_t len) {
394 initializeHandshake();
395 if (!checkHandshake())
396 throw TTransportException(TTransportException::UNKNOWN, "retry again");
397 int32_t bytes = 0;
398 while (readRetryCount_ < maxRecvRetries_) {
399 bytes = SSL_read(ssl_, buf, len);
400 int32_t errno_copy = THRIFT_GET_SOCKET_ERROR;
401 int32_t error = SSL_get_error(ssl_, bytes);
402 readRetryCount_++;
403 if (error == SSL_ERROR_NONE) {
404 readRetryCount_ = 0;
405 break;
406 }
407 unsigned int waitEventReturn;
408 bool breakout = false;
409 switch (error) {
410 case SSL_ERROR_ZERO_RETURN:
411 throw TTransportException(TTransportException::END_OF_FILE, "client disconnected");
412
413 case SSL_ERROR_SYSCALL:
414 if (errno_copy == 0 && ERR_peek_error() == 0) {
415 breakout = true;
416 break;
417 }
418 if ((errno_copy != THRIFT_EINTR)
419 && (errno_copy != THRIFT_EAGAIN)) {
420 break;
421 }
422 if (readRetryCount_ >= maxRecvRetries_) {
423 // THRIFT_EINTR needs to be handled manually and we can tolerate
424 // a certain number
425 break;
426 }
427 // fallthrough
428
429 case SSL_ERROR_WANT_READ:
430 case SSL_ERROR_WANT_WRITE:
431 if (isLibeventSafe()) {
432 if (readRetryCount_ < maxRecvRetries_) {
433 // THRIFT_EINTR needs to be handled manually and we can tolerate
434 // a certain number
435 throw TTransportException(TTransportException::UNKNOWN, "retry again");
436 }
437 throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries");
438 }
439 // in the case of SSL_ERROR_SYSCALL we want to wait for an read event again
440 else if ((waitEventReturn = waitForEvent(error != SSL_ERROR_WANT_WRITE)) == TSSL_EINTR ) {
441 // repeat operation
442 if (readRetryCount_ < maxRecvRetries_) {
443 // THRIFT_EINTR needs to be handled manually and we can tolerate
444 // a certain number
445 continue;
446 }
447 throw TTransportException(TTransportException::INTERNAL_ERROR, "too much recv retries");
448 }
449 else if (waitEventReturn == TSSL_DATA) {
450 // in case of SSL and huge thrift packets, there may be a number of
451 // socket operations, before any data becomes available by SSL_read().
452 // Therefore the number of retries should not be increased and
453 // the operation should be repeated.
454 readRetryCount_--;
455 continue;
456 }
457 throw TTransportException(TTransportException::INTERNAL_ERROR, "unkown waitForEvent return value");
458 default:;// do nothing
459 }
460 if (breakout) {
461 break;
462 }
463 string errors;
464 buildErrors(errors, errno_copy, error);
465 throw TSSLException("SSL_read: " + errors);
466 }
467 return bytes;
468 }
469
470 void TSSLSocket::write(const uint8_t* buf, uint32_t len) {
471 initializeHandshake();
472 if (!checkHandshake())
473 return;
474 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
475 uint32_t written = 0;
476 while (written < len) {
477 ERR_clear_error();
478 int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
479 if (bytes <= 0) {
480 int errno_copy = THRIFT_GET_SOCKET_ERROR;
481 int error = SSL_get_error(ssl_, bytes);
482 switch (error) {
483 case SSL_ERROR_SYSCALL:
484 if ((errno_copy != THRIFT_EINTR)
485 && (errno_copy != THRIFT_EAGAIN)) {
486 break;
487 }
488 // fallthrough
489 case SSL_ERROR_WANT_READ:
490 case SSL_ERROR_WANT_WRITE:
491 if (isLibeventSafe()) {
492 return;
493 }
494 else {
495 // in the case of SSL_ERROR_SYSCALL we want to wait for an write event again
496 waitForEvent(error == SSL_ERROR_WANT_READ);
497 continue;
498 }
499 default:;// do nothing
500 }
501 string errors;
502 buildErrors(errors, errno_copy, error);
503 throw TSSLException("SSL_write: " + errors);
504 }
505 written += bytes;
506 }
507 }
508
509 /*
510 * Returns number of bytes written in SSL Socket.
511 * If eventSafe is set, and it may returns 0 bytes then write method
512 * needs to be called again until it is successfull or it throws
513 * exception incase of failure.
514 */
515 uint32_t TSSLSocket::write_partial(const uint8_t* buf, uint32_t len) {
516 initializeHandshake();
517 if (!checkHandshake())
518 return 0;
519 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
520 uint32_t written = 0;
521 while (written < len) {
522 ERR_clear_error();
523 int32_t bytes = SSL_write(ssl_, &buf[written], len - written);
524 if (bytes <= 0) {
525 int errno_copy = THRIFT_GET_SOCKET_ERROR;
526 int error = SSL_get_error(ssl_, bytes);
527 switch (error) {
528 case SSL_ERROR_SYSCALL:
529 if ((errno_copy != THRIFT_EINTR)
530 && (errno_copy != THRIFT_EAGAIN)) {
531 break;
532 }
533 // fallthrough
534 case SSL_ERROR_WANT_READ:
535 case SSL_ERROR_WANT_WRITE:
536 if (isLibeventSafe()) {
537 return 0;
538 }
539 else {
540 // in the case of SSL_ERROR_SYSCALL we want to wait for an write event again
541 waitForEvent(error == SSL_ERROR_WANT_READ);
542 continue;
543 }
544 default:;// do nothing
545 }
546 string errors;
547 buildErrors(errors, errno_copy, error);
548 throw TSSLException("SSL_write: " + errors);
549 }
550 written += bytes;
551 }
552 return written;
553 }
554
555 void TSSLSocket::flush() {
556 // Don't throw exception if not open. Thrift servers close socket twice.
557 if (ssl_ == nullptr) {
558 return;
559 }
560 initializeHandshake();
561 if (!checkHandshake())
562 throw TSSLException("BIO_flush: Handshake is not completed");
563 BIO* bio = SSL_get_wbio(ssl_);
564 if (bio == nullptr) {
565 throw TSSLException("SSL_get_wbio returns NULL");
566 }
567 if (BIO_flush(bio) != 1) {
568 int errno_copy = THRIFT_GET_SOCKET_ERROR;
569 string errors;
570 buildErrors(errors, errno_copy);
571 throw TSSLException("BIO_flush: " + errors);
572 }
573 }
574
575 void TSSLSocket::initializeHandshakeParams() {
576 // set underlying socket to non-blocking
577 int flags;
578 if ((flags = THRIFT_FCNTL(socket_, THRIFT_F_GETFL, 0)) < 0
579 || THRIFT_FCNTL(socket_, THRIFT_F_SETFL, flags | THRIFT_O_NONBLOCK) < 0) {
580 GlobalOutput.perror("thriftServerEventHandler: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ",
581 THRIFT_GET_SOCKET_ERROR);
582 ::THRIFT_CLOSESOCKET(socket_);
583 return;
584 }
585 ssl_ = ctx_->createSSL();
586
587 SSL_set_fd(ssl_, static_cast<int>(socket_));
588 }
589
590 bool TSSLSocket::checkHandshake() {
591 return handshakeCompleted_;
592 }
593
594 void TSSLSocket::initializeHandshake() {
595 if (!TSocket::isOpen()) {
596 throw TTransportException(TTransportException::NOT_OPEN);
597 }
598 if (checkHandshake()) {
599 return;
600 }
601
602 if (ssl_ == nullptr) {
603 initializeHandshakeParams();
604 }
605
606 int rc;
607 int errno_copy = 0;
608 int error = 0;
609 if (server()) {
610 do {
611 rc = SSL_accept(ssl_);
612 if (rc <= 0) {
613 errno_copy = THRIFT_GET_SOCKET_ERROR;
614 error = SSL_get_error(ssl_, rc);
615 switch (error) {
616 case SSL_ERROR_SYSCALL:
617 if ((errno_copy != THRIFT_EINTR)
618 && (errno_copy != THRIFT_EAGAIN)) {
619 break;
620 }
621 // fallthrough
622 case SSL_ERROR_WANT_READ:
623 case SSL_ERROR_WANT_WRITE:
624 if (isLibeventSafe()) {
625 return;
626 }
627 else {
628 // repeat operation
629 // in the case of SSL_ERROR_SYSCALL we want to wait for an write/read event again
630 waitForEvent(error == SSL_ERROR_WANT_READ);
631 rc = 2;
632 }
633 default:;// do nothing
634 }
635 }
636 } while (rc == 2);
637 } else {
638 // OpenSSL < 0.9.8f does not have SSL_set_tlsext_host_name()
639 #if defined(SSL_set_tlsext_host_name)
640 // set the SNI hostname
641 SSL_set_tlsext_host_name(ssl_, getHost().c_str());
642 #endif
643 do {
644 rc = SSL_connect(ssl_);
645 if (rc <= 0) {
646 errno_copy = THRIFT_GET_SOCKET_ERROR;
647 error = SSL_get_error(ssl_, rc);
648 switch (error) {
649 case SSL_ERROR_SYSCALL:
650 if ((errno_copy != THRIFT_EINTR)
651 && (errno_copy != THRIFT_EAGAIN)) {
652 break;
653 }
654 // fallthrough
655 case SSL_ERROR_WANT_READ:
656 case SSL_ERROR_WANT_WRITE:
657 if (isLibeventSafe()) {
658 return;
659 }
660 else {
661 // repeat operation
662 // in the case of SSL_ERROR_SYSCALL we want to wait for an write/read event again
663 waitForEvent(error == SSL_ERROR_WANT_READ);
664 rc = 2;
665 }
666 default:;// do nothing
667 }
668 }
669 } while (rc == 2);
670 }
671 if (rc <= 0) {
672 string fname(server() ? "SSL_accept" : "SSL_connect");
673 string errors;
674 buildErrors(errors, errno_copy, error);
675 throw TSSLException(fname + ": " + errors);
676 }
677 authorize();
678 handshakeCompleted_ = true;
679 }
680
681 void TSSLSocket::authorize() {
682 int rc = SSL_get_verify_result(ssl_);
683 if (rc != X509_V_OK) { // verify authentication result
684 throw TSSLException(string("SSL_get_verify_result(), ") + X509_verify_cert_error_string(rc));
685 }
686
687 X509* cert = SSL_get_peer_certificate(ssl_);
688 if (cert == nullptr) {
689 // certificate is not present
690 if (SSL_get_verify_mode(ssl_) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT) {
691 throw TSSLException("authorize: required certificate not present");
692 }
693 // certificate was optional: didn't intend to authorize remote
694 if (server() && access_ != nullptr) {
695 throw TSSLException("authorize: certificate required for authorization");
696 }
697 return;
698 }
699 // certificate is present
700 if (access_ == nullptr) {
701 X509_free(cert);
702 return;
703 }
704 // both certificate and access manager are present
705
706 string host;
707 sockaddr_storage sa;
708 socklen_t saLength = sizeof(sa);
709
710 if (getpeername(socket_, (sockaddr*)&sa, &saLength) != 0) {
711 sa.ss_family = AF_UNSPEC;
712 }
713
714 AccessManager::Decision decision = access_->verify(sa);
715
716 if (decision != AccessManager::SKIP) {
717 X509_free(cert);
718 if (decision != AccessManager::ALLOW) {
719 throw TSSLException("authorize: access denied based on remote IP");
720 }
721 return;
722 }
723
724 // extract subjectAlternativeName
725 auto* alternatives
726 = (STACK_OF(GENERAL_NAME)*)X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr);
727 if (alternatives != nullptr) {
728 const int count = sk_GENERAL_NAME_num(alternatives);
729 for (int i = 0; decision == AccessManager::SKIP && i < count; i++) {
730 const GENERAL_NAME* name = sk_GENERAL_NAME_value(alternatives, i);
731 if (name == nullptr) {
732 continue;
733 }
734 char* data = (char*)ASN1_STRING_data(name->d.ia5);
735 int length = ASN1_STRING_length(name->d.ia5);
736 switch (name->type) {
737 case GEN_DNS:
738 if (host.empty()) {
739 host = (server() ? getPeerHost() : getHost());
740 }
741 decision = access_->verify(host, data, length);
742 break;
743 case GEN_IPADD:
744 decision = access_->verify(sa, data, length);
745 break;
746 }
747 }
748 sk_GENERAL_NAME_pop_free(alternatives, GENERAL_NAME_free);
749 }
750
751 if (decision != AccessManager::SKIP) {
752 X509_free(cert);
753 if (decision != AccessManager::ALLOW) {
754 throw TSSLException("authorize: access denied");
755 }
756 return;
757 }
758
759 // extract commonName
760 X509_NAME* name = X509_get_subject_name(cert);
761 if (name != nullptr) {
762 X509_NAME_ENTRY* entry;
763 unsigned char* utf8;
764 int last = -1;
765 while (decision == AccessManager::SKIP) {
766 last = X509_NAME_get_index_by_NID(name, NID_commonName, last);
767 if (last == -1)
768 break;
769 entry = X509_NAME_get_entry(name, last);
770 if (entry == nullptr)
771 continue;
772 ASN1_STRING* common = X509_NAME_ENTRY_get_data(entry);
773 int size = ASN1_STRING_to_UTF8(&utf8, common);
774 if (host.empty()) {
775 host = (server() ? getPeerHost() : getHost());
776 }
777 decision = access_->verify(host, (char*)utf8, size);
778 OPENSSL_free(utf8);
779 }
780 }
781 X509_free(cert);
782 if (decision != AccessManager::ALLOW) {
783 throw TSSLException("authorize: cannot authorize peer");
784 }
785 }
786
787 /*
788 * Note: This method is not libevent safe.
789 */
790 unsigned int TSSLSocket::waitForEvent(bool wantRead) {
791 int fdSocket;
792 BIO* bio;
793
794 if (wantRead) {
795 bio = SSL_get_rbio(ssl_);
796 } else {
797 bio = SSL_get_wbio(ssl_);
798 }
799
800 if (bio == nullptr) {
801 throw TSSLException("SSL_get_?bio returned NULL");
802 }
803
804 if (BIO_get_fd(bio, &fdSocket) <= 0) {
805 throw TSSLException("BIO_get_fd failed");
806 }
807
808 struct THRIFT_POLLFD fds[2];
809 memset(fds, 0, sizeof(fds));
810 fds[0].fd = fdSocket;
811 // use POLLIN also on write operations too, this is needed for operations
812 // which requires read and write on the socket.
813 fds[0].events = wantRead ? THRIFT_POLLIN : THRIFT_POLLIN | THRIFT_POLLOUT;
814
815 if (interruptListener_) {
816 fds[1].fd = *(interruptListener_.get());
817 fds[1].events = THRIFT_POLLIN;
818 }
819
820 int timeout = -1;
821 if (wantRead && recvTimeout_) {
822 timeout = recvTimeout_;
823 }
824 if (!wantRead && sendTimeout_) {
825 timeout = sendTimeout_;
826 }
827
828 int ret = THRIFT_POLL(fds, interruptListener_ ? 2 : 1, timeout);
829
830 if (ret < 0) {
831 // error cases
832 if (THRIFT_GET_SOCKET_ERROR == THRIFT_EINTR) {
833 return TSSL_EINTR; // repeat operation
834 }
835 int errno_copy = THRIFT_GET_SOCKET_ERROR;
836 GlobalOutput.perror("TSSLSocket::read THRIFT_POLL() ", errno_copy);
837 throw TTransportException(TTransportException::UNKNOWN, "Unknown", errno_copy);
838 } else if (ret > 0){
839 if (fds[1].revents & THRIFT_POLLIN) {
840 throw TTransportException(TTransportException::INTERRUPTED, "Interrupted");
841 }
842 return TSSL_DATA;
843 } else {
844 throw TTransportException(TTransportException::TIMED_OUT, "THRIFT_POLL (timed out)");
845 }
846 }
847
848 // TSSLSocketFactory implementation
849 uint64_t TSSLSocketFactory::count_ = 0;
850 Mutex TSSLSocketFactory::mutex_;
851 bool TSSLSocketFactory::manualOpenSSLInitialization_ = false;
852
853 TSSLSocketFactory::TSSLSocketFactory(SSLProtocol protocol) : server_(false) {
854 Guard guard(mutex_);
855 if (count_ == 0) {
856 if (!manualOpenSSLInitialization_) {
857 initializeOpenSSL();
858 }
859 randomize();
860 }
861 count_++;
862 ctx_ = std::make_shared<SSLContext>(protocol);
863 }
864
865 TSSLSocketFactory::~TSSLSocketFactory() {
866 Guard guard(mutex_);
867 ctx_.reset();
868 count_--;
869 if (count_ == 0 && !manualOpenSSLInitialization_) {
870 cleanupOpenSSL();
871 }
872 }
873
874 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket() {
875 std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_));
876 setup(ssl);
877 return ssl;
878 }
879
880 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(std::shared_ptr<THRIFT_SOCKET> interruptListener) {
881 std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, interruptListener));
882 setup(ssl);
883 return ssl;
884 }
885
886 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket) {
887 std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket));
888 setup(ssl);
889 return ssl;
890 }
891
892 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket, std::shared_ptr<THRIFT_SOCKET> interruptListener) {
893 std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, socket, interruptListener));
894 setup(ssl);
895 return ssl;
896 }
897
898 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port) {
899 std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port));
900 setup(ssl);
901 return ssl;
902 }
903
904 std::shared_ptr<TSSLSocket> TSSLSocketFactory::createSocket(const string& host, int port, std::shared_ptr<THRIFT_SOCKET> interruptListener) {
905 std::shared_ptr<TSSLSocket> ssl(new TSSLSocket(ctx_, host, port, interruptListener));
906 setup(ssl);
907 return ssl;
908 }
909
910
911 void TSSLSocketFactory::setup(std::shared_ptr<TSSLSocket> ssl) {
912 ssl->server(server());
913 if (access_ == nullptr && !server()) {
914 access_ = std::shared_ptr<AccessManager>(new DefaultClientAccessManager);
915 }
916 if (access_ != nullptr) {
917 ssl->access(access_);
918 }
919 }
920
921 void TSSLSocketFactory::ciphers(const string& enable) {
922 int rc = SSL_CTX_set_cipher_list(ctx_->get(), enable.c_str());
923 if (ERR_peek_error() != 0) {
924 string errors;
925 buildErrors(errors);
926 throw TSSLException("SSL_CTX_set_cipher_list: " + errors);
927 }
928 if (rc == 0) {
929 throw TSSLException("None of specified ciphers are supported");
930 }
931 }
932
933 void TSSLSocketFactory::authenticate(bool required) {
934 int mode;
935 if (required) {
936 mode = SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT | SSL_VERIFY_CLIENT_ONCE;
937 } else {
938 mode = SSL_VERIFY_NONE;
939 }
940 SSL_CTX_set_verify(ctx_->get(), mode, nullptr);
941 }
942
943 void TSSLSocketFactory::loadCertificate(const char* path, const char* format) {
944 if (path == nullptr || format == nullptr) {
945 throw TTransportException(TTransportException::BAD_ARGS,
946 "loadCertificateChain: either <path> or <format> is NULL");
947 }
948 if (strcmp(format, "PEM") == 0) {
949 if (SSL_CTX_use_certificate_chain_file(ctx_->get(), path) == 0) {
950 int errno_copy = THRIFT_GET_SOCKET_ERROR;
951 string errors;
952 buildErrors(errors, errno_copy);
953 throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors);
954 }
955 } else {
956 throw TSSLException("Unsupported certificate format: " + string(format));
957 }
958 }
959
960 void TSSLSocketFactory::loadPrivateKey(const char* path, const char* format) {
961 if (path == nullptr || format == nullptr) {
962 throw TTransportException(TTransportException::BAD_ARGS,
963 "loadPrivateKey: either <path> or <format> is NULL");
964 }
965 if (strcmp(format, "PEM") == 0) {
966 if (SSL_CTX_use_PrivateKey_file(ctx_->get(), path, SSL_FILETYPE_PEM) == 0) {
967 int errno_copy = THRIFT_GET_SOCKET_ERROR;
968 string errors;
969 buildErrors(errors, errno_copy);
970 throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors);
971 }
972 }
973 }
974
975 void TSSLSocketFactory::loadTrustedCertificates(const char* path, const char* capath) {
976 if (path == nullptr) {
977 throw TTransportException(TTransportException::BAD_ARGS,
978 "loadTrustedCertificates: <path> is NULL");
979 }
980 if (SSL_CTX_load_verify_locations(ctx_->get(), path, capath) == 0) {
981 int errno_copy = THRIFT_GET_SOCKET_ERROR;
982 string errors;
983 buildErrors(errors, errno_copy);
984 throw TSSLException("SSL_CTX_load_verify_locations: " + errors);
985 }
986 }
987
988 void TSSLSocketFactory::randomize() {
989 RAND_poll();
990 }
991
992 void TSSLSocketFactory::overrideDefaultPasswordCallback() {
993 SSL_CTX_set_default_passwd_cb(ctx_->get(), passwordCallback);
994 SSL_CTX_set_default_passwd_cb_userdata(ctx_->get(), this);
995 }
996
997 int TSSLSocketFactory::passwordCallback(char* password, int size, int, void* data) {
998 auto* factory = (TSSLSocketFactory*)data;
999 string userPassword;
1000 factory->getPassword(userPassword, size);
1001 int length = static_cast<int>(userPassword.size());
1002 if (length > size) {
1003 length = size;
1004 }
1005 strncpy(password, userPassword.c_str(), length);
1006 userPassword.assign(userPassword.size(), '*');
1007 return length;
1008 }
1009
1010 // extract error messages from error queue
1011 void buildErrors(string& errors, int errno_copy, int sslerrno) {
1012 unsigned long errorCode;
1013 char message[256];
1014
1015 errors.reserve(512);
1016 while ((errorCode = ERR_get_error()) != 0) {
1017 if (!errors.empty()) {
1018 errors += "; ";
1019 }
1020 const char* reason = ERR_reason_error_string(errorCode);
1021 if (reason == nullptr) {
1022 THRIFT_SNPRINTF(message, sizeof(message) - 1, "SSL error # %lu", errorCode);
1023 reason = message;
1024 }
1025 errors += reason;
1026 }
1027 if (errors.empty()) {
1028 if (errno_copy != 0) {
1029 errors += TOutput::strerror_s(errno_copy);
1030 }
1031 }
1032 if (errors.empty()) {
1033 errors = "error code: " + to_string(errno_copy);
1034 }
1035 if (sslerrno) {
1036 errors += " (SSL_error_code = " + to_string(sslerrno) + ")";
1037 if (sslerrno == SSL_ERROR_SYSCALL) {
1038 char buf[4096];
1039 int err;
1040 while ((err = ERR_get_error()) != 0) {
1041 errors += " ";
1042 errors += ERR_error_string(err, buf);
1043 }
1044 }
1045 }
1046 }
1047
1048 /**
1049 * Default implementation of AccessManager
1050 */
1051 Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa) noexcept {
1052 (void)sa;
1053 return SKIP;
1054 }
1055
1056 Decision DefaultClientAccessManager::verify(const string& host,
1057 const char* name,
1058 int size) noexcept {
1059 if (host.empty() || name == nullptr || size <= 0) {
1060 return SKIP;
1061 }
1062 return (matchName(host.c_str(), name, size) ? ALLOW : SKIP);
1063 }
1064
1065 Decision DefaultClientAccessManager::verify(const sockaddr_storage& sa,
1066 const char* data,
1067 int size) noexcept {
1068 bool match = false;
1069 if (sa.ss_family == AF_INET && size == sizeof(in_addr)) {
1070 match = (memcmp(&((sockaddr_in*)&sa)->sin_addr, data, size) == 0);
1071 } else if (sa.ss_family == AF_INET6 && size == sizeof(in6_addr)) {
1072 match = (memcmp(&((sockaddr_in6*)&sa)->sin6_addr, data, size) == 0);
1073 }
1074 return (match ? ALLOW : SKIP);
1075 }
1076
// This is to work around the Turkish locale issue, i.e.,
// toupper('i') != toupper('I') if locale is "tr_TR"
// Upper-cases ASCII 'a'..'z' only; every other byte is returned unchanged.
char uppercase(char c) {
  return (c >= 'a' && c <= 'z') ? static_cast<char>(c - 'a' + 'A') : c;
}

/**
 * Match a name with a pattern, ignoring ASCII case. The pattern may include
 * wildcards: each "*" consumes the rest of the current host label (up to, but
 * not including, the next '.').
 *
 * @param host Host name, typically the name of the remote host (NUL-terminated)
 * @param pattern Name retrieved from certificate (not necessarily NUL-terminated)
 * @param size Size of "pattern"
 * @return True if all of "pattern" and all of "host" were consumed. False otherwise.
 */
bool matchName(const char* host, const char* pattern, int size) {
  int p = 0; // position in pattern
  int h = 0; // position in host
  while (p < size && host[h] != '\0') {
    const char pc = pattern[p];
    if (uppercase(pc) == uppercase(host[h])) {
      ++p;
      ++h;
    } else if (pc == '*') {
      while (host[h] != '\0' && host[h] != '.') {
        ++h;
      }
      ++p;
    } else {
      break;
    }
  }
  return p == size && host[h] == '\0';
}
1118 }
1119 }
1120 }