]>
Commit | Line | Data |
---|---|---|
add651ee FG |
1 | use crate::frame::{self, Frame, Kind, Reason}; |
2 | use crate::frame::{ | |
3 | DEFAULT_MAX_FRAME_SIZE, DEFAULT_SETTINGS_HEADER_TABLE_SIZE, MAX_MAX_FRAME_SIZE, | |
4 | }; | |
5 | use crate::proto::Error; | |
6 | ||
7 | use crate::hpack; | |
8 | ||
9 | use futures_core::Stream; | |
10 | ||
11 | use bytes::BytesMut; | |
12 | ||
13 | use std::io; | |
14 | ||
15 | use std::pin::Pin; | |
16 | use std::task::{Context, Poll}; | |
17 | use tokio::io::AsyncRead; | |
18 | use tokio_util::codec::FramedRead as InnerFramedRead; | |
19 | use tokio_util::codec::{LengthDelimitedCodec, LengthDelimitedCodecError}; | |
20 | ||
/// Default cap on the decoded header list size: 16 MiB, the "sane default"
/// taken from golang's http2 implementation.
const DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE: usize = 16 * 1024 * 1024;
23 | ||
/// Reads raw HTTP/2 frames from `T` via a length-delimited codec and decodes
/// them into [`Frame`]s, reassembling header blocks that are split across
/// CONTINUATION frames.
#[derive(Debug)]
pub struct FramedRead<T> {
    /// Splits the byte stream into individual frames; each item still carries
    /// the 9-byte frame header, which `decode_frame` parses and strips.
    inner: InnerFramedRead<T, LengthDelimitedCodec>,

    // hpack decoder state
    hpack: hpack::Decoder,

    /// Header list size limit handed to the HPACK loader for every header block.
    max_header_list_size: usize,

