]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | // Licensed to the Apache Software Foundation(ASF) under one |
2 | // or more contributor license agreements.See the NOTICE file | |
3 | // distributed with this work for additional information | |
4 | // regarding copyright ownership.The ASF licenses this file | |
5 | // to you under the Apache License, Version 2.0 (the | |
6 | // "License"); you may not use this file except in compliance | |
7 | // with the License. You may obtain a copy of the License at | |
8 | // | |
9 | // http://www.apache.org/licenses/LICENSE-2.0 | |
10 | // | |
11 | // Unless required by applicable law or agreed to in writing, | |
12 | // software distributed under the License is distributed on an | |
13 | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
14 | // KIND, either express or implied. See the License for the | |
15 | // specific language governing permissions and limitations | |
16 | // under the License. | |
17 | ||
18 | using System; | |
19 | using System.Net; | |
20 | using System.Net.Security; | |
21 | using System.Net.Sockets; | |
22 | using System.Security.Authentication; | |
23 | using System.Security.Cryptography.X509Certificates; | |
24 | using System.Threading; | |
25 | using System.Threading.Tasks; | |
26 | ||
namespace Thrift.Transport.Client
{
    // TODO: verify that this transport works correctly end to end.

    /// <summary>
    /// Thrift stream transport running over a <see cref="TcpClient"/> wrapped in an <see cref="SslStream"/>.
    /// It either connects to a remote host and authenticates as a TLS client, or wraps an existing
    /// (e.g. accepted) <see cref="TcpClient"/> and authenticates as client or server via <see cref="SetupTlsAsync"/>.
    /// </summary>
    // ReSharper disable once InconsistentNaming
    public class TTlsSocketTransport : TStreamTransport
    {
        private readonly X509Certificate2 _certificate;
        private readonly RemoteCertificateValidationCallback _certValidator;
        private readonly IPAddress _host;
        private readonly bool _isServer;
        private readonly LocalCertificateSelectionCallback _localCertificateSelectionCallback;
        private readonly int _port;
        private readonly SslProtocols _sslProtocols;

        // Host name the server certificate is validated against when the transport was created
        // from a DNS name; null when it was created from an IPAddress or an existing TcpClient.
        private readonly string _targetHost;

        private TcpClient _client;
        private SslStream _secureStream;
        private int _timeout;

        /// <summary>
        /// Wraps an existing <see cref="TcpClient"/>. If it is already connected, the raw network stream
        /// is used for I/O until <see cref="SetupTlsAsync"/> performs the TLS handshake.
        /// </summary>
        /// <param name="client">The TCP client to wrap.</param>
        /// <param name="certificate">Local certificate; required when <paramref name="isServer"/> is true.</param>
        /// <param name="isServer">True to authenticate as the TLS server side.</param>
        /// <param name="certValidator">Optional remote certificate validator; when set on a server, a client certificate is requested.</param>
        /// <param name="localCertificateSelectionCallback">Optional local certificate selection callback.</param>
        /// <param name="sslProtocols">Allowed TLS protocol versions.</param>
        /// <exception cref="ArgumentException"><paramref name="isServer"/> is true and <paramref name="certificate"/> is null.</exception>
        public TTlsSocketTransport(TcpClient client, X509Certificate2 certificate, bool isServer = false,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls12)
        {
            _client = client;
            _certificate = certificate;
            _certValidator = certValidator;
            _localCertificateSelectionCallback = localCertificateSelectionCallback;
            _sslProtocols = sslProtocols;
            _isServer = isServer;

            if (isServer && certificate == null)
            {
                throw new ArgumentException("TTlsSocketTransport needs certificate to be used for server",
                    nameof(certificate));
            }

            if (IsOpen)
            {
                InputStream = client.GetStream();
                OutputStream = client.GetStream();
            }
        }

        /// <summary>
        /// Creates a client transport to <paramref name="host"/>:<paramref name="port"/> with no timeout (0),
        /// loading the local certificate from <paramref name="certificatePath"/>.
        /// </summary>
        public TTlsSocketTransport(IPAddress host, int port, string certificatePath,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls12)
            : this(host, port, 0,
                new X509Certificate2(certificatePath),
                certValidator,
                localCertificateSelectionCallback,
                sslProtocols)
        {
        }

        /// <summary>
        /// Creates a client transport to <paramref name="host"/>:<paramref name="port"/> with no timeout (0).
        /// </summary>
        public TTlsSocketTransport(IPAddress host, int port,
            X509Certificate2 certificate = null,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls12)
            : this(host, port, 0,
                certificate,
                certValidator,
                localCertificateSelectionCallback,
                sslProtocols)
        {
        }

        /// <summary>
        /// Creates a client transport to <paramref name="host"/>:<paramref name="port"/> and allocates
        /// an unconnected socket using <paramref name="timeout"/> as send/receive timeout.
        /// </summary>
        public TTlsSocketTransport(IPAddress host, int port, int timeout,
            X509Certificate2 certificate,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls12)
        {
            _host = host;
            _port = port;
            _timeout = timeout;
            _certificate = certificate;
            _certValidator = certValidator;
            _localCertificateSelectionCallback = localCertificateSelectionCallback;
            _sslProtocols = sslProtocols;

            InitSocket();
        }

