]> git.lizzy.rs Git - connect-rs.git/blobdiff - src/writer.rs
better handling of stream/sink closing
[connect-rs.git] / src / writer.rs
index 5b5b17a99102a1539f9349d07643354aa5fa4b4e..f6442a614e693bd4666cf17e10cee775faf333ca 100644 (file)
@@ -11,10 +11,11 @@ pub use futures::SinkExt;
 pub use futures::StreamExt;
 
 pub struct ConnectionWriter {
-    local_addr: SocketAddr,
-    peer_addr: SocketAddr,
-    write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
+    local_addr:    SocketAddr,
+    peer_addr:     SocketAddr,
+    write_stream:  Box<dyn AsyncWrite + Send + Sync + Unpin>,
     pending_write: Option<ConnectionMessage>,
+    closed:        bool,
 }
 
 impl ConnectionWriter {
@@ -28,6 +29,7 @@ impl ConnectionWriter {
             peer_addr,
             write_stream,
             pending_write: None,
+            closed: false,
         }
     }
 
@@ -38,41 +40,18 @@ impl ConnectionWriter {
     pub fn peer_addr(&self) -> SocketAddr {
         self.peer_addr.clone()
     }
-}
-
-impl<T: Message> Sink<T> for ConnectionWriter {
-    type Error = RecvError;
 
-    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        if self.pending_write.is_some() {
-            trace!("Connection not ready to send message yet, waiting for prior message");
-            Poll::Pending
-        } else {
-            trace!("Connection ready to send message");
-            Poll::Ready(Ok(()))
-        }
+    pub fn is_closed(&self) -> bool {
+        self.closed
     }
 
-    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
-        trace!("Preparing message to be sent next");
-        let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item);
-        self.pending_write.replace(stitch_msg);
-
-        Ok(())
-    }
-
-    fn poll_flush(
-        mut self: Pin<&mut Self>,
-        _cx: &mut Context<'_>,
-    ) -> Poll<Result<(), Self::Error>> {
+    fn send_to_conn(&mut self) -> Poll<Result<(), RecvError>> {
         if let Some(pending_msg) = self.pending_write.take() {
             trace!("Send pending message");
             if let Ok(buffer) = pending_msg.write_to_bytes() {
                 let msg_size = buffer.len();
                 trace!("{} bytes to be sent over network connection", msg_size);
 
-                trace!("{:?}", buffer.as_slice());
-
                 return if let Ok(_) =
                     futures::executor::block_on(self.write_stream.write_all(buffer.as_slice()))
                 {
@@ -97,11 +76,43 @@ impl<T: Message> Sink<T> for ConnectionWriter {
 
         Poll::Ready(Ok(()))
     }
+}
+
+impl<M: Message> Sink<M> for ConnectionWriter {
+    type Error = RecvError;
+
+    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
+        if self.pending_write.is_some() {
+            trace!("Connection not ready to send message yet, waiting for prior message");
+            Poll::Pending
+        } else {
+            trace!("Connection ready to send message");
+            Poll::Ready(Ok(()))
+        }
+    }
+
+    fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
+        trace!("Preparing message to be sent next");
+        let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item);
+        self.pending_write.replace(stitch_msg);
+
+        Ok(())
+    }
+
+    fn poll_flush(
+        mut self: Pin<&mut Self>,
+        _cx: &mut Context<'_>,
+    ) -> Poll<Result<(), Self::Error>> {
+        self.send_to_conn()
+    }
 
     fn poll_close(
         mut self: Pin<&mut Self>,
         _cx: &mut Context<'_>,
     ) -> Poll<Result<(), Self::Error>> {
+        let _ = self.send_to_conn();
+
+        self.closed = true;
         if let Ok(_) = futures::executor::block_on(self.write_stream.close()) {
             Poll::Ready(Ok(()))
         } else {