git.proxmox.com Git - rustc.git/blob - vendor/hyper/src/client/connect/http.rs
New upstream version 1.73.0+dfsg1
[rustc.git] / vendor / hyper / src / client / connect / http.rs
1 use std::error::Error as StdError;
2 use std::fmt;
3 use std::future::Future;
4 use std::io;
5 use std::marker::PhantomData;
6 use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
7 use std::pin::Pin;
8 use std::sync::Arc;
9 use std::task::{self, Poll};
10 use std::time::Duration;
11
12 use futures_util::future::Either;
13 use http::uri::{Scheme, Uri};
14 use pin_project_lite::pin_project;
15 use tokio::net::{TcpSocket, TcpStream};
16 use tokio::time::Sleep;
17 use tracing::{debug, trace, warn};
18
19 use super::dns::{self, resolve, GaiResolver, Resolve};
20 use super::{Connected, Connection};
21 //#[cfg(feature = "runtime")] use super::dns::TokioThreadpoolGaiResolver;
22
23 /// A connector for the `http` scheme.
24 ///
25 /// Performs DNS resolution in a thread pool, and then connects over TCP.
26 ///
27 /// # Note
28 ///
29 /// Sets the [`HttpInfo`](HttpInfo) value on responses, which includes
30 /// transport information such as the remote socket address used.
31 #[cfg_attr(docsrs, doc(cfg(feature = "tcp")))]
32 #[derive(Clone)]
33 pub struct HttpConnector<R = GaiResolver> {
34 config: Arc<Config>,
35 resolver: R,
36 }
37
/// Extra information about the transport when an HttpConnector is used.
///
/// # Example
///
/// ```
/// # async fn doc() -> hyper::Result<()> {
/// use hyper::Uri;
/// use hyper::client::{Client, connect::HttpInfo};
///
/// let client = Client::new();
/// let uri = Uri::from_static("http://example.com");
///
/// let res = client.get(uri).await?;
/// if let Some(info) = res.extensions().get::<HttpInfo>() {
///     println!("remote addr = {}", info.remote_addr());
/// }
/// # Ok(())
/// # }
/// ```
///
/// # Note
///
/// This value is only present in the extensions when the connection was made
/// by an [`HttpConnector`](HttpConnector). Other connectors may attach their
/// own "extra" information instead; consult their documentation.
#[derive(Clone, Debug)]
pub struct HttpInfo {
    // Address of the peer the TCP stream is connected to.
    remote_addr: SocketAddr,
    // Local address the TCP stream is bound to.
    local_addr: SocketAddr,
}
71
/// Settings applied to every connection attempt made by an `HttpConnector`.
#[derive(Clone)]
struct Config {
    /// Total connect budget; it is divided evenly across the addresses tried.
    connect_timeout: Option<Duration>,
    /// When `true`, only URIs with the `http` scheme are accepted.
    enforce_http: bool,
    /// Delay before the fallback address group is raced against the preferred
    /// one; `None` disables the fallback race entirely.
    happy_eyeballs_timeout: Option<Duration>,
    /// Keepalive time passed to `TcpKeepalive::with_time`, if any.
    keep_alive_timeout: Option<Duration>,
    /// Local address bound before connecting to an IPv4 destination.
    local_address_ipv4: Option<Ipv4Addr>,
    /// Local address bound before connecting to an IPv6 destination.
    local_address_ipv6: Option<Ipv6Addr>,
    /// Value applied with `set_nodelay` on the connected stream.
    nodelay: bool,
    /// Whether `SO_REUSEADDR` is enabled on the socket.
    reuse_address: bool,
    /// Requested `SO_SNDBUF` size, if any.
    send_buffer_size: Option<usize>,
    /// Requested `SO_RCVBUF` size, if any.
    recv_buffer_size: Option<usize>,
}
85
86 // ===== impl HttpConnector =====
87
88 impl HttpConnector {
89 /// Construct a new HttpConnector.
90 pub fn new() -> HttpConnector {
91 HttpConnector::new_with_resolver(GaiResolver::new())
92 }
93 }
94
95 /*
96 #[cfg(feature = "runtime")]
97 impl HttpConnector<TokioThreadpoolGaiResolver> {
98 /// Construct a new HttpConnector using the `TokioThreadpoolGaiResolver`.
99 ///
100 /// This resolver **requires** the threadpool runtime to be used.
101 pub fn new_with_tokio_threadpool_resolver() -> Self {
102 HttpConnector::new_with_resolver(TokioThreadpoolGaiResolver::new())
103 }
104 }
105 */
106
107 impl<R> HttpConnector<R> {
108 /// Construct a new HttpConnector.
109 ///
110 /// Takes a [`Resolver`](crate::client::connect::dns#resolvers-are-services) to handle DNS lookups.
111 pub fn new_with_resolver(resolver: R) -> HttpConnector<R> {
112 HttpConnector {
113 config: Arc::new(Config {
114 connect_timeout: None,
115 enforce_http: true,
116 happy_eyeballs_timeout: Some(Duration::from_millis(300)),
117 keep_alive_timeout: None,
118 local_address_ipv4: None,
119 local_address_ipv6: None,
120 nodelay: false,
121 reuse_address: false,
122 send_buffer_size: None,
123 recv_buffer_size: None,
124 }),
125 resolver,
126 }
127 }
128
129 /// Option to enforce all `Uri`s have the `http` scheme.
130 ///
131 /// Enabled by default.
132 #[inline]
133 pub fn enforce_http(&mut self, is_enforced: bool) {
134 self.config_mut().enforce_http = is_enforced;
135 }
136
137 /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration.
138 ///
139 /// If `None`, the option will not be set.
140 ///
141 /// Default is `None`.
142 #[inline]
143 pub fn set_keepalive(&mut self, dur: Option<Duration>) {
144 self.config_mut().keep_alive_timeout = dur;
145 }
146
147 /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`.
148 ///
149 /// Default is `false`.
150 #[inline]
151 pub fn set_nodelay(&mut self, nodelay: bool) {
152 self.config_mut().nodelay = nodelay;
153 }
154
155 /// Sets the value of the SO_SNDBUF option on the socket.
156 #[inline]
157 pub fn set_send_buffer_size(&mut self, size: Option<usize>) {
158 self.config_mut().send_buffer_size = size;
159 }
160
161 /// Sets the value of the SO_RCVBUF option on the socket.
162 #[inline]
163 pub fn set_recv_buffer_size(&mut self, size: Option<usize>) {
164 self.config_mut().recv_buffer_size = size;
165 }
166
167 /// Set that all sockets are bound to the configured address before connection.
168 ///
169 /// If `None`, the sockets will not be bound.
170 ///
171 /// Default is `None`.
172 #[inline]
173 pub fn set_local_address(&mut self, addr: Option<IpAddr>) {
174 let (v4, v6) = match addr {
175 Some(IpAddr::V4(a)) => (Some(a), None),
176 Some(IpAddr::V6(a)) => (None, Some(a)),
177 _ => (None, None),
178 };
179
180 let cfg = self.config_mut();
181
182 cfg.local_address_ipv4 = v4;
183 cfg.local_address_ipv6 = v6;
184 }
185
186 /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's
187 /// preferences) before connection.
188 #[inline]
189 pub fn set_local_addresses(&mut self, addr_ipv4: Ipv4Addr, addr_ipv6: Ipv6Addr) {
190 let cfg = self.config_mut();
191
192 cfg.local_address_ipv4 = Some(addr_ipv4);
193 cfg.local_address_ipv6 = Some(addr_ipv6);
194 }
195
196 /// Set the connect timeout.
197 ///
198 /// If a domain resolves to multiple IP addresses, the timeout will be
199 /// evenly divided across them.
200 ///
201 /// Default is `None`.
202 #[inline]
203 pub fn set_connect_timeout(&mut self, dur: Option<Duration>) {
204 self.config_mut().connect_timeout = dur;
205 }
206
207 /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm.
208 ///
209 /// If hostname resolves to both IPv4 and IPv6 addresses and connection
210 /// cannot be established using preferred address family before timeout
211 /// elapses, then connector will in parallel attempt connection using other
212 /// address family.
213 ///
214 /// If `None`, parallel connection attempts are disabled.
215 ///
216 /// Default is 300 milliseconds.
217 ///
218 /// [RFC 6555]: https://tools.ietf.org/html/rfc6555
219 #[inline]
220 pub fn set_happy_eyeballs_timeout(&mut self, dur: Option<Duration>) {
221 self.config_mut().happy_eyeballs_timeout = dur;
222 }
223
224 /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`.
225 ///
226 /// Default is `false`.
227 #[inline]
228 pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self {
229 self.config_mut().reuse_address = reuse_address;
230 self
231 }
232
233 // private
234
235 fn config_mut(&mut self) -> &mut Config {
236 // If the are HttpConnector clones, this will clone the inner
237 // config. So mutating the config won't ever affect previous
238 // clones.
239 Arc::make_mut(&mut self.config)
240 }
241 }
242
/// Reported when `enforce_http` is on and the URI's scheme is not `http`.
const INVALID_NOT_HTTP: &str = "invalid URL, scheme is not http";
/// Reported when `enforce_http` is off and the URI has no scheme at all.
const INVALID_MISSING_SCHEME: &str = "invalid URL, scheme is missing";
/// Reported when the URI has no host component.
const INVALID_MISSING_HOST: &str = "invalid URL, host is missing";
246
247 // R: Debug required for now to allow adding it to debug output later...
248 impl<R: fmt::Debug> fmt::Debug for HttpConnector<R> {
249 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
250 f.debug_struct("HttpConnector").finish()
251 }
252 }
253
254 impl<R> tower_service::Service<Uri> for HttpConnector<R>
255 where
256 R: Resolve + Clone + Send + Sync + 'static,
257 R::Future: Send,
258 {
259 type Response = TcpStream;
260 type Error = ConnectError;
261 type Future = HttpConnecting<R>;
262
263 fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
264 ready!(self.resolver.poll_ready(cx)).map_err(ConnectError::dns)?;
265 Poll::Ready(Ok(()))
266 }
267
268 fn call(&mut self, dst: Uri) -> Self::Future {
269 let mut self_ = self.clone();
270 HttpConnecting {
271 fut: Box::pin(async move { self_.call_async(dst).await }),
272 _marker: PhantomData,
273 }
274 }
275 }
276
277 fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> {
278 trace!(
279 "Http::connect; scheme={:?}, host={:?}, port={:?}",
280 dst.scheme(),
281 dst.host(),
282 dst.port(),
283 );
284
285 if config.enforce_http {
286 if dst.scheme() != Some(&Scheme::HTTP) {
287 return Err(ConnectError {
288 msg: INVALID_NOT_HTTP.into(),
289 cause: None,
290 });
291 }
292 } else if dst.scheme().is_none() {
293 return Err(ConnectError {
294 msg: INVALID_MISSING_SCHEME.into(),
295 cause: None,
296 });
297 }
298
299 let host = match dst.host() {
300 Some(s) => s,
301 None => {
302 return Err(ConnectError {
303 msg: INVALID_MISSING_HOST.into(),
304 cause: None,
305 })
306 }
307 };
308 let port = match dst.port() {
309 Some(port) => port.as_u16(),
310 None => {
311 if dst.scheme() == Some(&Scheme::HTTPS) {
312 443
313 } else {
314 80
315 }
316 }
317 };
318
319 Ok((host, port))
320 }
321
322 impl<R> HttpConnector<R>
323 where
324 R: Resolve,
325 {
326 async fn call_async(&mut self, dst: Uri) -> Result<TcpStream, ConnectError> {
327 let config = &self.config;
328
329 let (host, port) = get_host_port(config, &dst)?;
330 let host = host.trim_start_matches('[').trim_end_matches(']');
331
332 // If the host is already an IP addr (v4 or v6),
333 // skip resolving the dns and start connecting right away.
334 let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) {
335 addrs
336 } else {
337 let addrs = resolve(&mut self.resolver, dns::Name::new(host.into()))
338 .await
339 .map_err(ConnectError::dns)?;
340 let addrs = addrs
341 .map(|mut addr| {
342 addr.set_port(port);
343 addr
344 })
345 .collect();
346 dns::SocketAddrs::new(addrs)
347 };
348
349 let c = ConnectingTcp::new(addrs, config);
350
351 let sock = c.connect().await?;
352
353 if let Err(e) = sock.set_nodelay(config.nodelay) {
354 warn!("tcp set_nodelay error: {}", e);
355 }
356
357 Ok(sock)
358 }
359 }
360
361 impl Connection for TcpStream {
362 fn connected(&self) -> Connected {
363 let connected = Connected::new();
364 if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) {
365 connected.extra(HttpInfo { remote_addr, local_addr })
366 } else {
367 connected
368 }
369 }
370 }
371
372 impl HttpInfo {
373 /// Get the remote address of the transport used.
374 pub fn remote_addr(&self) -> SocketAddr {
375 self.remote_addr
376 }
377
378 /// Get the local address of the transport used.
379 pub fn local_addr(&self) -> SocketAddr {
380 self.local_addr
381 }
382 }
383
384 pin_project! {
385 // Not publicly exported (so missing_docs doesn't trigger).
386 //
387 // We return this `Future` instead of the `Pin<Box<dyn Future>>` directly
388 // so that users don't rely on it fitting in a `Pin<Box<dyn Future>>` slot
389 // (and thus we can change the type in the future).
390 #[must_use = "futures do nothing unless polled"]
391 #[allow(missing_debug_implementations)]
392 pub struct HttpConnecting<R> {
393 #[pin]
394 fut: BoxConnecting,
395 _marker: PhantomData<R>,
396 }
397 }
398
399 type ConnectResult = Result<TcpStream, ConnectError>;
400 type BoxConnecting = Pin<Box<dyn Future<Output = ConnectResult> + Send>>;
401
402 impl<R: Resolve> Future for HttpConnecting<R> {
403 type Output = ConnectResult;
404
405 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
406 self.project().fut.poll(cx)
407 }
408 }
409
// Not publicly exported (so missing_docs doesn't trigger).
/// Error returned when an `HttpConnector` fails to produce a connection.
pub struct ConnectError {
    /// Short description of the step that failed, e.g. "dns error".
    msg: Box<str>,
    /// Underlying error, if any; exposed through `Error::source`.
    cause: Option<Box<dyn StdError + Send + Sync>>,
}
415
416 impl ConnectError {
417 fn new<S, E>(msg: S, cause: E) -> ConnectError
418 where
419 S: Into<Box<str>>,
420 E: Into<Box<dyn StdError + Send + Sync>>,
421 {
422 ConnectError {
423 msg: msg.into(),
424 cause: Some(cause.into()),
425 }
426 }
427
428 fn dns<E>(cause: E) -> ConnectError
429 where
430 E: Into<Box<dyn StdError + Send + Sync>>,
431 {
432 ConnectError::new("dns error", cause)
433 }
434
435 fn m<S, E>(msg: S) -> impl FnOnce(E) -> ConnectError
436 where
437 S: Into<Box<str>>,
438 E: Into<Box<dyn StdError + Send + Sync>>,
439 {
440 move |cause| ConnectError::new(msg, cause)
441 }
442 }
443
444 impl fmt::Debug for ConnectError {
445 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
446 if let Some(ref cause) = self.cause {
447 f.debug_tuple("ConnectError")
448 .field(&self.msg)
449 .field(cause)
450 .finish()
451 } else {
452 self.msg.fmt(f)
453 }
454 }
455 }
456
457 impl fmt::Display for ConnectError {
458 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
459 f.write_str(&self.msg)?;
460
461 if let Some(ref cause) = self.cause {
462 write!(f, ": {}", cause)?;
463 }
464
465 Ok(())
466 }
467 }
468
469 impl StdError for ConnectError {
470 fn source(&self) -> Option<&(dyn StdError + 'static)> {
471 self.cause.as_ref().map(|e| &**e as _)
472 }
473 }
474
475 struct ConnectingTcp<'a> {
476 preferred: ConnectingTcpRemote,
477 fallback: Option<ConnectingTcpFallback>,
478 config: &'a Config,
479 }
480
481 impl<'a> ConnectingTcp<'a> {
482 fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self {
483 if let Some(fallback_timeout) = config.happy_eyeballs_timeout {
484 let (preferred_addrs, fallback_addrs) = remote_addrs
485 .split_by_preference(config.local_address_ipv4, config.local_address_ipv6);
486 if fallback_addrs.is_empty() {
487 return ConnectingTcp {
488 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
489 fallback: None,
490 config,
491 };
492 }
493
494 ConnectingTcp {
495 preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout),
496 fallback: Some(ConnectingTcpFallback {
497 delay: tokio::time::sleep(fallback_timeout),
498 remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout),
499 }),
500 config,
501 }
502 } else {
503 ConnectingTcp {
504 preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout),
505 fallback: None,
506 config,
507 }
508 }
509 }
510 }
511
512 struct ConnectingTcpFallback {
513 delay: Sleep,
514 remote: ConnectingTcpRemote,
515 }
516
517 struct ConnectingTcpRemote {
518 addrs: dns::SocketAddrs,
519 connect_timeout: Option<Duration>,
520 }
521
522 impl ConnectingTcpRemote {
523 fn new(addrs: dns::SocketAddrs, connect_timeout: Option<Duration>) -> Self {
524 let connect_timeout = connect_timeout.map(|t| t / (addrs.len() as u32));
525
526 Self {
527 addrs,
528 connect_timeout,
529 }
530 }
531 }
532
533 impl ConnectingTcpRemote {
534 async fn connect(&mut self, config: &Config) -> Result<TcpStream, ConnectError> {
535 let mut err = None;
536 for addr in &mut self.addrs {
537 debug!("connecting to {}", addr);
538 match connect(&addr, config, self.connect_timeout)?.await {
539 Ok(tcp) => {
540 debug!("connected to {}", addr);
541 return Ok(tcp);
542 }
543 Err(e) => {
544 trace!("connect error for {}: {:?}", addr, e);
545 err = Some(e);
546 }
547 }
548 }
549
550 match err {
551 Some(e) => Err(e),
552 None => Err(ConnectError::new(
553 "tcp connect error",
554 std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"),
555 )),
556 }
557 }
558 }
559
560 fn bind_local_address(
561 socket: &socket2::Socket,
562 dst_addr: &SocketAddr,
563 local_addr_ipv4: &Option<Ipv4Addr>,
564 local_addr_ipv6: &Option<Ipv6Addr>,
565 ) -> io::Result<()> {
566 match (*dst_addr, local_addr_ipv4, local_addr_ipv6) {
567 (SocketAddr::V4(_), Some(addr), _) => {
568 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
569 }
570 (SocketAddr::V6(_), _, Some(addr)) => {
571 socket.bind(&SocketAddr::new(addr.clone().into(), 0).into())?;
572 }
573 _ => {
574 if cfg!(windows) {
575 // Windows requires a socket be bound before calling connect
576 let any: SocketAddr = match *dst_addr {
577 SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(),
578 SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(),
579 };
580 socket.bind(&any.into())?;
581 }
582 }
583 }
584
585 Ok(())
586 }
587
588 fn connect(
589 addr: &SocketAddr,
590 config: &Config,
591 connect_timeout: Option<Duration>,
592 ) -> Result<impl Future<Output = Result<TcpStream, ConnectError>>, ConnectError> {
593 // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the
594 // keepalive timeout, it would be nice to use that instead of socket2,
595 // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance...
596 use socket2::{Domain, Protocol, Socket, TcpKeepalive, Type};
597 use std::convert::TryInto;
598
599 let domain = Domain::for_address(*addr);
600 let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
601 .map_err(ConnectError::m("tcp open error"))?;
602
603 // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is
604 // responsible for ensuring O_NONBLOCK is set.
605 socket
606 .set_nonblocking(true)
607 .map_err(ConnectError::m("tcp set_nonblocking error"))?;
608
609 if let Some(dur) = config.keep_alive_timeout {
610 let conf = TcpKeepalive::new().with_time(dur);
611 if let Err(e) = socket.set_tcp_keepalive(&conf) {
612 warn!("tcp set_keepalive error: {}", e);
613 }
614 }
615
616 bind_local_address(
617 &socket,
618 addr,
619 &config.local_address_ipv4,
620 &config.local_address_ipv6,
621 )
622 .map_err(ConnectError::m("tcp bind local error"))?;
623
624 #[cfg(unix)]
625 let socket = unsafe {
626 // Safety: `from_raw_fd` is only safe to call if ownership of the raw
627 // file descriptor is transferred. Since we call `into_raw_fd` on the
628 // socket2 socket, it gives up ownership of the fd and will not close
629 // it, so this is safe.
630 use std::os::unix::io::{FromRawFd, IntoRawFd};
631 TcpSocket::from_raw_fd(socket.into_raw_fd())
632 };
633 #[cfg(windows)]
634 let socket = unsafe {
635 // Safety: `from_raw_socket` is only safe to call if ownership of the raw
636 // Windows SOCKET is transferred. Since we call `into_raw_socket` on the
637 // socket2 socket, it gives up ownership of the SOCKET and will not close
638 // it, so this is safe.
639 use std::os::windows::io::{FromRawSocket, IntoRawSocket};
640 TcpSocket::from_raw_socket(socket.into_raw_socket())
641 };
642
643 if config.reuse_address {
644 if let Err(e) = socket.set_reuseaddr(true) {
645 warn!("tcp set_reuse_address error: {}", e);
646 }
647 }
648
649 if let Some(size) = config.send_buffer_size {
650 if let Err(e) = socket.set_send_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
651 warn!("tcp set_buffer_size error: {}", e);
652 }
653 }
654
655 if let Some(size) = config.recv_buffer_size {
656 if let Err(e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(std::u32::MAX)) {
657 warn!("tcp set_recv_buffer_size error: {}", e);
658 }
659 }
660
661 let connect = socket.connect(*addr);
662 Ok(async move {
663 match connect_timeout {
664 Some(dur) => match tokio::time::timeout(dur, connect).await {
665 Ok(Ok(s)) => Ok(s),
666 Ok(Err(e)) => Err(e),
667 Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)),
668 },
669 None => connect.await,
670 }
671 .map_err(ConnectError::m("tcp connect error"))
672 })
673 }
674
675 impl ConnectingTcp<'_> {
676 async fn connect(mut self) -> Result<TcpStream, ConnectError> {
677 match self.fallback {
678 None => self.preferred.connect(self.config).await,
679 Some(mut fallback) => {
680 let preferred_fut = self.preferred.connect(self.config);
681 futures_util::pin_mut!(preferred_fut);
682
683 let fallback_fut = fallback.remote.connect(self.config);
684 futures_util::pin_mut!(fallback_fut);
685
686 let fallback_delay = fallback.delay;
687 futures_util::pin_mut!(fallback_delay);
688
689 let (result, future) =
690 match futures_util::future::select(preferred_fut, fallback_delay).await {
691 Either::Left((result, _fallback_delay)) => {
692 (result, Either::Right(fallback_fut))
693 }
694 Either::Right(((), preferred_fut)) => {
695 // Delay is done, start polling both the preferred and the fallback
696 futures_util::future::select(preferred_fut, fallback_fut)
697 .await
698 .factor_first()
699 }
700 };
701
702 if result.is_err() {
703 // Fallback to the remaining future (could be preferred or fallback)
704 // if we get an error
705 future.await
706 } else {
707 result
708 }
709 }
710 }
711 }
712 }
713
#[cfg(test)]
mod tests {
    use std::io;

