]> git.proxmox.com Git - proxmox-backup.git/blame - src/client/vsock_client.rs
move more helpers to pbs-tools
[proxmox-backup.git] / src / client / vsock_client.rs
CommitLineData
89d25b19
SR
1use anyhow::{bail, format_err, Error};
2use futures::*;
3
4use core::task::Context;
5use std::pin::Pin;
6use std::task::Poll;
7
8use http::Uri;
9use http::{Request, Response};
10use hyper::client::connect::{Connected, Connection};
11use hyper::client::Client;
12use hyper::Body;
13use pin_project::pin_project;
14use serde_json::Value;
971bc6f9 15use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf};
89d25b19
SR
16use tokio::net::UnixStream;
17
89d25b19
SR
18use proxmox::api::error::HttpError;
19
89d25b19
SR
/// Default vsock port for connections made by this client
/// (presumably the port the file-restore daemon listens on — confirm with the daemon side).
pub const DEFAULT_VSOCK_PORT: u16 = 807;

/// Stateless hyper connector that turns `vsock://<cid>:<port>/...` URIs into socket connections.
#[derive(Clone)]
struct VsockConnector;
24
25#[pin_project]
26/// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
27struct UnixConnection {
28 #[pin]
29 stream: UnixStream,
30}
31
32impl tower_service::Service<Uri> for VsockConnector {
33 type Response = UnixConnection;
34 type Error = Error;
35 type Future = Pin<Box<dyn Future<Output = Result<UnixConnection, Error>> + Send>>;
36
37 fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
38 Poll::Ready(Ok(()))
39 }
40
41 fn call(&mut self, dst: Uri) -> Self::Future {
42 use nix::sys::socket::*;
43 use std::os::unix::io::FromRawFd;
44
45 // connect can block, so run in blocking task (though in reality it seems to immediately
46 // return with either ENODEV or ETIMEDOUT in case of error)
47 tokio::task::spawn_blocking(move || {
48 if dst.scheme_str().unwrap_or_default() != "vsock" {
49 bail!("invalid URI (scheme) for vsock connector: {}", dst);
50 }
51
52 let cid = match dst.host() {
53 Some(host) => host.parse().map_err(|err| {
54 format_err!(
55 "invalid URI (host not a number) for vsock connector: {} ({})",
56 dst,
57 err
58 )
59 })?,
60 None => bail!("invalid URI (no host) for vsock connector: {}", dst),
61 };
62
63 let port = match dst.port_u16() {
64 Some(port) => port,
65 None => bail!("invalid URI (bad port) for vsock connector: {}", dst),
66 };
67
68 let sock_fd = socket(
69 AddressFamily::Vsock,
70 SockType::Stream,
71 SockFlag::empty(),
72 None,
73 )?;
74
75 let sock_addr = VsockAddr::new(cid, port as u32);
76 connect(sock_fd, &SockAddr::Vsock(sock_addr))?;
77
78 // connect sync, but set nonblock after (tokio requires it)
79 let std_stream = unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) };
80 std_stream.set_nonblocking(true)?;
81
82 let stream = tokio::net::UnixStream::from_std(std_stream)?;
83 let connection = UnixConnection { stream };
84
85 Ok(connection)
86 })
d1d74c43 87 // unravel the thread JoinHandle to a usable future
89d25b19
SR
88 .map(|res| match res {
89 Ok(res) => res,
90 Err(err) => Err(format_err!("thread join error on vsock connect: {}", err)),
91 })
92 .boxed()
93 }
94}
95
96impl Connection for UnixConnection {
97 fn connected(&self) -> Connected {
98 Connected::new()
99 }
100}
101
102impl AsyncRead for UnixConnection {
103 fn poll_read(
104 self: Pin<&mut Self>,
105 cx: &mut Context<'_>,
106 buf: &mut ReadBuf,
107 ) -> Poll<Result<(), std::io::Error>> {
108 let this = self.project();
109 this.stream.poll_read(cx, buf)
110 }
111}
112
113impl AsyncWrite for UnixConnection {
114 fn poll_write(
115 self: Pin<&mut Self>,
116 cx: &mut Context<'_>,
117 buf: &[u8],
118 ) -> Poll<tokio::io::Result<usize>> {
119 let this = self.project();
120 this.stream.poll_write(cx, buf)
121 }
122
123 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
124 let this = self.project();
125 this.stream.poll_flush(cx)
126 }
127
128 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<tokio::io::Result<()>> {
129 let this = self.project();
130 this.stream.poll_shutdown(cx)
131 }
132}
133
134/// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
135pub struct VsockClient {
136 client: Client<VsockConnector>,
137 cid: i32,
138 port: u16,
48763935 139 auth: Option<String>,
89d25b19
SR
140}
141
142impl VsockClient {
48763935 143 pub fn new(cid: i32, port: u16, auth: Option<String>) -> Self {
89d25b19
SR
144 let conn = VsockConnector {};
145 let client = Client::builder().build::<_, Body>(conn);
48763935
SR
146 Self {
147 client,
148 cid,
149 port,
150 auth,
151 }
89d25b19
SR
152 }
153
154 pub async fn get(&self, path: &str, data: Option<Value>) -> Result<Value, Error> {
48763935 155 let req = self.request_builder("GET", path, data)?;
89d25b19
SR
156 self.api_request(req).await
157 }
158
971bc6f9 159 pub async fn post(&self, path: &str, data: Option<Value>) -> Result<Value, Error> {
48763935 160 let req = self.request_builder("POST", path, data)?;
89d25b19
SR
161 self.api_request(req).await
162 }
163
164 pub async fn download(
971bc6f9 165 &self,
89d25b19
SR
166 path: &str,
167 data: Option<Value>,
168 output: &mut (dyn AsyncWrite + Send + Unpin),
169 ) -> Result<(), Error> {
48763935 170 let req = self.request_builder("GET", path, data)?;
89d25b19
SR
171
172 let client = self.client.clone();
173
971bc6f9
SR
174 let resp = client
175 .request(req)
89d25b19
SR
176 .await
177 .map_err(|_| format_err!("vsock download request timed out"))?;
178 let status = resp.status();
179 if !status.is_success() {
971bc6f9 180 Self::api_response(resp).await.map(|_| ())?
89d25b19
SR
181 } else {
182 resp.into_body()
183 .map_err(Error::from)
184 .try_fold(output, move |acc, chunk| async move {
185 acc.write_all(&chunk).await?;
186 Ok::<_, Error>(acc)
187 })
188 .await?;
189 }
190 Ok(())
191 }
192
193 async fn api_response(response: Response<Body>) -> Result<Value, Error> {
194 let status = response.status();
195 let data = hyper::body::to_bytes(response.into_body()).await?;
196
197 let text = String::from_utf8(data.to_vec()).unwrap();
198 if status.is_success() {
199 if text.is_empty() {
200 Ok(Value::Null)
201 } else {
202 let value: Value = serde_json::from_str(&text)?;
203 Ok(value)
204 }
205 } else {
206 Err(Error::from(HttpError::new(status, text)))
207 }
208 }
209
210 async fn api_request(&self, req: Request<Body>) -> Result<Value, Error> {
211 self.client
212 .request(req)
213 .map_err(Error::from)
214 .and_then(Self::api_response)
215 .await
216 }
217
48763935
SR
218 fn request_builder(
219 &self,
89d25b19
SR
220 method: &str,
221 path: &str,
222 data: Option<Value>,
223 ) -> Result<Request<Body>, Error> {
224 let path = path.trim_matches('/');
48763935
SR
225 let url: Uri = format!("vsock://{}:{}/{}", self.cid, self.port, path).parse()?;
226
227 let make_builder = |content_type: &str, url: &Uri| {
228 let mut builder = Request::builder()
229 .method(method)
230 .uri(url)
231 .header(hyper::header::CONTENT_TYPE, content_type);
232 if let Some(auth) = &self.auth {
233 builder = builder.header(hyper::header::AUTHORIZATION, auth);
234 }
235 builder
236 };
89d25b19
SR
237
238 if let Some(data) = data {
239 if method == "POST" {
48763935
SR
240 let builder = make_builder("application/json", &url);
241 let request = builder.body(Body::from(data.to_string()))?;
89d25b19
SR
242 return Ok(request);
243 } else {
9eb78407 244 let query = pbs_tools::json::json_object_to_query(data)?;
48763935
SR
245 let url: Uri =
246 format!("vsock://{}:{}/{}?{}", self.cid, self.port, path, query).parse()?;
247 let builder = make_builder("application/x-www-form-urlencoded", &url);
248 let request = builder.body(Body::empty())?;
89d25b19
SR
249 return Ok(request);
250 }
251 }
252
48763935
SR
253 let builder = make_builder("application/x-www-form-urlencoded", &url);
254 let request = builder.body(Body::empty())?;
89d25b19
SR
255
256 Ok(request)
257 }
258}