]> git.lizzy.rs Git - connect-rs.git/commitdiff
don't block in poll_x fns, fixes conn closing issues
authorSachandhan Ganesh <sachan.ganesh@gmail.com>
Wed, 20 Jan 2021 04:50:21 +0000 (20:50 -0800)
committerSachandhan Ganesh <sachan.ganesh@gmail.com>
Wed, 20 Jan 2021 04:50:21 +0000 (20:50 -0800)
examples/tcp-client/src/schema/hello_world.rs
examples/tcp-echo-server/src/schema/hello_world.rs
examples/tls-client/src/schema/hello_world.rs
examples/tls-echo-server/src/schema/hello_world.rs
src/lib.rs
src/reader.rs
src/schema/message.rs
src/schema/mod.rs
src/tcp/client.rs
src/tls/client.rs
src/writer.rs

index 5af493570751990f5ef8b8923b6d61fa1aaf3877..1af771cf17386e220e5611e48047288af8c68b0d 100644 (file)
@@ -1,4 +1,4 @@
-// This file is generated by rust-protobuf 2.19.0. Do not edit
+// This file is generated by rust-protobuf 2.20.0. Do not edit
 // @generated
 
 // https://github.com/rust-lang/rust-clippy/issues/702
@@ -21,7 +21,7 @@
 
 /// Generated files are compatible only with the same version
 /// of protobuf runtime.
-// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0;
+// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0;
 
 #[derive(PartialEq,Clone,Default)]
 pub struct HelloWorld {
index 5af493570751990f5ef8b8923b6d61fa1aaf3877..1af771cf17386e220e5611e48047288af8c68b0d 100644 (file)
@@ -1,4 +1,4 @@
-// This file is generated by rust-protobuf 2.19.0. Do not edit
+// This file is generated by rust-protobuf 2.20.0. Do not edit
 // @generated
 
 // https://github.com/rust-lang/rust-clippy/issues/702
@@ -21,7 +21,7 @@
 
 /// Generated files are compatible only with the same version
 /// of protobuf runtime.
-// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0;
+// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0;
 
 #[derive(PartialEq,Clone,Default)]
 pub struct HelloWorld {
index 5af493570751990f5ef8b8923b6d61fa1aaf3877..1af771cf17386e220e5611e48047288af8c68b0d 100644 (file)
@@ -1,4 +1,4 @@
-// This file is generated by rust-protobuf 2.19.0. Do not edit
+// This file is generated by rust-protobuf 2.20.0. Do not edit
 // @generated
 
 // https://github.com/rust-lang/rust-clippy/issues/702
@@ -21,7 +21,7 @@
 
 /// Generated files are compatible only with the same version
 /// of protobuf runtime.
-// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0;
+// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0;
 
 #[derive(PartialEq,Clone,Default)]
 pub struct HelloWorld {
index 5af493570751990f5ef8b8923b6d61fa1aaf3877..1af771cf17386e220e5611e48047288af8c68b0d 100644 (file)
@@ -1,4 +1,4 @@
-// This file is generated by rust-protobuf 2.19.0. Do not edit
+// This file is generated by rust-protobuf 2.20.0. Do not edit
 // @generated
 
 // https://github.com/rust-lang/rust-clippy/issues/702
@@ -21,7 +21,7 @@
 
 /// Generated files are compatible only with the same version
 /// of protobuf runtime.
-// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0;
+// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0;
 
 #[derive(PartialEq,Clone,Default)]
 pub struct HelloWorld {
index 9239db5379336efb356c753c0a3c2bbf03384337..c85c2cd8d99a5b50f05a257ec447454b7bd77f92 100644 (file)
@@ -6,7 +6,7 @@ mod writer;
 
 pub use crate::reader::ConnectionReader;
 pub use crate::writer::ConnectionWriter;
-use async_std::net::SocketAddr;
+use async_std::{net::SocketAddr, pin::Pin};
 use futures::{AsyncRead, AsyncWrite};
 pub use futures::{SinkExt, StreamExt};
 
@@ -22,8 +22,8 @@ impl Connection {
     pub(crate) fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
-        read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
-        write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
+        read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
+        write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
     ) -> Self {
         Self {
             local_addr,
index 0452f6c59daf65e182f57c0445b24b10e43900e7..cfdc75404fd54962b61354a685b80b8f475d0c83 100644 (file)
@@ -3,7 +3,7 @@ use async_std::net::SocketAddr;
 use async_std::pin::Pin;
 use bytes::{Buf, BytesMut};
 use futures::task::{Context, Poll};
-use futures::{AsyncRead, AsyncReadExt, Stream};
+use futures::{AsyncRead, Stream};
 use log::*;
 use protobuf::Message;
 use std::convert::TryInto;
@@ -17,7 +17,7 @@ const BUFFER_SIZE: usize = 8192;
 pub struct ConnectionReader {
     local_addr:   SocketAddr,
     peer_addr:    SocketAddr,
-    read_stream:  Box<dyn AsyncRead + Send + Sync + Unpin>,
+    read_stream:  Pin<Box<dyn AsyncRead + Send + Sync>>,
     pending_read: Option<BytesMut>,
     closed:       bool,
 }
@@ -26,7 +26,7 @@ impl ConnectionReader {
     pub fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
-        read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+        read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
     ) -> Self {
         Self {
             local_addr,
@@ -48,24 +48,35 @@ impl ConnectionReader {
     pub fn is_closed(&self) -> bool {
         self.closed
     }
+
+    pub(crate) fn close_stream(&mut self) {
+        trace!("Closing the stream for connection with {}", self.peer_addr);
+        self.pending_read.take();
+        self.closed = true;
+    }
 }
 
 impl Stream for ConnectionReader {
     type Item = Any;
 
-    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let mut buffer = BytesMut::new();
         buffer.resize(BUFFER_SIZE, 0);
 
         trace!("Starting new read loop for {}", self.local_addr);
         loop {
             trace!("Reading from the stream");
-            match futures::executor::block_on(self.read_stream.read(&mut buffer)) {
-                Ok(mut bytes_read) => {
+            let stream = self.read_stream.as_mut();
+
+            match stream.poll_read(cx, &mut buffer) {
+                Poll::Pending => return Poll::Pending,
+
+                Poll::Ready(Ok(mut bytes_read)) => {
                     if bytes_read > 0 {
                         trace!("Read {} bytes from the network stream", bytes_read)
-                    } else if bytes_read == 0 && self.pending_read.is_none() {
-                        return Poll::Pending;
+                    } else if self.pending_read.is_none() {
+                        self.close_stream();
+                        return Poll::Ready(None)
                     }
 
                     if let Some(mut pending_buf) = self.pending_read.take() {
@@ -125,11 +136,9 @@ impl Stream for ConnectionReader {
                 }
 
                 // Close the stream
-                Err(_err) => {
-                    trace!("Closing the stream");
-                    self.pending_read.take();
-                    self.closed = true;
-                    return Poll::Ready(None);
+                Poll::Ready(Err(_e)) => {
+                    self.close_stream();
+                    Poll::Ready(None)
                 }
             }
         }
index c41fbb06e070fc36afc56ab4f2ae9383d39227aa..265075fc8d9a3fac37c50f7a07480eabbcb00c31 100644 (file)
@@ -1,4 +1,4 @@
-// This file is generated by rust-protobuf 2.19.0. Do not edit
+// This file is generated by rust-protobuf 2.20.0. Do not edit
 // @generated
 
 // https://github.com/rust-lang/rust-clippy/issues/702
@@ -21,7 +21,7 @@
 
 /// Generated files are compatible only with the same version
 /// of protobuf runtime.
-// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0;
+// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0;
 
 #[derive(PartialEq,Clone,Default)]
 pub struct ConnectionMessage {
index 49c0b7c849a0a354f47eadff247fc02ba0fd7816..2774e6463d3a0929ffea8cba49eede1ec3f88c3e 100644 (file)
@@ -5,7 +5,7 @@ use protobuf::well_known_types::Any;
 use protobuf::Message;
 
 impl ConnectionMessage {
-    pub(crate) fn from_msg<T: Message>(msg: T) -> Self {
+    pub(crate) fn from_msg<M: Message>(msg: M) -> Self {
         let mut sm = Self::new();
         let payload = Any::pack(&msg).expect("Protobuf Message could not be packed into Any type");
 
index 140c00d4fff8c63058323a20c825c5e4d07f165e..e9cd42bfc395b511675c4dabaf092ed11822f2e1 100644 (file)
@@ -28,8 +28,8 @@ impl From<TcpStream> for Connection {
         Self::new(
             local_addr,
             peer_addr,
-            Box::new(stream),
-            Box::new(write_stream),
+            Box::pin(stream),
+            Box::pin(write_stream),
         )
     }
 }
index 225c0db6536b1897851eb73105160094db59e933..afb95656eea924ab375ef168df5fde7357e72050 100644 (file)
@@ -58,8 +58,8 @@ impl From<TlsConnectionMetadata> for Connection {
                 Self::new(
                     local_addr,
                     peer_addr,
-                    Box::new(read_stream),
-                    Box::new(write_stream),
+                    Box::pin(read_stream),
+                    Box::pin(write_stream),
                 )
             }
 
@@ -73,8 +73,8 @@ impl From<TlsConnectionMetadata> for Connection {
                 Self::new(
                     local_addr,
                     peer_addr,
-                    Box::new(read_stream),
-                    Box::new(write_stream),
+                    Box::pin(read_stream),
+                    Box::pin(write_stream),
                 )
             }
         }
index f6442a614e693bd4666cf17e10cee775faf333ca..f26ab4b12eba3719b98947c25fc80dc96fa5f24d 100644 (file)
@@ -2,8 +2,9 @@ use crate::schema::ConnectionMessage;
 use async_channel::RecvError;
 use async_std::net::SocketAddr;
 use async_std::pin::Pin;
+use futures::{AsyncWrite, Sink};
+use futures::io::IoSlice;
 use futures::task::{Context, Poll};
-use futures::{AsyncWrite, AsyncWriteExt, Sink};
 use log::*;
 use protobuf::Message;
 
@@ -11,24 +12,24 @@ pub use futures::SinkExt;
 pub use futures::StreamExt;
 
 pub struct ConnectionWriter {
-    local_addr:    SocketAddr,
-    peer_addr:     SocketAddr,
-    write_stream:  Box<dyn AsyncWrite + Send + Sync + Unpin>,
-    pending_write: Option<ConnectionMessage>,
-    closed:        bool,
+    local_addr:     SocketAddr,
+    peer_addr:      SocketAddr,
+    write_stream:   Pin<Box<dyn AsyncWrite + Send + Sync>>,
+    pending_writes: Vec<Vec<u8>>,
+    closed:         bool,
 }
 
 impl ConnectionWriter {
     pub fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
-        write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
+        write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
     ) -> Self {
         Self {
             local_addr,
             peer_addr,
             write_stream,
-            pending_write: None,
+            pending_writes: Vec::new(),
             closed: false,
         }
     }
@@ -44,47 +45,15 @@ impl ConnectionWriter {
     pub fn is_closed(&self) -> bool {
         self.closed
     }
-
-    fn send_to_conn(&mut self) -> Poll<Result<(), RecvError>> {
-        if let Some(pending_msg) = self.pending_write.take() {
-            trace!("Send pending message");
-            if let Ok(buffer) = pending_msg.write_to_bytes() {
-                let msg_size = buffer.len();
-                trace!("{} bytes to be sent over network connection", msg_size);
-
-                return if let Ok(_) =
-                    futures::executor::block_on(self.write_stream.write_all(buffer.as_slice()))
-                {
-                    if let Ok(_) = futures::executor::block_on(self.write_stream.flush()) {
-                        trace!("Sent message of {} bytes", msg_size);
-                        Poll::Ready(Ok(()))
-                    } else {
-                        trace!("Encountered error while flushing queued bytes to network stream");
-                        Poll::Ready(Err(RecvError))
-                    }
-                } else {
-                    error!("Encountered error when writing to network stream");
-                    Poll::Ready(Err(RecvError))
-                };
-            } else {
-                error!("Encountered error when serializing message to bytes");
-                return Poll::Ready(Err(RecvError));
-            }
-        } else {
-            trace!("No message to send over connection");
-        }
-
-        Poll::Ready(Ok(()))
-    }
 }
 
 impl<M: Message> Sink<M> for ConnectionWriter {
     type Error = RecvError;
 
     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
-        if self.pending_write.is_some() {
-            trace!("Connection not ready to send message yet, waiting for prior message");
-            Poll::Pending
+        if self.is_closed() {
+            trace!("Connection is closed, cannot send message");
+            Poll::Ready(Err(RecvError))
         } else {
             trace!("Connection ready to send message");
             Poll::Ready(Ok(()))
@@ -93,30 +62,126 @@ impl<M: Message> Sink<M> for ConnectionWriter {
 
     fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
         trace!("Preparing message to be sent next");
-        let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item);
-        self.pending_write.replace(stitch_msg);
+        let msg: ConnectionMessage = ConnectionMessage::from_msg(item);
 
-        Ok(())
+        if let Ok(buffer) = msg.write_to_bytes() {
+            let msg_size = buffer.len();
+            trace!("Serialized pending message into {} bytes", msg_size);
+
+            self.pending_writes.push(buffer);
+
+            Ok(())
+        } else {
+            error!("Encountered error when serializing message to bytes");
+            Err(RecvError)
+        }
     }
 
     fn poll_flush(
         mut self: Pin<&mut Self>,
-        _cx: &mut Context<'_>,
+        cx: &mut Context<'_>,
     ) -> Poll<Result<(), Self::Error>> {
-        self.send_to_conn()
+        if self.pending_writes.len() > 0 {
+            let stream = self.write_stream.as_mut();
+
+            match stream.poll_flush(cx) {
+                Poll::Pending => Poll::Pending,
+
+                Poll::Ready(Ok(_)) => {
+                    trace!("Sending pending bytes");
+
+                    let pending = self.pending_writes.split_off(0);
+                    let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
+                        IoSlice::new(p)
+                    }).collect();
+
+                    let stream = self.write_stream.as_mut();
+                    match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
+                        Poll::Pending => Poll::Pending,
+
+                        Poll::Ready(Ok(bytes_written)) => {
+                            trace!("Wrote {} bytes to network stream", bytes_written);
+                            Poll::Ready(Ok(()))
+                        },
+
+                        Poll::Ready(Err(_e)) => {
+                            error!("Encountered error when writing to network stream");
+                            Poll::Ready(Err(RecvError))
+                        },
+                    }
+                },
+
+                Poll::Ready(Err(_e)) => {
+                    error!("Encountered error when flushing network stream");
+                    Poll::Ready(Err(RecvError))
+                }
+            }
+        } else {
+            Poll::Ready(Ok(()))
+        }
     }
 
     fn poll_close(
         mut self: Pin<&mut Self>,
-        _cx: &mut Context<'_>,
+        cx: &mut Context<'_>,
     ) -> Poll<Result<(), Self::Error>> {
-        let _ = self.send_to_conn();
-
         self.closed = true;
-        if let Ok(_) = futures::executor::block_on(self.write_stream.close()) {
-            Poll::Ready(Ok(()))
+
+        let flush = if self.pending_writes.len() > 0 {
+            let stream = self.write_stream.as_mut();
+
+            match stream.poll_flush(cx) {
+                Poll::Pending => Poll::Pending,
+
+                Poll::Ready(Ok(_)) => {
+                    trace!("Sending pending bytes");
+
+                    let pending = self.pending_writes.split_off(0);
+                    let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
+                        IoSlice::new(p)
+                    }).collect();
+
+                    let stream = self.write_stream.as_mut();
+                    match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
+                        Poll::Pending => Poll::Pending,
+
+                        Poll::Ready(Ok(bytes_written)) => {
+                            trace!("Wrote {} bytes to network stream", bytes_written);
+                            Poll::Ready(Ok(()))
+                        },
+
+                        Poll::Ready(Err(_e)) => {
+                            error!("Encountered error when writing to network stream");
+                            Poll::Ready(Err(RecvError))
+                        },
+                    }
+                },
+
+                Poll::Ready(Err(_e)) => {
+                    error!("Encountered error when flushing network stream");
+                    Poll::Ready(Err(RecvError))
+                }
+            }
         } else {
-            Poll::Ready(Err(RecvError))
+            Poll::Ready(Ok(()))
+        };
+
+        match flush {
+            Poll::Pending => Poll::Pending,
+
+            Poll::Ready(Ok(_)) => {
+                let stream = self.write_stream.as_mut();
+
+                match stream.poll_close(cx) {
+                    Poll::Pending => Poll::Pending,
+
+                    Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
+
+                    Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
+                }
+            },
+
+            err => err,
         }
     }
 }