]> git.proxmox.com Git - ceph.git/blame - ceph/src/jaegertracing/thrift/lib/java/test/org/apache/thrift/transport/TestTSaslTransports.java
buildsys: switch source download to quincy
[ceph.git] / ceph / src / jaegertracing / thrift / lib / java / test / org / apache / thrift / transport / TestTSaslTransports.java
CommitLineData
f67539c2
TL
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19
20package org.apache.thrift.transport;
21
22import java.io.IOException;
23import java.nio.charset.StandardCharsets;
24import java.util.HashMap;
25import java.util.Map;
26
27import javax.security.auth.callback.Callback;
28import javax.security.auth.callback.CallbackHandler;
29import javax.security.auth.callback.NameCallback;
30import javax.security.auth.callback.PasswordCallback;
31import javax.security.auth.callback.UnsupportedCallbackException;
32import javax.security.sasl.AuthorizeCallback;
33import javax.security.sasl.RealmCallback;
34import javax.security.sasl.Sasl;
35import javax.security.sasl.SaslClient;
36import javax.security.sasl.SaslClientFactory;
37import javax.security.sasl.SaslException;
38import javax.security.sasl.SaslServer;
39import javax.security.sasl.SaslServerFactory;
40
41import junit.framework.TestCase;
42
43import org.apache.thrift.TProcessor;
44import org.apache.thrift.protocol.TProtocolFactory;
45import org.apache.thrift.server.ServerTestBase;
46import org.apache.thrift.server.TServer;
47import org.apache.thrift.server.TSimpleServer;
48import org.apache.thrift.server.TServer.Args;
49import org.slf4j.Logger;
50import org.slf4j.LoggerFactory;
51
52public class TestTSaslTransports extends TestCase {
53
54 private static final Logger LOGGER = LoggerFactory.getLogger(TestTSaslTransports.class);
55
56 private static final String HOST = "localhost";
57 private static final String SERVICE = "thrift-test";
58 private static final String PRINCIPAL = "thrift-test-principal";
59 private static final String PASSWORD = "super secret password";
60 private static final String REALM = "thrift-test-realm";
61
62 private static final String UNWRAPPED_MECHANISM = "CRAM-MD5";
63 private static final Map<String, String> UNWRAPPED_PROPS = null;
64
65 private static final String WRAPPED_MECHANISM = "DIGEST-MD5";
66 private static final Map<String, String> WRAPPED_PROPS = new HashMap<String, String>();
67
68 static {
69 WRAPPED_PROPS.put(Sasl.QOP, "auth-int");
70 WRAPPED_PROPS.put("com.sun.security.sasl.digest.realm", REALM);
71 }
72
73 private static final String testMessage1 = "Hello, world! Also, four "
74 + "score and seven years ago our fathers brought forth on this "
75 + "continent a new nation, conceived in liberty, and dedicated to the "
76 + "proposition that all men are created equal.";
77
78 private static final String testMessage2 = "I have a dream that one day "
79 + "this nation will rise up and live out the true meaning of its creed: "
80 + "'We hold these truths to be self-evident, that all men are created equal.'";
81
82
83 private static class TestSaslCallbackHandler implements CallbackHandler {
84 private final String password;
85
86 public TestSaslCallbackHandler(String password) {
87 this.password = password;
88 }
89
90 @Override
91 public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
92 for (Callback c : callbacks) {
93 if (c instanceof NameCallback) {
94 ((NameCallback) c).setName(PRINCIPAL);
95 } else if (c instanceof PasswordCallback) {
96 ((PasswordCallback) c).setPassword(password.toCharArray());
97 } else if (c instanceof AuthorizeCallback) {
98 ((AuthorizeCallback) c).setAuthorized(true);
99 } else if (c instanceof RealmCallback) {
100 ((RealmCallback) c).setText(REALM);
101 } else {
102 throw new UnsupportedCallbackException(c);
103 }
104 }
105 }
106 }
107
108 private class ServerThread extends Thread {
109 final String mechanism;
110 final Map<String, String> props;
111 volatile Throwable thrown;
112
113 public ServerThread(String mechanism, Map<String, String> props) {
114 this.mechanism = mechanism;
115 this.props = props;
116 }
117
118 public void run() {
119 try {
120 internalRun();
121 } catch (Throwable t) {
122 thrown = t;
123 }
124 }
125
126 private void internalRun() throws Exception {
127 TServerSocket serverSocket = new TServerSocket(
128 new TServerSocket.ServerSocketTransportArgs().
129 port(ServerTestBase.PORT));
130 try {
131 acceptAndWrite(serverSocket);
132 } finally {
133 serverSocket.close();
134 }
135 }
136
137 private void acceptAndWrite(TServerSocket serverSocket)
138 throws Exception {
139 TTransport serverTransport = serverSocket.accept();
140 TTransport saslServerTransport = new TSaslServerTransport(
141 mechanism, SERVICE, HOST,
142 props, new TestSaslCallbackHandler(PASSWORD), serverTransport);
143
144 saslServerTransport.open();
145
146 byte[] inBuf = new byte[testMessage1.getBytes().length];
147 // Deliberately read less than the full buffer to ensure
148 // that TSaslTransport is correctly buffering reads. This
149 // will fail for the WRAPPED test, if it doesn't work.
150 saslServerTransport.readAll(inBuf, 0, 5);
151 saslServerTransport.readAll(inBuf, 5, 10);
152 saslServerTransport.readAll(inBuf, 15, inBuf.length - 15);
153 LOGGER.debug("server got: {}", new String(inBuf));
154 assertEquals(new String(inBuf), testMessage1);
155
156 LOGGER.debug("server writing: {}", testMessage2);
157 saslServerTransport.write(testMessage2.getBytes());
158 saslServerTransport.flush();
159
160 saslServerTransport.close();
161 }
162 }
163
164 private void testSaslOpen(final String mechanism, final Map<String, String> props)
165 throws Exception {
166 ServerThread serverThread = new ServerThread(mechanism, props);
167 serverThread.start();
168
169 try {
170 Thread.sleep(1000);
171 } catch (InterruptedException e) {
172 // Ah well.
173 }
174
175 try {
176 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
177 TTransport saslClientTransport = new TSaslClientTransport(mechanism,
178 PRINCIPAL, SERVICE, HOST, props, new TestSaslCallbackHandler(PASSWORD), clientSocket);
179 saslClientTransport.open();
180 LOGGER.debug("client writing: {}", testMessage1);
181 saslClientTransport.write(testMessage1.getBytes());
182 saslClientTransport.flush();
183
184 byte[] inBuf = new byte[testMessage2.getBytes().length];
185 saslClientTransport.readAll(inBuf, 0, inBuf.length);
186 LOGGER.debug("client got: {}", new String(inBuf));
187 assertEquals(new String(inBuf), testMessage2);
188
189 TTransportException expectedException = null;
190 try {
191 saslClientTransport.open();
192 } catch (TTransportException e) {
193 expectedException = e;
194 }
195 assertNotNull(expectedException);
196
197 saslClientTransport.close();
198 } catch (Exception e) {
199 LOGGER.warn("Exception caught", e);
200 throw e;
201 } finally {
202 serverThread.interrupt();
203 try {
204 serverThread.join();
205 } catch (InterruptedException e) {
206 // Ah well.
207 }
208 assertNull(serverThread.thrown);
209 }
210 }
211
212 public void testUnwrappedOpen() throws Exception {
213 testSaslOpen(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
214 }
215
216 public void testWrappedOpen() throws Exception {
217 testSaslOpen(WRAPPED_MECHANISM, WRAPPED_PROPS);
218 }
219
220 public void testAnonymousOpen() throws Exception {
221 testSaslOpen("ANONYMOUS", null);
222 }
223
224 /**
225 * Test that we get the proper exceptions thrown back the server when
226 * the client provides invalid password.
227 */
228 public void testBadPassword() throws Exception {
229 ServerThread serverThread = new ServerThread(UNWRAPPED_MECHANISM, UNWRAPPED_PROPS);
230 serverThread.start();
231
232 try {
233 Thread.sleep(1000);
234 } catch (InterruptedException e) {
235 // Ah well.
236 }
237
238 boolean clientSidePassed = true;
239
240 try {
241 TSocket clientSocket = new TSocket(HOST, ServerTestBase.PORT);
242 TTransport saslClientTransport = new TSaslClientTransport(
243 UNWRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, UNWRAPPED_PROPS,
244 new TestSaslCallbackHandler("NOT THE PASSWORD"), clientSocket);
245 saslClientTransport.open();
246 clientSidePassed = false;
247 fail("Was able to open transport with bad password");
248 } catch (TTransportException tte) {
249 LOGGER.error("Exception for bad password", tte);
250 assertNotNull(tte.getMessage());
251 assertTrue(tte.getMessage().contains("Invalid response"));
252
253 } finally {
254 serverThread.interrupt();
255 serverThread.join();
256
257 if (clientSidePassed) {
258 assertNotNull(serverThread.thrown);
259 assertTrue(serverThread.thrown.getMessage().contains("Invalid response"));
260 }
261 }
262 }
263
264 public void testWithServer() throws Exception {
265 new TestTSaslTransportsWithServer().testIt();
266 }
267
268 private static class TestTSaslTransportsWithServer extends ServerTestBase {
269
270 private Thread serverThread;
271 private TServer server;
272
273 @Override
274 public TTransport getClientTransport(TTransport underlyingTransport) throws Exception {
275 return new TSaslClientTransport(
276 WRAPPED_MECHANISM, PRINCIPAL, SERVICE, HOST, WRAPPED_PROPS,
277 new TestSaslCallbackHandler(PASSWORD), underlyingTransport);
278 }
279
280 @Override
281 public void startServer(final TProcessor processor, final TProtocolFactory protoFactory, final TTransportFactory factory) throws Exception {
282 serverThread = new Thread() {
283 public void run() {
284 try {
285 // Transport
286 TServerSocket socket = new TServerSocket(new TServerSocket.ServerSocketTransportArgs().port(PORT));
287
288 TTransportFactory factory = new TSaslServerTransport.Factory(
289 WRAPPED_MECHANISM, SERVICE, HOST, WRAPPED_PROPS,
290 new TestSaslCallbackHandler(PASSWORD));
291 server = new TSimpleServer(new Args(socket).processor(processor).transportFactory(factory).protocolFactory(protoFactory));
292
293 // Run it
294 LOGGER.debug("Starting the server on port {}", PORT);
295 server.serve();
296 } catch (Exception e) {
297 e.printStackTrace();
298 fail();
299 }
300 }
301 };
302 serverThread.start();
303 Thread.sleep(1000);
304 }
305
306 @Override
307 public void stopServer() throws Exception {
308 server.stop();
309 try {
310 serverThread.join();
311 } catch (InterruptedException e) {}
312 }
313
314 }
315
316
317 /**
318 * Implementation of SASL ANONYMOUS, used for testing client-side
319 * initial responses.
320 */
321 private static class AnonymousClient implements SaslClient {
322 private final String username;
323 private boolean hasProvidedInitialResponse;
324
325 public AnonymousClient(String username) {
326 this.username = username;
327 }
328
329 public String getMechanismName() { return "ANONYMOUS"; }
330 public boolean hasInitialResponse() { return true; }
331 public byte[] evaluateChallenge(byte[] challenge) throws SaslException {
332 if (hasProvidedInitialResponse) {
333 throw new SaslException("Already complete!");
334 }
335
336 hasProvidedInitialResponse = true;
337 return username.getBytes(StandardCharsets.UTF_8);
338 }
339 public boolean isComplete() { return hasProvidedInitialResponse; }
340 public byte[] unwrap(byte[] incoming, int offset, int len) {
341 throw new UnsupportedOperationException();
342 }
343 public byte[] wrap(byte[] outgoing, int offset, int len) {
344 throw new UnsupportedOperationException();
345 }
346 public Object getNegotiatedProperty(String propName) { return null; }
347 public void dispose() {}
348 }
349
350 private static class AnonymousServer implements SaslServer {
351 private String user;
352 public String getMechanismName() { return "ANONYMOUS"; }
353 public byte[] evaluateResponse(byte[] response) throws SaslException {
354 this.user = new String(response, StandardCharsets.UTF_8);
355 return null;
356 }
357 public boolean isComplete() { return user != null; }
358 public String getAuthorizationID() { return user; }
359 public byte[] unwrap(byte[] incoming, int offset, int len) {
360 throw new UnsupportedOperationException();
361 }
362 public byte[] wrap(byte[] outgoing, int offset, int len) {
363 throw new UnsupportedOperationException();
364 }
365 public Object getNegotiatedProperty(String propName) { return null; }
366 public void dispose() {}
367
368 }
369
370 public static class SaslAnonymousFactory
371 implements SaslClientFactory, SaslServerFactory {
372
373 public SaslClient createSaslClient(
374 String[] mechanisms, String authorizationId, String protocol,
375 String serverName, Map<String,?> props, CallbackHandler cbh)
376 {
377 for (String mech : mechanisms) {
378 if ("ANONYMOUS".equals(mech)) {
379 return new AnonymousClient(authorizationId);
380 }
381 }
382 return null;
383 }
384
385 public SaslServer createSaslServer(
386 String mechanism, String protocol, String serverName, Map<String,?> props, CallbackHandler cbh)
387 {
388 if ("ANONYMOUS".equals(mechanism)) {
389 return new AnonymousServer();
390 }
391 return null;
392 }
393 public String[] getMechanismNames(Map<String, ?> props) {
394 return new String[] { "ANONYMOUS" };
395 }
396 }
397
398 static {
399 java.security.Security.addProvider(new SaslAnonymousProvider());
400 }
401 public static class SaslAnonymousProvider extends java.security.Provider {
402 public SaslAnonymousProvider() {
403 super("ThriftSaslAnonymous", 1.0, "Thrift Anonymous SASL provider");
404 put("SaslClientFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
405 put("SaslServerFactory.ANONYMOUS", SaslAnonymousFactory.class.getName());
406 }
407 }
408
409 private static class MockTTransport extends TTransport {
410
411 byte[] badHeader = null;
412 private TMemoryInputTransport readBuffer = new TMemoryInputTransport();
413
414 public MockTTransport(int mode) {
415 if (mode==1) {
416 // Invalid status byte
417 badHeader = new byte[] { (byte)0xFF, (byte)0x00, (byte)0x00, (byte)0x00, (byte)0x05 };
418 } else if (mode == 2) {
419 // Valid status byte, negative payload length
420 badHeader = new byte[] { (byte)0x01, (byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF };
421 } else if (mode == 3) {
422 // Valid status byte, excessively large, bogus payload length
423 badHeader = new byte[] { (byte)0x01, (byte)0x64, (byte)0x00, (byte)0x00, (byte)0x00 };
424 }
425 readBuffer.reset(badHeader);
426 }
427
428 @Override
429 public boolean isOpen() {
430 return true;
431 }
432
433 @Override
434 public void open() throws TTransportException {}
435
436 @Override
437 public void close() {}
438
439 @Override
440 public int read(byte[] buf, int off, int len) throws TTransportException {
441 return readBuffer.read(buf, off, len);
442 }
443
444 @Override
445 public void write(byte[] buf, int off, int len) throws TTransportException {}
446 }
447
448 public void testBadHeader() {
449 TSaslTransport saslTransport = new TSaslServerTransport(new MockTTransport(1));
450 try {
451 saslTransport.receiveSaslMessage();
452 fail("Should have gotten an error due to incorrect status byte value.");
453 } catch (TTransportException e) {
454 assertEquals(e.getMessage(), "Invalid status -1");
455 }
456 saslTransport = new TSaslServerTransport(new MockTTransport(2));
457 try {
458 saslTransport.receiveSaslMessage();
459 fail("Should have gotten an error due to negative payload length.");
460 } catch (TTransportException e) {
461 assertEquals(e.getMessage(), "Invalid payload header length: -1");
462 }
463 saslTransport = new TSaslServerTransport(new MockTTransport(3));
464 try {
465 saslTransport.receiveSaslMessage();
466 fail("Should have gotten an error due to bogus (large) payload length.");
467 } catch (TTransportException e) {
468 assertEquals(e.getMessage(), "Invalid payload header length: 1677721600");
469 }
470 }
471}