]> git.proxmox.com Git - proxmox-backup.git/blob - pbs-client/src/vsock_client.rs
update to first proxmox crate split
[proxmox-backup.git] / pbs-client / src / vsock_client.rs
1 use std::pin::Pin;
2 use std::task::{Context, Poll};
3
4 use anyhow::{bail, format_err, Error};
5 use futures::*;
6 use http::Uri;
7 use http::{Request, Response};
8 use hyper::client::connect::{Connected, Connection};
9 use hyper::client::Client;
10 use hyper::Body;
11 use pin_project_lite::pin_project;
12 use serde_json::Value;
13 use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
14 use tokio::net::UnixStream;
15
16 use proxmox_router::HttpError;
17
/// Port used for vsock connections when the caller has no specific port to use.
pub const DEFAULT_VSOCK_PORT: u16 = 807;
19
/// Stateless connector handed to hyper; it opens a vsock stream for every `vsock://` URI it is
/// asked to connect to (see its `tower_service::Service<Uri>` implementation).
#[derive(Clone)]
struct VsockConnector;
22
23 pin_project! {
24 /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
25 struct UnixConnection {
26 #[pin]
27 stream: UnixStream,
28 }
29 }
30
31 impl tower_service::Service<Uri> for VsockConnector {
32 type Response = UnixConnection;
33 type Error = Error;
34 type Future = Pin<Box<dyn Future<Output = Result<UnixConnection, Error>> + Send>>;
35
36 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
37 Poll::Ready(Ok(()))
38 }
39
40 fn call(&mut self, dst: Uri) -> Self::Future {
41 use nix::sys::socket::*;
42 use std::os::unix::io::FromRawFd;
43
44 // connect can block, so run in blocking task (though in reality it seems to immediately
45 // return with either ENODEV or ETIMEDOUT in case of error)
46 tokio::task::spawn_blocking(move || {
47 if dst.scheme_str().unwrap_or_default() != "vsock" {
48 bail!("invalid URI (scheme) for vsock connector: {}", dst);
49 }
50
51 let cid = match dst.host() {
52 Some(host) => host.parse().map_err(|err| {
53 format_err!(
54 "invalid URI (host not a number) for vsock connector: {} ({})",
55 dst,
56 err
57 )
58 })?,
59 None => bail!("invalid URI (no host) for vsock connector: {}", dst),
60 };
61
62 let port = match dst.port_u16() {
63 Some(port) => port,
64 None => bail!("invalid URI (bad port) for vsock connector: {}", dst),
65 };
66
67 let sock_fd = socket(
68 AddressFamily::Vsock,
69 SockType::Stream,
70 SockFlag::empty(),
71 None,
72 )?;
73
74 let sock_addr = VsockAddr::new(cid, port as u32);
75 connect(sock_fd, &SockAddr::Vsock(sock_addr))?;
76
77 // connect sync, but set nonblock after (tokio requires it)
78 let std_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) };
79 std_stream.set_nonblocking(true)?;
80
81 let stream = tokio::net::UnixStream::from_std(std_stream)?;
82 let connection = UnixConnection { stream };
83
84 Ok(connection)
85 })
86 // unravel the thread JoinHandle to a usable future
87 .map(|res| match res {
88 Ok(res) => res,
89 Err(err) => Err(format_err!("thread join error on vsock connect: {}", err)),
90 })
91 .boxed()
92 }
93 }
94
95 impl Connection for UnixConnection {
96 fn connected(&self) -> Connected {
97 Connected::new()
98 }
99 }
100
101 impl AsyncRead for UnixConnection {
102 fn poll_read(
103 self: Pin<&mut Self>,
104 cx: &mut Context<'_>,
105 buf: &mut ReadBuf,
106 ) -> Poll<Result<(), std::io::Error>> {
107 let this = self.project();
108 this.stream.poll_read(cx, buf)
109 }
110 }
111
112 impl AsyncWrite for UnixConnection {
113 fn poll_write(
114 self: Pin<&mut Self>,
115 cx: &mut Context<'_>,
116 buf: &[u8],
117 ) -> Poll<tokio::io::Result<usize>> {
118 let this = self.project();
119 this.stream.poll_write(cx, buf)
120 }
121
122 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
123 let this = self.project();
124 this.stream.poll_flush(cx)
125 }
126
127 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
128 let this = self.project();
129 this.stream.poll_shutdown(cx)
130 }
131 }
132
133 /// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
134 pub struct VsockClient {
135 client: Client<VsockConnector>,
136 cid: i32,
137 port: u16,
138 auth: Option<String>,
139 }
140
141 impl VsockClient {
142 pub fn new(cid: i32, port: u16, auth: Option<String>) -> Self {
143 let conn = VsockConnector {};
144 let client = Client::builder().build::<_, Body>(conn);
145 Self {
146 client,
147 cid,
148 port,
149 auth,
150 }
151 }
152
153 pub async fn get(&self, path: &str, data: Option<Value>) -> Result<Value, Error> {
154 let req = self.request_builder("GET", path, data)?;
155 self.api_request(req).await
156 }
157
158 pub async fn post(&self, path: &str, data: Option<Value>) -> Result<Value, Error> {
159 let req = self.request_builder("POST", path, data)?;
160 self.api_request(req).await
161 }
162
163 pub async fn download(
164 &self,
165 path: &str,
166 data: Option<Value>,
167 output: &mut (dyn AsyncWrite + Send + Unpin),
168 ) -> Result<(), Error> {
169 let req = self.request_builder("GET", path, data)?;
170
171 let client = self.client.clone();
172
173 let resp = client
174 .request(req)
175 .await
176 .map_err(|_| format_err!("vsock download request timed out"))?;
177 let status = resp.status();
178 if !status.is_success() {
179 Self::api_response(resp).await.map(|_| ())?
180 } else {
181 resp.into_body()
182 .map_err(Error::from)
183 .try_fold(output, move |acc, chunk| async move {
184 acc.write_all(&chunk).await?;
185 Ok::<_, Error>(acc)
186 })
187 .await?;
188 }
189 Ok(())
190 }
191
192 async fn api_response(response: Response<Body>) -> Result<Value, Error> {
193 let status = response.status();
194 let data = hyper::body::to_bytes(response.into_body()).await?;
195
196 let text = String::from_utf8(data.to_vec()).unwrap();
197 if status.is_success() {
198 if text.is_empty() {
199 Ok(Value::Null)
200 } else {
201 let value: Value = serde_json::from_str(&text)?;
202 Ok(value)
203 }
204 } else {
205 Err(Error::from(HttpError::new(status, text)))
206 }
207 }
208
209 async fn api_request(&self, req: Request<Body>) -> Result<Value, Error> {
210 self.client
211 .request(req)
212 .map_err(Error::from)
213 .and_then(Self::api_response)
214 .await
215 }
216
217 fn request_builder(
218 &self,
219 method: &str,
220 path: &str,
221 data: Option<Value>,
222 ) -> Result<Request<Body>, Error> {
223 let path = path.trim_matches('/');
224 let url: Uri = format!("vsock://{}:{}/{}", self.cid, self.port, path).parse()?;
225
226 let make_builder = |content_type: &str, url: &Uri| {
227 let mut builder = Request::builder()
228 .method(method)
229 .uri(url)
230 .header(hyper::header::CONTENT_TYPE, content_type);
231 if let Some(auth) = &self.auth {
232 builder = builder.header(hyper::header::AUTHORIZATION, auth);
233 }
234 builder
235 };
236
237 if let Some(data) = data {
238 if method == "POST" {
239 let builder = make_builder("application/json", &url);
240 let request = builder.body(Body::from(data.to_string()))?;
241 return Ok(request);
242 } else {
243 let query = pbs_tools::json::json_object_to_query(data)?;
244 let url: Uri =
245 format!("vsock://{}:{}/{}?{}", self.cid, self.port, path, query).parse()?;
246 let builder = make_builder("application/x-www-form-urlencoded", &url);
247 let request = builder.body(Body::empty())?;
248 return Ok(request);
249 }
250 }
251
252 let builder = make_builder("application/x-www-form-urlencoded", &url);
253 let request = builder.body(Body::empty())?;
254
255 Ok(request)
256 }
257 }