]> git.proxmox.com Git - proxmox-backup.git/blob - src/client/vsock_client.rs
vsock_client: remove wrong comment
[proxmox-backup.git] / src / client / vsock_client.rs
1 use anyhow::{bail, format_err, Error};
2 use futures::*;
3
4 use core::task::Context;
5 use std::pin::Pin;
6 use std::task::Poll;
7
8 use http::Uri;
9 use http::{Request, Response};
10 use hyper::client::connect::{Connected, Connection};
11 use hyper::client::Client;
12 use hyper::Body;
13 use pin_project::pin_project;
14 use serde_json::Value;
15 use tokio::io::{ReadBuf, AsyncRead, AsyncWrite, AsyncWriteExt};
16 use tokio::net::UnixStream;
17
18 use crate::tools;
19 use proxmox::api::error::HttpError;
20
/// Default vsock port for client connections (presumably the port the file restore daemon
/// listens on — confirm against the daemon side).
pub const DEFAULT_VSOCK_PORT: u16 = 807_u16;
22
/// Stateless hyper connector that opens `AF_VSOCK` stream sockets for `vsock://cid:port` URIs.
#[derive(Clone)]
struct VsockConnector;
25
26 #[pin_project]
27 /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
28 struct UnixConnection {
29 #[pin]
30 stream: UnixStream,
31 }
32
33 impl tower_service::Service<Uri> for VsockConnector {
34 type Response = UnixConnection;
35 type Error = Error;
36 type Future = Pin<Box<dyn Future<Output = Result<UnixConnection, Error>> + Send>>;
37
38 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
39 Poll::Ready(Ok(()))
40 }
41
42 fn call(&mut self, dst: Uri) -> Self::Future {
43 use nix::sys::socket::*;
44 use std::os::unix::io::FromRawFd;
45
46 // connect can block, so run in blocking task (though in reality it seems to immediately
47 // return with either ENODEV or ETIMEDOUT in case of error)
48 tokio::task::spawn_blocking(move || {
49 if dst.scheme_str().unwrap_or_default() != "vsock" {
50 bail!("invalid URI (scheme) for vsock connector: {}", dst);
51 }
52
53 let cid = match dst.host() {
54 Some(host) => host.parse().map_err(|err| {
55 format_err!(
56 "invalid URI (host not a number) for vsock connector: {} ({})",
57 dst,
58 err
59 )
60 })?,
61 None => bail!("invalid URI (no host) for vsock connector: {}", dst),
62 };
63
64 let port = match dst.port_u16() {
65 Some(port) => port,
66 None => bail!("invalid URI (bad port) for vsock connector: {}", dst),
67 };
68
69 let sock_fd = socket(
70 AddressFamily::Vsock,
71 SockType::Stream,
72 SockFlag::empty(),
73 None,
74 )?;
75
76 let sock_addr = VsockAddr::new(cid, port as u32);
77 connect(sock_fd, &SockAddr::Vsock(sock_addr))?;
78
79 // connect sync, but set nonblock after (tokio requires it)
80 let std_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) };
81 std_stream.set_nonblocking(true)?;
82
83 let stream = tokio::net::UnixStream::from_std(std_stream)?;
84 let connection = UnixConnection { stream };
85
86 Ok(connection)
87 })
88 // unravel the thread JoinHandle to a usable future
89 .map(|res| match res {
90 Ok(res) => res,
91 Err(err) => Err(format_err!("thread join error on vsock connect: {}", err)),
92 })
93 .boxed()
94 }
95 }
96
97 impl Connection for UnixConnection {
98 fn connected(&self) -> Connected {
99 Connected::new()
100 }
101 }
102
103 impl AsyncRead for UnixConnection {
104 fn poll_read(
105 self: Pin<&mut Self>,
106 cx: &mut Context<'_>,
107 buf: &mut ReadBuf,
108 ) -> Poll<Result<(), std::io::Error>> {
109 let this = self.project();
110 this.stream.poll_read(cx, buf)
111 }
112 }
113
114 impl AsyncWrite for UnixConnection {
115 fn poll_write(
116 self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 buf: &[u8],
119 ) -> Poll<tokio::io::Result<usize>> {
120 let this = self.project();
121 this.stream.poll_write(cx, buf)
122 }
123
124 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
125 let this = self.project();
126 this.stream.poll_flush(cx)
127 }
128
129 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
130 let this = self.project();
131 this.stream.poll_shutdown(cx)
132 }
133 }
134
/// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
pub struct VsockClient {
    // hyper client whose connections are opened by `VsockConnector`
    client: Client<VsockConnector>,
    // vsock context ID of the peer; formatted into the host part of request URIs
    cid: i32,
    // vsock port of the peer; formatted into the port part of request URIs
    port: u16,
}
141
142 impl VsockClient {
143 pub fn new(cid: i32, port: u16) -> Self {
144 let conn = VsockConnector {};
145 let client = Client::builder().build::<_, Body>(conn);
146 Self { client, cid, port }
147 }
148
149 pub async fn get(&self, path: &str, data: Option<Value>) -> Result<Value, Error> {
150 let req = Self::request_builder(self.cid, self.port, "GET", path, data)?;
151 self.api_request(req).await
152 }
153
154 pub async fn post(&mut self, path: &str, data: Option<Value>) -> Result<Value, Error> {
155 let req = Self::request_builder(self.cid, self.port, "POST", path, data)?;
156 self.api_request(req).await
157 }
158
159 pub async fn download(
160 &mut self,
161 path: &str,
162 data: Option<Value>,
163 output: &mut (dyn AsyncWrite + Send + Unpin),
164 ) -> Result<(), Error> {
165 let req = Self::request_builder(self.cid, self.port, "GET", path, data)?;
166
167 let client = self.client.clone();
168
169 let resp = client.request(req)
170 .await
171 .map_err(|_| format_err!("vsock download request timed out"))?;
172 let status = resp.status();
173 if !status.is_success() {
174 Self::api_response(resp)
175 .await
176 .map(|_| ())?
177 } else {
178 resp.into_body()
179 .map_err(Error::from)
180 .try_fold(output, move |acc, chunk| async move {
181 acc.write_all(&chunk).await?;
182 Ok::<_, Error>(acc)
183 })
184 .await?;
185 }
186 Ok(())
187 }
188
189 async fn api_response(response: Response<Body>) -> Result<Value, Error> {
190 let status = response.status();
191 let data = hyper::body::to_bytes(response.into_body()).await?;
192
193 let text = String::from_utf8(data.to_vec()).unwrap();
194 if status.is_success() {
195 if text.is_empty() {
196 Ok(Value::Null)
197 } else {
198 let value: Value = serde_json::from_str(&text)?;
199 Ok(value)
200 }
201 } else {
202 Err(Error::from(HttpError::new(status, text)))
203 }
204 }
205
206 async fn api_request(&self, req: Request<Body>) -> Result<Value, Error> {
207 self.client
208 .request(req)
209 .map_err(Error::from)
210 .and_then(Self::api_response)
211 .await
212 }
213
214 pub fn request_builder(
215 cid: i32,
216 port: u16,
217 method: &str,
218 path: &str,
219 data: Option<Value>,
220 ) -> Result<Request<Body>, Error> {
221 let path = path.trim_matches('/');
222 let url: Uri = format!("vsock://{}:{}/{}", cid, port, path).parse()?;
223
224 if let Some(data) = data {
225 if method == "POST" {
226 let request = Request::builder()
227 .method(method)
228 .uri(url)
229 .header(hyper::header::CONTENT_TYPE, "application/json")
230 .body(Body::from(data.to_string()))?;
231 return Ok(request);
232 } else {
233 let query = tools::json_object_to_query(data)?;
234 let url: Uri = format!("vsock://{}:{}/{}?{}", cid, port, path, query).parse()?;
235 let request = Request::builder()
236 .method(method)
237 .uri(url)
238 .header(
239 hyper::header::CONTENT_TYPE,
240 "application/x-www-form-urlencoded",
241 )
242 .body(Body::empty())?;
243 return Ok(request);
244 }
245 }
246
247 let request = Request::builder()
248 .method(method)
249 .uri(url)
250 .header(
251 hyper::header::CONTENT_TYPE,
252 "application/x-www-form-urlencoded",
253 )
254 .body(Body::empty())?;
255
256 Ok(request)
257 }
258 }