/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
20 package org
.apache
.thrift
.transport
;
import java.lang.ref.WeakReference;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.WeakHashMap;

import javax.security.auth.callback.CallbackHandler;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslException;
import javax.security.sasl.SaslServer;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Wraps another Thrift <code>TTransport</code>, but performs SASL server
 * negotiation on the call to <code>open()</code>. This class will wrap ensuing
 * communication over it, if a SASL QOP is negotiated with the other party.
 */
42 public class TSaslServerTransport
extends TSaslTransport
{
44 private static final Logger LOGGER
= LoggerFactory
.getLogger(TSaslServerTransport
.class);
47 * Mapping from SASL mechanism name -> all the parameters required to
48 * instantiate a SASL server.
50 private Map
<String
, TSaslServerDefinition
> serverDefinitionMap
= new HashMap
<String
, TSaslServerDefinition
>();
53 * Contains all the parameters used to define a SASL server implementation.
55 private static class TSaslServerDefinition
{
56 public String mechanism
;
57 public String protocol
;
58 public String serverName
;
59 public Map
<String
, String
> props
;
60 public CallbackHandler cbh
;
62 public TSaslServerDefinition(String mechanism
, String protocol
, String serverName
,
63 Map
<String
, String
> props
, CallbackHandler cbh
) {
64 this.mechanism
= mechanism
;
65 this.protocol
= protocol
;
66 this.serverName
= serverName
;
73 * Uses the given underlying transport. Assumes that addServerDefinition is
77 * Transport underlying this one.
79 public TSaslServerTransport(TTransport transport
) {
84 * Creates a <code>SaslServer</code> using the given SASL-specific parameters.
85 * See the Java documentation for <code>Sasl.createSaslServer</code> for the
86 * details of the parameters.
89 * The underlying Thrift transport.
91 public TSaslServerTransport(String mechanism
, String protocol
, String serverName
,
92 Map
<String
, String
> props
, CallbackHandler cbh
, TTransport transport
) {
94 addServerDefinition(mechanism
, protocol
, serverName
, props
, cbh
);
97 private TSaslServerTransport(Map
<String
, TSaslServerDefinition
> serverDefinitionMap
, TTransport transport
) {
99 this.serverDefinitionMap
.putAll(serverDefinitionMap
);
103 * Add a supported server definition to this transport. See the Java
104 * documentation for <code>Sasl.createSaslServer</code> for the details of the
107 public void addServerDefinition(String mechanism
, String protocol
, String serverName
,
108 Map
<String
, String
> props
, CallbackHandler cbh
) {
109 serverDefinitionMap
.put(mechanism
, new TSaslServerDefinition(mechanism
, protocol
, serverName
,
114 protected SaslRole
getRole() {
115 return SaslRole
.SERVER
;
119 * Performs the server side of the initial portion of the Thrift SASL protocol.
120 * Receives the initial response from the client, creates a SASL server using
121 * the mechanism requested by the client (if this server supports it), and
122 * sends the first challenge back to the client.
125 protected void handleSaslStartMessage() throws TTransportException
, SaslException
{
126 SaslResponse message
= receiveSaslMessage();
128 LOGGER
.debug("Received start message with status {}", message
.status
);
129 if (message
.status
!= NegotiationStatus
.START
) {
130 throw sendAndThrowMessage(NegotiationStatus
.ERROR
, "Expecting START status, received " + message
.status
);
133 // Get the mechanism name.
134 String mechanismName
= new String(message
.payload
, StandardCharsets
.UTF_8
);
135 TSaslServerDefinition serverDefinition
= serverDefinitionMap
.get(mechanismName
);
136 LOGGER
.debug("Received mechanism name '{}'", mechanismName
);
138 if (serverDefinition
== null) {
139 throw sendAndThrowMessage(NegotiationStatus
.BAD
, "Unsupported mechanism type " + mechanismName
);
141 SaslServer saslServer
= Sasl
.createSaslServer(serverDefinition
.mechanism
,
142 serverDefinition
.protocol
, serverDefinition
.serverName
, serverDefinition
.props
,
143 serverDefinition
.cbh
);
144 setSaslServer(saslServer
);
/**
 * <code>TTransportFactory</code> to create
 * <code>TSaslServerTransports</code>. Ensures that a given
 * underlying <code>TTransport</code> instance receives the same
 * <code>TSaslServerTransport</code>. This is kind of an awful hack to work
 * around the fact that Thrift is designed assuming that
 * <code>TTransport</code> instances are stateless, and thus the existing
 * <code>TServers</code> use different <code>TTransport</code> instances for
 * input and output.
 */
157 public static class Factory
extends TTransportFactory
{
160 * This is the implementation of the awful hack described above.
161 * <code>WeakHashMap</code> is used to ensure that we don't leak memory.
163 private static Map
<TTransport
, WeakReference
<TSaslServerTransport
>> transportMap
=
164 Collections
.synchronizedMap(new WeakHashMap
<TTransport
, WeakReference
<TSaslServerTransport
>>());
167 * Mapping from SASL mechanism name -> all the parameters required to
168 * instantiate a SASL server.
170 private Map
<String
, TSaslServerDefinition
> serverDefinitionMap
= new HashMap
<String
, TSaslServerDefinition
>();
173 * Create a new Factory. Assumes that <code>addServerDefinition</code> will
181 * Create a new <code>Factory</code>, initially with the single server
182 * definition given. You may still call <code>addServerDefinition</code>
183 * later. See the Java documentation for <code>Sasl.createSaslServer</code>
184 * for the details of the parameters.
/**
 * Builds a factory that starts with one registered SASL server definition.
 * More definitions can be added afterwards with
 * <code>addServerDefinition</code>.
 *
 * @param mechanism SASL mechanism name
 * @param protocol protocol argument for <code>Sasl.createSaslServer</code>
 * @param serverName server name argument for <code>Sasl.createSaslServer</code>
 * @param props property map for <code>Sasl.createSaslServer</code>
 * @param cbh callback handler for <code>Sasl.createSaslServer</code>
 */
public Factory(String mechanism, String protocol, String serverName,
    Map<String, String> props, CallbackHandler cbh) {
  this.addServerDefinition(mechanism, protocol, serverName, props, cbh);
}
193 * Add a supported server definition to the transports created by this
194 * factory. See the Java documentation for
195 * <code>Sasl.createSaslServer</code> for the details of the parameters.
197 public void addServerDefinition(String mechanism
, String protocol
, String serverName
,
198 Map
<String
, String
> props
, CallbackHandler cbh
) {
199 serverDefinitionMap
.put(mechanism
, new TSaslServerDefinition(mechanism
, protocol
, serverName
,
/**
 * Get a new <code>TSaslServerTransport</code> instance, or reuse the
 * existing one if a <code>TSaslServerTransport</code> has already been
 * created before using the given <code>TTransport</code> as an underlying
 * transport. This ensures that a given underlying transport instance
 * receives the same <code>TSaslServerTransport</code>.
 */
211 public TTransport
getTransport(TTransport base
) {
212 WeakReference
<TSaslServerTransport
> ret
= transportMap
.get(base
);
213 if (ret
== null || ret
.get() == null) {
// SLF4J substitutes arguments only into "{}" placeholders. Without one, the
// base transport was silently dropped from this message.
LOGGER.debug("transport map does not contain key {}", base);
215 ret
= new WeakReference
<TSaslServerTransport
>(new TSaslServerTransport(serverDefinitionMap
, base
));
218 } catch (TTransportException e
) {
219 LOGGER
.debug("failed to open server transport", e
);
220 throw new RuntimeException(e
);
222 transportMap
.put(base
, ret
); // No need for putIfAbsent().
223 // Concurrent calls to getTransport() will pass in different TTransports.
225 LOGGER
.debug("transport map does contain key {}", base
);