]>
Commit | Line | Data |
---|---|---|
f67539c2 TL |
1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one | |
3 | * or more contributor license agreements. See the NOTICE file | |
4 | * distributed with this work for additional information | |
5 | * regarding copyright ownership. The ASF licenses this file | |
6 | * to you under the Apache License, Version 2.0 (the | |
7 | * "License"); you may not use this file except in compliance | |
8 | * with the License. You may obtain a copy of the License at | |
9 | * | |
10 | * http://www.apache.org/licenses/LICENSE-2.0 | |
11 | * | |
12 | * Unless required by applicable law or agreed to in writing, | |
13 | * software distributed under the License is distributed on an | |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
15 | * KIND, either express or implied. See the License for the | |
16 | * specific language governing permissions and limitations | |
17 | * under the License. | |
18 | */ | |
19 | ||
20 | package thrift | |
21 | ||
22 | import ( | |
23 | "context" | |
24 | "crypto/tls" | |
25 | "net" | |
26 | "time" | |
27 | ) | |
28 | ||
// TSSLSocket is a net.Conn-backed TTransport that talks TLS (crypto/tls).
type TSSLSocket struct {
	// conn is the live connection. It is nil until Open succeeds, unless it
	// was supplied directly via NewTSSLSocketFromConnTimeout.
	conn net.Conn
	// hostPort contains host:port (e.g. "asdf.com:12345"). It is set only by
	// the host:port constructors; Open dials it whenever it is non-empty.
	hostPort string
	// addr is nil when hostPort is not "", and is only used when the
	// TSSLSocket is constructed from a net.Addr (or an existing net.Conn).
	addr net.Addr
	// timeout is the dial timeout used by Open and, when positive, the
	// per-call deadline applied before each Read and Write.
	timeout time.Duration
	// cfg is handed to tls.DialWithDialer by Open. It may be nil, which
	// crypto/tls treats as a zero Config.
	cfg *tls.Config
}

// NewTSSLSocket creates a net.Conn-backed TTransport, given a host and port
// and a TLS configuration. It is NewTSSLSocketTimeout with a zero timeout.
//
// Example:
//
//	trans, err := thrift.NewTSSLSocket("localhost:9090", nil)
func NewTSSLSocket(hostPort string, cfg *tls.Config) (*TSSLSocket, error) {
	return NewTSSLSocketTimeout(hostPort, cfg, 0)
}

// NewTSSLSocketTimeout creates a net.Conn-backed TTransport for hostPort,
// using the given TLS configuration and timeout. No connection is made until
// Open is called.
//
// cfg may be nil. When cfg is non-nil and cfg.MinVersion is unset, it is set
// to tls.VersionTLS10 in place, which modifies the caller's *tls.Config.
// The returned error is currently always nil.
func NewTSSLSocketTimeout(hostPort string, cfg *tls.Config, timeout time.Duration) (*TSSLSocket, error) {
	// cfg must be nil-checked: NewTSSLSocket's documented usage passes nil,
	// and crypto/tls accepts a nil config at dial time.
	if cfg != nil && cfg.MinVersion == 0 {
		cfg.MinVersion = tls.VersionTLS10
	}
	return &TSSLSocket{hostPort: hostPort, timeout: timeout, cfg: cfg}, nil
}

// NewTSSLSocketFromAddrTimeout creates a TSSLSocket that Open will dial at
// addr (using addr.Network() and addr.String()). Unlike
// NewTSSLSocketTimeout, it does not touch cfg.MinVersion.
func NewTSSLSocketFromAddrTimeout(addr net.Addr, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
	return &TSSLSocket{addr: addr, timeout: timeout, cfg: cfg}
}