    use ::http::Uri;

    use super::super::sealed::{Connect, ConnectSvc};
    use super::{Config, ConnectError, HttpConnector};

    /// Drives `connector` through the sealed `Connect` trait, the same entry
    /// point the client uses.
    async fn connect<C>(
        connector: C,
        dst: Uri,
    ) -> Result<<C::_Svc as ConnectSvc>::Connection, <C::_Svc as ConnectSvc>::Error>
    where
        C: Connect,
    {
        connector.connect(super::super::sealed::Internal, dst).await
    }

    #[tokio::test]
    async fn test_errors_enforce_http() {
        let connector = HttpConnector::new();
        let dst = "https://example.domain/foo/bar?baz".parse().unwrap();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_NOT_HTTP);
    }

    #[tokio::test]
    async fn test_errors_missing_scheme() {
        let mut connector = HttpConnector::new();
        connector.enforce_http(false);
        let dst = "example.domain".parse().unwrap();

        let err = connect(connector, dst).await.unwrap_err();
        assert_eq!(&*err.msg, super::INVALID_MISSING_SCHEME);
    }

    /// Scans the host's interfaces for a bindable IPv4 and a bindable IPv6
    /// address, stopping as soon as one of each has been found.
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    fn get_local_ips() -> (Option<std::net::Ipv4Addr>, Option<std::net::Ipv6Addr>) {
        use std::net::{IpAddr, TcpListener};

        let mut found_v4 = None;
        let mut found_v6 = None;

        let candidates = pnet_datalink::interfaces()
            .into_iter()
            .flat_map(|iface| iface.ips.into_iter().map(|net| net.ip()));

        for candidate in candidates {
            let bindable = TcpListener::bind((candidate, 0)).is_ok();
            match candidate {
                IpAddr::V4(v4) if bindable => found_v4 = Some(v4),
                IpAddr::V6(v6) if bindable => found_v6 = Some(v6),
                _ => (),
            }

            if found_v4.is_some() && found_v6.is_some() {
                break;
            }
        }

        (found_v4, found_v6)
    }

    // NOTE: pnet crate that we use in this test doesn't compile on Windows
    #[cfg(any(target_os = "linux", target_os = "macos"))]
    #[tokio::test]
    async fn local_address() {
        use std::net::{IpAddr, TcpListener};
        let _ = pretty_env_logger::try_init();

        let (bind_ip_v4, bind_ip_v6) = get_local_ips();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = server4.local_addr().unwrap().port();
        let server6 = TcpListener::bind(&format!("[::1]:{}", port)).unwrap();

        let assert_client_ip = |dst: String, server: TcpListener, expected_ip: IpAddr| async move {
            let mut connector = HttpConnector::new();

            match (bind_ip_v4, bind_ip_v6) {
                (Some(v4), Some(v6)) => connector.set_local_addresses(v4, v6),
                (Some(v4), None) => connector.set_local_address(Some(v4.into())),
                (None, Some(v6)) => connector.set_local_address(Some(v6.into())),
                _ => unreachable!(),
            }

            connect(connector, dst.parse().unwrap()).await.unwrap();

            let (_, client_addr) = server.accept().unwrap();

            assert_eq!(client_addr.ip(), expected_ip);
        };

        if let Some(ip) = bind_ip_v4 {
            assert_client_ip(format!("http://127.0.0.1:{}", port), server4, ip.into()).await;
        }

        if let Some(ip) = bind_ip_v6 {
            assert_client_ip(format!("http://[::1]:{}", port), server6, ip.into()).await;
        }
    }

    #[test]
    #[cfg_attr(not(feature = "__internal_happy_eyeballs_tests"), ignore)]
    fn client_happy_eyeballs() {
        use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, TcpListener};
        use std::time::{Duration, Instant};

        use super::dns;
        use super::ConnectingTcp;

        let _ = pretty_env_logger::try_init();
        let server4 = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = server4.local_addr().unwrap();
        let _server6 = TcpListener::bind(&format!("[::1]:{}", addr.port())).unwrap();
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();

        let local_timeout = Duration::default();
        let unreachable_v4_timeout = measure_connect(unreachable_ipv4_addr()).1;
        let unreachable_v6_timeout = measure_connect(unreachable_ipv6_addr()).1;
        let fallback_timeout = std::cmp::max(unreachable_v4_timeout, unreachable_v6_timeout)
            + Duration::from_millis(250);

        // (addresses, expected family, expected duration, needs IPv6 access)
        let scenarios = &[
            // Fast primary, without fallback.
            (&[local_ipv4_addr()][..], 4, local_timeout, false),
            (&[local_ipv6_addr()][..], 6, local_timeout, false),
            // Fast primary, with (unused) fallback.
            (
                &[local_ipv4_addr(), local_ipv6_addr()][..],
                4,
                local_timeout,
                false,
            ),
            (
                &[local_ipv6_addr(), local_ipv4_addr()][..],
                6,
                local_timeout,
                false,
            ),
            // Unreachable + fast primary, without fallback.
            (
                &[unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                unreachable_v6_timeout,
                false,
            ),
            // Unreachable + fast primary, with (unused) fallback.
            (
                &[
                    unreachable_ipv4_addr(),
                    local_ipv4_addr(),
                    local_ipv6_addr(),
                ][..],
                4,
                unreachable_v4_timeout,
                false,
            ),
            (
                &[
                    unreachable_ipv6_addr(),
                    local_ipv6_addr(),
                    local_ipv4_addr(),
                ][..],
                6,
                unreachable_v6_timeout,
                true,
            ),
            // Slow primary, with (used) fallback.
            (
                &[slow_ipv4_addr(), local_ipv4_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), local_ipv6_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout,
                true,
            ),
            // Slow primary, with (used) unreachable + fast fallback.
            (
                &[slow_ipv4_addr(), unreachable_ipv6_addr(), local_ipv6_addr()][..],
                6,
                fallback_timeout + unreachable_v6_timeout,
                false,
            ),
            (
                &[slow_ipv6_addr(), unreachable_ipv4_addr(), local_ipv4_addr()][..],
                4,
                fallback_timeout + unreachable_v4_timeout,
                true,
            ),
        ];

        // Scenarios for IPv6 -> IPv4 fallback require that host can access IPv6 network.
        // Otherwise, connection to "slow" IPv6 address will error-out immediately.
        let ipv6_accessible = measure_connect(slow_ipv6_addr()).0;

        for &(hosts, family, timeout, needs_ipv6_access) in scenarios {
            if needs_ipv6_access && !ipv6_accessible {
                continue;
            }

            let (start, stream) = rt
                .block_on(async move {
                    let addrs = hosts
                        .iter()
                        .map(|&host| (host, addr.port()).into())
                        .collect();
                    let cfg = Config {
                        local_address_ipv4: None,
                        local_address_ipv6: None,
                        connect_timeout: None,
                        keep_alive_timeout: None,
                        happy_eyeballs_timeout: Some(fallback_timeout),
                        nodelay: false,
                        reuse_address: false,
                        enforce_http: false,
                        send_buffer_size: None,
                        recv_buffer_size: None,
                    };
                    let connecting_tcp = ConnectingTcp::new(dns::SocketAddrs::new(addrs), &cfg);
                    let start = Instant::now();
                    Ok::<_, ConnectError>((start, ConnectingTcp::connect(connecting_tcp).await?))
                })
                .unwrap();
            let observed_family = if stream.peer_addr().unwrap().is_ipv4() {
                4
            } else {
                6
            };
            let duration = start.elapsed();

            // Allow actual duration to be +/- 150ms off (never below zero).
            let tolerance = Duration::from_millis(150);
            let min_duration = timeout.checked_sub(tolerance).unwrap_or_default();
            let max_duration = timeout + tolerance;

            assert_eq!(observed_family, family);
            assert!(duration >= min_duration);
            assert!(duration <= max_duration);
        }

        fn local_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 1).into()
        }

        fn local_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1).into()
        }

        fn unreachable_ipv4_addr() -> IpAddr {
            Ipv4Addr::new(127, 0, 0, 2).into()
        }

        fn unreachable_ipv6_addr() -> IpAddr {
            Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 2).into()
        }

        fn slow_ipv4_addr() -> IpAddr {
            // RFC 6890 reserved IPv4 address.
            Ipv4Addr::new(198, 18, 0, 25).into()
        }

        fn slow_ipv6_addr() -> IpAddr {
            // RFC 6890 reserved IPv6 address.
            Ipv6Addr::new(2001, 2, 0, 0, 0, 0, 0, 254).into()
        }

        /// Attempts a blocking connect to `addr:80` with a one-second timeout.
        /// Returns whether the address counts as reachable and how long the
        /// attempt took.
        fn measure_connect(addr: IpAddr) -> (bool, Duration) {
            let start = Instant::now();
            let result =
                std::net::TcpStream::connect_timeout(&(addr, 80).into(), Duration::from_secs(1));

            // A timeout counts as reachable: an inaccessible network errors
            // out immediately instead of timing out.
            let reachable = match result {
                Ok(_) => true,
                Err(e) => e.kind() == io::ErrorKind::TimedOut,
            };
            let duration = start.elapsed();
            (reachable, duration)
        }
    }
}