From 7afb21ae61aaeb1255e1fe3a584c6733ef2e5d84 Mon Sep 17 00:00:00 2001 From: Sachandhan Ganesh Date: Sat, 13 Feb 2021 22:47:24 -0800 Subject: [PATCH] refactor read/write for correctness and ordering of messages --- examples/tcp-client/Cargo.toml | 4 +- examples/tcp-client/src/main.rs | 4 +- examples/tcp-echo-server/Cargo.toml | 4 +- examples/tls-client/Cargo.toml | 4 +- examples/tls-client/src/main.rs | 4 +- examples/tls-echo-server/Cargo.toml | 4 +- src/lib.rs | 11 +- src/protocol.rs | 132 ++++++++++--------- src/reader.rs | 189 ++++++++++++++++++---------- src/tls/mod.rs | 14 +-- src/writer.rs | 26 ++-- 11 files changed, 232 insertions(+), 164 deletions(-) diff --git a/examples/tcp-client/Cargo.toml b/examples/tcp-client/Cargo.toml index 4097809..66da115 100644 --- a/examples/tcp-client/Cargo.toml +++ b/examples/tcp-client/Cargo.toml @@ -7,8 +7,8 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = "1.0.31" -async-std = { version = "1.6.2", features = ["attributes"] } +anyhow = "1.0" +async-std = { version = "1.9.0", features = ["attributes"] } env_logger = "0.7" log = "0.4" diff --git a/examples/tcp-client/src/main.rs b/examples/tcp-client/src/main.rs index 5e15ada..827341d 100644 --- a/examples/tcp-client/src/main.rs +++ b/examples/tcp-client/src/main.rs @@ -28,12 +28,10 @@ async fn main() -> anyhow::Result<()> { // wait for the server to reply with an ack if let Some(mut reply) = conn.reader().next().await { - info!("Received message"); - let data = reply.take_data().unwrap(); let msg = String::from_utf8(data)?; - info!("Unpacked reply: {}", msg); + info!("Received message: {}", msg); } Ok(()) diff --git a/examples/tcp-echo-server/Cargo.toml b/examples/tcp-echo-server/Cargo.toml index e4417c7..5becda0 100644 --- a/examples/tcp-echo-server/Cargo.toml +++ b/examples/tcp-echo-server/Cargo.toml @@ -7,8 +7,8 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = "1.0.31" -async-std = { version = "1.6.2", features = ["attributes"] } +anyhow = "1.0" +async-std = { version = "1.9.0", features = ["attributes"] } env_logger = "0.7" log = "0.4" diff --git a/examples/tls-client/Cargo.toml b/examples/tls-client/Cargo.toml index 33013c4..1745ba2 100644 --- a/examples/tls-client/Cargo.toml +++ b/examples/tls-client/Cargo.toml @@ -7,8 +7,8 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = "1.0.31" -async-std = { version = "1.6.2", features = ["attributes"] } +anyhow = "1.0" +async-std = { version = "1.9.0", features = ["attributes"] } env_logger = "0.7" log = "0.4" diff --git a/examples/tls-client/src/main.rs b/examples/tls-client/src/main.rs index c0b9280..cee6750 100644 --- a/examples/tls-client/src/main.rs +++ b/examples/tls-client/src/main.rs @@ -33,12 +33,10 @@ async fn main() -> anyhow::Result<()> { // wait for the server to reply with an ack if let Some(mut reply) = conn.reader().next().await { - info!("Received message"); - let data = reply.take_data().unwrap(); let msg = String::from_utf8(data)?; - info!("Unpacked reply: {}", msg); + info!("Received message: {}", msg); } Ok(()) diff --git a/examples/tls-echo-server/Cargo.toml b/examples/tls-echo-server/Cargo.toml index 5f8478c..cbf276c 100644 --- a/examples/tls-echo-server/Cargo.toml +++ b/examples/tls-echo-server/Cargo.toml @@ -7,8 +7,8 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] -anyhow = "1.0.31" -async-std = { version = "1.6.2", features = ["attributes"] } +anyhow = "1.0" +async-std = { version = "1.9.0", features = ["attributes"] } env_logger = "0.7" log = "0.4" diff --git a/src/lib.rs b/src/lib.rs index 033d389..522db96 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,18 +17,18 @@ mod protocol; mod reader; pub mod tcp; +mod writer; #[cfg(feature = "tls")] #[doc(cfg(feature = "tls"))] pub mod tls; -mod writer; +use async_std::{net::SocketAddr, pin::Pin}; +use futures::{AsyncRead, AsyncWrite}; -pub use crate::protocol::{ConnectDatagram, DatagramEmptyError}; +pub use crate::protocol::{ConnectDatagram, DatagramError}; pub use crate::reader::ConnectionReader; pub use crate::writer::{ConnectionWriteError, ConnectionWriter}; -use async_std::{net::SocketAddr, pin::Pin}; -use futures::{AsyncRead, AsyncWrite}; pub use futures::{SinkExt, StreamExt}; /// Wrapper around a [`ConnectionReader`] and [`ConnectionWriter`] to read and write on a network @@ -98,3 +98,6 @@ impl Connection { return peer_addr; } } + +#[cfg(test)] +mod tests {} diff --git a/src/protocol.rs b/src/protocol.rs index 73da4bc..a2fe0b0 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -1,29 +1,39 @@ +use std::array::TryFromSliceError; +use std::convert::TryInto; use std::error::Error; -use std::io::Read; -const VERSION: u8 = 1; +const VERSION: u16 = 1; -/// Encountered when trying to construct a [`ConnectDatagram`] with an empty message body. +/// Encountered when there is an issue constructing, serializing, or deserializing a [`ConnectDatagram`]. /// #[derive(Debug, Clone)] -pub struct DatagramEmptyError; +pub enum DatagramError { + /// Tried to construct a [`ConnectDatagram`] with an empty message body. + EmptyBody, -impl Error for DatagramEmptyError {} + /// Did not provide the complete byte-string necessary to deserialize the [`ConnectDatagram`]. + IncompleteBytes, -impl std::fmt::Display for DatagramEmptyError { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!( - f, - "datagram cannot be constructed when provided payload is empty" - ) + BytesParseFail(TryFromSliceError), +} + +impl Error for DatagramError {} + +impl std::fmt::Display for DatagramError { + fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DatagramError::EmptyBody => formatter.write_str("tried to construct a `ConnectDatagram` with an empty message body"), + DatagramError::IncompleteBytes => formatter.write_str("did not provide the complete byte-string necessary to deserialize the `ConnectDatagram`"), + DatagramError::BytesParseFail(err) => std::fmt::Display::fmt(err, formatter), + } } } -/// A simple packet format containing a version, recipient tag, and message body. +/// A simple size-prefixed packet format containing a version tag, recipient tag, and message body. /// #[derive(Clone)] pub struct ConnectDatagram { - version: u8, + version: u16, recipient: u16, data: Option>, } @@ -31,13 +41,13 @@ pub struct ConnectDatagram { impl ConnectDatagram { /// Creates a new [`ConnectDatagram`] based on an intended recipient and message body. /// - /// This will return a [`DatagramEmptyError`] if the `data` parameter contains no bytes, or - /// in other words, when there is no message body. + /// This will return a [EmptyBody](`DatagramError::EmptyBody`) error if the `data` parameter + /// contains no bytes, or in other words, when there is no message body. /// /// The version field is decided by the library version and used to maintain backwards /// compatibility with previous datagram formats. /// - pub fn new(recipient: u16, data: Vec) -> Result { + pub fn new(recipient: u16, data: Vec) -> Result { if data.len() > 0 { Ok(Self { version: VERSION, @@ -45,13 +55,13 @@ impl ConnectDatagram { data: Some(data), }) } else { - Err(DatagramEmptyError) + Err(DatagramError::EmptyBody) } } /// Gets the version number of the datagram. /// - pub fn version(&self) -> u8 { + pub fn version(&self) -> u16 { self.version } @@ -73,7 +83,9 @@ impl ConnectDatagram { self.data.take() } - /// Calculates the serialized byte-size of the datagram. + /// Calculates the size-prefixed serialized byte-size of the datagram. + /// + /// This will include the byte-size of the size-prefix. /// pub fn size(&self) -> usize { let data_len = if let Some(data) = self.data() { @@ -82,7 +94,7 @@ impl ConnectDatagram { 0 }; - 3 + data_len + 8 + data_len } /// Constructs a serialized representation of the datagram contents. @@ -103,48 +115,47 @@ impl ConnectDatagram { /// Serializes the datagram. /// pub fn encode(self) -> Vec { - let size: u32 = (self.size()) as u32; + let content_encoded = self.bytes(); + let size: u32 = (content_encoded.len()) as u32; let mut bytes = Vec::from(size.to_be_bytes()); - bytes.extend(self.bytes()); + bytes.extend(content_encoded); return bytes; } - /// Deserializes the datagram from a `source`. + /// Deserializes the datagram from a buffer. /// - pub fn decode(source: &mut (dyn Read + Send + Sync)) -> anyhow::Result { - // payload size - let mut payload_size_bytes: [u8; 4] = [0; 4]; - source.read_exact(&mut payload_size_bytes)?; - let payload_size = u32::from_be_bytes(payload_size_bytes); - - // read whole payload - let mut payload_bytes = vec![0; payload_size as usize]; - source.read_exact(payload_bytes.as_mut_slice())?; - - // version - let version_bytes = payload_bytes.remove(0); - let version = u8::from_be(version_bytes); - - // recipient - let mut recipient_bytes: [u8; 2] = [0; 2]; - for i in 0..recipient_bytes.len() { - recipient_bytes[i] = payload_bytes.remove(0); - } - let recipient = u16::from_be_bytes(recipient_bytes); - - // data - let data = payload_bytes; - - if data.len() > 0 { - Ok(Self { - version, - recipient, - data: Some(data), - }) + /// The buffer **should not** contain the size-prefix, and only contain the byte contents of the + /// struct (version, recipient, and message body). + /// + pub fn decode(mut buffer: Vec) -> Result { + if buffer.len() > 4 { + let mem_size = std::mem::size_of::(); + let data = buffer.split_off(mem_size * 2); + + let (version_bytes, recipient_bytes) = buffer.split_at(mem_size); + + match version_bytes.try_into() { + Ok(version_slice) => match recipient_bytes.try_into() { + Ok(recipient_slice) => { + let version = u16::from_be_bytes(version_slice); + let recipient = u16::from_be_bytes(recipient_slice); + + Ok(Self { + version, + recipient, + data: Some(data), + }) + } + + Err(err) => Err(DatagramError::BytesParseFail(err)), + }, + + Err(err) => Err(DatagramError::BytesParseFail(err)), + } } else { - Err(anyhow::Error::from(DatagramEmptyError)) + Err(DatagramError::IncompleteBytes) } } } @@ -152,10 +163,9 @@ impl ConnectDatagram { #[cfg(test)] mod tests { use crate::protocol::ConnectDatagram; - use std::io::Cursor; #[test] - fn encoded_size() -> anyhow::Result<()> { + fn serialized_size() -> anyhow::Result<()> { let mut data = Vec::new(); for _ in 0..5 { data.push(1); @@ -163,7 +173,7 @@ mod tests { assert_eq!(5, data.len()); let sample = ConnectDatagram::new(1, data)?; - assert_eq!(7 + 5, sample.encode().len()); + assert_eq!(8 + 5, sample.encode().len()); Ok(()) } @@ -193,12 +203,14 @@ mod tests { assert_eq!(5, data.len()); let sample = ConnectDatagram::new(1, data)?; + let serialized_size = sample.size(); + assert_eq!(8 + 5, serialized_size); let mut payload = sample.encode(); - assert_eq!(7 + 5, payload.len()); + assert_eq!(serialized_size, payload.len()); - let mut cursor: Cursor<&mut [u8]> = Cursor::new(payload.as_mut()); - let sample_back_res = ConnectDatagram::decode(&mut cursor); + let payload = payload.split_off(std::mem::size_of::()); + let sample_back_res = ConnectDatagram::decode(payload); assert!(sample_back_res.is_ok()); let sample_back = sample_back_res.unwrap(); diff --git a/src/reader.rs b/src/reader.rs index b1b7248..3954c4e 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -1,14 +1,13 @@ use crate::protocol::ConnectDatagram; use async_std::net::SocketAddr; use async_std::pin::Pin; -use bytes::{Buf, BytesMut}; +use bytes::BytesMut; use futures::task::{Context, Poll}; use futures::{AsyncRead, Stream}; use log::*; -use std::io::Cursor; +use std::convert::TryInto; -pub use futures::SinkExt; -pub use futures::StreamExt; +pub use futures::{SinkExt, StreamExt}; /// A default buffer size to read in bytes and then deserialize as messages. const BUFFER_SIZE: usize = 8192; @@ -35,7 +34,9 @@ pub struct ConnectionReader { local_addr: SocketAddr, peer_addr: SocketAddr, read_stream: Pin>, + buffer: Option, pending_read: Option, + pending_datagram: Option, closed: bool, } @@ -47,11 +48,16 @@ impl ConnectionReader { peer_addr: SocketAddr, read_stream: Pin>, ) -> Self { + let mut buffer = BytesMut::with_capacity(BUFFER_SIZE); + buffer.resize(BUFFER_SIZE, 0); + Self { local_addr, peer_addr, read_stream, + buffer: Some(buffer), pending_read: None, + pending_datagram: None, closed: false, } } @@ -72,7 +78,9 @@ impl ConnectionReader { } pub(crate) fn close_stream(&mut self) { - trace!("Closing the stream for connection with {}", self.peer_addr); + trace!("closing the stream for connection with {}", self.peer_addr); + self.buffer.take(); + self.pending_datagram.take(); self.pending_read.take(); self.closed = true; } @@ -82,80 +90,133 @@ impl Stream for ConnectionReader { type Item = ConnectDatagram; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let mut buffer = BytesMut::new(); - buffer.resize(BUFFER_SIZE, 0); - - trace!("Starting new read loop for {}", self.local_addr); loop { - trace!("Reading from the stream"); - let stream = self.read_stream.as_mut(); + if let Some(size) = self.pending_datagram.take() { + if let Some(pending_buf) = self.pending_read.take() { + if pending_buf.len() >= size { + trace!("{} pending bytes is large enough to deserialize datagram of size {} bytes", pending_buf.len(), size); + let mut data_buf = pending_buf; + let pending_buf = data_buf.split_off(size); + + let datagram = ConnectDatagram::decode(data_buf.to_vec()).expect( + "could not construct ConnectDatagram from bytes despite explicit check", + ); + + trace!("deserialized message of size {} bytes", datagram.size()); + return match datagram.version() { + // do some special work based on version number if necessary + _ => { + if pending_buf.len() >= std::mem::size_of::() { + trace!("can deserialize size of next datagram from remaining {} pending bytes", pending_buf.len()); + + let mut size_buf = pending_buf; + let pending_buf = + size_buf.split_off(std::mem::size_of::()); + let size = u32::from_be_bytes( + size_buf + .to_vec() + .as_slice() + .try_into() + .expect("could not parse bytes into u32"), + ) as usize; + + self.pending_datagram.replace(size); + self.pending_read.replace(pending_buf); + } else { + trace!("cannot deserialize size of next datagram from remaining {} pending bytes", pending_buf.len()); + self.pending_read.replace(pending_buf); + } - match stream.poll_read(cx, &mut buffer) { - Poll::Pending => return Poll::Pending, + trace!("returning deserialized datagram to user"); + Poll::Ready(Some(datagram)) + } + }; + } else { + trace!("{} pending bytes is not large enough to deserialize datagram of size {} bytes", pending_buf.len(), size); + self.pending_datagram.replace(size); + self.pending_read.replace(pending_buf); + } + } else { + unreachable!() + } + } - Poll::Ready(Ok(mut bytes_read)) => { + let mut buffer = if let Some(buffer) = self.buffer.take() { + trace!("prepare buffer to read from the network stream"); + buffer + } else { + trace!("construct new buffer to read from the network stream"); + let mut buffer = BytesMut::with_capacity(BUFFER_SIZE); + buffer.resize(BUFFER_SIZE, 0); + buffer + }; + + trace!("reading from the network stream"); + let stream = self.read_stream.as_mut(); + match stream.poll_read(cx, &mut buffer) { + Poll::Ready(Ok(bytes_read)) => { if bytes_read > 0 { - trace!("Read {} bytes from the network stream", bytes_read) - } else if self.pending_read.is_none() { + trace!("read {} bytes from the network stream", bytes_read); + } else { self.close_stream(); return Poll::Ready(None); } - if let Some(mut pending_buf) = self.pending_read.take() { - trace!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len()); - bytes_read += pending_buf.len(); - - pending_buf.unsplit(buffer); - buffer = pending_buf; + let mut pending_buf = if let Some(pending_buf) = self.pending_read.take() { + trace!("preparing {} pending bytes", pending_buf.len()); + pending_buf + } else { + trace!("constructing new pending bytes"); + BytesMut::new() + }; + + trace!( + "prepending incomplete data ({} bytes) from earlier read of network stream", + pending_buf.len() + ); + pending_buf.extend_from_slice(&buffer[0..bytes_read]); + + if self.pending_datagram.is_none() + && pending_buf.len() >= std::mem::size_of::() + { + trace!( + "can deserialize size of next datagram from remaining {} pending bytes", + pending_buf.len() + ); + let mut size_buf = pending_buf; + let pending_buf = size_buf.split_off(std::mem::size_of::()); + let size = u32::from_be_bytes( + size_buf + .to_vec() + .as_slice() + .try_into() + .expect("could not parse bytes into u32"), + ) as usize; + + self.pending_datagram.replace(size); + self.pending_read.replace(pending_buf); + } else { + trace!("size of next datagram already deserialized"); + self.pending_read.replace(pending_buf); } - while bytes_read > 0 { - trace!("{} bytes from network stream still unprocessed", bytes_read); - - buffer.resize(bytes_read, 0); - - let mut cursor = Cursor::new(buffer.as_mut()); - match ConnectDatagram::decode(&mut cursor) { - Ok(data) => { - return match data.version() { - _ => { - let serialized_size = data.size(); - trace!( - "Deserialized message of size {} bytes", - serialized_size - ); - - buffer.advance(serialized_size); - bytes_read -= serialized_size; - trace!("{} bytes still unprocessed", bytes_read); - - trace!("Sending deserialized message downstream"); - Poll::Ready(Some(data)) - } - } - } - - Err(err) => { - warn!( - "Could not deserialize data from the received bytes: {:#?}", - err - ); - - self.pending_read = Some(buffer); - buffer = BytesMut::new(); - break; - } - } - } - - buffer.resize(BUFFER_SIZE, 0); + trace!("finished reading from stream and storing buffer"); + self.buffer.replace(buffer); } - // Close the stream - Poll::Ready(Err(_e)) => { + Poll::Ready(Err(err)) => { + error!( + "Encountered error when trying to read from network stream {}", + err + ); self.close_stream(); return Poll::Ready(None); } + + Poll::Pending => { + self.buffer.replace(buffer); + return Poll::Pending; + } } } } diff --git a/src/tls/mod.rs b/src/tls/mod.rs index eeefd47..a0b7f10 100644 --- a/src/tls/mod.rs +++ b/src/tls/mod.rs @@ -12,20 +12,20 @@ pub(crate) use crate::Connection; pub(crate) mod client; pub(crate) mod listener; -#[cfg(feature = "tls")] -#[doc(cfg(feature = "tls"))] -pub use async_tls; +use async_std::net::TcpStream; +use async_tls::server; +use std::net::SocketAddr; pub use client::*; pub use listener::*; #[cfg(feature = "tls")] #[doc(cfg(feature = "tls"))] -pub use rustls; +pub use async_tls; -use async_std::net::TcpStream; -use async_tls::server; -use std::net::SocketAddr; +#[cfg(feature = "tls")] +#[doc(cfg(feature = "tls"))] +pub use rustls; /// Used to differentiate between an outgoing connection ([`TlsConnectionMetadata::Client`]) or /// incoming connection listener ([`TlsConnectionMetadata::Listener`]). diff --git a/src/writer.rs b/src/writer.rs index 9b48c02..040bafd 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -1,7 +1,6 @@ use crate::protocol::ConnectDatagram; use async_std::net::SocketAddr; use async_std::pin::Pin; -use futures::io::IoSlice; use futures::task::{Context, Poll}; use futures::{AsyncWrite, Sink}; use log::*; @@ -54,7 +53,7 @@ pub struct ConnectionWriter { local_addr: SocketAddr, peer_addr: SocketAddr, write_stream: Pin>, - pending_writes: Vec>, + pending_writes: Vec, closed: bool, } @@ -101,18 +100,15 @@ impl ConnectionWriter { Poll::Pending => Poll::Pending, Poll::Ready(Ok(_)) => { - trace!("Sending pending bytes"); - - let pending = self.pending_writes.split_off(0); - let writeable_vec: Vec = - pending.iter().map(|p| IoSlice::new(p)).collect(); - let stream = self.write_stream.as_mut(); - match stream.poll_write_vectored(cx, writeable_vec.as_slice()) { + + trace!("sending pending bytes to network stream"); + match stream.poll_write(cx, self.pending_writes.as_slice()) { Poll::Pending => Poll::Pending, Poll::Ready(Ok(bytes_written)) => { - trace!("Wrote {} bytes to network stream", bytes_written); + trace!("wrote {} bytes to network stream", bytes_written); + self.pending_writes.clear(); Poll::Ready(Ok(())) } @@ -139,22 +135,22 @@ impl Sink for ConnectionWriter { fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { if self.is_closed() { - trace!("Connection is closed, cannot send message"); + trace!("connection is closed, cannot send message"); Poll::Ready(Err(ConnectionWriteError::ConnectionClosed)) } else { - trace!("Connection ready to send message"); + trace!("connection ready to send message"); Poll::Ready(Ok(())) } } fn start_send(mut self: Pin<&mut Self>, item: ConnectDatagram) -> Result<(), Self::Error> { - trace!("Preparing message to be sent next"); + trace!("preparing datagram to be queued for sending"); let buffer = item.encode(); let msg_size = buffer.len(); - trace!("Serialized pending message into {} bytes", msg_size); + trace!("serialized pending message into {} bytes", msg_size); - self.pending_writes.push(buffer); + self.pending_writes.extend(buffer); Ok(()) } -- 2.44.0