]> git.proxmox.com Git - proxmox-backup.git/blob - src/client/vsock_client.rs
typo fixes all over the place
[proxmox-backup.git] / src / client / vsock_client.rs
1 use anyhow::{bail, format_err, Error};
2 use futures::*;
3
4 use core::task::Context;
5 use std::pin::Pin;
6 use std::task::Poll;
7
8 use http::Uri;
9 use http::{Request, Response};
10 use hyper::client::connect::{Connected, Connection};
11 use hyper::client::Client;
12 use hyper::Body;
13 use pin_project::pin_project;
14 use serde_json::Value;
15 use tokio::io::{ReadBuf, AsyncRead, AsyncWrite, AsyncWriteExt};
16 use tokio::net::UnixStream;
17
18 use crate::tools;
19 use proxmox::api::error::HttpError;
20
/// Default vsock port used to reach the daemon.
///
/// The value is deliberately below 1024, i.e. in the privileged port range, so that only
/// root (on the host) can connect.
pub const DEFAULT_VSOCK_PORT: u16 = 807;
23
/// Stateless hyper connector that opens `AF_VSOCK` stream sockets for `vsock://` URIs.
#[derive(Clone)]
struct VsockConnector;
26
27 #[pin_project]
28 /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
29 struct UnixConnection {
30 #[pin]
31 stream: UnixStream,
32 }
33
34 impl tower_service::Service<Uri> for VsockConnector {
35 type Response = UnixConnection;
36 type Error = Error;
37 type Future = Pin<Box<dyn Future<Output = Result<UnixConnection, Error>> + Send>>;
38
39 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
40 Poll::Ready(Ok(()))
41 }
42
43 fn call(&mut self, dst: Uri) -> Self::Future {
44 use nix::sys::socket::*;
45 use std::os::unix::io::FromRawFd;
46
47 // connect can block, so run in blocking task (though in reality it seems to immediately
48 // return with either ENODEV or ETIMEDOUT in case of error)
49 tokio::task::spawn_blocking(move || {
50 if dst.scheme_str().unwrap_or_default() != "vsock" {
51 bail!("invalid URI (scheme) for vsock connector: {}", dst);
52 }
53
54 let cid = match dst.host() {
55 Some(host) => host.parse().map_err(|err| {
56 format_err!(
57 "invalid URI (host not a number) for vsock connector: {} ({})",
58 dst,
59 err
60 )
61 })?,
62 None => bail!("invalid URI (no host) for vsock connector: {}", dst),
63 };
64
65 let port = match dst.port_u16() {
66 Some(port) => port,
67 None => bail!("invalid URI (bad port) for vsock connector: {}", dst),
68 };
69
70 let sock_fd = socket(
71 AddressFamily::Vsock,
72 SockType::Stream,
73 SockFlag::empty(),
74 None,
75 )?;
76
77 let sock_addr = VsockAddr::new(cid, port as u32);
78 connect(sock_fd, &SockAddr::Vsock(sock_addr))?;
79
80 // connect sync, but set nonblock after (tokio requires it)
81 let std_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) };
82 std_stream.set_nonblocking(true)?;
83
84 let stream = tokio::net::UnixStream::from_std(std_stream)?;
85 let connection = UnixConnection { stream };
86
87 Ok(connection)
88 })
89 // unravel the thread JoinHandle to a usable future
90 .map(|res| match res {
91 Ok(res) => res,
92 Err(err) => Err(format_err!("thread join error on vsock connect: {}", err)),
93 })
94 .boxed()
95 }
96 }
97
98 impl Connection for UnixConnection {
99 fn connected(&self) -> Connected {
100 Connected::new()
101 }
102 }
103
104 impl AsyncRead for UnixConnection {
105 fn poll_read(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 buf: &mut ReadBuf,
109 ) -> Poll<Result<(), std::io::Error>> {
110 let this = self.project();
111 this.stream.poll_read(cx, buf)
112 }
113 }
114
115 impl AsyncWrite for UnixConnection {
116 fn poll_write(
117 self: Pin<&mut Self>,
118 cx: &mut Context<'_>,
119 buf: &[u8],
120 ) -> Poll<tokio::io::Result<usize>> {
121 let this = self.project();
122 this.stream.poll_write(cx, buf)
123 }
124
125 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
126 let this = self.project();
127 this.stream.poll_flush(cx)
128 }
129
130 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
131 let this = self.project();
132 this.stream.poll_shutdown(cx)
133 }
134 }
135
136 /// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
137 pub struct VsockClient {
138 client: Client<VsockConnector>,
139 cid: i32,
140 port: u16,
141 }
142
143 impl VsockClient {
144 pub fn new(cid: i32, port: u16) -> Self {
145 let conn = VsockConnector {};
146 let client = Client::builder().build::<_, Body>(conn);
147 Self { client, cid, port }
148 }
149
150 pub async fn get(&self, path: &str, data: Option<Value>) -> Result<Value, Error> {
151 let req = Self::request_builder(self.cid, self.port, "GET", path, data)?;
152 self.api_request(req).await
153 }
154
155 pub async fn post(&mut self, path: &str, data: Option<Value>) -> Result<Value, Error> {
156 let req = Self::request_builder(self.cid, self.port, "POST", path, data)?;
157 self.api_request(req).await
158 }
159
160 pub async fn download(
161 &mut self,
162 path: &str,
163 data: Option<Value>,
164 output: &mut (dyn AsyncWrite + Send + Unpin),
165 ) -> Result<(), Error> {
166 let req = Self::request_builder(self.cid, self.port, "GET", path, data)?;
167
168 let client = self.client.clone();
169
170 let resp = client.request(req)
171 .await
172 .map_err(|_| format_err!("vsock download request timed out"))?;
173 let status = resp.status();
174 if !status.is_success() {
175 Self::api_response(resp)
176 .await
177 .map(|_| ())?
178 } else {
179 resp.into_body()
180 .map_err(Error::from)
181 .try_fold(output, move |acc, chunk| async move {
182 acc.write_all(&chunk).await?;
183 Ok::<_, Error>(acc)
184 })
185 .await?;
186 }
187 Ok(())
188 }
189
190 async fn api_response(response: Response<Body>) -> Result<Value, Error> {
191 let status = response.status();
192 let data = hyper::body::to_bytes(response.into_body()).await?;
193
194 let text = String::from_utf8(data.to_vec()).unwrap();
195 if status.is_success() {
196 if text.is_empty() {
197 Ok(Value::Null)
198 } else {
199 let value: Value = serde_json::from_str(&text)?;
200 Ok(value)
201 }
202 } else {
203 Err(Error::from(HttpError::new(status, text)))
204 }
205 }
206
207 async fn api_request(&self, req: Request<Body>) -> Result<Value, Error> {
208 self.client
209 .request(req)
210 .map_err(Error::from)
211 .and_then(Self::api_response)
212 .await
213 }
214
215 pub fn request_builder(
216 cid: i32,
217 port: u16,
218 method: &str,
219 path: &str,
220 data: Option<Value>,
221 ) -> Result<Request<Body>, Error> {
222 let path = path.trim_matches('/');
223 let url: Uri = format!("vsock://{}:{}/{}", cid, port, path).parse()?;
224
225 if let Some(data) = data {
226 if method == "POST" {
227 let request = Request::builder()
228 .method(method)
229 .uri(url)
230 .header(hyper::header::CONTENT_TYPE, "application/json")
231 .body(Body::from(data.to_string()))?;
232 return Ok(request);
233 } else {
234 let query = tools::json_object_to_query(data)?;
235 let url: Uri = format!("vsock://{}:{}/{}?{}", cid, port, path, query).parse()?;
236 let request = Request::builder()
237 .method(method)
238 .uri(url)
239 .header(
240 hyper::header::CONTENT_TYPE,
241 "application/x-www-form-urlencoded",
242 )
243 .body(Body::empty())?;
244 return Ok(request);
245 }
246 }
247
248 let request = Request::builder()
249 .method(method)
250 .uri(url)
251 .header(
252 hyper::header::CONTENT_TYPE,
253 "application/x-www-form-urlencoded",
254 )
255 .body(Body::empty())?;
256
257 Ok(request)
258 }
259 }