1 use anyhow
::{bail, format_err, Error}
;
4 use core
::task
::Context
;
9 use http
::{Request, Response}
;
10 use hyper
::client
::connect
::{Connected, Connection}
;
11 use hyper
::client
::Client
;
13 use pin_project
::pin_project
;
14 use serde_json
::Value
;
15 use tokio
::io
::{ReadBuf, AsyncRead, AsyncWrite, AsyncWriteExt}
;
16 use tokio
::net
::UnixStream
;
19 use proxmox
::api
::error
::HttpError
;
/// Port used for vsock connections when the caller does not specify one.
pub const DEFAULT_VSOCK_PORT: u16 = 807;
/// Stateless hyper connector that opens `vsock://<cid>:<port>` connections.
///
/// `Clone` is required: hyper's `Client::builder().build()` demands a cloneable connector,
/// and cloning a `Client<VsockConnector>` clones the connector too.
#[derive(Clone, Debug)]
struct VsockConnector;
27 /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
28 struct UnixConnection
{
33 impl tower_service
::Service
<Uri
> for VsockConnector
{
34 type Response
= UnixConnection
;
36 type Future
= Pin
<Box
<dyn Future
<Output
= Result
<UnixConnection
, Error
>> + Send
>>;
38 fn poll_ready(&mut self, _cx
: &mut task
::Context
<'_
>) -> Poll
<Result
<(), Self::Error
>> {
42 fn call(&mut self, dst
: Uri
) -> Self::Future
{
43 use nix
::sys
::socket
::*;
44 use std
::os
::unix
::io
::FromRawFd
;
46 // connect can block, so run in blocking task (though in reality it seems to immediately
47 // return with either ENODEV or ETIMEDOUT in case of error)
48 tokio
::task
::spawn_blocking(move || {
49 if dst
.scheme_str().unwrap_or_default() != "vsock" {
50 bail
!("invalid URI (scheme) for vsock connector: {}", dst
);
53 let cid
= match dst
.host() {
54 Some(host
) => host
.parse().map_err(|err
| {
56 "invalid URI (host not a number) for vsock connector: {} ({})",
61 None
=> bail
!("invalid URI (no host) for vsock connector: {}", dst
),
64 let port
= match dst
.port_u16() {
66 None
=> bail
!("invalid URI (bad port) for vsock connector: {}", dst
),
76 let sock_addr
= VsockAddr
::new(cid
, port
as u32);
77 connect(sock_fd
, &SockAddr
::Vsock(sock_addr
))?
;
79 // connect sync, but set nonblock after (tokio requires it)
80 let std_stream
= unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) }
;
81 std_stream
.set_nonblocking(true)?
;
83 let stream
= tokio
::net
::UnixStream
::from_std(std_stream
)?
;
84 let connection
= UnixConnection { stream }
;
88 // unravel the thread JoinHandle to a usable future
89 .map(|res
| match res
{
91 Err(err
) => Err(format_err
!("thread join error on vsock connect: {}", err
)),
97 impl Connection
for UnixConnection
{
98 fn connected(&self) -> Connected
{
103 impl AsyncRead
for UnixConnection
{
105 self: Pin
<&mut Self>,
106 cx
: &mut Context
<'_
>,
108 ) -> Poll
<Result
<(), std
::io
::Error
>> {
109 let this
= self.project();
110 this
.stream
.poll_read(cx
, buf
)
114 impl AsyncWrite
for UnixConnection
{
116 self: Pin
<&mut Self>,
117 cx
: &mut Context
<'_
>,
119 ) -> Poll
<tokio
::io
::Result
<usize>> {
120 let this
= self.project();
121 this
.stream
.poll_write(cx
, buf
)
124 fn poll_flush(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
125 let this
= self.project();
126 this
.stream
.poll_flush(cx
)
129 fn poll_shutdown(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
130 let this
= self.project();
131 this
.stream
.poll_shutdown(cx
)
135 /// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
136 pub struct VsockClient
{
137 client
: Client
<VsockConnector
>,
143 pub fn new(cid
: i32, port
: u16) -> Self {
144 let conn
= VsockConnector {}
;
145 let client
= Client
::builder().build
::<_
, Body
>(conn
);
146 Self { client, cid, port }
149 pub async
fn get(&self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
150 let req
= Self::request_builder(self.cid
, self.port
, "GET", path
, data
)?
;
151 self.api_request(req
).await
154 pub async
fn post(&mut self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
155 let req
= Self::request_builder(self.cid
, self.port
, "POST", path
, data
)?
;
156 self.api_request(req
).await
159 pub async
fn download(
163 output
: &mut (dyn AsyncWrite
+ Send
+ Unpin
),
164 ) -> Result
<(), Error
> {
165 let req
= Self::request_builder(self.cid
, self.port
, "GET", path
, data
)?
;
167 let client
= self.client
.clone();
169 let resp
= client
.request(req
)
171 .map_err(|_
| format_err
!("vsock download request timed out"))?
;
172 let status
= resp
.status();
173 if !status
.is_success() {
174 Self::api_response(resp
)
179 .map_err(Error
::from
)
180 .try_fold(output
, move |acc
, chunk
| async
move {
181 acc
.write_all(&chunk
).await?
;
189 async
fn api_response(response
: Response
<Body
>) -> Result
<Value
, Error
> {
190 let status
= response
.status();
191 let data
= hyper
::body
::to_bytes(response
.into_body()).await?
;
193 let text
= String
::from_utf8(data
.to_vec()).unwrap();
194 if status
.is_success() {
198 let value
: Value
= serde_json
::from_str(&text
)?
;
202 Err(Error
::from(HttpError
::new(status
, text
)))
206 async
fn api_request(&self, req
: Request
<Body
>) -> Result
<Value
, Error
> {
209 .map_err(Error
::from
)
210 .and_then(Self::api_response
)
214 pub fn request_builder(
220 ) -> Result
<Request
<Body
>, Error
> {
221 let path
= path
.trim_matches('
/'
);
222 let url
: Uri
= format
!("vsock://{}:{}/{}", cid
, port
, path
).parse()?
;
224 if let Some(data
) = data
{
225 if method
== "POST" {
226 let request
= Request
::builder()
229 .header(hyper
::header
::CONTENT_TYPE
, "application/json")
230 .body(Body
::from(data
.to_string()))?
;
233 let query
= tools
::json_object_to_query(data
)?
;
234 let url
: Uri
= format
!("vsock://{}:{}/{}?{}", cid
, port
, path
, query
).parse()?
;
235 let request
= Request
::builder()
239 hyper
::header
::CONTENT_TYPE
,
240 "application/x-www-form-urlencoded",
242 .body(Body
::empty())?
;
247 let request
= Request
::builder()
251 hyper
::header
::CONTENT_TYPE
,
252 "application/x-www-form-urlencoded",
254 .body(Body
::empty())?
;