2 use std
::task
::{Context, Poll}
;
4 use anyhow
::{bail, format_err, Error}
;
7 use http
::{Request, Response}
;
8 use hyper
::client
::connect
::{Connected, Connection}
;
9 use hyper
::client
::Client
;
11 use pin_project_lite
::pin_project
;
12 use serde_json
::Value
;
13 use tokio
::io
::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadBuf}
;
14 use tokio
::net
::UnixStream
;
16 use proxmox_router
::HttpError
;
/// Default port number for vsock connections.
pub const DEFAULT_VSOCK_PORT: u16 = 807;
/// Hyper connector that resolves `vsock://<cid>:<port>` URIs to virtio-vsock connections.
///
/// The type carries no state. `Clone` is derived because hyper's
/// `Client::builder().build(connector)` only accepts connectors that implement `Clone`.
#[derive(Clone)]
struct VsockConnector;
24 /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
25 struct UnixConnection
{
31 impl tower_service
::Service
<Uri
> for VsockConnector
{
32 type Response
= UnixConnection
;
34 type Future
= Pin
<Box
<dyn Future
<Output
= Result
<UnixConnection
, Error
>> + Send
>>;
36 fn poll_ready(&mut self, _cx
: &mut task
::Context
<'_
>) -> Poll
<Result
<(), Self::Error
>> {
40 fn call(&mut self, dst
: Uri
) -> Self::Future
{
41 use nix
::sys
::socket
::*;
42 use std
::os
::unix
::io
::FromRawFd
;
44 // connect can block, so run in blocking task (though in reality it seems to immediately
45 // return with either ENODEV or ETIMEDOUT in case of error)
46 tokio
::task
::spawn_blocking(move || {
47 if dst
.scheme_str().unwrap_or_default() != "vsock" {
48 bail
!("invalid URI (scheme) for vsock connector: {}", dst
);
51 let cid
= match dst
.host() {
52 Some(host
) => host
.parse().map_err(|err
| {
54 "invalid URI (host not a number) for vsock connector: {} ({})",
59 None
=> bail
!("invalid URI (no host) for vsock connector: {}", dst
),
62 let port
= match dst
.port_u16() {
64 None
=> bail
!("invalid URI (bad port) for vsock connector: {}", dst
),
74 let sock_addr
= VsockAddr
::new(cid
, port
as u32);
75 connect(sock_fd
, &SockAddr
::Vsock(sock_addr
))?
;
77 // connect sync, but set nonblock after (tokio requires it)
78 let std_stream
= unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) }
;
79 std_stream
.set_nonblocking(true)?
;
81 let stream
= tokio
::net
::UnixStream
::from_std(std_stream
)?
;
82 let connection
= UnixConnection { stream }
;
86 // unravel the thread JoinHandle to a usable future
87 .map(|res
| match res
{
89 Err(err
) => Err(format_err
!("thread join error on vsock connect: {}", err
)),
95 impl Connection
for UnixConnection
{
96 fn connected(&self) -> Connected
{
101 impl AsyncRead
for UnixConnection
{
103 self: Pin
<&mut Self>,
104 cx
: &mut Context
<'_
>,
106 ) -> Poll
<Result
<(), std
::io
::Error
>> {
107 let this
= self.project();
108 this
.stream
.poll_read(cx
, buf
)
112 impl AsyncWrite
for UnixConnection
{
114 self: Pin
<&mut Self>,
115 cx
: &mut Context
<'_
>,
117 ) -> Poll
<tokio
::io
::Result
<usize>> {
118 let this
= self.project();
119 this
.stream
.poll_write(cx
, buf
)
122 fn poll_flush(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
123 let this
= self.project();
124 this
.stream
.poll_flush(cx
)
127 fn poll_shutdown(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
128 let this
= self.project();
129 this
.stream
.poll_shutdown(cx
)
133 /// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
134 pub struct VsockClient
{
135 client
: Client
<VsockConnector
>,
138 auth
: Option
<String
>,
142 pub fn new(cid
: i32, port
: u16, auth
: Option
<String
>) -> Self {
143 let conn
= VsockConnector {}
;
144 let client
= Client
::builder().build
::<_
, Body
>(conn
);
153 pub async
fn get(&self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
154 let req
= self.request_builder("GET", path
, data
)?
;
155 self.api_request(req
).await
158 pub async
fn post(&self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
159 let req
= self.request_builder("POST", path
, data
)?
;
160 self.api_request(req
).await
163 pub async
fn download(
167 output
: &mut (dyn AsyncWrite
+ Send
+ Unpin
),
168 ) -> Result
<(), Error
> {
169 let req
= self.request_builder("GET", path
, data
)?
;
171 let client
= self.client
.clone();
176 .map_err(|_
| format_err
!("vsock download request timed out"))?
;
177 let status
= resp
.status();
178 if !status
.is_success() {
179 Self::api_response(resp
).await
.map(|_
| ())?
182 .map_err(Error
::from
)
183 .try_fold(output
, move |acc
, chunk
| async
move {
184 acc
.write_all(&chunk
).await?
;
192 async
fn api_response(response
: Response
<Body
>) -> Result
<Value
, Error
> {
193 let status
= response
.status();
194 let data
= hyper
::body
::to_bytes(response
.into_body()).await?
;
196 let text
= String
::from_utf8(data
.to_vec()).unwrap();
197 if status
.is_success() {
201 let value
: Value
= serde_json
::from_str(&text
)?
;
205 Err(Error
::from(HttpError
::new(status
, text
)))
209 async
fn api_request(&self, req
: Request
<Body
>) -> Result
<Value
, Error
> {
212 .map_err(Error
::from
)
213 .and_then(Self::api_response
)
222 ) -> Result
<Request
<Body
>, Error
> {
223 let path
= path
.trim_matches('
/'
);
224 let url
: Uri
= format
!("vsock://{}:{}/{}", self.cid
, self.port
, path
).parse()?
;
226 let make_builder
= |content_type
: &str, url
: &Uri
| {
227 let mut builder
= Request
::builder()
230 .header(hyper
::header
::CONTENT_TYPE
, content_type
);
231 if let Some(auth
) = &self.auth
{
232 builder
= builder
.header(hyper
::header
::AUTHORIZATION
, auth
);
237 if let Some(data
) = data
{
238 if method
== "POST" {
239 let builder
= make_builder("application/json", &url
);
240 let request
= builder
.body(Body
::from(data
.to_string()))?
;
243 let query
= pbs_tools
::json
::json_object_to_query(data
)?
;
245 format
!("vsock://{}:{}/{}?{}", self.cid
, self.port
, path
, query
).parse()?
;
246 let builder
= make_builder("application/x-www-form-urlencoded", &url
);
247 let request
= builder
.body(Body
::empty())?
;
252 let builder
= make_builder("application/x-www-form-urlencoded", &url
);
253 let request
= builder
.body(Body
::empty())?
;