        /// <summary>
        /// Creates a client transport to the first address <paramref name="host"/> resolves to. The host name
        /// itself is used as the TLS target host when validating the server certificate.
        /// </summary>
        /// <exception cref="TTransportException">The name does not resolve, or resolution fails with a socket error.</exception>
        public TTlsSocketTransport(string host, int port, int timeout,
            X509Certificate2 certificate,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls12)
        {
            try
            {
                var entry = Dns.GetHostEntry(host);
                if (entry.AddressList.Length == 0)
                    throw new TTransportException(TTransportException.ExceptionType.Unknown, "unable to resolve host name");

                // Use the resolved address as-is. Rebuilding it via IPAddress(bytes, ScopeId) failed for
                // every IPv4 result: ScopeId throws SocketException for IPv4 addresses.
                _host = entry.AddressList[0];
                _targetHost = host;
                _port = port;
                _timeout = timeout;
                _certificate = certificate;
                _certValidator = certValidator;
                _localCertificateSelectionCallback = localCertificateSelectionCallback;
                _sslProtocols = sslProtocols;

                InitSocket();
            }
            catch (SocketException e)
            {
                throw new TTransportException(TTransportException.ExceptionType.Unknown, e.Message, e);
            }
        }

        /// <summary>
        /// Send/receive timeout applied to the current socket and to any socket created later by
        /// <see cref="OpenAsync"/> after <see cref="Close"/>.
        /// </summary>
        public int Timeout
        {
            set
            {
                // _client is null after Close(); InitSocket() applies _timeout to the next client.
                if (_client != null)
                {
                    _client.ReceiveTimeout = _client.SendTimeout = value;
                }

                _timeout = value;
            }
        }

        public TcpClient TcpClient => _client;

        public IPAddress Host => _host;

        public int Port => _port;

        /// <summary>True when a client exists and reports itself connected.</summary>
        public override bool IsOpen => _client != null && _client.Connected;

        // Creates a fresh, unconnected TcpClient configured with the current timeout and Nagle disabled.
        private void InitSocket()
        {
            _client = new TcpClient();
            _client.ReceiveTimeout = _client.SendTimeout = _timeout;
            _client.Client.NoDelay = true;
        }

        // Accepts the remote certificate only when the platform reported no policy errors.
        private bool DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain,
            SslPolicyErrors sslValidationErrors)
        {
            return sslValidationErrors == SslPolicyErrors.None;
        }

        /// <summary>
        /// Connects to the configured host/port and performs the TLS handshake.
        /// </summary>
        /// <exception cref="TTransportException">Already open, or host/port not configured.</exception>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was cancelled before connecting.</exception>
        public override async Task OpenAsync(CancellationToken cancellationToken)
        {
            if (IsOpen)
            {
                throw new TTransportException(TTransportException.ExceptionType.AlreadyOpen, "Socket already connected");
            }

            if (_host == null)
            {
                throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open null host");
            }

            if (_port <= 0)
            {
                throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open without port");
            }

            cancellationToken.ThrowIfCancellationRequested();

            if (_client == null)
            {
                InitSocket();
            }

            try
            {
                await _client.ConnectAsync(_host, _port).ConfigureAwait(false);
            }
            catch (Exception)
            {
                // A socket whose connect attempt failed may not be reusable on some platforms;
                // drop it so the next OpenAsync starts from a fresh TcpClient.
                Close();
                throw;
            }

            await SetupTlsAsync().ConfigureAwait(false);
        }

        /// <summary>
        /// Wraps the connected network stream in an <see cref="SslStream"/>, authenticates as server or client,
        /// and switches the transport's input/output to the secure stream. The transport is closed if the
        /// handshake fails.
        /// </summary>
        /// <exception cref="TTransportException">The transport has been closed.</exception>
        public async Task SetupTlsAsync()
        {
            if (_client == null)
            {
                throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot setup TLS on a closed transport");
            }

            var validator = _certValidator ?? DefaultCertificateValidator;
            var networkStream = _client.GetStream();

            _secureStream = _localCertificateSelectionCallback != null
                ? new SslStream(networkStream, false, validator, _localCertificateSelectionCallback)
                : new SslStream(networkStream, false, validator);

            try
            {
                if (_isServer)
                {
                    // Server authentication; request a client certificate only when a validator was supplied.
                    await _secureStream.AuthenticateAsServerAsync(_certificate, _certValidator != null, _sslProtocols,
                        true).ConfigureAwait(false);
                }
                else
                {
                    // Client authentication
                    var certs = _certificate != null
                        ? new X509CertificateCollection {_certificate}
                        : new X509CertificateCollection();

                    var targetHost = ResolveTargetHost();
                    await _secureStream.AuthenticateAsClientAsync(targetHost, certs, _sslProtocols, true)
                        .ConfigureAwait(false);
                }
            }
            catch (Exception)
            {
                Close();
                throw;
            }

            InputStream = _secureStream;
            OutputStream = _secureStream;
        }

        // Picks the name the server certificate is validated against, in order: the DNS name given to the
        // constructor, the configured IP address, then the connected peer's address (for transports built
        // from an existing TcpClient).
        private string ResolveTargetHost()
        {
            if (!string.IsNullOrEmpty(_targetHost))
            {
                return _targetHost;
            }

            if (_host != null)
            {
                return _host.ToString();
            }

            var remote = _client?.Client?.RemoteEndPoint as IPEndPoint;
            if (remote != null)
            {
                return remote.Address.ToString();
            }

            throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot determine TLS target host");
        }

        /// <summary>Closes the base streams, then releases the TLS stream and the TCP client.</summary>
        public override void Close()
        {
            base.Close();

            // Dispose the TLS wrapper before the TcpClient that owns its inner network stream.
            if (_secureStream != null)
            {
                _secureStream.Dispose();
                _secureStream = null;
            }

            if (_client != null)
            {
                _client.Dispose();
                _client = null;
            }
        }
    }
}