]> git.proxmox.com Git - pxar.git/blobdiff - src/encoder/mod.rs
don't hold temp buffer mutex across await point
[pxar.git] / src / encoder / mod.rs
index 962965cb9e5d17db28e932a0966c491611aa9951..d14b2d04765bce11378af2de2007205675f64635 100644 (file)
@@ -2,6 +2,8 @@
 //!
 //! This is the implementation used by both the synchronous and async pxar wrappers.
 
+#![deny(missing_docs)]
+
 use std::io;
 use std::mem::{forget, size_of, size_of_val, take};
 use std::os::unix::ffi::OsStrExt;
@@ -29,6 +31,7 @@ pub use sync::Encoder;
 pub struct LinkOffset(u64);
 
 impl LinkOffset {
+    /// Get the raw byte offset of this link.
     #[inline]
     pub fn raw(self) -> u64 {
         self.0
@@ -41,27 +44,31 @@ impl LinkOffset {
 /// synchronous wrapper and for both `tokio` and `future` `AsyncWrite` types in the asynchronous
 /// wrapper.
 pub trait SeqWrite {
+    /// Attempt to perform a sequential write to the file. On success, the number of written bytes
+    /// is returned as `Poll::Ready(Ok(bytes))`.
+    ///
+    /// If writing is not yet possible, `Poll::Pending` is returned and the current task will be
+    /// notified via the `cx.waker()` when writing becomes possible.
     fn poll_seq_write(
         self: Pin<&mut Self>,
         cx: &mut Context,
         buf: &[u8],
     ) -> Poll<io::Result<usize>>;
 
+    /// Attempt to flush the output, ensuring that all buffered data reaches the destination.
+    ///
+    /// On success, returns `Poll::Ready(Ok(()))`.
+    ///
+    /// If flushing cannot complete immediately, `Poll::Pending` is returned and the current task
+    /// will be notified via `cx.waker()` when progress can be made.
     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>>;
-
-    /// To avoid recursively borrowing each time we nest into a subdirectory we add this helper.
-    /// Otherwise starting a subdirectory will get a trait object pointing to `T`, nesting another
-    /// subdirectory in that would have a trait object pointing to the trait object, and so on.
-    fn as_trait_object(&mut self) -> &mut dyn SeqWrite
-    where
-        Self: Sized,
-    {
-        self as &mut dyn SeqWrite
-    }
 }
 
 /// Allow using trait objects for generics taking a `SeqWrite`.
-impl<'a> SeqWrite for &mut (dyn SeqWrite + 'a) {
+impl<S> SeqWrite for &mut S
+where
+    S: SeqWrite + ?Sized,
+{
     fn poll_seq_write(
         self: Pin<&mut Self>,
         cx: &mut Context,
@@ -76,13 +83,6 @@ impl<'a> SeqWrite for &mut (dyn SeqWrite + 'a) {
     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
         unsafe { self.map_unchecked_mut(|this| &mut **this).poll_flush(cx) }
     }
-
-    fn as_trait_object(&mut self) -> &mut dyn SeqWrite
-    where
-        Self: Sized,
-    {
-        &mut **self
-    }
 }
 
 /// awaitable verison of `poll_seq_write`.
@@ -97,6 +97,11 @@ async fn seq_write<T: SeqWrite + ?Sized>(
     Ok(put)
 }
 
+/// awaitable version of 'poll_flush'.
+async fn flush<T: SeqWrite + ?Sized>(output: &mut T) -> io::Result<()> {
+    poll_fn(|cx| unsafe { Pin::new_unchecked(&mut *output).poll_flush(cx) }).await
+}
+
 /// Write the entire contents of a buffer, handling short writes.
 async fn seq_write_all<T: SeqWrite + ?Sized>(
     output: &mut T,
@@ -122,12 +127,7 @@ where
     let data = data.to_le();
     let buf =
         unsafe { std::slice::from_raw_parts(&data as *const E as *const u8, size_of_val(&data)) };
-    seq_write_all(
-        output,
-        buf,
-        position,
-    )
-    .await
+    seq_write_all(output, buf, position).await
 }
 
 /// Write a pxar entry.
@@ -180,13 +180,7 @@ where
     let data = data.to_le();
     let buf =
         unsafe { std::slice::from_raw_parts(&data as *const E as *const u8, size_of_val(&data)) };
-    seq_write_pxar_entry(
-        output,
-        htype,
-        buf,
-        position,
-    )
-    .await
+    seq_write_pxar_entry(output, htype, buf, position).await
 }
 
 /// Error conditions caused by wrong usage of this crate.
@@ -215,6 +209,7 @@ struct EncoderState {
     /// Offset of this directory's ENTRY.
     entry_offset: u64,
 
+    #[allow(dead_code)]
     /// Offset to this directory's first FILENAME.
     files_offset: u64,
 
@@ -241,12 +236,48 @@ impl EncoderState {
     }
 }
 
+pub(crate) enum EncoderOutput<'a, T> {
+    Owned(T),
+    Borrowed(&'a mut T),
+}
+
+impl<'a, T> EncoderOutput<'a, T> {
+    #[inline]
+    fn to_borrowed_mut<'s>(&'s mut self) -> EncoderOutput<'s, T>
+    where
+        'a: 's,
+    {
+        EncoderOutput::Borrowed(self.as_mut())
+    }
+}
+
+impl<'a, T> std::convert::AsMut<T> for EncoderOutput<'a, T> {
+    fn as_mut(&mut self) -> &mut T {
+        match self {
+            EncoderOutput::Owned(o) => o,
+            EncoderOutput::Borrowed(b) => b,
+        }
+    }
+}
+
+impl<'a, T> std::convert::From<T> for EncoderOutput<'a, T> {
+    fn from(t: T) -> Self {
+        EncoderOutput::Owned(t)
+    }
+}
+
+impl<'a, T> std::convert::From<&'a mut T> for EncoderOutput<'a, T> {
+    fn from(t: &'a mut T) -> Self {
+        EncoderOutput::Borrowed(t)
+    }
+}
+
 /// The encoder state machine implementation for a directory.
 ///
 /// We use `async fn` to implement the encoder state machine so that we can easily plug in both
 /// synchronous or `async` I/O objects in as output.
 pub(crate) struct EncoderImpl<'a, T: SeqWrite + 'a> {
-    output: Option<T>,
+    output: EncoderOutput<'a, T>,
     state: EncoderState,
     parent: Option<&'a mut EncoderState>,
     finished: bool,
@@ -272,16 +303,21 @@ impl<'a, T: SeqWrite + 'a> Drop for EncoderImpl<'a, T> {
 }
 
 impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
