1 use anyhow
::{bail, format_err, Error}
;
4 use core
::task
::Context
;
9 use http
::{Request, Response}
;
10 use hyper
::client
::connect
::{Connected, Connection}
;
11 use hyper
::client
::Client
;
13 use pin_project
::pin_project
;
14 use serde_json
::Value
;
15 use tokio
::io
::{ReadBuf, AsyncRead, AsyncWrite, AsyncWriteExt}
;
16 use tokio
::net
::UnixStream
;
19 use proxmox
::api
::error
::HttpError
;
/// Default vsock port the client connects to.
///
/// The value is deliberately below 1024: such ports are privileged, so only
/// root on the host is able to connect to it.
pub const DEFAULT_VSOCK_PORT: u16 = 807;
/// Stateless connector that hyper uses to open connections for
/// `vsock://<cid>:<port>/...` URIs; the connection logic lives in its
/// `tower_service::Service<Uri>` implementation.
struct VsockConnector;
28 /// Wrapper around UnixStream so we can implement hyper::client::connect::Connection
29 struct UnixConnection
{
34 impl tower_service
::Service
<Uri
> for VsockConnector
{
35 type Response
= UnixConnection
;
37 type Future
= Pin
<Box
<dyn Future
<Output
= Result
<UnixConnection
, Error
>> + Send
>>;
39 fn poll_ready(&mut self, _cx
: &mut task
::Context
<'_
>) -> Poll
<Result
<(), Self::Error
>> {
43 fn call(&mut self, dst
: Uri
) -> Self::Future
{
44 use nix
::sys
::socket
::*;
45 use std
::os
::unix
::io
::FromRawFd
;
47 // connect can block, so run in blocking task (though in reality it seems to immediately
48 // return with either ENODEV or ETIMEDOUT in case of error)
49 tokio
::task
::spawn_blocking(move || {
50 if dst
.scheme_str().unwrap_or_default() != "vsock" {
51 bail
!("invalid URI (scheme) for vsock connector: {}", dst
);
54 let cid
= match dst
.host() {
55 Some(host
) => host
.parse().map_err(|err
| {
57 "invalid URI (host not a number) for vsock connector: {} ({})",
62 None
=> bail
!("invalid URI (no host) for vsock connector: {}", dst
),
65 let port
= match dst
.port_u16() {
67 None
=> bail
!("invalid URI (bad port) for vsock connector: {}", dst
),
77 let sock_addr
= VsockAddr
::new(cid
, port
as u32);
78 connect(sock_fd
, &SockAddr
::Vsock(sock_addr
))?
;
80 // connect sync, but set nonblock after (tokio requires it)
81 let std_stream
= unsafe { std::os::unix::net::UnixStream::from_raw_fd(sock_fd) }
;
82 std_stream
.set_nonblocking(true)?
;
84 let stream
= tokio
::net
::UnixStream
::from_std(std_stream
)?
;
85 let connection
= UnixConnection { stream }
;
89 // unravel the thread JoinHandle to a usable future
90 .map(|res
| match res
{
92 Err(err
) => Err(format_err
!("thread join error on vsock connect: {}", err
)),
98 impl Connection
for UnixConnection
{
99 fn connected(&self) -> Connected
{
104 impl AsyncRead
for UnixConnection
{
106 self: Pin
<&mut Self>,
107 cx
: &mut Context
<'_
>,
109 ) -> Poll
<Result
<(), std
::io
::Error
>> {
110 let this
= self.project();
111 this
.stream
.poll_read(cx
, buf
)
115 impl AsyncWrite
for UnixConnection
{
117 self: Pin
<&mut Self>,
118 cx
: &mut Context
<'_
>,
120 ) -> Poll
<tokio
::io
::Result
<usize>> {
121 let this
= self.project();
122 this
.stream
.poll_write(cx
, buf
)
125 fn poll_flush(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
126 let this
= self.project();
127 this
.stream
.poll_flush(cx
)
130 fn poll_shutdown(self: Pin
<&mut Self>, cx
: &mut Context
<'_
>) -> Poll
<tokio
::io
::Result
<()>> {
131 let this
= self.project();
132 this
.stream
.poll_shutdown(cx
)
136 /// Slimmed down version of HttpClient for virtio-vsock connections (file restore daemon)
137 pub struct VsockClient
{
138 client
: Client
<VsockConnector
>,
144 pub fn new(cid
: i32, port
: u16) -> Self {
145 let conn
= VsockConnector {}
;
146 let client
= Client
::builder().build
::<_
, Body
>(conn
);
147 Self { client, cid, port }
150 pub async
fn get(&self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
151 let req
= Self::request_builder(self.cid
, self.port
, "GET", path
, data
)?
;
152 self.api_request(req
).await
155 pub async
fn post(&mut self, path
: &str, data
: Option
<Value
>) -> Result
<Value
, Error
> {
156 let req
= Self::request_builder(self.cid
, self.port
, "POST", path
, data
)?
;
157 self.api_request(req
).await
160 pub async
fn download(
164 output
: &mut (dyn AsyncWrite
+ Send
+ Unpin
),
165 ) -> Result
<(), Error
> {
166 let req
= Self::request_builder(self.cid
, self.port
, "GET", path
, data
)?
;
168 let client
= self.client
.clone();
170 let resp
= client
.request(req
)
172 .map_err(|_
| format_err
!("vsock download request timed out"))?
;
173 let status
= resp
.status();
174 if !status
.is_success() {
175 Self::api_response(resp
)
180 .map_err(Error
::from
)
181 .try_fold(output
, move |acc
, chunk
| async
move {
182 acc
.write_all(&chunk
).await?
;
190 async
fn api_response(response
: Response
<Body
>) -> Result
<Value
, Error
> {
191 let status
= response
.status();
192 let data
= hyper
::body
::to_bytes(response
.into_body()).await?
;
194 let text
= String
::from_utf8(data
.to_vec()).unwrap();
195 if status
.is_success() {
199 let value
: Value
= serde_json
::from_str(&text
)?
;
203 Err(Error
::from(HttpError
::new(status
, text
)))
207 async
fn api_request(&self, req
: Request
<Body
>) -> Result
<Value
, Error
> {
210 .map_err(Error
::from
)
211 .and_then(Self::api_response
)
215 pub fn request_builder(
221 ) -> Result
<Request
<Body
>, Error
> {
222 let path
= path
.trim_matches('
/'
);
223 let url
: Uri
= format
!("vsock://{}:{}/{}", cid
, port
, path
).parse()?
;
225 if let Some(data
) = data
{
226 if method
== "POST" {
227 let request
= Request
::builder()
230 .header(hyper
::header
::CONTENT_TYPE
, "application/json")
231 .body(Body
::from(data
.to_string()))?
;
234 let query
= tools
::json_object_to_query(data
)?
;
235 let url
: Uri
= format
!("vsock://{}:{}/{}?{}", cid
, port
, path
, query
).parse()?
;
236 let request
= Request
::builder()
240 hyper
::header
::CONTENT_TYPE
,
241 "application/x-www-form-urlencoded",
243 .body(Body
::empty())?
;
248 let request
= Request
::builder()
252 hyper
::header
::CONTENT_TYPE
,
253 "application/x-www-form-urlencoded",
255 .body(Body
::empty())?
;