2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
10 * http://www.apache.org/licenses/LICENSE-2.0
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
20 #include <thrift/thrift-config.h>
26 #ifdef HAVE_ARPA_INET_H
27 #include <arpa/inet.h>
29 #include <sys/types.h>
30 #ifdef HAVE_SYS_SOCKET_H
31 #include <sys/socket.h>
33 #ifdef HAVE_SYS_POLL_H
40 #define OPENSSL_VERSION_NO_THREAD_ID_BEFORE 0x10000000L
41 #define OPENSSL_ENGINE_CLEANUP_REQUIRED_BEFORE 0x10100000L
43 #include <boost/shared_array.hpp>
44 #include <openssl/opensslv.h>
45 #if (OPENSSL_VERSION_NUMBER < OPENSSL_ENGINE_CLEANUP_REQUIRED_BEFORE)
46 #include <openssl/engine.h>
48 #include <openssl/err.h>
49 #include <openssl/rand.h>
50 #include <openssl/ssl.h>
51 #include <openssl/x509v3.h>
52 #include <thrift/concurrency/Mutex.h>
53 #include <thrift/transport/TSSLSocket.h>
54 #include <thrift/transport/PlatformSocket.h>
55 #include <thrift/TToString.h>
57 using namespace apache::thrift::concurrency
;
60 struct CRYPTO_dynlock_value
{
68 // OpenSSL initialization/cleanup
70 static bool openSSLInitialized
= false;
71 static boost::shared_array
<Mutex
> mutexes
;
73 static void callbackLocking(int mode
, int n
, const char*, int) {
74 if (mode
& CRYPTO_LOCK
) {
75 // assertion of (px != 0) here typically means that a TSSLSocket's lifetime
76 // exceeded the lifetime of the TSSLSocketFactory that created it, and the
77 // TSSLSocketFactory already ran cleanupOpenSSL(), which deleted "mutexes".
84 #if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID_BEFORE)
85 static unsigned long callbackThreadID() {
87 return (unsigned long)GetCurrentThreadId();
89 return (unsigned long)pthread_self();
94 static CRYPTO_dynlock_value
* dyn_create(const char*, int) {
95 return new CRYPTO_dynlock_value
;
98 static void dyn_lock(int mode
, struct CRYPTO_dynlock_value
* lock
, const char*, int) {
99 if (lock
!= nullptr) {
100 if (mode
& CRYPTO_LOCK
) {
103 lock
->mutex
.unlock();
108 static void dyn_destroy(struct CRYPTO_dynlock_value
* lock
, const char*, int) {
112 void initializeOpenSSL() {
113 if (openSSLInitialized
) {
116 openSSLInitialized
= true;
118 SSL_load_error_strings();
119 ERR_load_crypto_strings();
122 // newer versions of OpenSSL changed CRYPTO_num_locks - see THRIFT-3878
123 #ifdef CRYPTO_num_locks
124 mutexes
= boost::shared_array
<Mutex
>(new Mutex
[CRYPTO_num_locks()]);
126 mutexes
= boost::shared_array
<Mutex
>(new Mutex
[ ::CRYPTO_num_locks()]);
129 #if (OPENSSL_VERSION_NUMBER < OPENSSL_VERSION_NO_THREAD_ID_BEFORE)
130 CRYPTO_set_id_callback(callbackThreadID
);
133 CRYPTO_set_locking_callback(callbackLocking
);
136 CRYPTO_set_dynlock_create_callback(dyn_create
);
137 CRYPTO_set_dynlock_lock_callback(dyn_lock
);
138 CRYPTO_set_dynlock_destroy_callback(dyn_destroy
);
141 void cleanupOpenSSL() {
142 if (!openSSLInitialized
) {
145 openSSLInitialized
= false;
147 // https://wiki.openssl.org/index.php/Library_Initialization#Cleanup
148 // we purposefully do NOT call FIPS_mode_set(0) and leave it up to the enclosing application to manage FIPS entirely
149 #if (OPENSSL_VERSION_NUMBER < OPENSSL_ENGINE_CLEANUP_REQUIRED_BEFORE)
150 ENGINE_cleanup(); // https://www.openssl.org/docs/man1.1.0/crypto/ENGINE_cleanup.html - cleanup call is needed before 1.1.0
152 CONF_modules_unload(1);
154 CRYPTO_cleanup_all_ex_data();
161 static void buildErrors(string
& message
, int errno_copy
= 0, int sslerrno
= 0);
162 static bool matchName(const char* host
, const char* pattern
, int size
);
163 static char uppercase(char c
);
165 // SSLContext implementation
166 SSLContext::SSLContext(const SSLProtocol
& protocol
) {
167 if (protocol
== SSLTLS
) {
168 ctx_
= SSL_CTX_new(SSLv23_method());
169 #ifndef OPENSSL_NO_SSL3
170 } else if (protocol
== SSLv3
) {
171 ctx_
= SSL_CTX_new(SSLv3_method());
173 } else if (protocol
== TLSv1_0
) {
174 ctx_
= SSL_CTX_new(TLSv1_method());
175 } else if (protocol
== TLSv1_1
) {
176 ctx_
= SSL_CTX_new(TLSv1_1_method());
177 } else if (protocol
== TLSv1_2
) {
178 ctx_
= SSL_CTX_new(TLSv1_2_method());
180 /// UNKNOWN PROTOCOL!
181 throw TSSLException("SSL_CTX_new: Unknown protocol");
184 if (ctx_
== nullptr) {
187 throw TSSLException("SSL_CTX_new: " + errors
);
189 SSL_CTX_set_mode(ctx_
, SSL_MODE_AUTO_RETRY
);
191 // Disable horribly insecure SSLv2 and SSLv3 protocols but allow a handshake
192 // with older clients so they get a graceful denial.
193 if (protocol
== SSLTLS
) {
194 SSL_CTX_set_options(ctx_
, SSL_OP_NO_SSLv2
);
195 SSL_CTX_set_options(ctx_
, SSL_OP_NO_SSLv3
); // THRIFT-3164
199 SSLContext::~SSLContext() {
200 if (ctx_
!= nullptr) {
206 SSL
* SSLContext::createSSL() {
207 SSL
* ssl
= SSL_new(ctx_
);
208 if (ssl
== nullptr) {
211 throw TSSLException("SSL_new: " + errors
);
216 // TSSLSocket implementation
217 TSSLSocket::TSSLSocket(std::shared_ptr
<SSLContext
> ctx
)
218 : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx
) {
222 TSSLSocket::TSSLSocket(std::shared_ptr
<SSLContext
> ctx
, std::shared_ptr
<THRIFT_SOCKET
> interruptListener
)
223 : TSocket(), server_(false), ssl_(nullptr), ctx_(ctx
) {
225 interruptListener_
= interruptListener
;
228 TSSLSocket::TSSLSocket(std::shared_ptr
<SSLContext
> ctx
, THRIFT_SOCKET socket
)
229 : TSocket(socket
), server_(false), ssl_(nullptr), ctx_(ctx
) {
233 TSSLSocket::TSSLSocket(std::shared_ptr
<SSLContext
> ctx
, THRIFT_SOCKET socket
, std::shared_ptr
<THRIFT_SOCKET
> interruptListener
)
234 : TSocket(socket
, interruptListener
), server_(false), ssl_(nullptr), ctx_(ctx
) {
238 TSSLSocket::TSSLSocket(std::shared_ptr
<SSLContext
> ctx
, string host
, int port
)
239 : TSocket(host
, port
), server_(false), ssl_(nullptr), ctx_(ctx
) {
243 TSSLSocket::TSSLSocket(std::shared_ptr
<SSLContext
> ctx
, string host
, int port
, std::shared_ptr
<THRIFT_SOCKET
> interruptListener
)
244 : TSocket(host
, port
), server_(false), ssl_(nullptr), ctx_(ctx
) {
246 interruptListener_
= interruptListener
;
249 TSSLSocket::~TSSLSocket() {
253 bool TSSLSocket::hasPendingDataToRead() {
257 initializeHandshake();
258 if (!checkHandshake())
259 throw TSSLException("TSSLSocket::hasPendingDataToRead: Handshake is not completed");
260 // data may be available in SSL buffers (note: SSL_pending does not have a failure mode)
261 return SSL_pending(ssl_
) > 0 || TSocket::hasPendingDataToRead();
264 void TSSLSocket::init() {
265 handshakeCompleted_
= false;
270 bool TSSLSocket::isOpen() const {
271 if (ssl_
== nullptr || !TSocket::isOpen()) {
274 int shutdown
= SSL_get_shutdown(ssl_
);
275 // "!!" is squelching C4800 "forcing bool -> true or false" performance warning
276 bool shutdownReceived
= !!(shutdown
& SSL_RECEIVED_SHUTDOWN
);
277 bool shutdownSent
= !!(shutdown
& SSL_SENT_SHUTDOWN
);
278 if (shutdownReceived
&& shutdownSent
) {
285 * Note: This method is not libevent safe.
287 bool TSSLSocket::peek() {
291 initializeHandshake();
292 if (!checkHandshake())
293 throw TSSLException("SSL_peek: Handshake is not completed");
297 rc
= SSL_peek(ssl_
, &byte
, 1);
299 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
300 int error
= SSL_get_error(ssl_
, rc
);
302 case SSL_ERROR_SYSCALL
:
303 if ((errno_copy
!= THRIFT_EINTR
)
304 && (errno_copy
!= THRIFT_EAGAIN
)) {
308 case SSL_ERROR_WANT_READ
:
309 case SSL_ERROR_WANT_WRITE
:
310 // in the case of SSL_ERROR_SYSCALL we want to wait for an read event again
311 waitForEvent(error
!= SSL_ERROR_WANT_WRITE
);
313 default:;// do nothing
316 buildErrors(errors
, errno_copy
, error
);
317 throw TSSLException("SSL_peek: " + errors
);
318 } else if (rc
== 0) {
328 void TSSLSocket::open() {
329 if (isOpen() || server()) {
330 throw TTransportException(TTransportException::BAD_ARGS
);
336 * Note: This method is not libevent safe.
338 void TSSLSocket::close() {
339 if (ssl_
!= nullptr) {
346 rc
= SSL_shutdown(ssl_
);
348 errno_copy
= THRIFT_GET_SOCKET_ERROR
;
349 error
= SSL_get_error(ssl_
, rc
);
351 case SSL_ERROR_SYSCALL
:
352 if ((errno_copy
!= THRIFT_EINTR
)
353 && (errno_copy
!= THRIFT_EAGAIN
)) {
357 case SSL_ERROR_WANT_READ
:
358 case SSL_ERROR_WANT_WRITE
:
359 // in the case of SSL_ERROR_SYSCALL we want to wait for an write/read event again
360 waitForEvent(error
== SSL_ERROR_WANT_READ
);
362 default:;// do nothing
369 buildErrors(errors
, errno_copy
, error
);
370 GlobalOutput(("SSL_shutdown: " + errors
).c_str());
372 } catch (TTransportException
& te
) {
373 // Don't emit an exception because this method is called by the
374 // destructor. There's also not much that a user can do to recover, so
375 // just clean up as much as possible without throwing, similar to the rc
377 GlobalOutput
.printf("SSL_shutdown: %s", te
.what());
381 handshakeCompleted_
= false;
388 * Returns number of bytes read in SSL Socket.
389 * If eventSafe is set, and it may returns 0 bytes then read method
390 * needs to be called again until it is successfull or it throws
391 * exception incase of failure.
393 uint32_t TSSLSocket::read(uint8_t* buf
, uint32_t len
) {
394 initializeHandshake();
395 if (!checkHandshake())
396 throw TTransportException(TTransportException::UNKNOWN
, "retry again");
398 while (readRetryCount_
< maxRecvRetries_
) {
399 bytes
= SSL_read(ssl_
, buf
, len
);
400 int32_t errno_copy
= THRIFT_GET_SOCKET_ERROR
;
401 int32_t error
= SSL_get_error(ssl_
, bytes
);
403 if (error
== SSL_ERROR_NONE
) {
407 unsigned int waitEventReturn
;
408 bool breakout
= false;
410 case SSL_ERROR_ZERO_RETURN
:
411 throw TTransportException(TTransportException::END_OF_FILE
, "client disconnected");
413 case SSL_ERROR_SYSCALL
:
414 if (errno_copy
== 0 && ERR_peek_error() == 0) {
418 if ((errno_copy
!= THRIFT_EINTR
)
419 && (errno_copy
!= THRIFT_EAGAIN
)) {
422 if (readRetryCount_
>= maxRecvRetries_
) {
423 // THRIFT_EINTR needs to be handled manually and we can tolerate
429 case SSL_ERROR_WANT_READ
:
430 case SSL_ERROR_WANT_WRITE
:
431 if (isLibeventSafe()) {
432 if (readRetryCount_
< maxRecvRetries_
) {
433 // THRIFT_EINTR needs to be handled manually and we can tolerate
435 throw TTransportException(TTransportException::UNKNOWN
, "retry again");
437 throw TTransportException(TTransportException::INTERNAL_ERROR
, "too much recv retries");
439 // in the case of SSL_ERROR_SYSCALL we want to wait for an read event again
440 else if ((waitEventReturn
= waitForEvent(error
!= SSL_ERROR_WANT_WRITE
)) == TSSL_EINTR
) {
442 if (readRetryCount_
< maxRecvRetries_
) {
443 // THRIFT_EINTR needs to be handled manually and we can tolerate
447 throw TTransportException(TTransportException::INTERNAL_ERROR
, "too much recv retries");
449 else if (waitEventReturn
== TSSL_DATA
) {
450 // in case of SSL and huge thrift packets, there may be a number of
451 // socket operations, before any data becomes available by SSL_read().
452 // Therefore the number of retries should not be increased and
453 // the operation should be repeated.
457 throw TTransportException(TTransportException::INTERNAL_ERROR
, "unkown waitForEvent return value");
458 default:;// do nothing
464 buildErrors(errors
, errno_copy
, error
);
465 throw TSSLException("SSL_read: " + errors
);
470 void TSSLSocket::write(const uint8_t* buf
, uint32_t len
) {
471 initializeHandshake();
472 if (!checkHandshake())
474 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
475 uint32_t written
= 0;
476 while (written
< len
) {
478 int32_t bytes
= SSL_write(ssl_
, &buf
[written
], len
- written
);
480 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
481 int error
= SSL_get_error(ssl_
, bytes
);
483 case SSL_ERROR_SYSCALL
:
484 if ((errno_copy
!= THRIFT_EINTR
)
485 && (errno_copy
!= THRIFT_EAGAIN
)) {
489 case SSL_ERROR_WANT_READ
:
490 case SSL_ERROR_WANT_WRITE
:
491 if (isLibeventSafe()) {
495 // in the case of SSL_ERROR_SYSCALL we want to wait for an write event again
496 waitForEvent(error
== SSL_ERROR_WANT_READ
);
499 default:;// do nothing
502 buildErrors(errors
, errno_copy
, error
);
503 throw TSSLException("SSL_write: " + errors
);
510 * Returns number of bytes written in SSL Socket.
511 * If eventSafe is set, and it may returns 0 bytes then write method
512 * needs to be called again until it is successfull or it throws
513 * exception incase of failure.
515 uint32_t TSSLSocket::write_partial(const uint8_t* buf
, uint32_t len
) {
516 initializeHandshake();
517 if (!checkHandshake())
519 // loop in case SSL_MODE_ENABLE_PARTIAL_WRITE is set in SSL_CTX.
520 uint32_t written
= 0;
521 while (written
< len
) {
523 int32_t bytes
= SSL_write(ssl_
, &buf
[written
], len
- written
);
525 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
526 int error
= SSL_get_error(ssl_
, bytes
);
528 case SSL_ERROR_SYSCALL
:
529 if ((errno_copy
!= THRIFT_EINTR
)
530 && (errno_copy
!= THRIFT_EAGAIN
)) {
534 case SSL_ERROR_WANT_READ
:
535 case SSL_ERROR_WANT_WRITE
:
536 if (isLibeventSafe()) {
540 // in the case of SSL_ERROR_SYSCALL we want to wait for an write event again
541 waitForEvent(error
== SSL_ERROR_WANT_READ
);
544 default:;// do nothing
547 buildErrors(errors
, errno_copy
, error
);
548 throw TSSLException("SSL_write: " + errors
);
555 void TSSLSocket::flush() {
556 // Don't throw exception if not open. Thrift servers close socket twice.
557 if (ssl_
== nullptr) {
560 initializeHandshake();
561 if (!checkHandshake())
562 throw TSSLException("BIO_flush: Handshake is not completed");
563 BIO
* bio
= SSL_get_wbio(ssl_
);
564 if (bio
== nullptr) {
565 throw TSSLException("SSL_get_wbio returns NULL");
567 if (BIO_flush(bio
) != 1) {
568 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
570 buildErrors(errors
, errno_copy
);
571 throw TSSLException("BIO_flush: " + errors
);
575 void TSSLSocket::initializeHandshakeParams() {
576 // set underlying socket to non-blocking
578 if ((flags
= THRIFT_FCNTL(socket_
, THRIFT_F_GETFL
, 0)) < 0
579 || THRIFT_FCNTL(socket_
, THRIFT_F_SETFL
, flags
| THRIFT_O_NONBLOCK
) < 0) {
580 GlobalOutput
.perror("thriftServerEventHandler: set THRIFT_O_NONBLOCK (THRIFT_FCNTL) ",
581 THRIFT_GET_SOCKET_ERROR
);
582 ::THRIFT_CLOSESOCKET(socket_
);
585 ssl_
= ctx_
->createSSL();
587 SSL_set_fd(ssl_
, static_cast<int>(socket_
));
590 bool TSSLSocket::checkHandshake() {
591 return handshakeCompleted_
;
594 void TSSLSocket::initializeHandshake() {
595 if (!TSocket::isOpen()) {
596 throw TTransportException(TTransportException::NOT_OPEN
);
598 if (checkHandshake()) {
602 if (ssl_
== nullptr) {
603 initializeHandshakeParams();
611 rc
= SSL_accept(ssl_
);
613 errno_copy
= THRIFT_GET_SOCKET_ERROR
;
614 error
= SSL_get_error(ssl_
, rc
);
616 case SSL_ERROR_SYSCALL
:
617 if ((errno_copy
!= THRIFT_EINTR
)
618 && (errno_copy
!= THRIFT_EAGAIN
)) {
622 case SSL_ERROR_WANT_READ
:
623 case SSL_ERROR_WANT_WRITE
:
624 if (isLibeventSafe()) {
629 // in the case of SSL_ERROR_SYSCALL we want to wait for an write/read event again
630 waitForEvent(error
== SSL_ERROR_WANT_READ
);
633 default:;// do nothing
638 // OpenSSL < 0.9.8f does not have SSL_set_tlsext_host_name()
639 #if defined(SSL_set_tlsext_host_name)
640 // set the SNI hostname
641 SSL_set_tlsext_host_name(ssl_
, getHost().c_str());
644 rc
= SSL_connect(ssl_
);
646 errno_copy
= THRIFT_GET_SOCKET_ERROR
;
647 error
= SSL_get_error(ssl_
, rc
);
649 case SSL_ERROR_SYSCALL
:
650 if ((errno_copy
!= THRIFT_EINTR
)
651 && (errno_copy
!= THRIFT_EAGAIN
)) {
655 case SSL_ERROR_WANT_READ
:
656 case SSL_ERROR_WANT_WRITE
:
657 if (isLibeventSafe()) {
662 // in the case of SSL_ERROR_SYSCALL we want to wait for an write/read event again
663 waitForEvent(error
== SSL_ERROR_WANT_READ
);
666 default:;// do nothing
672 string
fname(server() ? "SSL_accept" : "SSL_connect");
674 buildErrors(errors
, errno_copy
, error
);
675 throw TSSLException(fname
+ ": " + errors
);
678 handshakeCompleted_
= true;
681 void TSSLSocket::authorize() {
682 int rc
= SSL_get_verify_result(ssl_
);
683 if (rc
!= X509_V_OK
) { // verify authentication result
684 throw TSSLException(string("SSL_get_verify_result(), ") + X509_verify_cert_error_string(rc
));
687 X509
* cert
= SSL_get_peer_certificate(ssl_
);
688 if (cert
== nullptr) {
689 // certificate is not present
690 if (SSL_get_verify_mode(ssl_
) & SSL_VERIFY_FAIL_IF_NO_PEER_CERT
) {
691 throw TSSLException("authorize: required certificate not present");
693 // certificate was optional: didn't intend to authorize remote
694 if (server() && access_
!= nullptr) {
695 throw TSSLException("authorize: certificate required for authorization");
699 // certificate is present
700 if (access_
== nullptr) {
704 // both certificate and access manager are present
708 socklen_t saLength
= sizeof(sa
);
710 if (getpeername(socket_
, (sockaddr
*)&sa
, &saLength
) != 0) {
711 sa
.ss_family
= AF_UNSPEC
;
714 AccessManager::Decision decision
= access_
->verify(sa
);
716 if (decision
!= AccessManager::SKIP
) {
718 if (decision
!= AccessManager::ALLOW
) {
719 throw TSSLException("authorize: access denied based on remote IP");
724 // extract subjectAlternativeName
726 = (STACK_OF(GENERAL_NAME
)*)X509_get_ext_d2i(cert
, NID_subject_alt_name
, nullptr, nullptr);
727 if (alternatives
!= nullptr) {
728 const int count
= sk_GENERAL_NAME_num(alternatives
);
729 for (int i
= 0; decision
== AccessManager::SKIP
&& i
< count
; i
++) {
730 const GENERAL_NAME
* name
= sk_GENERAL_NAME_value(alternatives
, i
);
731 if (name
== nullptr) {
734 char* data
= (char*)ASN1_STRING_data(name
->d
.ia5
);
735 int length
= ASN1_STRING_length(name
->d
.ia5
);
736 switch (name
->type
) {
739 host
= (server() ? getPeerHost() : getHost());
741 decision
= access_
->verify(host
, data
, length
);
744 decision
= access_
->verify(sa
, data
, length
);
748 sk_GENERAL_NAME_pop_free(alternatives
, GENERAL_NAME_free
);
751 if (decision
!= AccessManager::SKIP
) {
753 if (decision
!= AccessManager::ALLOW
) {
754 throw TSSLException("authorize: access denied");
759 // extract commonName
760 X509_NAME
* name
= X509_get_subject_name(cert
);
761 if (name
!= nullptr) {
762 X509_NAME_ENTRY
* entry
;
765 while (decision
== AccessManager::SKIP
) {
766 last
= X509_NAME_get_index_by_NID(name
, NID_commonName
, last
);
769 entry
= X509_NAME_get_entry(name
, last
);
770 if (entry
== nullptr)
772 ASN1_STRING
* common
= X509_NAME_ENTRY_get_data(entry
);
773 int size
= ASN1_STRING_to_UTF8(&utf8
, common
);
775 host
= (server() ? getPeerHost() : getHost());
777 decision
= access_
->verify(host
, (char*)utf8
, size
);
782 if (decision
!= AccessManager::ALLOW
) {
783 throw TSSLException("authorize: cannot authorize peer");
788 * Note: This method is not libevent safe.
790 unsigned int TSSLSocket::waitForEvent(bool wantRead
) {
795 bio
= SSL_get_rbio(ssl_
);
797 bio
= SSL_get_wbio(ssl_
);
800 if (bio
== nullptr) {
801 throw TSSLException("SSL_get_?bio returned NULL");
804 if (BIO_get_fd(bio
, &fdSocket
) <= 0) {
805 throw TSSLException("BIO_get_fd failed");
808 struct THRIFT_POLLFD fds
[2];
809 memset(fds
, 0, sizeof(fds
));
810 fds
[0].fd
= fdSocket
;
811 // use POLLIN also on write operations too, this is needed for operations
812 // which requires read and write on the socket.
813 fds
[0].events
= wantRead
? THRIFT_POLLIN
: THRIFT_POLLIN
| THRIFT_POLLOUT
;
815 if (interruptListener_
) {
816 fds
[1].fd
= *(interruptListener_
.get());
817 fds
[1].events
= THRIFT_POLLIN
;
821 if (wantRead
&& recvTimeout_
) {
822 timeout
= recvTimeout_
;
824 if (!wantRead
&& sendTimeout_
) {
825 timeout
= sendTimeout_
;
828 int ret
= THRIFT_POLL(fds
, interruptListener_
? 2 : 1, timeout
);
832 if (THRIFT_GET_SOCKET_ERROR
== THRIFT_EINTR
) {
833 return TSSL_EINTR
; // repeat operation
835 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
836 GlobalOutput
.perror("TSSLSocket::read THRIFT_POLL() ", errno_copy
);
837 throw TTransportException(TTransportException::UNKNOWN
, "Unknown", errno_copy
);
839 if (fds
[1].revents
& THRIFT_POLLIN
) {
840 throw TTransportException(TTransportException::INTERRUPTED
, "Interrupted");
844 throw TTransportException(TTransportException::TIMED_OUT
, "THRIFT_POLL (timed out)");
848 // TSSLSocketFactory implementation
849 uint64_t TSSLSocketFactory::count_
= 0;
850 Mutex
TSSLSocketFactory::mutex_
;
851 bool TSSLSocketFactory::manualOpenSSLInitialization_
= false;
853 TSSLSocketFactory::TSSLSocketFactory(SSLProtocol protocol
) : server_(false) {
856 if (!manualOpenSSLInitialization_
) {
862 ctx_
= std::make_shared
<SSLContext
>(protocol
);
865 TSSLSocketFactory::~TSSLSocketFactory() {
869 if (count_
== 0 && !manualOpenSSLInitialization_
) {
874 std::shared_ptr
<TSSLSocket
> TSSLSocketFactory::createSocket() {
875 std::shared_ptr
<TSSLSocket
> ssl(new TSSLSocket(ctx_
));
880 std::shared_ptr
<TSSLSocket
> TSSLSocketFactory::createSocket(std::shared_ptr
<THRIFT_SOCKET
> interruptListener
) {
881 std::shared_ptr
<TSSLSocket
> ssl(new TSSLSocket(ctx_
, interruptListener
));
886 std::shared_ptr
<TSSLSocket
> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket
) {
887 std::shared_ptr
<TSSLSocket
> ssl(new TSSLSocket(ctx_
, socket
));
892 std::shared_ptr
<TSSLSocket
> TSSLSocketFactory::createSocket(THRIFT_SOCKET socket
, std::shared_ptr
<THRIFT_SOCKET
> interruptListener
) {
893 std::shared_ptr
<TSSLSocket
> ssl(new TSSLSocket(ctx_
, socket
, interruptListener
));
898 std::shared_ptr
<TSSLSocket
> TSSLSocketFactory::createSocket(const string
& host
, int port
) {
899 std::shared_ptr
<TSSLSocket
> ssl(new TSSLSocket(ctx_
, host
, port
));
904 std::shared_ptr
<TSSLSocket
> TSSLSocketFactory::createSocket(const string
& host
, int port
, std::shared_ptr
<THRIFT_SOCKET
> interruptListener
) {
905 std::shared_ptr
<TSSLSocket
> ssl(new TSSLSocket(ctx_
, host
, port
, interruptListener
));
911 void TSSLSocketFactory::setup(std::shared_ptr
<TSSLSocket
> ssl
) {
912 ssl
->server(server());
913 if (access_
== nullptr && !server()) {
914 access_
= std::shared_ptr
<AccessManager
>(new DefaultClientAccessManager
);
916 if (access_
!= nullptr) {
917 ssl
->access(access_
);
921 void TSSLSocketFactory::ciphers(const string
& enable
) {
922 int rc
= SSL_CTX_set_cipher_list(ctx_
->get(), enable
.c_str());
923 if (ERR_peek_error() != 0) {
926 throw TSSLException("SSL_CTX_set_cipher_list: " + errors
);
929 throw TSSLException("None of specified ciphers are supported");
933 void TSSLSocketFactory::authenticate(bool required
) {
936 mode
= SSL_VERIFY_PEER
| SSL_VERIFY_FAIL_IF_NO_PEER_CERT
| SSL_VERIFY_CLIENT_ONCE
;
938 mode
= SSL_VERIFY_NONE
;
940 SSL_CTX_set_verify(ctx_
->get(), mode
, nullptr);
943 void TSSLSocketFactory::loadCertificate(const char* path
, const char* format
) {
944 if (path
== nullptr || format
== nullptr) {
945 throw TTransportException(TTransportException::BAD_ARGS
,
946 "loadCertificateChain: either <path> or <format> is NULL");
948 if (strcmp(format
, "PEM") == 0) {
949 if (SSL_CTX_use_certificate_chain_file(ctx_
->get(), path
) == 0) {
950 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
952 buildErrors(errors
, errno_copy
);
953 throw TSSLException("SSL_CTX_use_certificate_chain_file: " + errors
);
956 throw TSSLException("Unsupported certificate format: " + string(format
));
960 void TSSLSocketFactory::loadPrivateKey(const char* path
, const char* format
) {
961 if (path
== nullptr || format
== nullptr) {
962 throw TTransportException(TTransportException::BAD_ARGS
,
963 "loadPrivateKey: either <path> or <format> is NULL");
965 if (strcmp(format
, "PEM") == 0) {
966 if (SSL_CTX_use_PrivateKey_file(ctx_
->get(), path
, SSL_FILETYPE_PEM
) == 0) {
967 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
969 buildErrors(errors
, errno_copy
);
970 throw TSSLException("SSL_CTX_use_PrivateKey_file: " + errors
);
975 void TSSLSocketFactory::loadTrustedCertificates(const char* path
, const char* capath
) {
976 if (path
== nullptr) {
977 throw TTransportException(TTransportException::BAD_ARGS
,
978 "loadTrustedCertificates: <path> is NULL");
980 if (SSL_CTX_load_verify_locations(ctx_
->get(), path
, capath
) == 0) {
981 int errno_copy
= THRIFT_GET_SOCKET_ERROR
;
983 buildErrors(errors
, errno_copy
);
984 throw TSSLException("SSL_CTX_load_verify_locations: " + errors
);
988 void TSSLSocketFactory::randomize() {
992 void TSSLSocketFactory::overrideDefaultPasswordCallback() {
993 SSL_CTX_set_default_passwd_cb(ctx_
->get(), passwordCallback
);
994 SSL_CTX_set_default_passwd_cb_userdata(ctx_
->get(), this);
997 int TSSLSocketFactory::passwordCallback(char* password
, int size
, int, void* data
) {
998 auto* factory
= (TSSLSocketFactory
*)data
;
1000 factory
->getPassword(userPassword
, size
);
1001 int length
= static_cast<int>(userPassword
.size());
1002 if (length
> size
) {
1005 strncpy(password
, userPassword
.c_str(), length
);
1006 userPassword
.assign(userPassword
.size(), '*');
1010 // extract error messages from error queue
1011 void buildErrors(string
& errors
, int errno_copy
, int sslerrno
) {
1012 unsigned long errorCode
;
1015 errors
.reserve(512);
1016 while ((errorCode
= ERR_get_error()) != 0) {
1017 if (!errors
.empty()) {
1020 const char* reason
= ERR_reason_error_string(errorCode
);
1021 if (reason
== nullptr) {
1022 THRIFT_SNPRINTF(message
, sizeof(message
) - 1, "SSL error # %lu", errorCode
);
1027 if (errors
.empty()) {
1028 if (errno_copy
!= 0) {
1029 errors
+= TOutput::strerror_s(errno_copy
);
1032 if (errors
.empty()) {
1033 errors
= "error code: " + to_string(errno_copy
);
1036 errors
+= " (SSL_error_code = " + to_string(sslerrno
) + ")";
1037 if (sslerrno
== SSL_ERROR_SYSCALL
) {
1040 while ((err
= ERR_get_error()) != 0) {
1042 errors
+= ERR_error_string(err
, buf
);
1049 * Default implementation of AccessManager
1051 Decision
DefaultClientAccessManager::verify(const sockaddr_storage
& sa
) noexcept
{
1056 Decision
DefaultClientAccessManager::verify(const string
& host
,
1058 int size
) noexcept
{
1059 if (host
.empty() || name
== nullptr || size
<= 0) {
1062 return (matchName(host
.c_str(), name
, size
) ? ALLOW
: SKIP
);
1065 Decision
DefaultClientAccessManager::verify(const sockaddr_storage
& sa
,
1067 int size
) noexcept
{
1069 if (sa
.ss_family
== AF_INET
&& size
== sizeof(in_addr
)) {
1070 match
= (memcmp(&((sockaddr_in
*)&sa
)->sin_addr
, data
, size
) == 0);
1071 } else if (sa
.ss_family
== AF_INET6
&& size
== sizeof(in6_addr
)) {
1072 match
= (memcmp(&((sockaddr_in6
*)&sa
)->sin6_addr
, data
, size
) == 0);
1074 return (match
? ALLOW
: SKIP
);
1078 * Match a name with a pattern. The pattern may include wildcard. A single
1079 * wildcard "*" can match up to one component in the domain name.
1081 * @param host Host name, typically the name of the remote host
1082 * @param pattern Name retrieved from certificate
1083 * @param size Size of "pattern"
1084 * @return True, if "host" matches "pattern". False otherwise.
1086 bool matchName(const char* host
, const char* pattern
, int size
) {
1089 while (i
< size
&& host
[j
] != '\0') {
1090 if (uppercase(pattern
[i
]) == uppercase(host
[j
])) {
1095 if (pattern
[i
] == '*') {
1096 while (host
[j
] != '.' && host
[j
] != '\0') {
1104 if (i
== size
&& host
[j
] == '\0') {
1110 // This is to work around the Turkish locale issue, i.e.,
1111 // toupper('i') != toupper('I') if locale is "tr_TR"
1112 char uppercase(char c
) {
1113 if ('a' <= c
&& c
<= 'z') {
1114 return c
+ ('A' - 'a');