2 use std
::task
::{Context, Poll}
;
4 use anyhow
::{bail, format_err, Error}
;
7 use http
::{Request, Response}
;
8 use hyper
::client
::connect
::{Connected, Connection}
;
9 use hyper
::client
::Client
;
11 use pin_project_lite
::pin_project
;
12 use serde_json
::Value
;
13 use tokio
::io
::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}
;
14 use tokio
::net
::UnixStream
;
16 use proxmox_http
::uri
::json_object_to_query
;
17 use proxmox_router
::HttpError
;
19 pub const DEFAULT_VSOCK_PORT
: u16 = 807;
22 struct VsockConnector
;
25 /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
26 struct UnixConnection
{
32 impl tower_service
::Service
<Uri
> for VsockConnector
{
33 type Response
= UnixConnection
;
35 type Future
= Pin
<Box
<dyn Future
<Output
= Result
<UnixConnection
, Error
>> + Send
>>;
37 fn poll_ready(&mut self, _cx
: &mut task
::Context
<'_
>) -> Poll
<Result
<(), Self::Error
>> {
41 fn call(&mut self, dst
: Uri
) -> Self::Future
{
42 use nix
::sys
::socket
::*;
43 use std
::os
::unix
::io
::FromRawFd
;
45 // connect can block, so run in blocking task (though in reality it seems to immediately
46 // return with either ENODEV or ETIMEDOUT in case of error)
47 tokio
::task
::spawn_blocking(move || {
48 if dst
.scheme_str().unwrap_or_default() != "vsock" {
49 bail
!("invalid URI (scheme) for vsock connector: {}", dst
);
52 let cid
= match dst
.host() {
53 Some(host
) => host
.parse().map_err(|err
| {
55 "invalid URI (host not a number) for vsock connector: {} ({})",
60 None
=> bail
!("invalid URI (no host) for vsock connector: {}", dst
),
63 let port
= match dst
.port_u16() {
65 None
=> bail
!("invalid URI (bad port) for vsock connector: {}", dst
),
75 let sock_addr
= VsockAddr
::new(cid
, port
as u32);
76 connect(sock_fd
, &sock_addr
)?
;
78 // connect sync, but set nonblock after (tokio requires it)
79 let std_stream
= unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) }
;
80 std_stream
.set_nonblocking(true)?
;
82 let stream
= tokio
::net
::UnixStream
::from_std(std_stream
)?
;
83 let connection
= UnixConnection { stream }
;
87 // unravel the thread JoinHandle to a usable future
88 .map(|res
| match res
{
90 Err(err
) => Err(format_err
!("thread join error on vsock connect: {}", err
)),
96 impl Connection
for UnixConnection
{
97 fn connected(&self) -> Connected
{
102 impl AsyncRead
for UnixConnection
{
104 self: Pin
<&mut Self>,
105 cx
: &mut Context
<'_
>,
107 ) -> Poll
<Result
<(), std
::io
::Error
>> {
108 let this
= self.project();
109 this
.stream
.poll_read(cx
, buf
)
113 impl AsyncWrite
for UnixConnection
{
115 self: Pin
<&mut Self>,
116 cx
: &mut Context
<'_
>,
118 ) -> Poll
<tokio
::io
::Result
<usize>> {
119 let this
= self.project();
120 this
.stream
.poll_write(cx
, buf
)
123 fn poll_flush(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
124 let this
= self.project();
125 this
.stream
.poll_flush(cx
)
128 fn poll_shutdown(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
129 let this
= self.project();
130 this
.stream
.poll_shutdown(cx
)
134 /// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
135 pub struct VsockClient
{
136 client
: Client
<VsockConnector
>,
139 auth
: Option
<String
>,
143 pub fn new(cid
: i32, port
: u16, auth
: Option
<String
>) -> Self {
144 let conn
= VsockConnector {}
;
145 let client
= Client
::builder().build
::<_
, Body
>(conn
);
154 pub async
fn get(&self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
155 let req
= self.request_builder("GET", path
, data
)?
;
156 self.api_request(req
).await
159 pub async
fn post(&self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
160 let req
= self.request_builder("POST", path
, data
)?
;
161 self.api_request(req
).await
164 pub async
fn download(
168 output
: &mut (dyn AsyncWrite
+ Send
+ Unpin
),
169 ) -> Result
<(), Error
> {
170 let req
= self.request_builder("GET", path
, data
)?
;
172 let client
= self.client
.clone();
177 .map_err(|_
| format_err
!("vsock download request timed out"))?
;
178 let status
= resp
.status();
179 if !status
.is_success() {
180 Self::api_response(resp
).await
.map(|_
| ())?
183 .map_err(Error
::from
)
184 .try_fold(output
, move |acc
, chunk
| async
move {
185 acc
.write_all(&chunk
).await?
;
193 async
fn api_response(response
: Response
<Body
>) -> Result
<Value
, Error
> {
194 let status
= response
.status();
195 let data
= hyper
::body
::to_bytes(response
.into_body()).await?
;
197 let text
= String
::from_utf8(data
.to_vec()).unwrap();
198 if status
.is_success() {
202 let value
: Value
= serde_json
::from_str(&text
)?
;
206 Err(Error
::from(HttpError
::new(status
, text
)))
210 async
fn api_request(&self, req
: Request
<Body
>) -> Result
<Value
, Error
> {
213 .map_err(Error
::from
)
214 .and_then(Self::api_response
)
223 ) -> Result
<Request
<Body
>, Error
> {
224 let path
= path
.trim_matches('
/'
);
225 let url
: Uri
= format
!("vsock://{}:{}/{}", self.cid
, self.port
, path
).parse()?
;
227 let make_builder
= |content_type
: &str, url
: &Uri
| {
228 let mut builder
= Request
::builder()
231 .header(hyper
::header
::CONTENT_TYPE
, content_type
);
232 if let Some(auth
) = &self.auth
{
233 builder
= builder
.header(hyper
::header
::AUTHORIZATION
, auth
);
238 if let Some(data
) = data
{
239 if method
== "POST" {
240 let builder
= make_builder("application/json", &url
);
241 let request
= builder
.body(Body
::from(data
.to_string()))?
;
244 let query
= json_object_to_query(data
)?
;
246 format
!("vsock://{}:{}/{}?{}", self.cid
, self.port
, path
, query
).parse()?
;
247 let builder
= make_builder("application/x-www-form-urlencoded", &url
);
248 let request
= builder
.body(Body
::empty())?
;
253 let builder
= make_builder("application/x-www-form-urlencoded", &url
);
254 let request
= builder
.body(Body
::empty())?
;