    /// Header block still waiting for its END_HEADERS CONTINUATION frame, if any.
    partial: Option<Partial>,
}
35 | ||
36 | /// Partially loaded headers frame | |
37 | #[derive(Debug)] | |
38 | struct Partial { | |
39 | /// Empty frame | |
40 | frame: Continuable, | |
41 | ||
42 | /// Partial header payload | |
43 | buf: BytesMut, | |
44 | } | |
45 | ||
46 | #[derive(Debug)] | |
47 | enum Continuable { | |
48 | Headers(frame::Headers), | |
49 | PushPromise(frame::PushPromise), | |
50 | } | |
51 | ||
52 | impl<T> FramedRead<T> { | |
53 | pub fn new(inner: InnerFramedRead<T, LengthDelimitedCodec>) -> FramedRead<T> { | |
54 | FramedRead { | |
55 | inner, | |
56 | hpack: hpack::Decoder::new(DEFAULT_SETTINGS_HEADER_TABLE_SIZE), | |
57 | max_header_list_size: DEFAULT_SETTINGS_MAX_HEADER_LIST_SIZE, | |
58 | partial: None, | |
59 | } | |
60 | } | |
61 | ||
62 | pub fn get_ref(&self) -> &T { | |
63 | self.inner.get_ref() | |
64 | } | |
65 | ||
66 | pub fn get_mut(&mut self) -> &mut T { | |
67 | self.inner.get_mut() | |
68 | } | |
69 | ||
70 | /// Returns the current max frame size setting | |
71 | #[cfg(feature = "unstable")] | |
72 | #[inline] | |
73 | pub fn max_frame_size(&self) -> usize { | |
74 | self.inner.decoder().max_frame_length() | |
75 | } | |
76 | ||
77 | /// Updates the max frame size setting. | |
78 | /// | |
79 | /// Must be within 16,384 and 16,777,215. | |
80 | #[inline] | |
81 | pub fn set_max_frame_size(&mut self, val: usize) { | |
82 | assert!(DEFAULT_MAX_FRAME_SIZE as usize <= val && val <= MAX_MAX_FRAME_SIZE as usize); | |
83 | self.inner.decoder_mut().set_max_frame_length(val) | |
84 | } | |
85 | ||
86 | /// Update the max header list size setting. | |
87 | #[inline] | |
88 | pub fn set_max_header_list_size(&mut self, val: usize) { | |
89 | self.max_header_list_size = val; | |
90 | } | |
4b012472 FG |
91 | |
92 | /// Update the header table size setting. | |
93 | #[inline] | |
94 | pub fn set_header_table_size(&mut self, val: usize) { | |
95 | self.hpack.queue_size_update(val); | |
96 | } | |
add651ee FG |
97 | } |
98 | ||
99 | /// Decodes a frame. | |
100 | /// | |
101 | /// This method is intentionally de-generified and outlined because it is very large. | |
102 | fn decode_frame( | |
103 | hpack: &mut hpack::Decoder, | |
104 | max_header_list_size: usize, | |
105 | partial_inout: &mut Option<Partial>, | |
106 | mut bytes: BytesMut, | |
107 | ) -> Result<Option<Frame>, Error> { | |
108 | let span = tracing::trace_span!("FramedRead::decode_frame", offset = bytes.len()); | |
109 | let _e = span.enter(); | |
110 | ||
111 | tracing::trace!("decoding frame from {}B", bytes.len()); | |
112 | ||
113 | // Parse the head | |
114 | let head = frame::Head::parse(&bytes); | |
115 | ||
116 | if partial_inout.is_some() && head.kind() != Kind::Continuation { | |
117 | proto_err!(conn: "expected CONTINUATION, got {:?}", head.kind()); | |
118 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
119 | } | |
120 | ||
121 | let kind = head.kind(); | |
122 | ||
123 | tracing::trace!(frame.kind = ?kind); | |
124 | ||
125 | macro_rules! header_block { | |
126 | ($frame:ident, $head:ident, $bytes:ident) => ({ | |
127 | // Drop the frame header | |
128 | // TODO: Change to drain: carllerche/bytes#130 | |
129 | let _ = $bytes.split_to(frame::HEADER_LEN); | |
130 | ||
131 | // Parse the header frame w/o parsing the payload | |
132 | let (mut frame, mut payload) = match frame::$frame::load($head, $bytes) { | |
133 | Ok(res) => res, | |
134 | Err(frame::Error::InvalidDependencyId) => { | |
135 | proto_err!(stream: "invalid HEADERS dependency ID"); | |
136 | // A stream cannot depend on itself. An endpoint MUST | |
137 | // treat this as a stream error (Section 5.4.2) of type | |
138 | // `PROTOCOL_ERROR`. | |
139 | return Err(Error::library_reset($head.stream_id(), Reason::PROTOCOL_ERROR)); | |
140 | }, | |
141 | Err(e) => { | |
142 | proto_err!(conn: "failed to load frame; err={:?}", e); | |
143 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
144 | } | |
145 | }; | |
146 | ||
147 | let is_end_headers = frame.is_end_headers(); | |
148 | ||
149 | // Load the HPACK encoded headers | |
150 | match frame.load_hpack(&mut payload, max_header_list_size, hpack) { | |
151 | Ok(_) => {}, | |
152 | Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {}, | |
153 | Err(frame::Error::MalformedMessage) => { | |
154 | let id = $head.stream_id(); | |
155 | proto_err!(stream: "malformed header block; stream={:?}", id); | |
156 | return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); | |
157 | }, | |
158 | Err(e) => { | |
159 | proto_err!(conn: "failed HPACK decoding; err={:?}", e); | |
160 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
161 | } | |
162 | } | |
163 | ||
164 | if is_end_headers { | |
165 | frame.into() | |
166 | } else { | |
167 | tracing::trace!("loaded partial header block"); | |
168 | // Defer returning the frame | |
169 | *partial_inout = Some(Partial { | |
170 | frame: Continuable::$frame(frame), | |
171 | buf: payload, | |
172 | }); | |
173 | ||
174 | return Ok(None); | |
175 | } | |
176 | }); | |
177 | } | |
178 | ||
179 | let frame = match kind { | |
180 | Kind::Settings => { | |
181 | let res = frame::Settings::load(head, &bytes[frame::HEADER_LEN..]); | |
182 | ||
183 | res.map_err(|e| { | |
184 | proto_err!(conn: "failed to load SETTINGS frame; err={:?}", e); | |
185 | Error::library_go_away(Reason::PROTOCOL_ERROR) | |
186 | })? | |
187 | .into() | |
188 | } | |
189 | Kind::Ping => { | |
190 | let res = frame::Ping::load(head, &bytes[frame::HEADER_LEN..]); | |
191 | ||
192 | res.map_err(|e| { | |
193 | proto_err!(conn: "failed to load PING frame; err={:?}", e); | |
194 | Error::library_go_away(Reason::PROTOCOL_ERROR) | |
195 | })? | |
196 | .into() | |
197 | } | |
198 | Kind::WindowUpdate => { | |
199 | let res = frame::WindowUpdate::load(head, &bytes[frame::HEADER_LEN..]); | |
200 | ||
201 | res.map_err(|e| { | |
202 | proto_err!(conn: "failed to load WINDOW_UPDATE frame; err={:?}", e); | |
203 | Error::library_go_away(Reason::PROTOCOL_ERROR) | |
204 | })? | |
205 | .into() | |
206 | } | |
207 | Kind::Data => { | |
208 | let _ = bytes.split_to(frame::HEADER_LEN); | |
209 | let res = frame::Data::load(head, bytes.freeze()); | |
210 | ||
211 | // TODO: Should this always be connection level? Probably not... | |
212 | res.map_err(|e| { | |
213 | proto_err!(conn: "failed to load DATA frame; err={:?}", e); | |
214 | Error::library_go_away(Reason::PROTOCOL_ERROR) | |
215 | })? | |
216 | .into() | |
217 | } | |
218 | Kind::Headers => header_block!(Headers, head, bytes), | |
219 | Kind::Reset => { | |
220 | let res = frame::Reset::load(head, &bytes[frame::HEADER_LEN..]); | |
221 | res.map_err(|e| { | |
222 | proto_err!(conn: "failed to load RESET frame; err={:?}", e); | |
223 | Error::library_go_away(Reason::PROTOCOL_ERROR) | |
224 | })? | |
225 | .into() | |
226 | } | |
227 | Kind::GoAway => { | |
228 | let res = frame::GoAway::load(&bytes[frame::HEADER_LEN..]); | |
229 | res.map_err(|e| { | |
230 | proto_err!(conn: "failed to load GO_AWAY frame; err={:?}", e); | |
231 | Error::library_go_away(Reason::PROTOCOL_ERROR) | |
232 | })? | |
233 | .into() | |
234 | } | |
235 | Kind::PushPromise => header_block!(PushPromise, head, bytes), | |
236 | Kind::Priority => { | |
237 | if head.stream_id() == 0 { | |
238 | // Invalid stream identifier | |
239 | proto_err!(conn: "invalid stream ID 0"); | |
240 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
241 | } | |
242 | ||
243 | match frame::Priority::load(head, &bytes[frame::HEADER_LEN..]) { | |
244 | Ok(frame) => frame.into(), | |
245 | Err(frame::Error::InvalidDependencyId) => { | |
246 | // A stream cannot depend on itself. An endpoint MUST | |
247 | // treat this as a stream error (Section 5.4.2) of type | |
248 | // `PROTOCOL_ERROR`. | |
249 | let id = head.stream_id(); | |
250 | proto_err!(stream: "PRIORITY invalid dependency ID; stream={:?}", id); | |
251 | return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); | |
252 | } | |
253 | Err(e) => { | |
254 | proto_err!(conn: "failed to load PRIORITY frame; err={:?};", e); | |
255 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
256 | } | |
257 | } | |
258 | } | |
259 | Kind::Continuation => { | |
260 | let is_end_headers = (head.flag() & 0x4) == 0x4; | |
261 | ||
262 | let mut partial = match partial_inout.take() { | |
263 | Some(partial) => partial, | |
264 | None => { | |
265 | proto_err!(conn: "received unexpected CONTINUATION frame"); | |
266 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
267 | } | |
268 | }; | |
269 | ||
270 | // The stream identifiers must match | |
271 | if partial.frame.stream_id() != head.stream_id() { | |
272 | proto_err!(conn: "CONTINUATION frame stream ID does not match previous frame stream ID"); | |
273 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
274 | } | |
275 | ||
276 | // Extend the buf | |
277 | if partial.buf.is_empty() { | |
278 | partial.buf = bytes.split_off(frame::HEADER_LEN); | |
279 | } else { | |
280 | if partial.frame.is_over_size() { | |
281 | // If there was left over bytes previously, they may be | |
282 | // needed to continue decoding, even though we will | |
283 | // be ignoring this frame. This is done to keep the HPACK | |
284 | // decoder state up-to-date. | |
285 | // | |
286 | // Still, we need to be careful, because if a malicious | |
287 | // attacker were to try to send a gigantic string, such | |
288 | // that it fits over multiple header blocks, we could | |
289 | // grow memory uncontrollably again, and that'd be a shame. | |
290 | // | |
291 | // Instead, we use a simple heuristic to determine if | |
292 | // we should continue to ignore decoding, or to tell | |
293 | // the attacker to go away. | |
294 | if partial.buf.len() + bytes.len() > max_header_list_size { | |
295 | proto_err!(conn: "CONTINUATION frame header block size over ignorable limit"); | |
296 | return Err(Error::library_go_away(Reason::COMPRESSION_ERROR)); | |
297 | } | |
298 | } | |
299 | partial.buf.extend_from_slice(&bytes[frame::HEADER_LEN..]); | |
300 | } | |
301 | ||
302 | match partial | |
303 | .frame | |
304 | .load_hpack(&mut partial.buf, max_header_list_size, hpack) | |
305 | { | |
306 | Ok(_) => {} | |
307 | Err(frame::Error::Hpack(hpack::DecoderError::NeedMore(_))) if !is_end_headers => {} | |
308 | Err(frame::Error::MalformedMessage) => { | |
309 | let id = head.stream_id(); | |
310 | proto_err!(stream: "malformed CONTINUATION frame; stream={:?}", id); | |
311 | return Err(Error::library_reset(id, Reason::PROTOCOL_ERROR)); | |
312 | } | |
313 | Err(e) => { | |
314 | proto_err!(conn: "failed HPACK decoding; err={:?}", e); | |
315 | return Err(Error::library_go_away(Reason::PROTOCOL_ERROR)); | |
316 | } | |
317 | } | |
318 | ||
319 | if is_end_headers { | |
320 | partial.frame.into() | |
321 | } else { | |
322 | *partial_inout = Some(partial); | |
323 | return Ok(None); | |
324 | } | |
325 | } | |
326 | Kind::Unknown => { | |
327 | // Unknown frames are ignored | |
328 | return Ok(None); | |
329 | } | |
330 | }; | |
331 | ||
332 | Ok(Some(frame)) | |
333 | } | |
334 | ||
335 | impl<T> Stream for FramedRead<T> | |
336 | where | |
337 | T: AsyncRead + Unpin, | |
338 | { | |
339 | type Item = Result<Frame, Error>; | |
340 | ||
341 | fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { | |
342 | let span = tracing::trace_span!("FramedRead::poll_next"); | |
343 | let _e = span.enter(); | |
344 | loop { | |
345 | tracing::trace!("poll"); | |
346 | let bytes = match ready!(Pin::new(&mut self.inner).poll_next(cx)) { | |
347 | Some(Ok(bytes)) => bytes, | |
348 | Some(Err(e)) => return Poll::Ready(Some(Err(map_err(e)))), | |
349 | None => return Poll::Ready(None), | |
350 | }; | |
351 | ||
352 | tracing::trace!(read.bytes = bytes.len()); | |
353 | let Self { | |
354 | ref mut hpack, | |
355 | max_header_list_size, | |
356 | ref mut partial, | |
357 | .. | |
358 | } = *self; | |
359 | if let Some(frame) = decode_frame(hpack, max_header_list_size, partial, bytes)? { | |
360 | tracing::debug!(?frame, "received"); | |
361 | return Poll::Ready(Some(Ok(frame))); | |
362 | } | |
363 | } | |
364 | } | |
365 | } | |
366 | ||
367 | fn map_err(err: io::Error) -> Error { | |
368 | if let io::ErrorKind::InvalidData = err.kind() { | |
369 | if let Some(custom) = err.get_ref() { | |
370 | if custom.is::<LengthDelimitedCodecError>() { | |
371 | return Error::library_go_away(Reason::FRAME_SIZE_ERROR); | |
372 | } | |
373 | } | |
374 | } | |
375 | err.into() | |
376 | } | |
377 | ||
378 | // ===== impl Continuable ===== | |
379 | ||
380 | impl Continuable { | |
381 | fn stream_id(&self) -> frame::StreamId { | |
382 | match *self { | |
383 | Continuable::Headers(ref h) => h.stream_id(), | |
384 | Continuable::PushPromise(ref p) => p.stream_id(), | |
385 | } | |
386 | } | |
387 | ||
388 | fn is_over_size(&self) -> bool { | |
389 | match *self { | |
390 | Continuable::Headers(ref h) => h.is_over_size(), | |
391 | Continuable::PushPromise(ref p) => p.is_over_size(), | |
392 | } | |
393 | } | |
394 | ||
395 | fn load_hpack( | |
396 | &mut self, | |
397 | src: &mut BytesMut, | |
398 | max_header_list_size: usize, | |
399 | decoder: &mut hpack::Decoder, | |
400 | ) -> Result<(), frame::Error> { | |
401 | match *self { | |
402 | Continuable::Headers(ref mut h) => h.load_hpack(src, max_header_list_size, decoder), | |
403 | Continuable::PushPromise(ref mut p) => p.load_hpack(src, max_header_list_size, decoder), | |
404 | } | |
405 | } | |
406 | } | |
407 | ||
408 | impl<T> From<Continuable> for Frame<T> { | |
409 | fn from(cont: Continuable) -> Self { | |
410 | match cont { | |
411 | Continuable::Headers(mut headers) => { | |
412 | headers.set_end_headers(); | |
413 | headers.into() | |
414 | } | |
415 | Continuable::PushPromise(mut push) => { | |
416 | push.set_end_headers(); | |
417 | push.into() | |
418 | } | |
419 | } | |
420 | } | |
421 | } |