2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
10 * http://www.apache.org/licenses/LICENSE-2.0
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
20 #include <boost/test/auto_unit_test.hpp>
21 #include <boost/test/unit_test_suite.hpp>
22 #include <boost/chrono/duration.hpp>
23 #include <boost/date_time/posix_time/posix_time_duration.hpp>
24 #include <boost/thread/thread.hpp>
25 #include <boost/filesystem.hpp>
26 #include <boost/format.hpp>
28 #include <thrift/transport/TSSLSocket.h>
29 #include <thrift/transport/TSSLServerSocket.h>
34 using apache::thrift::transport::TSSLServerSocket
;
35 using apache::thrift::transport::TSSLSocket
;
36 using apache::thrift::transport::TTransport
;
37 using apache::thrift::transport::TTransportException
;
38 using apache::thrift::transport::TSSLSocketFactory
;
40 using std::static_pointer_cast
;
41 using std::shared_ptr
;
// Opens the Boost.Test suite that contains every test case in this file.
43 BOOST_AUTO_TEST_SUITE(TSSLSocketInterruptTest
)
// Directory in which the test certificates and keys are looked up.
// It is assigned by the GlobalFixtureSSL setup code and read by certFile().
45 boost::filesystem::path keyDir
;
// Returns the path of `filename` inside the global key directory `keyDir`.
// The result is simply `keyDir / filename`; no existence check is done here.
46 boost::filesystem::path
certFile(const std::string
& filename
)
48 return keyDir
/ filename
;
// Fixture holding the process-wide setup/teardown for the SSL tests:
// logs the command-line arguments, ignores SIGPIPE, initializes OpenSSL and
// resolves `keyDir`; the destructor undoes the OpenSSL and SIGPIPE changes.
// NOTE(review): the constructor header and all braces are not visible in this
// excerpt; the statements up to the destructor presumably form the
// constructor body -- confirm against the complete file.
52 struct GlobalFixtureSSL
56 using namespace boost::unit_test::framework
;
// Echo every command-line argument into the test log.
57 for (int i
= 0; i
< master_test_suite().argc
; ++i
)
59 BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i
% master_test_suite().argv
[i
]);
63 // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has
64 // disconnected can cause a SIGPIPE signal...
65 signal(SIGPIPE
, SIG_IGN
);
// Switch the socket factory to manual OpenSSL initialization and perform that
// initialization once here; the destructor calls the matching cleanup.
68 TSSLSocketFactory::setManualOpenSSLInitialization(true);
69 apache::thrift::transport::initializeOpenSSL();
// First guess for the key directory: "test/keys" three levels above the
// current working directory.
71 keyDir
= boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys";
72 if (!boost::filesystem::exists(certFile("server.crt")))
// Fallback: treat the last command-line argument as the key directory.
74 keyDir
= boost::filesystem::path(master_test_suite().argv
[master_test_suite().argc
- 1]);
75 if (!boost::filesystem::exists(certFile("server.crt")))
// Neither location contains server.crt, so the tests cannot run.
77 throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s).");
// Teardown: release OpenSSL state and restore the default SIGPIPE disposition.
82 virtual ~GlobalFixtureSSL()
84 apache::thrift::transport::cleanupOpenSSL();
86 signal(SIGPIPE
, SIG_DFL
);
// Registers GlobalFixtureSSL as a Boost.Test global fixture. The form with a
// trailing ';' is the one guarded by BOOST_VERSION >= 105900.
// NOTE(review): the #else / #endif that separate the two forms are not visible
// in this excerpt -- confirm against the complete file.
91 #if (BOOST_VERSION >= 105900)
92 BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL
);
94 BOOST_GLOBAL_FIXTURE(GlobalFixtureSSL
)
// Thread body: reads from transport `tt` and checks the outcome.
//   tt             - transport to read from.
//   expectedResult - value the 4-byte read() call is expected to return.
// If the read throws a TTransportException instead, its type must be TIMED_OUT.
// NOTE(review): the declaration of `buf`, the opening `try` and any statement
// before the checked read are not visible in this excerpt.
97 void readerWorker(shared_ptr
<TTransport
> tt
, uint32_t expectedResult
) {
101 BOOST_CHECK_EQUAL(expectedResult
, tt
->read(buf
, 4));
102 } catch (const TTransportException
& tx
) {
103 BOOST_CHECK_EQUAL(TTransportException::TIMED_OUT
, tx
.getType());
// Thread body for the "interruptable" tests: the work done on `tt` is expected
// to end in a TTransportException of type INTERRUPTED. Falling through to the
// BOOST_ERROR line records a test error.
// NOTE(review): the statements preceding BOOST_ERROR (presumably blocking
// read call(s) on `tt`) are not visible in this excerpt.
107 void readerWorkerMustThrow(shared_ptr
<TTransport
> tt
) {
112 BOOST_ERROR("should not have gotten here");
113 } catch (const TTransportException
& tx
) {
114 BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED
, tx
.getType());
// Builds the TSSLSocketFactory used for the server side of every test.
// The factory gets an explicit cipher list, the server certificate and private
// key ("server.crt" / "server.key", resolved through certFile()), and is
// switched to server mode.
// Returns the configured factory.
118 shared_ptr
<TSSLSocketFactory
> createServerSocketFactory() {
119 shared_ptr
<TSSLSocketFactory
> pServerSocketFactory
;
121 pServerSocketFactory
.reset(new TSSLSocketFactory());
122 pServerSocketFactory
->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");
123 pServerSocketFactory
->loadCertificate(certFile("server.crt").string().c_str());
124 pServerSocketFactory
->loadPrivateKey(certFile("server.key").string().c_str());
125 pServerSocketFactory
->server(true);
126 return pServerSocketFactory
;
// Builds the TSSLSocketFactory used for the client side of every test.
// Authentication is enabled and the factory loads the client certificate,
// the client private key and the trusted CA bundle ("client.crt",
// "client.key", "CA.pem"), all resolved through certFile().
// Returns the configured factory.
129 shared_ptr
<TSSLSocketFactory
> createClientSocketFactory() {
130 shared_ptr
<TSSLSocketFactory
> pClientSocketFactory
;
132 pClientSocketFactory
.reset(new TSSLSocketFactory());
133 pClientSocketFactory
->authenticate(true);
134 pClientSocketFactory
->loadCertificate(certFile("client.crt").string().c_str());
135 pClientSocketFactory
->loadPrivateKey(certFile("client.key").string().c_str());
136 pClientSocketFactory
->loadTrustedCertificates(certFile("CA.pem").string().c_str());
137 return pClientSocketFactory
;
// Checks that interruptChildren() on the server socket terminates a reader
// thread working on an accepted connection. Unlike
// test_ssl_interruptable_child_read, no client write is visible here, so the
// reader is presumably still inside the SSL handshake when it is interrupted.
// NOTE(review): calls that start listening, open the client socket and close
// the sockets are not visible in this excerpt -- confirm against the complete
// file.
140 BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read_while_handshaking
) {
141 shared_ptr
<TSSLSocketFactory
> pServerSocketFactory
= createServerSocketFactory();
// Port 0 is passed to the server socket; the actual port is queried below.
142 TSSLServerSocket
sock1("localhost", 0, pServerSocketFactory
);
144 int port
= sock1
.getPort();
145 shared_ptr
<TSSLSocketFactory
> pClientSocketFactory
= createClientSocketFactory();
146 shared_ptr
<TSSLSocket
> clientSock
= pClientSocketFactory
->createSocket("localhost", port
);
148 shared_ptr
<TTransport
> accepted
= sock1
.accept();
// The reader runs on its own thread and must finish with INTERRUPTED.
149 boost::thread
readThread(std::bind(readerWorkerMustThrow
, accepted
));
150 boost::this_thread::sleep(boost::posix_time::milliseconds(50));
151 // readThread is practically guaranteed to be blocking now
152 sock1
.interruptChildren();
// NOTE(review): the 20 ms join window is much shorter than the 200 ms used by
// the peek tests in this file; it may be timing-sensitive on loaded machines.
153 BOOST_CHECK_MESSAGE(readThread
.try_join_for(boost::chrono::milliseconds(20)),
154 "server socket interruptChildren did not interrupt child read");
// Checks that interruptChildren() on the server socket terminates a reader
// thread that is working on an accepted connection after the client has
// written a single byte.
// NOTE(review): calls that start listening, open the client socket and close
// the sockets are not visible in this excerpt -- confirm against the complete
// file.
160 BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_read
) {
161 shared_ptr
<TSSLSocketFactory
> pServerSocketFactory
= createServerSocketFactory();
// Port 0 is passed to the server socket; the actual port is queried below.
162 TSSLServerSocket
sock1("localhost", 0, pServerSocketFactory
);
164 int port
= sock1
.getPort();
165 shared_ptr
<TSSLSocketFactory
> pClientSocketFactory
= createClientSocketFactory();
166 shared_ptr
<TSSLSocket
> clientSock
= pClientSocketFactory
->createSocket("localhost", port
);
168 shared_ptr
<TTransport
> accepted
= sock1
.accept();
// The reader runs on its own thread and must finish with INTERRUPTED.
169 boost::thread
readThread(std::bind(readerWorkerMustThrow
, accepted
));
// Send exactly one byte from the client side.
170 clientSock
->write((const uint8_t*)"0", 1);
171 boost::this_thread::sleep(boost::posix_time::milliseconds(50));
172 // readThread is practically guaranteed to be blocking now
173 sock1
.interruptChildren();
// NOTE(review): the 20 ms join window is much shorter than the 200 ms used by
// the peek tests in this file; it may be timing-sensitive on loaded machines.
174 BOOST_CHECK_MESSAGE(readThread
.try_join_for(boost::chrono::milliseconds(20)),
175 "server socket interruptChildren did not interrupt child read");
// Checks the opposite case: with interruptable children disabled on the server
// socket, interruptChildren() must NOT end the reader thread; the reader is
// expected to still be running 200 ms after the interrupt request.
// NOTE(review): calls that start listening, open the client socket, join the
// thread and close the sockets are not visible in this excerpt -- confirm
// against the complete file.
181 BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_read
) {
182 shared_ptr
<TSSLSocketFactory
> pServerSocketFactory
= createServerSocketFactory();
183 TSSLServerSocket
sock1("localhost", 0, pServerSocketFactory
);
184 sock1
.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior
186 int port
= sock1
.getPort();
187 shared_ptr
<TSSLSocketFactory
> pClientSocketFactory
= createClientSocketFactory();
188 shared_ptr
<TSSLSocket
> clientSock
= pClientSocketFactory
->createSocket("localhost", port
);
190 shared_ptr
<TTransport
> accepted
= sock1
.accept();
// Give the accepted socket a receive timeout of 1000 (presumably milliseconds
// -- TODO confirm) so the reader can end by timing out rather than by interrupt.
191 static_pointer_cast
<TSSLSocket
>(accepted
)->setRecvTimeout(1000);
// readerWorker accepts either a read result of 0 or a TIMED_OUT exception.
192 boost::thread
readThread(std::bind(readerWorker
, accepted
, 0));
193 clientSock
->write((const uint8_t*)"0", 1);
194 boost::this_thread::sleep(boost::posix_time::milliseconds(50));
195 // readThread is practically guaranteed to be blocking here
196 sock1
.interruptChildren();
// The check passes only if the thread has NOT finished within 200 ms.
197 BOOST_CHECK_MESSAGE(!readThread
.try_join_for(boost::chrono::milliseconds(200)),
198 "server socket interruptChildren interrupted child read");
200 // wait for receive timeout to kick in
// Checks that setInterruptableChildren(false) throws std::logic_error on this
// server socket.
// NOTE(review): judging by the test name the socket should already be
// listening at that point, but the listen() call is not visible in this
// excerpt -- confirm against the complete file.
207 BOOST_AUTO_TEST_CASE(test_ssl_cannot_change_after_listen
) {
208 shared_ptr
<TSSLSocketFactory
> pServerSocketFactory
= createServerSocketFactory();
209 TSSLServerSocket
sock1("localhost", 0, pServerSocketFactory
);
211 BOOST_CHECK_THROW(sock1
.setInterruptableChildren(false), std::logic_error
);
// Thread body: calls peek() on transport `tt` and checks the outcome.
//   tt             - transport to peek on.
//   expectedResult - value peek() is expected to return.
// If peek() throws a TTransportException instead, its type must be TIMED_OUT.
// NOTE(review): the opening `try` is not visible in this excerpt.
215 void peekerWorker(shared_ptr
<TTransport
> tt
, bool expectedResult
) {
219 BOOST_CHECK_EQUAL(expectedResult
, tt
->peek());
220 } catch (const TTransportException
& tx
) {
221 BOOST_CHECK_EQUAL(TTransportException::TIMED_OUT
, tx
.getType());
// Thread body for the interruptable peek test: any TTransportException raised
// while working on `tt` must have type INTERRUPTED.
// NOTE(review): the guarded statements (presumably a peek() call on `tt`) are
// not visible in this excerpt.
225 void peekerWorkerInterrupt(shared_ptr
<TTransport
> tt
) {
230 } catch (const TTransportException
& tx
) {
231 BOOST_CHECK_EQUAL(TTransportException::INTERRUPTED
, tx
.getType());
// Checks that interruptChildren() on the server socket terminates a thread
// running peekerWorkerInterrupt on an accepted connection; the thread must be
// joinable within 200 ms of the interrupt request.
// NOTE(review): calls that start listening, open the client socket and close
// the sockets are not visible in this excerpt -- confirm against the complete
// file.
235 BOOST_AUTO_TEST_CASE(test_ssl_interruptable_child_peek
) {
236 shared_ptr
<TSSLSocketFactory
> pServerSocketFactory
= createServerSocketFactory();
// Port 0 is passed to the server socket; the actual port is queried below.
237 TSSLServerSocket
sock1("localhost", 0, pServerSocketFactory
);
239 int port
= sock1
.getPort();
240 shared_ptr
<TSSLSocketFactory
> pClientSocketFactory
= createClientSocketFactory();
241 shared_ptr
<TSSLSocket
> clientSock
= pClientSocketFactory
->createSocket("localhost", port
);
243 shared_ptr
<TTransport
> accepted
= sock1
.accept();
244 boost::thread
peekThread(std::bind(peekerWorkerInterrupt
, accepted
));
// Send exactly one byte from the client side.
245 clientSock
->write((const uint8_t*)"0", 1);
246 boost::this_thread::sleep(boost::posix_time::milliseconds(50));
247 // peekThread is practically guaranteed to be blocking now
248 sock1
.interruptChildren();
249 BOOST_CHECK_MESSAGE(peekThread
.try_join_for(boost::chrono::milliseconds(200)),
250 "server socket interruptChildren did not interrupt child peek");
// Checks the opposite case for peek(): with interruptable children disabled on
// the server socket, interruptChildren() must NOT end the peeking thread; the
// thread is expected to still be running 200 ms after the interrupt request.
// NOTE(review): calls that start listening, open the client socket, join the
// thread and close the sockets are not visible in this excerpt -- confirm
// against the complete file.
256 BOOST_AUTO_TEST_CASE(test_ssl_non_interruptable_child_peek
) {
257 shared_ptr
<TSSLSocketFactory
> pServerSocketFactory
= createServerSocketFactory();
258 TSSLServerSocket
sock1("localhost", 0, pServerSocketFactory
);
259 sock1
.setInterruptableChildren(false); // returns to pre-THRIFT-2441 behavior
261 int port
= sock1
.getPort();
262 shared_ptr
<TSSLSocketFactory
> pClientSocketFactory
= createClientSocketFactory();
263 shared_ptr
<TSSLSocket
> clientSock
= pClientSocketFactory
->createSocket("localhost", port
);
265 shared_ptr
<TTransport
> accepted
= sock1
.accept();
// Give the accepted socket a receive timeout of 1000 (presumably milliseconds
// -- TODO confirm) so the peek can end by timing out rather than by interrupt.
266 static_pointer_cast
<TSSLSocket
>(accepted
)->setRecvTimeout(1000);
// peekerWorker accepts either a peek() result of false or a TIMED_OUT exception.
267 boost::thread
peekThread(std::bind(peekerWorker
, accepted
, false));
268 clientSock
->write((const uint8_t*)"0", 1);
269 boost::this_thread::sleep(boost::posix_time::milliseconds(50));
270 // peekThread is practically guaranteed to be blocking now
271 sock1
.interruptChildren();
// The check passes only if the thread has NOT finished within 200 ms.
272 BOOST_CHECK_MESSAGE(!peekThread
.try_join_for(boost::chrono::milliseconds(200)),
273 "server socket interruptChildren interrupted child peek");
275 // wait for the receive timeout to kick in
// Closes the TSSLSocketInterruptTest suite opened near the top of the file.
282 BOOST_AUTO_TEST_SUITE_END()