]>
Commit | Line | Data |
---|---|---|
ac4e349b WB |
1 | use std::io::{self, Read, Write}; |
2 | use std::mem; | |
3 | use std::ptr; | |
4 | ||
5 | use failure::*; | |
6 | ||
7 | use endian_trait::Endian; | |
8 | ||
9 | use crate::protocol::*; | |
10 | ||
11 | type Result<T> = std::result::Result<T, Error>; | |
12 | ||
13 | pub(crate) struct Connection<S> | |
14 | where | |
15 | S: Read + Write, | |
16 | { | |
17 | socket: S, | |
18 | pub buffer: Vec<u8>, | |
19 | pub current_packet: Packet, | |
20 | pub current_packet_type: PacketType, | |
21 | pub error: bool, | |
22 | pub eof: bool, | |
23 | upload_queue: Option<(Vec<u8>, usize)>, | |
24 | } | |
25 | ||
26 | impl<S> Connection<S> | |
27 | where | |
28 | S: Read + Write, | |
29 | { | |
30 | pub fn new(socket: S) -> Self { | |
31 | Self { | |
32 | socket, | |
33 | buffer: Vec::new(), | |
34 | current_packet: unsafe { mem::zeroed() }, | |
35 | current_packet_type: PacketType::Error, | |
36 | error: false, | |
37 | eof: false, | |
38 | upload_queue: None, | |
39 | } | |
40 | } | |
41 | ||
42 | pub fn write_some(&mut self, buf: &[u8]) -> std::io::Result<usize> { | |
43 | self.socket.write(buf) | |
44 | } | |
45 | ||
46 | /// It is safe to clear the error after an `io::ErrorKind::Interrupted`. | |
47 | pub fn clear_err(&mut self) { | |
48 | self.error = false; | |
49 | } | |
50 | ||
51 | // None => nothing was queued | |
52 | // Some(true) => queue finished | |
53 | // Some(false) => queue not finished | |
54 | pub fn poll_send(&mut self) -> Result<Option<bool>> { | |
55 | if let Some((ref data, ref mut pos)) = self.upload_queue { | |
56 | loop { | |
57 | match self.socket.write(&data[*pos..]) { | |
58 | Ok(put) => { | |
59 | *pos += put; | |
60 | if *pos == data.len() { | |
61 | self.upload_queue = None; | |
62 | return Ok(Some(true)); | |
63 | } | |
64 | // Keep writing | |
65 | continue; | |
66 | } | |
67 | Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { | |
68 | return Ok(Some(false)); | |
69 | } | |
70 | Err(e) => return Err(e.into()), | |
71 | } | |
72 | } | |
73 | } else { | |
74 | Ok(None) | |
75 | } | |
76 | } | |
77 | ||
78 | // Returns true when the data was also sent out, false if the queue is now full. | |
79 | // For now we only allow a single dataset to be queued at once. | |
80 | pub fn queue_data(&mut self, buf: Vec<u8>) -> Result<bool> { | |
81 | if self.upload_queue.is_some() { | |
82 | bail!("upload queue clash"); | |
83 | } | |
84 | ||
85 | self.upload_queue = Some((buf, 0)); | |
86 | ||
87 | match self.poll_send()? { | |
88 | None => unreachable!(), // We literally just set self.upload_queue to Some(value) | |
89 | Some(v) => Ok(v), | |
90 | } | |
91 | } | |
92 | ||
93 | // Returns 'true' if there's data available, 'false' if there isn't (if the | |
94 | // underlying reader returned `WouldBlock` or the `read()` was short). | |
95 | // Other errors are propagated. | |
96 | pub fn poll_read(&mut self) -> Result<bool> { | |
97 | if self.eof { | |
98 | return Ok(false); | |
99 | } | |
100 | ||
101 | if self.error { | |
102 | eprintln!("refusing to read from a client in error state"); | |
103 | bail!("client is in error state"); | |
104 | } | |
105 | ||
106 | match self.poll_data_do() { | |
107 | Ok(has_packet) => Ok(has_packet), | |
108 | Err(e) => { | |
ac4e349b WB |
109 | self.error = true; |
110 | Err(e) | |
111 | } | |
112 | } | |
113 | } | |
114 | ||
115 | fn poll_data_do(&mut self) -> Result<bool> { | |
116 | if !self.read_packet()? { | |
117 | return Ok(false); | |
118 | } | |
119 | ||
120 | if self.current_packet.length > MAX_PACKET_SIZE { | |
121 | bail!("client tried to send a huge packet"); | |
122 | } | |
123 | ||
124 | if !self.fill_packet()? { | |
125 | return Ok(false); | |
126 | } | |
127 | ||
128 | Ok(true) | |
129 | } | |
130 | ||
131 | pub fn packet_length(&self) -> usize { | |
132 | self.current_packet.length as usize | |
133 | } | |
134 | ||
135 | pub fn packet_data(&self) -> &[u8] { | |
136 | let beg = mem::size_of::<Packet>(); | |
137 | let end = self.packet_length(); | |
138 | &self.buffer[beg..end] | |
139 | } | |
140 | ||
e4027693 | 141 | pub fn next(&mut self) -> Result<()> { |
ac4e349b WB |
142 | let pktlen = self.packet_length(); |
143 | unsafe { | |
144 | if self.buffer.len() != pktlen { | |
145 | std::ptr::copy_nonoverlapping( | |
146 | &self.buffer[pktlen], | |
147 | &mut self.buffer[0], | |
148 | self.buffer.len() - pktlen, | |
149 | ); | |
150 | } | |
151 | self.buffer.set_len(self.buffer.len() - pktlen); | |
152 | } | |
e4027693 | 153 | Ok(()) |
ac4e349b WB |
154 | } |
155 | ||
156 | // NOTE: After calling this you must `self.buffer.set_len()` when done! | |
157 | #[must_use] | |
158 | fn buffer_set_min_size(&mut self, size: usize) -> usize { | |
159 | if self.buffer.capacity() < size { | |
160 | self.buffer.reserve(size - self.buffer.len()); | |
161 | } | |
162 | let start = self.buffer.len(); | |
163 | unsafe { | |
164 | self.buffer.set_len(size); | |
165 | } | |
166 | start | |
167 | } | |
168 | ||
169 | fn fill_buffer(&mut self, size: usize) -> Result<bool> { | |
170 | if self.buffer.len() >= size { | |
171 | return Ok(true); | |
172 | } | |
173 | let mut filled = self.buffer_set_min_size(size); | |
174 | loop { | |
175 | // We don't use read_exact to not block too long or busy-read on nonblocking sockets... | |
176 | match self.socket.read(&mut self.buffer[filled..]) { | |
177 | Ok(got) => { | |
178 | if got == 0 { | |
179 | self.eof = true; | |
180 | unsafe { | |
181 | self.buffer.set_len(filled); | |
182 | } | |
183 | return Ok(false); | |
184 | } | |
185 | filled += got; | |
186 | if filled >= size { | |
187 | unsafe { | |
188 | self.buffer.set_len(filled); | |
189 | } | |
190 | return Ok(true); | |
191 | } | |
192 | // reloop | |
193 | } | |
194 | Err(e) => { | |
195 | unsafe { | |
196 | self.buffer.set_len(filled); | |
197 | } | |
198 | return Err(e.into()); | |
199 | } | |
200 | } | |
201 | } | |
202 | } | |
203 | ||
204 | fn read_packet_do(&mut self) -> Result<bool> { | |
205 | if !self.fill_buffer(mem::size_of::<Packet>())? { | |
206 | return Ok(false); | |
207 | } | |
208 | ||
209 | self.current_packet = self.read_unaligned::<Packet>(0)?.from_le(); | |
210 | ||
211 | self.current_packet_type = match PacketType::try_from(self.current_packet.pkttype) { | |
212 | Some(t) => t, | |
213 | None => bail!("unexpected packet type"), | |
214 | }; | |
215 | ||
216 | let length = self.current_packet.length; | |
217 | if (length as usize) < mem::size_of::<Packet>() { | |
218 | bail!("received packet of bad length ({})", length); | |
219 | } | |
220 | ||
221 | Ok(true) | |
222 | } | |
223 | ||
224 | fn read_packet(&mut self) -> Result<bool> { | |
225 | match self.read_packet_do() { | |
226 | Ok(b) => Ok(b), | |
227 | Err(e) => { | |
228 | if let Some(ioe) = e.downcast_ref::<std::io::Error>() { | |
229 | if ioe.kind() == io::ErrorKind::WouldBlock { | |
230 | return Ok(false); | |
231 | } | |
232 | } | |
233 | Err(e) | |
234 | } | |
235 | } | |
236 | } | |
237 | ||
238 | fn read_unaligned<T: Endian>(&self, offset: usize) -> Result<T> { | |
239 | if offset + mem::size_of::<T>() > self.buffer.len() { | |
240 | bail!("buffer underrun"); | |
241 | } | |
242 | Ok(unsafe { ptr::read_unaligned(&self.buffer[offset] as *const _ as *const T) }.from_le()) | |
243 | } | |
244 | ||
245 | pub fn read_unaligned_data<T: Endian>(&self, offset: usize) -> Result<T> { | |
246 | self.read_unaligned(offset + mem::size_of::<Packet>()) | |
247 | } | |
248 | ||
249 | fn fill_packet(&mut self) -> Result<bool> { | |
250 | self.fill_buffer(self.current_packet.length as usize) | |
251 | } | |
252 | ||
253 | // convenience helpers: | |
254 | ||
255 | pub fn assert_size(&self, size: usize) -> Result<()> { | |
256 | if self.packet_data().len() != size { | |
257 | bail!( | |
258 | "protocol error: invalid packet size (type {})", | |
259 | self.current_packet.pkttype, | |
260 | ); | |
261 | } | |
262 | Ok(()) | |
263 | } | |
264 | ||
265 | pub fn assert_atleast(&self, size: usize) -> Result<()> { | |
266 | if self.packet_data().len() < size { | |
267 | bail!( | |
268 | "protocol error: invalid packet size (type {})", | |
269 | self.current_packet.pkttype, | |
270 | ); | |
271 | } | |
272 | Ok(()) | |
273 | } | |
274 | } |