]>
Commit | Line | Data |
---|---|---|
89d25b19 | 1 | use std::pin::Pin; |
4805edc4 | 2 | use std::task::{Context, Poll}; |
89d25b19 | 3 | |
4805edc4 WB |
4 | use anyhow::{bail, format_err, Error}; |
5 | use futures::*; | |
89d25b19 SR |
6 | use http::Uri; |
7 | use http::{Request, Response}; | |
8 | use hyper::client::connect::{Connected, Connection}; | |
9 | use hyper::client::Client; | |
10 | use hyper::Body; | |
71549afa | 11 | use pin_project_lite::pin_project; |
89d25b19 | 12 | use serde_json::Value; |
971bc6f9 | 13 | use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}; |
89d25b19 SR |
14 | use tokio::net::UnixStream; |
15 | ||
2b9cf927 | 16 | use proxmox_http::uri::json_object_to_query; |
6ef1b649 | 17 | use proxmox_router::HttpError; |
89d25b19 | 18 | |
/// Default vsock port for reaching the peer daemon.
///
/// NOTE(review): not referenced in this file; presumably the port callers pass to
/// `VsockClient::new` when talking to the file restore daemon — confirm against the daemon.
pub const DEFAULT_VSOCK_PORT: u16 = 807;
/// Stateless hyper connector that opens connections to `vsock://<cid>:<port>` URIs.
#[derive(Clone)]
struct VsockConnector;
71549afa TL |
24 | pin_project! { |
25 | /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection | |
26 | struct UnixConnection { | |
27 | #[pin] | |
28 | stream: UnixStream, | |
29 | } | |
89d25b19 SR |
30 | } |
31 | ||
32 | impl tower_service::Service<Uri> for VsockConnector { | |
33 | type Response = UnixConnection; | |
34 | type Error = Error; | |
35 | type Future = Pin<Box<dyn Future<Output = Result<UnixConnection, Error>> + Send>>; | |
36 | ||
37 | fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> { | |
38 | Poll::Ready(Ok(())) | |
39 | } | |
40 | ||
41 | fn call(&mut self, dst: Uri) -> Self::Future { | |
42 | use nix::sys::socket::*; | |
43 | use std::os::unix::io::FromRawFd; | |
44 | ||
45 | // connect can block, so run in blocking task (though in reality it seems to immediately | |
46 | // return with either ENODEV or ETIMEDOUT in case of error) | |
47 | tokio::task::spawn_blocking(move || { | |
48 | if dst.scheme_str().unwrap_or_default() != "vsock" { | |
49 | bail!("invalid URI (scheme) for vsock connector: {}", dst); | |
50 | } | |
51 | ||
52 | let cid = match dst.host() { | |
53 | Some(host) => host.parse().map_err(|err| { | |
54 | format_err!( | |
55 | "invalid URI (host not a number) for vsock connector: {} ({})", | |
56 | dst, | |
57 | err | |
58 | ) | |
59 | })?, | |
60 | None => bail!("invalid URI (no host) for vsock connector: {}", dst), | |
61 | }; | |
62 | ||
63 | let port = match dst.port_u16() { | |
64 | Some(port) => port, | |
65 | None => bail!("invalid URI (bad port) for vsock connector: {}", dst), | |
66 | }; | |
67 | ||
68 | let sock_fd = socket( | |
69 | AddressFamily::Vsock, | |
70 | SockType::Stream, | |
71 | SockFlag::empty(), | |
72 | None, | |
73 | )?; | |
74 | ||
75 | let sock_addr = VsockAddr::new(cid, port as u32); | |
11ca8343 | 76 | connect(sock_fd, &sock_addr)?; |
89d25b19 SR |
77 | |
78 | // connect sync, but set nonblock after (tokio requires it) | |
79 | let std_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) }; | |
80 | std_stream.set_nonblocking(true)?; | |
81 | ||
82 | let stream = tokio::net::UnixStream::from_std(std_stream)?; | |
83 | let connection = UnixConnection { stream }; | |
84 | ||
85 | Ok(connection) | |
86 | }) | |
d1d74c43 | 87 | // unravel the thread JoinHandle to a usable future |
89d25b19 SR |
88 | .map(|res| match res { |
89 | Ok(res) => res, | |
90 | Err(err) => Err(format_err!("thread join error on vsock connect: {}", err)), | |
91 | }) | |
92 | .boxed() | |
93 | } | |
94 | } | |
95 | ||
96 | impl Connection for UnixConnection { | |
97 | fn connected(&self) -> Connected { | |
98 | Connected::new() | |
99 | } | |
100 | } | |
101 | ||
102 | impl AsyncRead for UnixConnection { | |
103 | fn poll_read( | |
104 | self: Pin<&mut Self>, | |
105 | cx: &mut Context<'_>, | |
106 | buf: &mut ReadBuf, | |
107 | ) -> Poll<Result<(), std::io::Error>> { | |
108 | let this = self.project(); | |
109 | this.stream.poll_read(cx, buf) | |
110 | } | |
111 | } | |
112 | ||
113 | impl AsyncWrite for UnixConnection { | |
114 | fn poll_write( | |
115 | self: Pin<&mut Self>, | |
116 | cx: &mut Context<'_>, | |
117 | buf: &[u8], | |
118 | ) -> Poll<tokio::io::Result<usize>> { | |
119 | let this = self.project(); | |
120 | this.stream.poll_write(cx, buf) | |
121 | } | |
122 | ||
123 | fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { | |
124 | let this = self.project(); | |
125 | this.stream.poll_flush(cx) | |
126 | } | |
127 | ||
128 | fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> { | |
129 | let this = self.project(); | |
130 | this.stream.poll_shutdown(cx) | |
131 | } | |
132 | } | |
133 | ||
134 | /// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon) | |
135 | pub struct VsockClient { | |
136 | client: Client<VsockConnector>, | |
137 | cid: i32, | |
138 | port: u16, | |
48763935 | 139 | auth: Option<String>, |
89d25b19 SR |
140 | } |
141 | ||
142 | impl VsockClient { | |
48763935 | 143 | pub fn new(cid: i32, port: u16, auth: Option<String>) -> Self { |
89d25b19 SR |
144 | let conn = VsockConnector {}; |
145 | let client = Client::builder().build::<_, Body>(conn); | |
48763935 SR |
146 | Self { |
147 | client, | |
148 | cid, | |
149 | port, | |
150 | auth, | |
151 | } | |
89d25b19 SR |
152 | } |
153 | ||
154 | pub async fn get(&self, path: &str, data: Option<Value>) -> Result<Value, Error> { | |
48763935 | 155 | let req = self.request_builder("GET", path, data)?; |
89d25b19 SR |
156 | self.api_request(req).await |
157 | } | |
158 | ||
971bc6f9 | 159 | pub async fn post(&self, path: &str, data: Option<Value>) -> Result<Value, Error> { |
48763935 | 160 | let req = self.request_builder("POST", path, data)?; |
89d25b19 SR |
161 | self.api_request(req).await |
162 | } | |
163 | ||
164 | pub async fn download( | |
971bc6f9 | 165 | &self, |
89d25b19 SR |
166 | path: &str, |
167 | data: Option<Value>, | |
168 | output: &mut (dyn AsyncWrite + Send + Unpin), | |
169 | ) -> Result<(), Error> { | |
48763935 | 170 | let req = self.request_builder("GET", path, data)?; |
89d25b19 SR |
171 | |
172 | let client = self.client.clone(); | |
173 | ||
971bc6f9 SR |
174 | let resp = client |
175 | .request(req) | |
89d25b19 SR |
176 | .await |
177 | .map_err(|_| format_err!("vsock download request timed out"))?; | |
178 | let status = resp.status(); | |
179 | if !status.is_success() { | |
971bc6f9 | 180 | Self::api_response(resp).await.map(|_| ())? |
89d25b19 SR |
181 | } else { |
182 | resp.into_body() | |
183 | .map_err(Error::from) | |
184 | .try_fold(output, move |acc, chunk| async move { | |
185 | acc.write_all(&chunk).await?; | |
186 | Ok::<_, Error>(acc) | |
187 | }) | |
188 | .await?; | |
189 | } | |
190 | Ok(()) | |
191 | } | |
192 | ||
193 | async fn api_response(response: Response<Body>) -> Result<Value, Error> { | |
194 | let status = response.status(); | |
195 | let data = hyper::body::to_bytes(response.into_body()).await?; | |
196 | ||
197 | let text = String::from_utf8(data.to_vec()).unwrap(); | |
198 | if status.is_success() { | |
199 | if text.is_empty() { | |
200 | Ok(Value::Null) | |
201 | } else { | |
202 | let value: Value = serde_json::from_str(&text)?; | |
203 | Ok(value) | |
204 | } | |
205 | } else { | |
206 | Err(Error::from(HttpError::new(status, text))) | |
207 | } | |
208 | } | |
209 | ||
210 | async fn api_request(&self, req: Request<Body>) -> Result<Value, Error> { | |
211 | self.client | |
212 | .request(req) | |
213 | .map_err(Error::from) | |
214 | .and_then(Self::api_response) | |
215 | .await | |
216 | } | |
217 | ||
48763935 SR |
218 | fn request_builder( |
219 | &self, | |
89d25b19 SR |
220 | method: &str, |
221 | path: &str, | |
222 | data: Option<Value>, | |
223 | ) -> Result<Request<Body>, Error> { | |
224 | let path = path.trim_matches('/'); | |
48763935 SR |
225 | let url: Uri = format!("vsock://{}:{}/{}", self.cid, self.port, path).parse()?; |
226 | ||
227 | let make_builder = |content_type: &str, url: &Uri| { | |
228 | let mut builder = Request::builder() | |
229 | .method(method) | |
230 | .uri(url) | |
231 | .header(hyper::header::CONTENT_TYPE, content_type); | |
232 | if let Some(auth) = &self.auth { | |
233 | builder = builder.header(hyper::header::AUTHORIZATION, auth); | |
234 | } | |
235 | builder | |
236 | }; | |
89d25b19 SR |
237 | |
238 | if let Some(data) = data { | |
239 | if method == "POST" { | |
48763935 SR |
240 | let builder = make_builder("application/json", &url); |
241 | let request = builder.body(Body::from(data.to_string()))?; | |
89d25b19 SR |
242 | return Ok(request); |
243 | } else { | |
2b9cf927 | 244 | let query = json_object_to_query(data)?; |
48763935 SR |
245 | let url: Uri = |
246 | format!("vsock://{}:{}/{}?{}", self.cid, self.port, path, query).parse()?; | |
247 | let builder = make_builder("application/x-www-form-urlencoded", &url); | |
248 | let request = builder.body(Body::empty())?; | |
89d25b19 SR |
249 | return Ok(request); |
250 | } | |
251 | } | |
252 | ||
48763935 SR |
253 | let builder = make_builder("application/x-www-form-urlencoded", &url); |
254 | let request = builder.body(Body::empty())?; | |
89d25b19 SR |
255 | |
256 | Ok(request) | |
257 | } | |
258 | } |