]> git.lizzy.rs Git - connect-rs.git/blobdiff - src/reader.rs
add return stmt to fix bug
[connect-rs.git] / src / reader.rs
index 568e40a359ceaee18fbbef74c2675fa0d4b60745..6844e393f297148f09b0e5ef0b64c577fc923826 100644 (file)
@@ -1,9 +1,9 @@
-use crate::schema::StitchMessage;
+use crate::schema::ConnectionMessage;
 use async_std::net::SocketAddr;
 use async_std::pin::Pin;
 use bytes::{Buf, BytesMut};
 use futures::task::{Context, Poll};
-use futures::{AsyncRead, AsyncReadExt, Stream};
+use futures::{AsyncRead, Stream};
 use log::*;
 use protobuf::Message;
 use std::convert::TryInto;
@@ -14,24 +14,26 @@ use protobuf::well_known_types::Any;
 
 const BUFFER_SIZE: usize = 8192;
 
-pub struct StitchConnectionReader {
-    local_addr: SocketAddr,
-    peer_addr: SocketAddr,
-    read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+pub struct ConnectionReader {
+    local_addr:   SocketAddr,
+    peer_addr:    SocketAddr,
+    read_stream:  Pin<Box<dyn AsyncRead + Send + Sync>>,
     pending_read: Option<BytesMut>,
+    closed:       bool,
 }
 
-impl StitchConnectionReader {
+impl ConnectionReader {
     pub fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
-        read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+        read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
     ) -> Self {
         Self {
             local_addr,
             peer_addr,
             read_stream,
             pending_read: None,
+            closed: false,
         }
     }
 
@@ -42,26 +44,43 @@ impl StitchConnectionReader {
     pub fn peer_addr(&self) -> SocketAddr {
         self.peer_addr.clone()
     }
+
+    pub fn is_closed(&self) -> bool {
+        self.closed
+    }
+
+    pub(crate) fn close_stream(&mut self) {
+        trace!("Closing the stream for connection with {}", self.peer_addr);
+        self.pending_read.take();
+        self.closed = true;
+    }
 }
 
-impl Stream for StitchConnectionReader {
+impl Stream for ConnectionReader {
     type Item = Any;
 
-    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let mut buffer = BytesMut::new();
         buffer.resize(BUFFER_SIZE, 0);
 
-        debug!("Starting new read loop for {}", self.local_addr);
+        trace!("Starting new read loop for {}", self.local_addr);
         loop {
             trace!("Reading from the stream");
-            match futures::executor::block_on(self.read_stream.read(&mut buffer)) {
-                Ok(mut bytes_read) => {
+            let stream = self.read_stream.as_mut();
+
+            match stream.poll_read(cx, &mut buffer) {
+                Poll::Pending => return Poll::Pending,
+
+                Poll::Ready(Ok(mut bytes_read)) => {
                     if bytes_read > 0 {
-                        debug!("Read {} bytes from the network stream", bytes_read)
+                        trace!("Read {} bytes from the network stream", bytes_read)
+                    } else if self.pending_read.is_none() {
+                        self.close_stream();
+                        return Poll::Ready(None)
                     }
 
                     if let Some(mut pending_buf) = self.pending_read.take() {
-                        debug!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
+                        trace!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
                         bytes_read += pending_buf.len();
 
                         pending_buf.unsplit(buffer);
@@ -72,18 +91,17 @@ impl Stream for StitchConnectionReader {
                         format!("Conversion from usize ({}) to u64 failed", bytes_read).as_str(),
                     );
                     while bytes_read_u64 > 0 {
-                        debug!(
+                        trace!(
                             "{} bytes from network stream still unprocessed",
                             bytes_read_u64
                         );
 
                         buffer.resize(bytes_read, 0);
-                        debug!("{:?}", buffer.as_ref());
 
-                        match StitchMessage::parse_from_bytes(buffer.as_ref()) {
+                        match ConnectionMessage::parse_from_bytes(buffer.as_ref()) {
                             Ok(mut data) => {
                                 let serialized_size = data.compute_size();
-                                debug!("Deserialized message of size {} bytes", serialized_size);
+                                trace!("Deserialized message of size {} bytes", serialized_size);
 
                                 buffer.advance(serialized_size as usize);
 
@@ -95,9 +113,9 @@ impl Stream for StitchConnectionReader {
                                     .as_str(),
                                 );
                                 bytes_read_u64 -= serialized_size_u64;
-                                debug!("{} bytes still unprocessed", bytes_read_u64);
+                                trace!("{} bytes still unprocessed", bytes_read_u64);
 
-                                debug!("Sending deserialized message downstream");
+                                trace!("Sending deserialized message downstream");
                                 return Poll::Ready(Some(data.take_payload()));
                             }
 
@@ -117,7 +135,11 @@ impl Stream for StitchConnectionReader {
                     buffer.resize(BUFFER_SIZE, 0);
                 }
 
-                Err(_err) => return Poll::Pending,
+                // Close the stream
+                Poll::Ready(Err(_e)) => {
+                    self.close_stream();
+                    return Poll::Ready(None)
+                }
             }
         }
     }