]> git.lizzy.rs Git - connect-rs.git/blobdiff - src/reader.rs
remove unstable features for doc comments
[connect-rs.git] / src / reader.rs
index 568e40a359ceaee18fbbef74c2675fa0d4b60745..1e9c03155e6ef63e6268fd35ddc6f2aed0086048 100644 (file)
-use crate::schema::StitchMessage;
+use crate::protocol::ConnectDatagram;
 use async_std::net::SocketAddr;
 use async_std::pin::Pin;
-use bytes::{Buf, BytesMut};
+use bytes::BytesMut;
 use futures::task::{Context, Poll};
-use futures::{AsyncRead, AsyncReadExt, Stream};
+use futures::{AsyncRead, Stream};
 use log::*;
-use protobuf::Message;
 use std::convert::TryInto;
 
-pub use futures::SinkExt;
-pub use futures::StreamExt;
-use protobuf::well_known_types::Any;
+pub use futures::{SinkExt, StreamExt};
 
+/// A default buffer size to read in bytes and then deserialize as messages.
 const BUFFER_SIZE: usize = 8192;
 
-pub struct StitchConnectionReader {
+/// An interface to read messages from the network connection.
+///
+/// Implements the `Stream` trait to asynchronously read messages from the network connection.
+///
+/// # Example
+///
+/// Basic usage:
+///
+/// ```ignore
+/// while let Some(msg) = reader.next().await {
+///   // handle the received message
+/// }
+/// ```
+///
+/// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
+/// example program or other client example programs for a more thorough showcase.
+///
+
+pub struct ConnectionReader {
     local_addr: SocketAddr,
     peer_addr: SocketAddr,
-    read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+    read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
+    buffer: Option<BytesMut>,
     pending_read: Option<BytesMut>,
+    pending_datagram: Option<usize>,
+    closed: bool,
 }
 
-impl StitchConnectionReader {
+impl ConnectionReader {
+    /// Creates a new [`ConnectionReader`] from an [`AsyncRead`] trait object and the local and peer
+    /// socket metadata.
     pub fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
-        read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+        read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
     ) -> Self {
+        let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
+        buffer.resize(BUFFER_SIZE, 0);
+
         Self {
             local_addr,
             peer_addr,
             read_stream,
+            buffer: Some(buffer),
             pending_read: None,
+            pending_datagram: None,
+            closed: false,
         }
     }
 
+    /// Get the local IP address and port.
     pub fn local_addr(&self) -> SocketAddr {
         self.local_addr.clone()
     }
 
+    /// Get the peer IP address and port.
     pub fn peer_addr(&self) -> SocketAddr {
         self.peer_addr.clone()
     }
-}
 
-impl Stream for StitchConnectionReader {
-    type Item = Any;
+    /// Check if the `Stream` of messages from the network is closed.
+    pub fn is_closed(&self) -> bool {
+        self.closed
+    }
+
+    pub(crate) fn close_stream(&mut self) {
+        debug!("Closing the stream for connection with {}", self.peer_addr);
+        self.buffer.take();
+        self.pending_datagram.take();
+        self.pending_read.take();
+        self.closed = true;
+    }
+}
 
-    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        let mut buffer = BytesMut::new();
-        buffer.resize(BUFFER_SIZE, 0);
+impl Stream for ConnectionReader {
+    type Item = ConnectDatagram;
 
-        debug!("Starting new read loop for {}", self.local_addr);
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         loop {
-            trace!("Reading from the stream");
-            match futures::executor::block_on(self.read_stream.read(&mut buffer)) {
-                Ok(mut bytes_read) => {
-                    if bytes_read > 0 {
-                        debug!("Read {} bytes from the network stream", bytes_read)
-                    }
+            if let Some(size) = self.pending_datagram.take() {
+                if let Some(pending_buf) = self.pending_read.take() {
+                    if pending_buf.len() >= size {
+                        trace!("{} pending bytes is large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
+                        let mut data_buf = pending_buf;
+                        let pending_buf = data_buf.split_off(size);
+
+                        let datagram = ConnectDatagram::decode(data_buf.to_vec()).expect(
+                            "could not construct ConnectDatagram from bytes despite explicit check",
+                        );
 
-                    if let Some(mut pending_buf) = self.pending_read.take() {
-                        debug!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
-                        bytes_read += pending_buf.len();
+                        trace!("deserialized message of size {} bytes", datagram.size());
+                        return match datagram.version() {
+                            // do some special work based on version number if necessary
+                            _ => {
+                                if pending_buf.len() >= std::mem::size_of::<u32>() {
+                                    trace!("can deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
+
+                                    let mut size_buf = pending_buf;
+                                    let pending_buf =
+                                        size_buf.split_off(std::mem::size_of::<u32>());
+                                    let size = u32::from_be_bytes(
+                                        size_buf
+                                            .to_vec()
+                                            .as_slice()
+                                            .try_into()
+                                            .expect("could not parse bytes into u32"),
+                                    ) as usize;
+
+                                    self.pending_datagram.replace(size);
+                                    self.pending_read.replace(pending_buf);
+                                } else {
+                                    trace!("cannot deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
+                                    self.pending_read.replace(pending_buf);
+                                }
+
+                                trace!("returning deserialized datagram to user");
+                                Poll::Ready(Some(datagram))
+                            }
+                        };
+                    } else {
+                        trace!("{} pending bytes is not large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
+                        self.pending_datagram.replace(size);
+                        self.pending_read.replace(pending_buf);
+                    }
+                } else {
+                    unreachable!()
+                }
+            }
 
-                        pending_buf.unsplit(buffer);
-                        buffer = pending_buf;
+            let mut buffer = if let Some(buffer) = self.buffer.take() {
+                trace!("prepare buffer to read from the network stream");
+                buffer
+            } else {
+                trace!("construct new buffer to read from the network stream");
+                let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
+                buffer.resize(BUFFER_SIZE, 0);
+                buffer
+            };
+
+            trace!("reading from the network stream");
+            let stream = self.read_stream.as_mut();
+            match stream.poll_read(cx, &mut buffer) {
+                Poll::Ready(Ok(bytes_read)) => {
+                    if bytes_read > 0 {
+                        trace!("read {} bytes from the network stream", bytes_read);
+                    } else {
+                        self.close_stream();
+                        return Poll::Ready(None);
                     }
 
-                    let mut bytes_read_u64: u64 = bytes_read.try_into().expect(
-                        format!("Conversion from usize ({}) to u64 failed", bytes_read).as_str(),
+                    let mut pending_buf = if let Some(pending_buf) = self.pending_read.take() {
+                        trace!("preparing {} pending bytes", pending_buf.len());
+                        pending_buf
+                    } else {
+                        trace!("constructing new pending bytes");
+                        BytesMut::new()
+                    };
+
+                    trace!(
+                        "prepending incomplete data ({} bytes) from earlier read of network stream",
+                        pending_buf.len()
                     );
-                    while bytes_read_u64 > 0 {
-                        debug!(
-                            "{} bytes from network stream still unprocessed",
-                            bytes_read_u64
+                    pending_buf.extend_from_slice(&buffer[0..bytes_read]);
+
+                    if self.pending_datagram.is_none()
+                        && pending_buf.len() >= std::mem::size_of::<u32>()
+                    {
+                        trace!(
+                            "can deserialize size of next datagram from remaining {} pending bytes",
+                            pending_buf.len()
                         );
-
-                        buffer.resize(bytes_read, 0);
-                        debug!("{:?}", buffer.as_ref());
-
-                        match StitchMessage::parse_from_bytes(buffer.as_ref()) {
-                            Ok(mut data) => {
-                                let serialized_size = data.compute_size();
-                                debug!("Deserialized message of size {} bytes", serialized_size);
-
-                                buffer.advance(serialized_size as usize);
-
-                                let serialized_size_u64: u64 = serialized_size.try_into().expect(
-                                    format!(
-                                        "Conversion from usize ({}) to u64 failed",
-                                        serialized_size
-                                    )
-                                    .as_str(),
-                                );
-                                bytes_read_u64 -= serialized_size_u64;
-                                debug!("{} bytes still unprocessed", bytes_read_u64);
-
-                                debug!("Sending deserialized message downstream");
-                                return Poll::Ready(Some(data.take_payload()));
-                            }
-
-                            Err(err) => {
-                                warn!(
-                                    "Could not deserialize data from the received bytes: {:#?}",
-                                    err
-                                );
-
-                                self.pending_read = Some(buffer);
-                                buffer = BytesMut::new();
-                                break;
-                            }
-                        }
+                        let mut size_buf = pending_buf;
+                        let pending_buf = size_buf.split_off(std::mem::size_of::<u32>());
+                        let size = u32::from_be_bytes(
+                            size_buf
+                                .to_vec()
+                                .as_slice()
+                                .try_into()
+                                .expect("could not parse bytes into u32"),
+                        ) as usize;
+
+                        self.pending_datagram.replace(size);
+                        self.pending_read.replace(pending_buf);
+                    } else {
+                        trace!("size of next datagram already deserialized");
+                        self.pending_read.replace(pending_buf);
                     }
 
-                    buffer.resize(BUFFER_SIZE, 0);
+                    trace!("finished reading from stream and storing buffer");
+                    self.buffer.replace(buffer);
                 }
 
-                Err(_err) => return Poll::Pending,
+                Poll::Ready(Err(err)) => {
+                    error!(
+                        "Encountered error when trying to read from network stream {}",
+                        err
+                    );
+                    self.close_stream();
+                    return Poll::Ready(None);
+                }
+
+                Poll::Pending => {
+                    self.buffer.replace(buffer);
+                    return Poll::Pending;
+                }
             }
         }
     }