]> git.proxmox.com Git - proxmox-backup.git/blame - proxmox-protocol/src/common.rs
more formatting & use statement fixups
[proxmox-backup.git] / proxmox-protocol / src / common.rs
CommitLineData
ac4e349b
WB
1use std::io::{self, Read, Write};
2use std::mem;
3use std::ptr;
4
5use failure::*;
6
7use endian_trait::Endian;
8
9use crate::protocol::*;
10
11type Result<T> = std::result::Result<T, Error>;
12
13pub(crate) struct Connection<S>
14where
15 S: Read + Write,
16{
17 socket: S,
18 pub buffer: Vec<u8>,
19 pub current_packet: Packet,
20 pub current_packet_type: PacketType,
21 pub error: bool,
22 pub eof: bool,
23 upload_queue: Option<(Vec<u8>, usize)>,
24}
25
26impl<S> Connection<S>
27where
28 S: Read + Write,
29{
30 pub fn new(socket: S) -> Self {
31 Self {
32 socket,
33 buffer: Vec::new(),
34 current_packet: unsafe { mem::zeroed() },
35 current_packet_type: PacketType::Error,
36 error: false,
37 eof: false,
38 upload_queue: None,
39 }
40 }
41
42 pub fn write_some(&mut self, buf: &[u8]) -> std::io::Result<usize> {
43 self.socket.write(buf)
44 }
45
46 /// It is safe to clear the error after an `io::ErrorKind::Interrupted`.
47 pub fn clear_err(&mut self) {
48 self.error = false;
49 }
50
51 // None => nothing was queued
52 // Some(true) => queue finished
53 // Some(false) => queue not finished
54 pub fn poll_send(&mut self) -> Result<Option<bool>> {
55 if let Some((ref data, ref mut pos)) = self.upload_queue {
56 loop {
57 match self.socket.write(&data[*pos..]) {
58 Ok(put) => {
59 *pos += put;
60 if *pos == data.len() {
61 self.upload_queue = None;
62 return Ok(Some(true));
63 }
64 // Keep writing
65 continue;
66 }
67 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
68 return Ok(Some(false));
69 }
70 Err(e) => return Err(e.into()),
71 }
72 }
73 } else {
74 Ok(None)
75 }
76 }
77
78 // Returns true when the data was also sent out, false if the queue is now full.
79 // For now we only allow a single dataset to be queued at once.
80 pub fn queue_data(&mut self, buf: Vec<u8>) -> Result<bool> {
81 if self.upload_queue.is_some() {
82 bail!("upload queue clash");
83 }
84
85 self.upload_queue = Some((buf, 0));
86
87 match self.poll_send()? {
88 None => unreachable!(), // We literally just set self.upload_queue to Some(value)
89 Some(v) => Ok(v),
90 }
91 }
92
93 // Returns 'true' if there's data available, 'false' if there isn't (if the
94 // underlying reader returned `WouldBlock` or the `read()` was short).
95 // Other errors are propagated.
96 pub fn poll_read(&mut self) -> Result<bool> {
97 if self.eof {
98 return Ok(false);
99 }
100
101 if self.error {
102 eprintln!("refusing to read from a client in error state");
103 bail!("client is in error state");
104 }
105
106 match self.poll_data_do() {
107 Ok(has_packet) => Ok(has_packet),
108 Err(e) => {
ac4e349b
WB
109 self.error = true;
110 Err(e)
111 }
112 }
113 }
114
115 fn poll_data_do(&mut self) -> Result<bool> {
116 if !self.read_packet()? {
117 return Ok(false);
118 }
119
120 if self.current_packet.length > MAX_PACKET_SIZE {
121 bail!("client tried to send a huge packet");
122 }
123
124 if !self.fill_packet()? {
125 return Ok(false);
126 }
127
128 Ok(true)
129 }
130
131 pub fn packet_length(&self) -> usize {
132 self.current_packet.length as usize
133 }
134
135 pub fn packet_data(&self) -> &[u8] {
136 let beg = mem::size_of::<Packet>();
137 let end = self.packet_length();
138 &self.buffer[beg..end]
139 }
140
e4027693 141 pub fn next(&mut self) -> Result<()> {
ac4e349b
WB
142 let pktlen = self.packet_length();
143 unsafe {
144 if self.buffer.len() != pktlen {
145 std::ptr::copy_nonoverlapping(
146 &self.buffer[pktlen],
147 &mut self.buffer[0],
148 self.buffer.len() - pktlen,
149 );
150 }
151 self.buffer.set_len(self.buffer.len() - pktlen);
152 }
e4027693 153 Ok(())
ac4e349b
WB
154 }
155
156 // NOTE: After calling this you must `self.buffer.set_len()` when done!
157 #[must_use]
158 fn buffer_set_min_size(&mut self, size: usize) -> usize {
159 if self.buffer.capacity() < size {
160 self.buffer.reserve(size - self.buffer.len());
161 }
162 let start = self.buffer.len();
163 unsafe {
164 self.buffer.set_len(size);
165 }
166 start
167 }
168
169 fn fill_buffer(&mut self, size: usize) -> Result<bool> {
170 if self.buffer.len() >= size {
171 return Ok(true);
172 }
173 let mut filled = self.buffer_set_min_size(size);
174 loop {
175 // We don't use read_exact to not block too long or busy-read on nonblocking sockets...
176 match self.socket.read(&mut self.buffer[filled..]) {
177 Ok(got) => {
178 if got == 0 {
179 self.eof = true;
180 unsafe {
181 self.buffer.set_len(filled);
182 }
183 return Ok(false);
184 }
185 filled += got;
186 if filled >= size {
187 unsafe {
188 self.buffer.set_len(filled);
189 }
190 return Ok(true);
191 }
192 // reloop
193 }
194 Err(e) => {
195 unsafe {
196 self.buffer.set_len(filled);
197 }
198 return Err(e.into());
199 }
200 }
201 }
202 }
203
204 fn read_packet_do(&mut self) -> Result<bool> {
205 if !self.fill_buffer(mem::size_of::<Packet>())? {
206 return Ok(false);
207 }
208
209 self.current_packet = self.read_unaligned::<Packet>(0)?.from_le();
210
211 self.current_packet_type = match PacketType::try_from(self.current_packet.pkttype) {
212 Some(t) => t,
213 None => bail!("unexpected packet type"),
214 };
215
216 let length = self.current_packet.length;
217 if (length as usize) < mem::size_of::<Packet>() {
218 bail!("received packet of bad length ({})", length);
219 }
220
221 Ok(true)
222 }
223
224 fn read_packet(&mut self) -> Result<bool> {
225 match self.read_packet_do() {
226 Ok(b) => Ok(b),
227 Err(e) => {
228 if let Some(ioe) = e.downcast_ref::<std::io::Error>() {
229 if ioe.kind() == io::ErrorKind::WouldBlock {
230 return Ok(false);
231 }
232 }
233 Err(e)
234 }
235 }
236 }
237
238 fn read_unaligned<T: Endian>(&self, offset: usize) -> Result<T> {
239 if offset + mem::size_of::<T>() > self.buffer.len() {
240 bail!("buffer underrun");
241 }
242 Ok(unsafe { ptr::read_unaligned(&self.buffer[offset] as *const _ as *const T) }.from_le())
243 }
244
245 pub fn read_unaligned_data<T: Endian>(&self, offset: usize) -> Result<T> {
246 self.read_unaligned(offset + mem::size_of::<Packet>())
247 }
248
249 fn fill_packet(&mut self) -> Result<bool> {
250 self.fill_buffer(self.current_packet.length as usize)
251 }
252
253 // convenience helpers:
254
255 pub fn assert_size(&self, size: usize) -> Result<()> {
256 if self.packet_data().len() != size {
257 bail!(
258 "protocol error: invalid packet size (type {})",
259 self.current_packet.pkttype,
260 );
261 }
262 Ok(())
263 }
264
265 pub fn assert_atleast(&self, size: usize) -> Result<()> {
266 if self.packet_data().len() < size {
267 bail!(
268 "protocol error: invalid packet size (type {})",
269 self.current_packet.pkttype,
270 );
271 }
272 Ok(())
273 }
274}