]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one | |
3 | * or more contributor license agreements. See the NOTICE file | |
4 | * distributed with this work for additional information | |
5 | * regarding copyright ownership. The ASF licenses this file | |
6 | * to you under the Apache License, Version 2.0 (the | |
7 | * "License"); you may not use this file except in compliance | |
8 | * with the License. You may obtain a copy of the License at | |
9 | * | |
10 | * http://www.apache.org/licenses/LICENSE-2.0 | |
11 | * | |
12 | * Unless required by applicable law or agreed to in writing, | |
13 | * software distributed under the License is distributed on an | |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | * KIND, either express or implied. See the License for the | |
16 | * specific language governing permissions and limitations | |
17 | * under the License. | |
18 | */ | |
19 | ||
20 | #define BOOST_TEST_MODULE SecurityTest | |
21 | #include <boost/test/unit_test.hpp> | |
22 | #include <boost/filesystem.hpp> | |
23 | #include <boost/foreach.hpp> | |
24 | #include <boost/format.hpp> | |
25 | #include <boost/thread.hpp> | |
26 | #include <memory> | |
27 | #include <thrift/transport/TSSLServerSocket.h> | |
28 | #include <thrift/transport/TSSLSocket.h> | |
29 | #include <thrift/transport/TTransport.h> | |
30 | #include <vector> | |
31 | #ifdef __linux__ | |
32 | #include <signal.h> | |
33 | #endif | |
34 | ||
35 | using apache::thrift::transport::TSSLServerSocket; | |
36 | using apache::thrift::transport::TServerTransport; | |
37 | using apache::thrift::transport::TSSLSocket; | |
38 | using apache::thrift::transport::TSSLSocketFactory; | |
39 | using apache::thrift::transport::TTransport; | |
40 | using apache::thrift::transport::TTransportException; | |
41 | using apache::thrift::transport::TTransportFactory; | |
42 | ||
43 | using std::bind; | |
44 | using std::shared_ptr; | |
45 | ||
46 | boost::filesystem::path keyDir; | |
47 | boost::filesystem::path certFile(const std::string& filename) | |
48 | { | |
49 | return keyDir / filename; | |
50 | } | |
51 | boost::mutex gMutex; | |
52 | ||
53 | struct GlobalFixture | |
54 | { | |
55 | GlobalFixture() | |
56 | { | |
57 | using namespace boost::unit_test::framework; | |
58 | for (int i = 0; i < master_test_suite().argc; ++i) | |
59 | { | |
60 | BOOST_TEST_MESSAGE(boost::format("argv[%1%] = \"%2%\"") % i % master_test_suite().argv[i]); | |
61 | } | |
62 | ||
63 | #ifdef __linux__ | |
64 | // OpenSSL calls send() without MSG_NOSIGPIPE so writing to a socket that has | |
65 | // disconnected can cause a SIGPIPE signal... | |
66 | signal(SIGPIPE, SIG_IGN); | |
67 | #endif | |
68 | ||
69 | TSSLSocketFactory::setManualOpenSSLInitialization(true); | |
70 | apache::thrift::transport::initializeOpenSSL(); | |
71 | ||
72 | keyDir = boost::filesystem::current_path().parent_path().parent_path().parent_path() / "test" / "keys"; | |
73 | if (!boost::filesystem::exists(certFile("server.crt"))) | |
74 | { | |
75 | keyDir = boost::filesystem::path(master_test_suite().argv[master_test_suite().argc - 1]); | |
76 | if (!boost::filesystem::exists(certFile("server.crt"))) | |
77 | { | |
78 | throw std::invalid_argument("The last argument to this test must be the directory containing the test certificate(s)."); | |
79 | } | |
80 | } | |
81 | } | |
82 | ||
83 | virtual ~GlobalFixture() | |
84 | { | |
85 | apache::thrift::transport::cleanupOpenSSL(); | |
86 | #ifdef __linux__ | |
87 | signal(SIGPIPE, SIG_DFL); | |
88 | #endif | |
89 | } | |
90 | }; | |
91 | ||
92 | #if (BOOST_VERSION >= 105900) | |
93 | BOOST_GLOBAL_FIXTURE(GlobalFixture); | |
94 | #else | |
95 | BOOST_GLOBAL_FIXTURE(GlobalFixture) | |
96 | #endif | |
97 | ||
98 | struct SecurityFixture | |
99 | { | |
100 | void server(apache::thrift::transport::SSLProtocol protocol) | |
101 | { | |
102 | try | |
103 | { | |
104 | boost::mutex::scoped_lock lock(mMutex); | |
105 | ||
106 | shared_ptr<TSSLSocketFactory> pServerSocketFactory; | |
107 | shared_ptr<TSSLServerSocket> pServerSocket; | |
108 | ||
109 | pServerSocketFactory.reset(new TSSLSocketFactory(static_cast<apache::thrift::transport::SSLProtocol>(protocol))); | |
110 | pServerSocketFactory->ciphers("ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH"); | |
111 | pServerSocketFactory->loadCertificate(certFile("server.crt").string().c_str()); | |
112 | pServerSocketFactory->loadPrivateKey(certFile("server.key").string().c_str()); | |
113 | pServerSocketFactory->server(true); | |
114 | pServerSocket.reset(new TSSLServerSocket("localhost", 0, pServerSocketFactory)); | |
115 | shared_ptr<TTransport> connectedClient; | |
116 | ||
117 | try | |
118 | { | |
119 | pServerSocket->listen(); | |
120 | mPort = pServerSocket->getPort(); | |
121 | mCVar.notify_one(); | |
122 | lock.unlock(); | |
123 | ||
124 | connectedClient = pServerSocket->accept(); | |
125 | uint8_t buf[2]; | |
126 | buf[0] = 'O'; | |
127 | buf[1] = 'K'; | |
128 | connectedClient->write(&buf[0], 2); | |
129 | connectedClient->flush(); | |
130 | } | |
131 | ||
132 | catch (apache::thrift::transport::TTransportException& ex) | |
133 | { | |
134 | boost::mutex::scoped_lock lock(gMutex); | |
135 | BOOST_TEST_MESSAGE(boost::format("SRV %1% Exception: %2%") % boost::this_thread::get_id() % ex.what()); | |
136 | } | |
137 | ||
138 | if (connectedClient) | |
139 | { | |
140 | connectedClient->close(); | |
141 | connectedClient.reset(); | |
142 | } | |
143 | ||
144 | pServerSocket->close(); | |
145 | pServerSocket.reset(); | |
146 | } | |
147 | catch (std::exception& ex) | |
148 | { | |
149 | BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what()); | |
150 | } | |
151 | } | |
152 | ||
153 | void client(apache::thrift::transport::SSLProtocol protocol) | |
154 | { | |
155 | try | |
156 | { | |
157 | shared_ptr<TSSLSocketFactory> pClientSocketFactory; | |
158 | shared_ptr<TSSLSocket> pClientSocket; | |
159 | ||
160 | try | |
161 | { | |
162 | pClientSocketFactory.reset(new TSSLSocketFactory(static_cast<apache::thrift::transport::SSLProtocol>(protocol))); | |
163 | pClientSocketFactory->authenticate(true); | |
164 | pClientSocketFactory->loadCertificate(certFile("client.crt").string().c_str()); | |
165 | pClientSocketFactory->loadPrivateKey(certFile("client.key").string().c_str()); | |
166 | pClientSocketFactory->loadTrustedCertificates(certFile("CA.pem").string().c_str()); | |
167 | pClientSocket = pClientSocketFactory->createSocket("localhost", mPort); | |
168 | pClientSocket->open(); | |
169 | ||
170 | uint8_t buf[3]; | |
171 | buf[0] = 0; | |
172 | buf[1] = 0; | |
173 | BOOST_CHECK_EQUAL(2, pClientSocket->read(&buf[0], 2)); | |
174 | BOOST_CHECK_EQUAL(0, memcmp(&buf[0], "OK", 2)); | |
175 | mConnected = true; | |
176 | } | |
177 | catch (apache::thrift::transport::TTransportException& ex) | |
178 | { | |
179 | boost::mutex::scoped_lock lock(gMutex); | |
180 | BOOST_TEST_MESSAGE(boost::format("CLI %1% Exception: %2%") % boost::this_thread::get_id() % ex.what()); | |
181 | } | |
182 | ||
183 | if (pClientSocket) | |
184 | { | |
185 | pClientSocket->close(); | |
186 | pClientSocket.reset(); | |
187 | } | |
188 | } | |
189 | catch (std::exception& ex) | |
190 | { | |
191 | BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what()); | |
192 | } | |
193 | } | |
194 | ||
195 | static const char *protocol2str(size_t protocol) | |
196 | { | |
197 | static const char *strings[apache::thrift::transport::LATEST + 1] = | |
198 | { | |
199 | "SSLTLS", | |
200 | "SSLv2", | |
201 | "SSLv3", | |
202 | "TLSv1_0", | |
203 | "TLSv1_1", | |
204 | "TLSv1_2" | |
205 | }; | |
206 | return strings[protocol]; | |
207 | } | |
208 | ||
209 | boost::mutex mMutex; | |
210 | boost::condition_variable mCVar; | |
211 | int mPort; | |
212 | bool mConnected; | |
213 | }; | |
214 | ||
215 | BOOST_FIXTURE_TEST_SUITE(BOOST_TEST_MODULE, SecurityFixture) | |
216 | ||
217 | BOOST_AUTO_TEST_CASE(ssl_security_matrix) | |
218 | { | |
219 | try | |
220 | { | |
221 | // matrix of connection success between client and server with different SSLProtocol selections | |
222 | bool matrix[apache::thrift::transport::LATEST + 1][apache::thrift::transport::LATEST + 1] = | |
223 | { | |
224 | // server = SSLTLS SSLv2 SSLv3 TLSv1_0 TLSv1_1 TLSv1_2 | |
225 | // client | |
226 | /* SSLTLS */ { true, false, false, true, true, true }, | |
227 | /* SSLv2 */ { false, false, false, false, false, false }, | |
228 | /* SSLv3 */ { false, false, true, false, false, false }, | |
229 | /* TLSv1_0 */ { true, false, false, true, false, false }, | |
230 | /* TLSv1_1 */ { true, false, false, false, true, false }, | |
231 | /* TLSv1_2 */ { true, false, false, false, false, true } | |
232 | }; | |
233 | ||
234 | for (size_t si = 0; si <= apache::thrift::transport::LATEST; ++si) | |
235 | { | |
236 | for (size_t ci = 0; ci <= apache::thrift::transport::LATEST; ++ci) | |
237 | { | |
238 | if (si == 1 || ci == 1) | |
239 | { | |
240 | // Skip all SSLv2 cases - protocol not supported | |
241 | continue; | |
242 | } | |
243 | ||
244 | #ifdef OPENSSL_NO_SSL3 | |
245 | if (si == 2 || ci == 2) | |
246 | { | |
247 | // Skip all SSLv3 cases - protocol not supported | |
248 | continue; | |
249 | } | |
250 | #endif | |
251 | ||
252 | boost::mutex::scoped_lock lock(mMutex); | |
253 | ||
254 | BOOST_TEST_MESSAGE(boost::format("TEST: Server = %1%, Client = %2%") | |
255 | % protocol2str(si) % protocol2str(ci)); | |
256 | ||
257 | mConnected = false; | |
258 | // thread_group manages the thread lifetime - ignore the return value of create_thread | |
259 | boost::thread_group threads; | |
260 | (void)threads.create_thread(bind(&SecurityFixture::server, this, static_cast<apache::thrift::transport::SSLProtocol>(si))); | |
261 | mCVar.wait(lock); // wait for listen() to succeed | |
262 | lock.unlock(); | |
263 | (void)threads.create_thread(bind(&SecurityFixture::client, this, static_cast<apache::thrift::transport::SSLProtocol>(ci))); | |
264 | threads.join_all(); | |
265 | ||
266 | BOOST_CHECK_MESSAGE(mConnected == matrix[ci][si], | |
267 | boost::format(" Server = %1%, Client = %2% expected mConnected == %3% but was %4%") | |
268 | % protocol2str(si) % protocol2str(ci) % matrix[ci][si] % mConnected); | |
269 | } | |
270 | } | |
271 | } | |
272 | catch (std::exception& ex) | |
273 | { | |
274 | BOOST_FAIL(boost::format("%1%: %2%") % typeid(ex).name() % ex.what()); | |
275 | } | |
276 | } | |
277 | ||
278 | BOOST_AUTO_TEST_SUITE_END() |