]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one | |
3 | * or more contributor license agreements. See the NOTICE file | |
4 | * distributed with this work for additional information | |
5 | * regarding copyright ownership. The ASF licenses this file | |
6 | * to you under the Apache License, Version 2.0 (the | |
7 | * "License"); you may not use this file except in compliance | |
8 | * with the License. You may obtain a copy of the License at | |
9 | * | |
10 | * http://www.apache.org/licenses/LICENSE-2.0 | |
11 | * | |
12 | * Unless required by applicable law or agreed to in writing, | |
13 | * software distributed under the License is distributed on an | |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | * KIND, either express or implied. See the License for the | |
16 | * specific language governing permissions and limitations | |
17 | * under the License. | |
18 | */ | |
19 | ||
20 | package org.apache.thrift.transport; | |
21 | ||
22 | import java.io.IOException; | |
23 | import java.nio.charset.StandardCharsets; | |
24 | import java.util.HashMap; | |
25 | import java.util.Map; | |
26 | ||
27 | import javax.security.auth.callback.Callback; | |
28 | import javax.security.auth.callback.CallbackHandler; | |
29 | import javax.security.auth.callback.NameCallback; | |
30 | import javax.security.auth.callback.PasswordCallback; | |
31 | import javax.security.auth.callback.UnsupportedCallbackException; | |
32 | import javax.security.sasl.AuthorizeCallback; | |
33 | import javax.security.sasl.RealmCallback; | |
34 | import javax.security.sasl.Sasl; | |
35 | import javax.security.sasl.SaslClient; | |
36 | import javax.security.sasl.SaslClientFactory; | |
37 | import javax.security.sasl.SaslException; | |
38 | import javax.security.sasl.SaslServer; | |
39 | import javax.security.sasl.SaslServerFactory; | |
40 | ||
41 | import junit.framework.TestCase; | |
42 | ||
43 | import org.apache.thrift.TProcessor; | |
44 | import org.apache.thrift.protocol.TProtocolFactory; | |
45 | import org.apache.thrift.server.ServerTestBase; | |
46 | import org.apache.thrift.server.TServer; | |
47 | import org.apache.thrift.server.TSimpleServer; | |
48 | import org.apache.thrift.server.TServer.Args; | |
49 | import org.slf4j.Logger; | |
50 | import org.slf4j.LoggerFactory; | |
51 | ||
52 | public class TestTSaslTransports extends TestCase { | |
53 | ||
54 | private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class); | |
55 | ||
56 | private static final String HOST = "localhost"; | |
57 | private static final String SERVICE = "thrift-test"; | |
58 | private static final String PRINCIPAL = "thrift-test-principal"; | |
59 | private static final String PASSWORD = "super secret password"; | |
60 | private static final String REALM = "thrift-test-realm"; | |
61 | ||
62 | private static final String UNWRAPPED_MECHANISM = "CRAM-MD5"; | |
63 | private static final Map<String, String> UNWRAPPED_PROPS = null; | |
64 | ||
65 | private static final String WRAPPED_MECHANISM = "DIGEST-MD5"; | |
66 | private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>(); | |
67 | ||
68 | static { | |
69 | WRAPPED_PROPS.put(Sasl.QOP, "auth-int"); | |
70 | WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM); | |
71 | } | |
72 | ||
73 | private static final String testMessage1 = "Hello, world! Also, four " | |
74 | + "score and seven years ago our fathers brought forth on this " | |
75 | + "continent a new nation, conceived in liberty, and dedicated to the " | |
76 | + "proposition that all men are created equal."; | |
77 | ||
78 | private static final String testMessage2 = "I have a dream that one day " | |
79 | + "this nation will rise up and live out the true meaning of its creed: " | |
80 | + "'We hold these truths to be self-evident, that all men are created equal.'"; | |
81 | ||
82 | ||
83 | private static class TestSaslCallbackHandler implements CallbackHandler { | |
84 | private final String password; | |
85 | ||
86 | public TestSaslCallbackHandler(String password) { | |
87 | this.password = password; | |
88 | } | |
89 | ||
90 | @Override | |
91 | public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException { | |
92 | for (Callback c : callbacks) { | |
93 | if (c instanceof NameCallback) { | |
94 | ((NameCallback) c).setName(PRINCIPAL); | |
95 | } else if (c instanceof PasswordCallback) { | |
96 | ((PasswordCallback) c).setPassword(password.toCharArray()); | |
97 | } else if (c instanceof AuthorizeCallback) { | |
98 | ((AuthorizeCallback) c).setAuthorized(true); | |
99 | } else if (c instanceof RealmCallback) { | |
100 | ((RealmCallback) c).setText(REALM); | |
101 | } else { | |
102 | throw new UnsupportedCallbackException(c); | |
103 | } | |
104 | } | |
105 | } | |
106 | } | |
107 | ||
108 | private class ServerThread extends Thread { | |
109 | final String mechanism; | |
110 | final Map<String, String> props; | |
111 | volatile Throwable thrown; | |
112 | ||
113 | public ServerThread(String mechanism, Map<String, String> props) { | |
114 | this.mechanism = mechanism; | |
115 | this.props = props; | |
116 | } | |
117 | ||
118 | public void run() { | |
119 | try { | |
120 | internalRun(); | |
121 | } catch (Throwable t) { | |
122 | thrown = t; | |
123 | } | |
124 | } | |
125 | ||
126 | private void internalRun() throws Exception { | |
127 | TServerSocket serverSocket = new TServerSocket( | |
128 | new TServerSocket.ServerSocketTransportArgs(). | |
129 | port(ServerTestBase.PORT)); | |
130 | try { | |
131 | acceptAndWrite(serverSocket); | |
132 | } finally { | |
133 | serverSocket.close(); | |
134 | } | |
135 | } | |
136 | ||
137 | private void acceptAndWrite(TServerSocket serverSocket) | |
138 | throws Exception { | |
139 | TTransport serverTransport = serverSocket.accept(); | |
140 | TTransport saslServerTransport = new TSaslServerTransport( | |
141 | mechanism, SERVICE, HOST, | |
142 | props, new TestSaslCallbackHandler(PASSWORD), serverTransport); | |
143 | ||
144 | saslServerTransport.open(); | |
145 | ||
146 | byte[] inBuf = new byte[testMessage1.getBytes().length]; | |
147 | // Deliberately read less than the full buffer to ensure | |
148 | // that TSaslTransport is correctly buffering reads. This | |
149 | // will fail for the WRAPPED test, if it doesn't work. | |
150 | saslServerTransport.readAll(inBuf, 0, 5); | |
151 | saslServerTransport.readAll(inBuf, 5, 10); | |
152 | saslServerTransport.readAll(inBuf, 15, inBuf.length - 15); | |
153 | LOGGER.debug("server got: {}", new String(inBuf)); | |
154 | assertEquals(new String(inBuf), testMessage1); | |
155 | ||
156 | LOGGER.debug("server writing: {}", testMessage2); | |
157 | saslServerTransport.write(testMessage2.getBytes()); | |
158 | saslServerTransport.flush(); | |
159 | ||
160 | saslServerTransport.close(); | |
161 | } | |
162 | } | |
163 | ||
164 | private void testSaslOpen(final String mechanism, final Map<String, String> props) | |
165 | throws Exception { | |
166 | ServerThread serverThread = new ServerThread(mechanism, props); | |
167 | serverThread.start(); | |
168 | ||
169 | try { | |
170 | Thread.sleep(1000); | |
171 | } catch (InterruptedException e) { | |
172 | // Ah well. | |
173 | } | |
174 | ||
175 | try { | |
176 | TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); | |
177 | TTransport saslClientTransport = new TSaslClientTransport(mechanism, | |
178 | PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket); | |
179 | saslClientTransport.open(); | |
180 | LOGGER.debug("client writing: {}", testMessage1); | |
181 | saslClientTransport.write(testMessage1.getBytes()); | |
182 | saslClientTransport.flush(); | |
183 | ||
184 | byte[] inBuf = new byte[testMessage2.getBytes().length]; | |
185 | saslClientTransport.readAll(inBuf, 0, inBuf.length); | |
186 | LOGGER.debug("client got: {}", new String(inBuf)); | |
187 | assertEquals(new String(inBuf), testMessage2); | |
188 | ||
189 | TTransportException expectedException = null; | |
190 | try { | |
191 | saslClientTransport.open(); | |
192 | } catch (TTransportException e) { | |
193 | expectedException = e; | |
194 | } | |
195 | assertNotNull(expectedException); | |
196 | ||
197 | saslClientTransport.close(); | |
198 | } catch (Exception e) { | |
199 | LOGGER.warn("Exception caught", e); | |
200 | throw e; | |
201 | } finally { | |
202 | serverThread.interrupt(); | |
203 | try { | |
204 | serverThread.join(); | |
205 | } catch (InterruptedException e) { | |
206 | // Ah well. | |
207 | } | |
208 | assertNull(serverThread.thrown); | |
209 | } | |
210 | } | |
211 | ||
212 | public void testUnwrappedOpen() throws Exception { | |
213 | testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); | |
214 | } | |
215 | ||
216 | public void testWrappedOpen() throws Exception { | |
217 | testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS); | |
218 | } | |
219 | ||
220 | public void testAnonymousOpen() throws Exception { | |
221 | testSaslOpen("ANONYMOUS", null); | |
222 | } | |
223 | ||
224 | /** | |
225 | * Test that we get the proper exceptions thrown back the server when | |
226 | * the client provides invalid password. | |
227 | */ | |
228 | public void testBadPassword() throws Exception { | |
229 | ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS); | |
230 | serverThread.start(); | |
231 | ||
232 | try { | |
233 | Thread.sleep(1000); | |
234 | } catch (InterruptedException e) { | |
235 | // Ah well. | |
236 | } | |
237 | ||
238 | boolean clientSidePassed = true; | |
239 | ||
240 | try { | |
241 | TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT); | |
242 | TTransport saslClientTransport = new TSaslClientTransport( | |
243 | UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS, | |
244 | new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket); | |
245 | saslClientTransport.open(); | |
246 | clientSidePassed = false; | |
247 | fail("Was able to open transport with bad password"); | |
248 | } catch (TTransportException tte) { | |
249 | LOGGER.error("Exception for bad password", tte); | |
250 | assertNotNull(tte.getMessage()); | |
251 | assertTrue(tte.getMessage().contains("Invalid response")); | |
252 | ||
253 | } finally { | |
254 | serverThread.interrupt(); | |
255 | serverThread.join(); | |
256 | ||
257 | if (clientSidePassed) { | |
258 | assertNotNull(serverThread.thrown); | |
259 | assertTrue(serverThread.thrown.getMessage().contains("Invalid response")); | |
260 | } | |
261 | } | |
262 | } | |
263 | ||
264 | public void testWithServer() throws Exception { | |
265 | new TestTSaslTransportsWithServer().testIt(); | |
266 | } | |
267 | ||
268 | private static class TestTSaslTransportsWithServer extends ServerTestBase { | |
269 | ||
270 | private Thread serverThread; | |
271 | private TServer server; | |
272 | ||
273 | @Override | |
274 | public TTransport getClientTransport(TTransport underlyingTransport) throws Exception { | |
275 | return new TSaslClientTransport( | |
276 | WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS, | |
277 | new TestSaslCallbackHandler(PASSWORD), underlyingTransport); | |
278 | } | |
279 | ||
280 | @Override | |
281 | public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception { | |
282 | serverThread = new Thread() { | |
283 | public void run() { | |
284 | try { | |
285 | // Transport | |
286 | TServerSocket socket = new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT)); | |
287 | ||
288 | TTransportFactory factory = new TSaslServerTransport.Factory( | |
289 | WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS, | |
290 | new TestSaslCallbackHandler(PASSWORD)); | |
291 | server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory)); | |
292 | ||
293 | // Run it | |
294 | LOGGER.debug("Starting the server on port {}", PORT); | |
295 | server.serve(); | |
296 | } catch (Exception e) { | |
297 | e.printStackTrace(); | |
298 | fail(); | |
299 | } | |
300 | } | |
301 | }; | |
302 | serverThread.start(); | |
303 | Thread.sleep(1000); | |
304 | } | |
305 | ||
306 | @Override | |
307 | public void stopServer() throws Exception { | |
308 | server.stop(); | |
309 | try { | |
310 | serverThread.join(); | |
311 | } catch (InterruptedException e) {} | |
312 | } | |
313 | ||
314 | } | |
315 | ||
316 | ||
317 | /** | |
318 | * Implementation of SASL ANONYMOUS, used for testing client-side | |
319 | * initial responses. | |
320 | */ | |
321 | private static class AnonymousClient implements SaslClient { | |
322 | private final String username; | |
323 | private boolean hasProvidedInitialResponse; | |
324 | ||
325 | public AnonymousClient(String username) { | |
326 | this.username = username; | |
327 | } | |
328 | ||
329 | public String getMechanismName() { return "ANONYMOUS"; } | |
330 | public boolean hasInitialResponse() { return true; } | |
331 | public byte[] evaluateChallenge(byte[] challenge) throws SaslException { | |
332 | if (hasProvidedInitialResponse) { | |
333 | throw new SaslException("Already complete!"); | |
334 | } | |
335 | ||
336 | hasProvidedInitialResponse = true; | |
337 | return username.getBytes(StandardCharsets.UTF_8); | |
338 | } | |
339 | public boolean isComplete() { return hasProvidedInitialResponse; } | |
340 | public byte[] unwrap(byte[] incoming, int offset, int len) { | |
341 | throw new UnsupportedOperationException(); | |
342 | } | |
343 | public byte[] wrap(byte[] outgoing, int offset, int len) { | |
344 | throw new UnsupportedOperationException(); | |
345 | } | |
346 | public Object getNegotiatedProperty(String propName) { return null; } | |
347 | public void dispose() {} | |
348 | } | |
349 | ||
350 | private static class AnonymousServer implements SaslServer { | |
351 | private String user; | |
352 | public String getMechanismName() { return "ANONYMOUS"; } | |
353 | public byte[] evaluateResponse(byte[] response) throws SaslException { | |
354 | this.user = new String(response, StandardCharsets.UTF_8); | |
355 | return null; | |
356 | } | |
357 | public boolean isComplete() { return user != null; } | |
358 | public String getAuthorizationID() { return user; } | |
359 | public byte[] unwrap(byte[] incoming, int offset, int len) { | |
360 | throw new UnsupportedOperationException(); | |
361 | } | |
362 | public byte[] wrap(byte[] outgoing, int offset, int len) { | |
363 | throw new UnsupportedOperationException(); | |
364 | } | |
365 | public Object getNegotiatedProperty(String propName) { return null; } | |
366 | public void dispose() {} | |
367 | ||
368 | } | |
369 | ||
370 | public static class SaslAnonymousFactory | |
371 | implements SaslClientFactory, SaslServerFactory { | |
372 | ||
373 | public SaslClient createSaslClient( | |
374 | String[] mechanisms, String authorizationId, String protocol, | |
375 | String serverName, Map<String,?> props, CallbackHandler cbh) | |
376 | { | |
377 | for (String mech : mechanisms) { | |
378 | if ("ANONYMOUS".equals(mech)) { | |
379 | return new AnonymousClient(authorizationId); | |
380 | } | |
381 | } | |
382 | return null; | |
383 | } | |
384 | ||
385 | public SaslServer createSaslServer( | |
386 | String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh) | |
387 | { | |
388 | if ("ANONYMOUS".equals(mechanism)) { | |
389 | return new AnonymousServer(); | |
390 | } | |
391 | return null; | |
392 | } | |
393 | public String[] getMechanismNames(Map<String, ?> props) { | |
394 | return new String[] { "ANONYMOUS" }; | |
395 | } | |
396 | } | |
397 | ||
398 | static { | |
399 | java.security.Security.addProvider(new SaslAnonymousProvider()); | |
400 | } | |
401 | public static class SaslAnonymousProvider extends java.security.Provider { | |
402 | public SaslAnonymousProvider() { | |
403 | super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider"); | |
404 | put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); | |
405 | put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName()); | |
406 | } | |
407 | } | |
408 | ||
409 | private static class MockTTransport extends TTransport { | |
410 | ||
411 | byte[] badHeader = null; | |
412 | private TMemoryInputTransport readBuffer = new TMemoryInputTransport(); | |
413 | ||
414 | public MockTTransport(int mode) { | |
415 | if (mode==1) { | |
416 | // Invalid status byte | |
417 | badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 }; | |
418 | } else if (mode == 2) { | |
419 | // Valid status byte, negative payload length | |
420 | badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF }; | |
421 | } else if (mode == 3) { | |
422 | // Valid status byte, excessively large, bogus payload length | |
423 | badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 }; | |
424 | } | |
425 | readBuffer.reset(badHeader); | |
426 | } | |
427 | ||
428 | @Override | |
429 | public boolean isOpen() { | |
430 | return true; | |
431 | } | |
432 | ||
433 | @Override | |
434 | public void open() throws TTransportException {} | |
435 | ||
436 | @Override | |
437 | public void close() {} | |
438 | ||
439 | @Override | |
440 | public int read(byte[] buf, int off, int len) throws TTransportException { | |
441 | return readBuffer.read(buf, off, len); | |
442 | } | |
443 | ||
444 | @Override | |
445 | public void write(byte[] buf, int off, int len) throws TTransportException {} | |
446 | } | |
447 | ||
448 | public void testBadHeader() { | |
449 | TSaslTransport saslTransport = new TSaslServerTransport(new MockTTransport(1)); | |
450 | try { | |
451 | saslTransport.receiveSaslMessage(); | |
452 | fail("Should have gotten an error due to incorrect status byte value."); | |
453 | } catch (TTransportException e) { | |
454 | assertEquals(e.getMessage(), "Invalid status -1"); | |
455 | } | |
456 | saslTransport = new TSaslServerTransport(new MockTTransport(2)); | |
457 | try { | |
458 | saslTransport.receiveSaslMessage(); | |
459 | fail("Should have gotten an error due to negative payload length."); | |
460 | } catch (TTransportException e) { | |
461 | assertEquals(e.getMessage(), "Invalid payload header length: -1"); | |
462 | } | |
463 | saslTransport = new TSaslServerTransport(new MockTTransport(3)); | |
464 | try { | |
465 | saslTransport.receiveSaslMessage(); | |
466 | fail("Should have gotten an error due to bogus (large) payload length."); | |
467 | } catch (TTransportException e) { | |
468 | assertEquals(e.getMessage(), "Invalid payload header length: 1677721600"); | |
469 | } | |
470 | } | |
471 | } |