-    pub async fn new(output: T, metadata: &Metadata) -> io::Result<EncoderImpl<'a, T>> {
+    pub async fn new(
+        output: EncoderOutput<'a, T>,
+        metadata: &Metadata,
+    ) -> io::Result<EncoderImpl<'a, T>> {
         if !metadata.is_dir() {
             io_bail!("directory metadata must contain the directory mode flag");
         }
         let mut this = Self {
-            output: Some(output),
+            output,
             state: EncoderState::default(),
             parent: None,
             finished: false,
-            file_copy_buffer: Arc::new(Mutex::new(crate::util::vec_new(1024 * 1024))),
+            file_copy_buffer: Arc::new(Mutex::new(unsafe {
+                crate::util::vec_new_uninitialized(1024 * 1024)
+            })),
         };
 
         this.encode_metadata(metadata).await?;
@@ -303,7 +339,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         metadata: &Metadata,
         file_name: &Path,
         file_size: u64,
-    ) -> io::Result<FileImpl<'b>>
+    ) -> io::Result<FileImpl<'b, T>>
     where
         'a: 'b,
     {
@@ -316,7 +352,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         metadata: &Metadata,
         file_name: &[u8],
         file_size: u64,
-    ) -> io::Result<FileImpl<'b>>
+    ) -> io::Result<FileImpl<'b, T>>
     where
         'a: 'b,
     {
@@ -328,19 +364,14 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         let header = format::Header::with_content_size(format::PXAR_PAYLOAD, file_size);
         header.check_header_size()?;
 
-        seq_write_struct(
-            self.output.as_mut().unwrap(),
-            header,
-            &mut self.state.write_position,
-        )
-        .await?;
+        seq_write_struct(self.output.as_mut(), header, &mut self.state.write_position).await?;
 
         let payload_data_offset = self.position();
 
         let meta_size = payload_data_offset - file_offset;
 
         Ok(FileImpl {
-            output: self.output.as_mut().unwrap(),
+            output: self.output.as_mut(),
             goodbye_item: GoodbyeItem {
                 hash: format::hash_filename(file_name),
                 offset: file_offset,
@@ -351,6 +382,31 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         })
     }
 
+    fn take_file_copy_buffer(&self) -> Vec<u8> {
+        let buf: Vec<_> = take(
+            &mut self
+                .file_copy_buffer
+                .lock()
+                .expect("failed to lock temporary buffer mutex"),
+        );
+        if buf.len() < 1024 * 1024 {
+            drop(buf);
+            unsafe { crate::util::vec_new_uninitialized(1024 * 1024) }
+        } else {
+            buf
+        }
+    }
+
+    fn put_file_copy_buffer(&self, buf: Vec<u8>) {
+        let mut lock = self
+            .file_copy_buffer
+            .lock()
+            .expect("failed to lock temporary buffer mutex");
+        if lock.len() < buf.len() {
+            *lock = buf;
+        }
+    }
+
     /// Return a file offset usable with `add_hardlink`.
     pub async fn add_file(
         &mut self,
@@ -359,9 +415,8 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         file_size: u64,
         content: &mut dyn SeqRead,
     ) -> io::Result<LinkOffset> {
-        let buf = Arc::clone(&self.file_copy_buffer);
+        let mut buf = self.take_file_copy_buffer();
         let mut file = self.create_file(metadata, file_name, file_size).await?;
-        let mut buf = buf.lock().expect("failed to lock temporary buffer mutex");
         loop {
             let got = decoder::seq_read(&mut *content, &mut buf[..]).await?;
             if got == 0 {
@@ -370,7 +425,10 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
                 file.write_all(&buf[..got]).await?;
             }
         }
-        Ok(file.file_offset())
+        let offset = file.file_offset();
+        drop(file);
+        self.put_file_copy_buffer(buf);
+        Ok(offset)
     }
 
     /// Return a file offset usable with `add_hardlink`.
@@ -482,7 +540,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         self.start_file_do(metadata, file_name).await?;
         if let Some((htype, entry_data)) = entry_htype_data {
             seq_write_pxar_entry(
-                self.output.as_mut().unwrap(),
+                self.output.as_mut(),
                 htype,
                 entry_data,
                 &mut self.state.write_position,
@@ -506,14 +564,11 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         self.state.write_position
     }
 
-    pub async fn create_directory<'b>(
-        &'b mut self,
+    pub async fn create_directory(
+        &mut self,
         file_name: &Path,
         metadata: &Metadata,
-    ) -> io::Result<EncoderImpl<'b, &'b mut dyn SeqWrite>>
-    where
-        'a: 'b,
-    {
+    ) -> io::Result<EncoderImpl<'_, T>> {
         self.check()?;
 
         if !metadata.is_dir() {
@@ -527,15 +582,18 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         self.encode_filename(file_name).await?;
 
         let entry_offset = self.position();
-        self.encode_metadata(&metadata).await?;
+        self.encode_metadata(metadata).await?;
 
         let files_offset = self.position();
 
         // the child will write to OUR state now:
         let write_position = self.position();
 
+        let file_copy_buffer = Arc::clone(&self.file_copy_buffer);
+
         Ok(EncoderImpl {
-            output: self.output.as_mut().map(SeqWrite::as_trait_object),
+            // always forward as Borrowed(), to avoid stacking references on nested calls
+            output: self.output.to_borrowed_mut(),
             state: EncoderState {
                 entry_offset,
                 files_offset,
@@ -546,7 +604,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
             },
             parent: Some(&mut self.state),
             finished: false,
-            file_copy_buffer: Arc::clone(&self.file_copy_buffer),
+            file_copy_buffer,
         })
     }
 
@@ -557,14 +615,14 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
     ) -> io::Result<()> {
         self.encode_filename(file_name).await?;
         if let Some(metadata) = metadata {
-            self.encode_metadata(&metadata).await?;
+            self.encode_metadata(metadata).await?;
         }
         Ok(())
     }
 
     async fn encode_metadata(&mut self, metadata: &Metadata) -> io::Result<()> {
         seq_write_pxar_struct_entry(
-            self.output.as_mut().unwrap(),
+            self.output.as_mut(),
             format::PXAR_ENTRY,
             metadata.stat.clone(),
             &mut self.state.write_position,
@@ -590,7 +648,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 
     async fn write_xattr(&mut self, xattr: &format::XAttr) -> io::Result<()> {
         seq_write_pxar_entry(
-            self.output.as_mut().unwrap(),
+            self.output.as_mut(),
             format::PXAR_XATTR,
             &xattr.data,
             &mut self.state.write_position,
@@ -601,7 +659,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
     async fn write_acls(&mut self, acl: &crate::Acl) -> io::Result<()> {
         for acl in &acl.users {
             seq_write_pxar_struct_entry(
-                self.output.as_mut().unwrap(),
+                self.output.as_mut(),
                 format::PXAR_ACL_USER,
                 acl.clone(),
                 &mut self.state.write_position,
@@ -611,7 +669,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 
         for acl in &acl.groups {
             seq_write_pxar_struct_entry(
-                self.output.as_mut().unwrap(),
+                self.output.as_mut(),
                 format::PXAR_ACL_GROUP,
                 acl.clone(),
                 &mut self.state.write_position,
@@ -621,7 +679,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 
         if let Some(acl) = &acl.group_obj {
             seq_write_pxar_struct_entry(
-                self.output.as_mut().unwrap(),
+                self.output.as_mut(),
                 format::PXAR_ACL_GROUP_OBJ,
                 acl.clone(),
                 &mut self.state.write_position,
@@ -631,7 +689,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 
         if let Some(acl) = &acl.default {
             seq_write_pxar_struct_entry(
-                self.output.as_mut().unwrap(),
+                self.output.as_mut(),
                 format::PXAR_ACL_DEFAULT,
                 acl.clone(),
                 &mut self.state.write_position,
@@ -641,7 +699,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 
         for acl in &acl.default_users {
             seq_write_pxar_struct_entry(
-                self.output.as_mut().unwrap(),
+                self.output.as_mut(),
                 format::PXAR_ACL_DEFAULT_USER,
                 acl.clone(),
                 &mut self.state.write_position,
@@ -651,7 +709,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 
         for acl in &acl.default_groups {
             seq_write_pxar_struct_entry(
-                self.output.as_mut().unwrap(),
+                self.output.as_mut(),
                 format::PXAR_ACL_DEFAULT_GROUP,
                 acl.clone(),
                 &mut self.state.write_position,
@@ -664,7 +722,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 
     async fn write_file_capabilities(&mut self, fcaps: &format::FCaps) -> io::Result<()> {
         seq_write_pxar_entry(
-            self.output.as_mut().unwrap(),
+            self.output.as_mut(),
             format::PXAR_FCAPS,
             &fcaps.data,
             &mut self.state.write_position,
@@ -677,7 +735,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         quota_project_id: &format::QuotaProjectId,
     ) -> io::Result<()> {
         seq_write_pxar_struct_entry(
-            self.output.as_mut().unwrap(),
+            self.output.as_mut(),
             format::PXAR_QUOTA_PROJID,
             *quota_project_id,
             &mut self.state.write_position,
@@ -688,7 +746,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
     async fn encode_filename(&mut self, file_name: &[u8]) -> io::Result<()> {
         crate::util::validate_filename(file_name)?;
         seq_write_pxar_entry_zero(
-            self.output.as_mut().unwrap(),
+            self.output.as_mut(),
             format::PXAR_FILENAME,
             file_name,
             &mut self.state.write_position,
@@ -696,16 +754,20 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         .await
     }
 
-    pub async fn finish(mut self) -> io::Result<T> {
+    pub async fn finish(mut self) -> io::Result<()> {
         let tail_bytes = self.finish_goodbye_table().await?;
         seq_write_pxar_entry(
-            self.output.as_mut().unwrap(),
+            self.output.as_mut(),
             format::PXAR_GOODBYE,
             &tail_bytes,
             &mut self.state.write_position,
         )
         .await?;
 
+        if let EncoderOutput::Owned(output) = &mut self.output {
+            flush(output).await?;
+        }
+
         // done up here because of the self-borrow and to propagate
         let end_offset = self.position();
 
@@ -724,11 +786,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
             });
         }
         self.finished = true;
-        Ok(self.output.take().unwrap())
-    }
-
-    pub fn into_writer(mut self) -> T {
-        self.output.take().unwrap()
+        Ok(())
     }
 
     async fn finish_goodbye_table(&mut self) -> io::Result<Vec<u8>> {
@@ -743,6 +801,7 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
         tail.sort_unstable_by(|a, b| a.hash.cmp(&b.hash));
 
         let mut bst = Vec::with_capacity(tail.len() + 1);
+        #[allow(clippy::uninit_vec)]
         unsafe {
             bst.set_len(tail.len());
         }
@@ -775,8 +834,8 @@ impl<'a, T: SeqWrite + 'a> EncoderImpl<'a, T> {
 }
 
 /// Writer for a file object in a directory.
-pub struct FileImpl<'a> {
-    output: &'a mut dyn SeqWrite,
+pub(crate) struct FileImpl<'a, S: SeqWrite> {
+    output: &'a mut S,
 
     /// This file's `GoodbyeItem`. FIXME: We currently don't touch this, can we just push it
     /// directly instead of on Drop of FileImpl?
@@ -791,7 +850,7 @@ pub struct FileImpl<'a> {
     parent: &'a mut EncoderState,
 }
 
-impl<'a> Drop for FileImpl<'a> {
+impl<'a, S: SeqWrite> Drop for FileImpl<'a, S> {
     fn drop(&mut self) {
         if self.remaining_size != 0 {
             self.parent.add_error(EncodeError::IncompleteFile);
@@ -801,7 +860,7 @@ impl<'a> Drop for FileImpl<'a> {
     }
 }
 
-impl<'a> FileImpl<'a> {
+impl<'a, S: SeqWrite> FileImpl<'a, S> {
     /// Get the file offset to be able to reference it with `add_hardlink`.
     pub fn file_offset(&self) -> LinkOffset {
         LinkOffset(self.goodbye_item.offset)
@@ -838,10 +897,7 @@ impl<'a> FileImpl<'a> {
     /// Poll flush interface to more easily connect to tokio/futures.
     #[cfg(feature = "tokio-io")]
     pub fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
-        unsafe {
-            self.map_unchecked_mut(|this| &mut this.output)
-                .poll_flush(cx)
-        }
+        unsafe { self.map_unchecked_mut(|this| this.output).poll_flush(cx) }
     }
 
     /// Poll close/shutdown interface to more easily connect to tokio/futures.
@@ -850,10 +906,7 @@ impl<'a> FileImpl<'a> {
     /// provided by our encoder.
     #[cfg(feature = "tokio-io")]
     pub fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
-        unsafe {
-            self.map_unchecked_mut(|this| &mut this.output)
-                .poll_flush(cx)
-        }
+        unsafe { self.map_unchecked_mut(|this| this.output).poll_flush(cx) }
     }
 
     /// Write file data for the current file entry in a pxar archive.
@@ -875,14 +928,14 @@ impl<'a> FileImpl<'a> {
     /// Completely write file data for the current file entry in a pxar archive.
     pub async fn write_all(&mut self, data: &[u8]) -> io::Result<()> {
         self.check_remaining(data.len())?;
-        seq_write_all(&mut self.output, data, &mut self.parent.write_position).await?;
+        seq_write_all(self.output, data, &mut self.parent.write_position).await?;
         self.remaining_size -= data.len() as u64;
         Ok(())
     }
 }
 
 #[cfg(feature = "tokio-io")]
-impl<'a> tokio::io::AsyncWrite for FileImpl<'a> {
+impl<'a, S: SeqWrite> tokio::io::AsyncWrite for FileImpl<'a, S> {
     fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
         FileImpl::poll_write(self, cx, buf)
     }