use crate::protocol::ConnectDatagram;
use async_std::net::SocketAddr;
use async_std::pin::Pin;
-use bytes::{Buf, BytesMut};
+use bytes::BytesMut;
use futures::task::{Context, Poll};
use futures::{AsyncRead, Stream};
use log::*;
-use std::io::Cursor;
+use std::convert::TryInto;
-pub use futures::SinkExt;
-pub use futures::StreamExt;
+pub use futures::{SinkExt, StreamExt};
/// A default buffer size to read in bytes and then deserialize as messages.
const BUFFER_SIZE: usize = 8192;
local_addr: SocketAddr,
peer_addr: SocketAddr,
read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
+ buffer: Option<BytesMut>,
pending_read: Option<BytesMut>,
+ pending_datagram: Option<usize>,
closed: bool,
}
peer_addr: SocketAddr,
read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
) -> Self {
+ let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
+ buffer.resize(BUFFER_SIZE, 0);
+
Self {
local_addr,
peer_addr,
read_stream,
+ buffer: Some(buffer),
pending_read: None,
+ pending_datagram: None,
closed: false,
}
}
}
pub(crate) fn close_stream(&mut self) {
- trace!("Closing the stream for connection with {}", self.peer_addr);
+ trace!("closing the stream for connection with {}", self.peer_addr);
+ self.buffer.take();
+ self.pending_datagram.take();
self.pending_read.take();
self.closed = true;
}
type Item = ConnectDatagram;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- let mut buffer = BytesMut::new();
- buffer.resize(BUFFER_SIZE, 0);
-
- trace!("Starting new read loop for {}", self.local_addr);
loop {
- trace!("Reading from the stream");
- let stream = self.read_stream.as_mut();
+ if let Some(size) = self.pending_datagram.take() {
+ if let Some(pending_buf) = self.pending_read.take() {
+ if pending_buf.len() >= size {
+ trace!("{} pending bytes is large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
+ let mut data_buf = pending_buf;
+ let pending_buf = data_buf.split_off(size);
+
+ let datagram = ConnectDatagram::decode(data_buf.to_vec()).expect(
+ "could not construct ConnectDatagram from bytes despite explicit check",
+ );
+
+ trace!("deserialized message of size {} bytes", datagram.size());
+ return match datagram.version() {
+ // do some special work based on version number if necessary
+ _ => {
+ if pending_buf.len() >= std::mem::size_of::<u32>() {
+ trace!("can deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
+
+ let mut size_buf = pending_buf;
+ let pending_buf =
+ size_buf.split_off(std::mem::size_of::<u32>());
+ let size = u32::from_be_bytes(
+ size_buf
+ .to_vec()
+ .as_slice()
+ .try_into()
+ .expect("could not parse bytes into u32"),
+ ) as usize;
+
+ self.pending_datagram.replace(size);
+ self.pending_read.replace(pending_buf);
+ } else {
+ trace!("cannot deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
+ self.pending_read.replace(pending_buf);
+ }
- match stream.poll_read(cx, &mut buffer) {
- Poll::Pending => return Poll::Pending,
+ trace!("returning deserialized datagram to user");
+ Poll::Ready(Some(datagram))
+ }
+ };
+ } else {
+ trace!("{} pending bytes is not large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
+ self.pending_datagram.replace(size);
+ self.pending_read.replace(pending_buf);
+ }
+ } else {
+ unreachable!()
+ }
+ }
- Poll::Ready(Ok(mut bytes_read)) => {
+ let mut buffer = if let Some(buffer) = self.buffer.take() {
+ trace!("prepare buffer to read from the network stream");
+ buffer
+ } else {
+ trace!("construct new buffer to read from the network stream");
+ let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
+ buffer.resize(BUFFER_SIZE, 0);
+ buffer
+ };
+
+ trace!("reading from the network stream");
+ let stream = self.read_stream.as_mut();
+ match stream.poll_read(cx, &mut buffer) {
+ Poll::Ready(Ok(bytes_read)) => {
if bytes_read > 0 {
- trace!("Read {} bytes from the network stream", bytes_read)
- } else if self.pending_read.is_none() {
+ trace!("read {} bytes from the network stream", bytes_read);
+ } else {
self.close_stream();
return Poll::Ready(None);
}
- if let Some(mut pending_buf) = self.pending_read.take() {
- trace!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
- bytes_read += pending_buf.len();
-
- pending_buf.unsplit(buffer);
- buffer = pending_buf;
+ let mut pending_buf = if let Some(pending_buf) = self.pending_read.take() {
+ trace!("preparing {} pending bytes", pending_buf.len());
+ pending_buf
+ } else {
+ trace!("constructing new pending bytes");
+ BytesMut::new()
+ };
+
+ trace!(
+ "prepending incomplete data ({} bytes) from earlier read of network stream",
+ pending_buf.len()
+ );
+ pending_buf.extend_from_slice(&buffer[0..bytes_read]);
+
+ if self.pending_datagram.is_none()
+ && pending_buf.len() >= std::mem::size_of::<u32>()
+ {
+ trace!(
+ "can deserialize size of next datagram from remaining {} pending bytes",
+ pending_buf.len()
+ );
+ let mut size_buf = pending_buf;
+ let pending_buf = size_buf.split_off(std::mem::size_of::<u32>());
+ let size = u32::from_be_bytes(
+ size_buf
+ .to_vec()
+ .as_slice()
+ .try_into()
+ .expect("could not parse bytes into u32"),
+ ) as usize;
+
+ self.pending_datagram.replace(size);
+ self.pending_read.replace(pending_buf);
+ } else {
+ trace!("size of next datagram already deserialized");
+ self.pending_read.replace(pending_buf);
}
- while bytes_read > 0 {
- trace!("{} bytes from network stream still unprocessed", bytes_read);
-
- buffer.resize(bytes_read, 0);
-
- let mut cursor = Cursor::new(buffer.as_mut());
- match ConnectDatagram::decode(&mut cursor) {
- Ok(data) => {
- return match data.version() {
- _ => {
- let serialized_size = data.size();
- trace!(
- "Deserialized message of size {} bytes",
- serialized_size
- );
-
- buffer.advance(serialized_size);
- bytes_read -= serialized_size;
- trace!("{} bytes still unprocessed", bytes_read);
-
- trace!("Sending deserialized message downstream");
- Poll::Ready(Some(data))
- }
- }
- }
-
- Err(err) => {
- warn!(
- "Could not deserialize data from the received bytes: {:#?}",
- err
- );
-
- self.pending_read = Some(buffer);
- buffer = BytesMut::new();
- break;
- }
- }
- }
-
- buffer.resize(BUFFER_SIZE, 0);
+ trace!("finished reading from stream and storing buffer");
+ self.buffer.replace(buffer);
}
- // Close the stream
- Poll::Ready(Err(_e)) => {
+ Poll::Ready(Err(err)) => {
+ error!(
+ "Encountered error when trying to read from network stream {}",
+ err
+ );
self.close_stream();
return Poll::Ready(None);
}
+
+ Poll::Pending => {
+ self.buffer.replace(buffer);
+ return Poll::Pending;
+ }
}
}
}