// NewTSSLSocketFromConnTimeout creates a TSSLSocket around an existing
// net.Conn, which is used as-is (it is not wrapped in TLS here; presumably
// the caller passes an already-established TLS connection — confirm at call
// sites). conn must be non-nil because its RemoteAddr is recorded.
func NewTSSLSocketFromConnTimeout(conn net.Conn, cfg *tls.Config, timeout time.Duration) *TSSLSocket {
	return &TSSLSocket{conn: conn, addr: conn.RemoteAddr(), timeout: timeout, cfg: cfg}
}
67 | ||
68 | // Sets the socket timeout | |
69 | func (p *TSSLSocket) SetTimeout(timeout time.Duration) error { | |
70 | p.timeout = timeout | |
71 | return nil | |
72 | } | |
73 | ||
74 | func (p *TSSLSocket) pushDeadline(read, write bool) { | |
75 | var t time.Time | |
76 | if p.timeout > 0 { | |
77 | t = time.Now().Add(time.Duration(p.timeout)) | |
78 | } | |
79 | if read && write { | |
80 | p.conn.SetDeadline(t) | |
81 | } else if read { | |
82 | p.conn.SetReadDeadline(t) | |
83 | } else if write { | |
84 | p.conn.SetWriteDeadline(t) | |
85 | } | |
86 | } | |
87 | ||
88 | // Connects the socket, creating a new socket object if necessary. | |
89 | func (p *TSSLSocket) Open() error { | |
90 | var err error | |
91 | // If we have a hostname, we need to pass the hostname to tls.Dial for | |
92 | // certificate hostname checks. | |
93 | if p.hostPort != "" { | |
94 | if p.conn, err = tls.DialWithDialer(&net.Dialer{ | |
95 | Timeout: p.timeout}, "tcp", p.hostPort, p.cfg); err != nil { | |
96 | return NewTTransportException(NOT_OPEN, err.Error()) | |
97 | } | |
98 | } else { | |
99 | if p.IsOpen() { | |
100 | return NewTTransportException(ALREADY_OPEN, "Socket already connected.") | |
101 | } | |
102 | if p.addr == nil { | |
103 | return NewTTransportException(NOT_OPEN, "Cannot open nil address.") | |
104 | } | |
105 | if len(p.addr.Network()) == 0 { | |
106 | return NewTTransportException(NOT_OPEN, "Cannot open bad network name.") | |
107 | } | |
108 | if len(p.addr.String()) == 0 { | |
109 | return NewTTransportException(NOT_OPEN, "Cannot open bad address.") | |
110 | } | |
111 | if p.conn, err = tls.DialWithDialer(&net.Dialer{ | |
112 | Timeout: p.timeout}, p.addr.Network(), p.addr.String(), p.cfg); err != nil { | |
113 | return NewTTransportException(NOT_OPEN, err.Error()) | |
114 | } | |
115 | } | |
116 | return nil | |
117 | } | |
118 | ||
119 | // Retrieve the underlying net.Conn | |
120 | func (p *TSSLSocket) Conn() net.Conn { | |
121 | return p.conn | |
122 | } | |
123 | ||
124 | // Returns true if the connection is open | |
125 | func (p *TSSLSocket) IsOpen() bool { | |
126 | if p.conn == nil { | |
127 | return false | |
128 | } | |
129 | return true | |
130 | } | |
131 | ||
132 | // Closes the socket. | |
133 | func (p *TSSLSocket) Close() error { | |
134 | // Close the socket | |
135 | if p.conn != nil { | |
136 | err := p.conn.Close() | |
137 | if err != nil { | |
138 | return err | |
139 | } | |
140 | p.conn = nil | |
141 | } | |
142 | return nil | |
143 | } | |
144 | ||
145 | func (p *TSSLSocket) Read(buf []byte) (int, error) { | |
146 | if !p.IsOpen() { | |
147 | return 0, NewTTransportException(NOT_OPEN, "Connection not open") | |
148 | } | |
149 | p.pushDeadline(true, false) | |
150 | n, err := p.conn.Read(buf) | |
151 | return n, NewTTransportExceptionFromError(err) | |
152 | } | |
153 | ||
154 | func (p *TSSLSocket) Write(buf []byte) (int, error) { | |
155 | if !p.IsOpen() { | |
156 | return 0, NewTTransportException(NOT_OPEN, "Connection not open") | |
157 | } | |
158 | p.pushDeadline(false, true) | |
159 | return p.conn.Write(buf) | |
160 | } | |
161 | ||
162 | func (p *TSSLSocket) Flush(ctx context.Context) error { | |
163 | return nil | |
164 | } | |
165 | ||
166 | func (p *TSSLSocket) Interrupt() error { | |
167 | if !p.IsOpen() { | |
168 | return nil | |
169 | } | |
170 | return p.conn.Close() | |
171 | } | |
172 | ||
173 | func (p *TSSLSocket) RemainingBytes() (num_bytes uint64) { | |
174 | const maxSize = ^uint64(0) | |
175 | return maxSize // the thruth is, we just don't know unless framed is used | |
176 | } |