]>
Commit | Line | Data |
---|---|---|
add651ee FG |
1 | use crate::codec::UserError; |
2 | use crate::codec::UserError::*; | |
3 | use crate::frame::{self, Frame, FrameSize}; | |
4 | use crate::hpack; | |
5 | ||
6 | use bytes::{Buf, BufMut, BytesMut}; | |
7 | use std::pin::Pin; | |
8 | use std::task::{Context, Poll}; | |
9 | use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; | |
4b012472 | 10 | use tokio_util::io::poll_write_buf; |
add651ee | 11 | |
4b012472 | 12 | use std::io::{self, Cursor}; |
add651ee FG |
13 | |
// Expands to `self.buf` wrapped in a `BufMut` limit of one maximum-size frame
// (the peer's max frame size plus the frame header). It is a macro rather than
// a `&mut self` method so the caller can hold this mutable borrow of `buf`
// while separately borrowing another field (e.g. `self.hpack`).
macro_rules! limited_write_buf {
    ($self:expr) => {{
        let cap = frame::HEADER_LEN + $self.max_frame_size();
        $self.buf.get_mut().limit(cap)
    }};
}
21 | ||
22 | #[derive(Debug)] | |
23 | pub struct FramedWrite<T, B> { | |
24 | /// Upstream `AsyncWrite` | |
25 | inner: T, | |
26 | ||
27 | encoder: Encoder<B>, | |
28 | } | |
29 | ||
30 | #[derive(Debug)] | |
31 | struct Encoder<B> { | |
32 | /// HPACK encoder | |
33 | hpack: hpack::Encoder, | |
34 | ||
35 | /// Write buffer | |
36 | /// | |
37 | /// TODO: Should this be a ring buffer? | |
38 | buf: Cursor<BytesMut>, | |
39 | ||
40 | /// Next frame to encode | |
41 | next: Option<Next<B>>, | |
42 | ||
43 | /// Last data frame | |
44 | last_data_frame: Option<frame::Data<B>>, | |
45 | ||
46 | /// Max frame size, this is specified by the peer | |
47 | max_frame_size: FrameSize, | |
48 | ||
4b012472 FG |
49 | /// Chain payloads bigger than this. |
50 | chain_threshold: usize, | |
51 | ||
52 | /// Min buffer required to attempt to write a frame | |
53 | min_buffer_capacity: usize, | |
add651ee FG |
54 | } |
55 | ||
56 | #[derive(Debug)] | |
57 | enum Next<B> { | |
58 | Data(frame::Data<B>), | |
59 | Continuation(frame::Continuation), | |
60 | } | |
61 | ||
/// Initial capacity, in bytes, of the connection's write buffer.
///
/// HTTP/2 never allows a MAX_FRAME_SIZE below 16 KiB, so this is always
/// enough room to hold a HEADERS frame of that size.
const DEFAULT_BUFFER_CAPACITY: usize = 16 * 1024;
67 | ||
4b012472 FG |
/// Data payloads of at least this many bytes are chained (rather than copied
/// into the write buffer) when the transport supports vectored writes.
///
/// This is far below anything the peer can advertise as its max frame size:
/// the spec forbids a max frame size smaller than 16 KiB.
const CHAIN_THRESHOLD: usize = 256;
72 | ||
4b012472 FG |
/// Data payloads of at least this many bytes are chained when the transport
/// does **not** support vectored writes.
///
/// The larger value means more payload is coalesced into the write buffer,
/// which cuts down on small, fragmented writes and thereby improves throughput.
const CHAIN_THRESHOLD_WITHOUT_VECTORED_IO: usize = 1024;
77 | ||
add651ee FG |
78 | // TODO: Make generic |
79 | impl<T, B> FramedWrite<T, B> | |
80 | where | |
81 | T: AsyncWrite + Unpin, | |
82 | B: Buf, | |
83 | { | |
84 | pub fn new(inner: T) -> FramedWrite<T, B> { | |
4b012472 FG |
85 | let chain_threshold = if inner.is_write_vectored() { |
86 | CHAIN_THRESHOLD | |
87 | } else { | |
88 | CHAIN_THRESHOLD_WITHOUT_VECTORED_IO | |
89 | }; | |
add651ee FG |
90 | FramedWrite { |
91 | inner, | |
92 | encoder: Encoder { | |
93 | hpack: hpack::Encoder::default(), | |
94 | buf: Cursor::new(BytesMut::with_capacity(DEFAULT_BUFFER_CAPACITY)), | |
95 | next: None, | |
96 | last_data_frame: None, | |
97 | max_frame_size: frame::DEFAULT_MAX_FRAME_SIZE, | |
4b012472 FG |
98 | chain_threshold, |
99 | min_buffer_capacity: chain_threshold + frame::HEADER_LEN, | |
add651ee FG |
100 | }, |
101 | } | |
102 | } | |
103 | ||
104 | /// Returns `Ready` when `send` is able to accept a frame | |
105 | /// | |
106 | /// Calling this function may result in the current contents of the buffer | |
107 | /// to be flushed to `T`. | |
108 | pub fn poll_ready(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { | |
109 | if !self.encoder.has_capacity() { | |
110 | // Try flushing | |
111 | ready!(self.flush(cx))?; | |
112 | ||
113 | if !self.encoder.has_capacity() { | |
114 | return Poll::Pending; | |
115 | } | |
116 | } | |
117 | ||
118 | Poll::Ready(Ok(())) | |
119 | } | |
120 | ||
121 | /// Buffer a frame. | |
122 | /// | |
123 | /// `poll_ready` must be called first to ensure that a frame may be | |
124 | /// accepted. | |
125 | pub fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { | |
126 | self.encoder.buffer(item) | |
127 | } | |
128 | ||
129 | /// Flush buffered data to the wire | |
130 | pub fn flush(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { | |
131 | let span = tracing::trace_span!("FramedWrite::flush"); | |
132 | let _e = span.enter(); | |
133 | ||
134 | loop { | |
135 | while !self.encoder.is_empty() { | |
136 | match self.encoder.next { | |
137 | Some(Next::Data(ref mut frame)) => { | |
138 | tracing::trace!(queued_data_frame = true); | |
139 | let mut buf = (&mut self.encoder.buf).chain(frame.payload_mut()); | |
4b012472 | 140 | ready!(poll_write_buf(Pin::new(&mut self.inner), cx, &mut buf))? |
add651ee FG |
141 | } |
142 | _ => { | |
143 | tracing::trace!(queued_data_frame = false); | |
4b012472 FG |
144 | ready!(poll_write_buf( |
145 | Pin::new(&mut self.inner), | |
add651ee | 146 | cx, |
4b012472 | 147 | &mut self.encoder.buf |
add651ee FG |
148 | ))? |
149 | } | |
4b012472 | 150 | }; |
add651ee FG |
151 | } |
152 | ||
153 | match self.encoder.unset_frame() { | |
154 | ControlFlow::Continue => (), | |
155 | ControlFlow::Break => break, | |
156 | } | |
157 | } | |
158 | ||
159 | tracing::trace!("flushing buffer"); | |
160 | // Flush the upstream | |
161 | ready!(Pin::new(&mut self.inner).poll_flush(cx))?; | |
162 | ||
163 | Poll::Ready(Ok(())) | |
164 | } | |
165 | ||
166 | /// Close the codec | |
167 | pub fn shutdown(&mut self, cx: &mut Context) -> Poll<io::Result<()>> { | |
168 | ready!(self.flush(cx))?; | |
169 | Pin::new(&mut self.inner).poll_shutdown(cx) | |
170 | } | |
171 | } | |
172 | ||
add651ee FG |
/// Result of `Encoder::unset_frame`: tells `flush` whether to keep writing.
#[must_use]
enum ControlFlow {
    /// Another frame was encoded into the buffer; write again.
    Continue,
    /// Nothing more to write during this flush.
    Break,
}
178 | ||
179 | impl<B> Encoder<B> | |
180 | where | |
181 | B: Buf, | |
182 | { | |
183 | fn unset_frame(&mut self) -> ControlFlow { | |
184 | // Clear internal buffer | |
185 | self.buf.set_position(0); | |
186 | self.buf.get_mut().clear(); | |
187 | ||
188 | // The data frame has been written, so unset it | |
189 | match self.next.take() { | |
190 | Some(Next::Data(frame)) => { | |
191 | self.last_data_frame = Some(frame); | |
192 | debug_assert!(self.is_empty()); | |
193 | ControlFlow::Break | |
194 | } | |
195 | Some(Next::Continuation(frame)) => { | |
196 | // Buffer the continuation frame, then try to write again | |
197 | let mut buf = limited_write_buf!(self); | |
198 | if let Some(continuation) = frame.encode(&mut buf) { | |
199 | self.next = Some(Next::Continuation(continuation)); | |
200 | } | |
201 | ControlFlow::Continue | |
202 | } | |
203 | None => ControlFlow::Break, | |
204 | } | |
205 | } | |
206 | ||
207 | fn buffer(&mut self, item: Frame<B>) -> Result<(), UserError> { | |
208 | // Ensure that we have enough capacity to accept the write. | |
209 | assert!(self.has_capacity()); | |
210 | let span = tracing::trace_span!("FramedWrite::buffer", frame = ?item); | |
211 | let _e = span.enter(); | |
212 | ||
213 | tracing::debug!(frame = ?item, "send"); | |
214 | ||
215 | match item { | |
216 | Frame::Data(mut v) => { | |
217 | // Ensure that the payload is not greater than the max frame. | |
218 | let len = v.payload().remaining(); | |
219 | ||
220 | if len > self.max_frame_size() { | |
221 | return Err(PayloadTooBig); | |
222 | } | |
223 | ||
4b012472 | 224 | if len >= self.chain_threshold { |
add651ee FG |
225 | let head = v.head(); |
226 | ||
227 | // Encode the frame head to the buffer | |
228 | head.encode(len, self.buf.get_mut()); | |
229 | ||
4b012472 FG |
230 | if self.buf.get_ref().remaining() < self.chain_threshold { |
231 | let extra_bytes = self.chain_threshold - self.buf.remaining(); | |
232 | self.buf.get_mut().put(v.payload_mut().take(extra_bytes)); | |
233 | } | |
234 | ||
add651ee FG |
235 | // Save the data frame |
236 | self.next = Some(Next::Data(v)); | |
237 | } else { | |
238 | v.encode_chunk(self.buf.get_mut()); | |
239 | ||
240 | // The chunk has been fully encoded, so there is no need to | |
241 | // keep it around | |
242 | assert_eq!(v.payload().remaining(), 0, "chunk not fully encoded"); | |
243 | ||
244 | // Save off the last frame... | |
245 | self.last_data_frame = Some(v); | |
246 | } | |
247 | } | |
248 | Frame::Headers(v) => { | |
249 | let mut buf = limited_write_buf!(self); | |
250 | if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { | |
251 | self.next = Some(Next::Continuation(continuation)); | |
252 | } | |
253 | } | |
254 | Frame::PushPromise(v) => { | |
255 | let mut buf = limited_write_buf!(self); | |
256 | if let Some(continuation) = v.encode(&mut self.hpack, &mut buf) { | |
257 | self.next = Some(Next::Continuation(continuation)); | |
258 | } | |
259 | } | |
260 | Frame::Settings(v) => { | |
261 | v.encode(self.buf.get_mut()); | |
262 | tracing::trace!(rem = self.buf.remaining(), "encoded settings"); | |
263 | } | |
264 | Frame::GoAway(v) => { | |
265 | v.encode(self.buf.get_mut()); | |
266 | tracing::trace!(rem = self.buf.remaining(), "encoded go_away"); | |
267 | } | |
268 | Frame::Ping(v) => { | |
269 | v.encode(self.buf.get_mut()); | |
270 | tracing::trace!(rem = self.buf.remaining(), "encoded ping"); | |
271 | } | |
272 | Frame::WindowUpdate(v) => { | |
273 | v.encode(self.buf.get_mut()); | |
274 | tracing::trace!(rem = self.buf.remaining(), "encoded window_update"); | |
275 | } | |
276 | ||
277 | Frame::Priority(_) => { | |
278 | /* | |
279 | v.encode(self.buf.get_mut()); | |
280 | tracing::trace!("encoded priority; rem={:?}", self.buf.remaining()); | |
281 | */ | |
282 | unimplemented!(); | |
283 | } | |
284 | Frame::Reset(v) => { | |
285 | v.encode(self.buf.get_mut()); | |
286 | tracing::trace!(rem = self.buf.remaining(), "encoded reset"); | |
287 | } | |
288 | } | |
289 | ||
290 | Ok(()) | |
291 | } | |
292 | ||
293 | fn has_capacity(&self) -> bool { | |
4b012472 FG |
294 | self.next.is_none() |
295 | && (self.buf.get_ref().capacity() - self.buf.get_ref().len() | |
296 | >= self.min_buffer_capacity) | |
add651ee FG |
297 | } |
298 | ||
299 | fn is_empty(&self) -> bool { | |
300 | match self.next { | |
301 | Some(Next::Data(ref frame)) => !frame.payload().has_remaining(), | |
302 | _ => !self.buf.has_remaining(), | |
303 | } | |
304 | } | |
305 | } | |
306 | ||
307 | impl<B> Encoder<B> { | |
308 | fn max_frame_size(&self) -> usize { | |
309 | self.max_frame_size as usize | |
310 | } | |
311 | } | |
312 | ||
313 | impl<T, B> FramedWrite<T, B> { | |
314 | /// Returns the max frame size that can be sent | |
315 | pub fn max_frame_size(&self) -> usize { | |
316 | self.encoder.max_frame_size() | |
317 | } | |
318 | ||
319 | /// Set the peer's max frame size. | |
320 | pub fn set_max_frame_size(&mut self, val: usize) { | |
321 | assert!(val <= frame::MAX_MAX_FRAME_SIZE as usize); | |
322 | self.encoder.max_frame_size = val as FrameSize; | |
323 | } | |
324 | ||
325 | /// Set the peer's header table size. | |
326 | pub fn set_header_table_size(&mut self, val: usize) { | |
327 | self.encoder.hpack.update_max_size(val); | |
328 | } | |
329 | ||
330 | /// Retrieve the last data frame that has been sent | |
331 | pub fn take_last_data_frame(&mut self) -> Option<frame::Data<B>> { | |
332 | self.encoder.last_data_frame.take() | |
333 | } | |
334 | ||
335 | pub fn get_mut(&mut self) -> &mut T { | |
336 | &mut self.inner | |
337 | } | |
338 | } | |
339 | ||
340 | impl<T: AsyncRead + Unpin, B> AsyncRead for FramedWrite<T, B> { | |
341 | fn poll_read( | |
342 | mut self: Pin<&mut Self>, | |
343 | cx: &mut Context<'_>, | |
344 | buf: &mut ReadBuf, | |
345 | ) -> Poll<io::Result<()>> { | |
346 | Pin::new(&mut self.inner).poll_read(cx, buf) | |
347 | } | |
348 | } | |
349 | ||
350 | // We never project the Pin to `B`. | |
351 | impl<T: Unpin, B> Unpin for FramedWrite<T, B> {} | |
352 | ||
#[cfg(feature = "unstable")]
mod unstable {
    use super::*;

    impl<T, B> FramedWrite<T, B> {
        /// Borrows the wrapped transport.
        pub fn get_ref(&self) -> &T {
            let FramedWrite { inner, .. } = self;
            inner
        }
    }
}