]> git.lizzy.rs Git - connect-rs.git/blobdiff - src/writer.rs
don't block in poll_x fns, fixes conn closing issues
[connect-rs.git] / src / writer.rs
index f6442a614e693bd4666cf17e10cee775faf333ca..f26ab4b12eba3719b98947c25fc80dc96fa5f24d 100644 (file)
@@ -2,8 +2,9 @@ use crate::schema::ConnectionMessage;
 use async_channel::RecvError;
 use async_std::net::SocketAddr;
 use async_std::pin::Pin;
+use futures::{AsyncWrite, Sink};
+use futures::io::IoSlice;
 use futures::task::{Context, Poll};
-use futures::{AsyncWrite, AsyncWriteExt, Sink};
 use log::*;
 use protobuf::Message;
 
@@ -11,24 +12,24 @@ pub use futures::SinkExt;
 pub use futures::StreamExt;
 
 pub struct ConnectionWriter {
-    local_addr:    SocketAddr,
-    peer_addr:     SocketAddr,
-    write_stream:  Box<dyn AsyncWrite + Send + Sync + Unpin>,
-    pending_write: Option<ConnectionMessage>,
-    closed:        bool,
+    local_addr:     SocketAddr,
+    peer_addr:      SocketAddr,
+    write_stream:   Pin<Box<dyn AsyncWrite + Send + Sync>>,
+    pending_writes: Vec<Vec<u8>>,
+    closed:         bool,
 }
 
 impl ConnectionWriter {
     pub fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
-        write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
+        write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
     ) -> Self {
         Self {
             local_addr,
             peer_addr,
             write_stream,
-            pending_write: None,
+            pending_writes: Vec::new(),
             closed: false,
         }
     }
@@ -44,47 +45,15 @@ impl ConnectionWriter {
     pub fn is_closed(&self) -> bool {
         self.closed
     }
-
-    fn send_to_conn(&mut self) -> Poll<Result<(), RecvError>> {
-        if let Some(pending_msg) = self.pending_write.take() {
-            trace!("Send pending message");
-            if let Ok(buffer) = pending_msg.write_to_bytes() {
-                let msg_size = buffer.len();
-                trace!("{} bytes to be sent over network connection", msg_size);
-
-                return if let Ok(_) =
-                    futures::executor::block_on(self.write_stream.write_all(buffer.as_slice()))
-                {
-                    if let Ok(_) = futures::executor::block_on(self.write_stream.flush()) {
-                        trace!("Sent message of {} bytes", msg_size);
-                        Poll::Ready(Ok(()))
-                    } else {
-                        trace!("Encountered error while flushing queued bytes to network stream");
-                        Poll::Ready(Err(RecvError))
-                    }
-                } else {
-                    error!("Encountered error when writing to network stream");
-                    Poll::Ready(Err(RecvError))
-                };
-            } else {
-                error!("Encountered error when serializing message to bytes");
-                return Poll::Ready(Err(RecvError));
-            }
-        } else {
-            trace!("No message to send over connection");
-        }
-
-        Poll::Ready(Ok(()))
-    }
 }
 
 impl<M: Message> Sink<M> for ConnectionWriter {
     type Error = RecvError;
 
     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        if self.pending_write.is_some() {
-            trace!("Connection not ready to send message yet, waiting for prior message");
-            Poll::Pending
+        if self.is_closed() {
+            trace!("Connection is closed, cannot send message");
+            Poll::Ready(Err(RecvError))
         } else {
             trace!("Connection ready to send message");
             Poll::Ready(Ok(()))
@@ -93,30 +62,126 @@ impl<M: Message> Sink<M> for ConnectionWriter {
 
     fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
         trace!("Preparing message to be sent next");
-        let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item);
-        self.pending_write.replace(stitch_msg);
+        let msg: ConnectionMessage = ConnectionMessage::from_msg(item);
 
-        Ok(())
+        if let Ok(buffer) = msg.write_to_bytes() {
+            let msg_size = buffer.len();
+            trace!("Serialized pending message into {} bytes", msg_size);
+
+            self.pending_writes.push(buffer);
+
+            Ok(())
+        } else {
+            error!("Encountered error when serializing message to bytes");
+            Err(RecvError)
+        }
     }
 
     fn poll_flush(
         mut self: Pin<&mut Self>,
-        _cx: &mut Context<'_>,
+        cx: &mut Context<'_>,
     ) -> Poll<Result<(), Self::Error>> {
-        self.send_to_conn()
+        if self.pending_writes.len() > 0 {
+            let stream = self.write_stream.as_mut();
+
+            match stream.poll_flush(cx) {
+                Poll::Pending => Poll::Pending,
+
+                Poll::Ready(Ok(_)) => {
+                    trace!("Sending pending bytes");
+
+                    let pending = self.pending_writes.split_off(0);
+                    let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
+                        IoSlice::new(p)
+                    }).collect();
+
+                    let stream = self.write_stream.as_mut();
+                    match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
+                        Poll::Pending => Poll::Pending,
+
+                        Poll::Ready(Ok(bytes_written)) => {
+                            trace!("Wrote {} bytes to network stream", bytes_written);
+                            Poll::Ready(Ok(()))
+                        },
+
+                        Poll::Ready(Err(_e)) => {
+                            error!("Encountered error when writing to network stream");
+                            Poll::Ready(Err(RecvError))
+                        },
+                    }
+                },
+
+                Poll::Ready(Err(_e)) => {
+                    error!("Encountered error when flushing network stream");
+                    Poll::Ready(Err(RecvError))
+                }
+            }
+        } else {
+            Poll::Ready(Ok(()))
+        }
     }
 
     fn poll_close(
         mut self: Pin<&mut Self>,
-        _cx: &mut Context<'_>,
+        cx: &mut Context<'_>,
     ) -> Poll<Result<(), Self::Error>> {
-        let _ = self.send_to_conn();
-
         self.closed = true;
-        if let Ok(_) = futures::executor::block_on(self.write_stream.close()) {
-            Poll::Ready(Ok(()))
+
+        let flush = if self.pending_writes.len() > 0 {
+            let stream = self.write_stream.as_mut();
+
+            match stream.poll_flush(cx) {
+                Poll::Pending => Poll::Pending,
+
+                Poll::Ready(Ok(_)) => {
+                    trace!("Sending pending bytes");
+
+                    let pending = self.pending_writes.split_off(0);
+                    let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
+                        IoSlice::new(p)
+                    }).collect();
+
+                    let stream = self.write_stream.as_mut();
+                    match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
+                        Poll::Pending => Poll::Pending,
+
+                        Poll::Ready(Ok(bytes_written)) => {
+                            trace!("Wrote {} bytes to network stream", bytes_written);
+                            Poll::Ready(Ok(()))
+                        },
+
+                        Poll::Ready(Err(_e)) => {
+                            error!("Encountered error when writing to network stream");
+                            Poll::Ready(Err(RecvError))
+                        },
+                    }
+                },
+
+                Poll::Ready(Err(_e)) => {
+                    error!("Encountered error when flushing network stream");
+                    Poll::Ready(Err(RecvError))
+                }
+            }
         } else {
-            Poll::Ready(Err(RecvError))
+            Poll::Ready(Ok(()))
+        };
+
+        match flush {
+            Poll::Pending => Poll::Pending,
+
+            Poll::Ready(Ok(_)) => {
+                let stream = self.write_stream.as_mut();
+
+                match stream.poll_close(cx) {
+                    Poll::Pending => Poll::Pending,
+
+                    Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
+
+                    Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
+                }
+            },
+
+            err => err,
         }
     }
 }