-use crate::schema::StitchMessage;
+use crate::schema::ConnectionMessage;
use async_std::net::SocketAddr;
use async_std::pin::Pin;
use bytes::{Buf, BytesMut};
use futures::task::{Context, Poll};
-use futures::{AsyncRead, AsyncReadExt, Stream};
+use futures::{AsyncRead, Stream};
use log::*;
use protobuf::Message;
use std::convert::TryInto;
const BUFFER_SIZE: usize = 8192;
-pub struct StitchConnectionReader {
- local_addr: SocketAddr,
- peer_addr: SocketAddr,
- read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+pub struct ConnectionReader {
+ local_addr: SocketAddr,
+ peer_addr: SocketAddr,
+ read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
pending_read: Option<BytesMut>,
+ closed: bool,
}
-impl StitchConnectionReader {
+impl ConnectionReader {
pub fn new(
local_addr: SocketAddr,
peer_addr: SocketAddr,
- read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+ read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
) -> Self {
Self {
local_addr,
peer_addr,
read_stream,
pending_read: None,
+ closed: false,
}
}
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr.clone()
}
+
+ pub fn is_closed(&self) -> bool {
+ self.closed
+ }
+
+ pub(crate) fn close_stream(&mut self) {
+ trace!("Closing the stream for connection with {}", self.peer_addr);
+ self.pending_read.take();
+ self.closed = true;
+ }
}
-impl Stream for StitchConnectionReader {
+impl Stream for ConnectionReader {
type Item = Any;
- fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut buffer = BytesMut::new();
buffer.resize(BUFFER_SIZE, 0);
- debug!("Starting new read loop for {}", self.local_addr);
+ trace!("Starting new read loop for {}", self.local_addr);
loop {
trace!("Reading from the stream");
- match futures::executor::block_on(self.read_stream.read(&mut buffer)) {
- Ok(mut bytes_read) => {
+ let stream = self.read_stream.as_mut();
+
+ match stream.poll_read(cx, &mut buffer) {
+ Poll::Pending => return Poll::Pending,
+
+ Poll::Ready(Ok(mut bytes_read)) => {
if bytes_read > 0 {
- debug!("Read {} bytes from the network stream", bytes_read)
+ trace!("Read {} bytes from the network stream", bytes_read)
+ } else if self.pending_read.is_none() {
+ self.close_stream();
+ return Poll::Ready(None)
}
if let Some(mut pending_buf) = self.pending_read.take() {
- debug!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
+ trace!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
bytes_read += pending_buf.len();
pending_buf.unsplit(buffer);
format!("Conversion from usize ({}) to u64 failed", bytes_read).as_str(),
);
while bytes_read_u64 > 0 {
- debug!(
+ trace!(
"{} bytes from network stream still unprocessed",
bytes_read_u64
);
buffer.resize(bytes_read, 0);
- debug!("{:?}", buffer.as_ref());
- match StitchMessage::parse_from_bytes(buffer.as_ref()) {
+ match ConnectionMessage::parse_from_bytes(buffer.as_ref()) {
Ok(mut data) => {
let serialized_size = data.compute_size();
- debug!("Deserialized message of size {} bytes", serialized_size);
+ trace!("Deserialized message of size {} bytes", serialized_size);
buffer.advance(serialized_size as usize);
.as_str(),
);
bytes_read_u64 -= serialized_size_u64;
- debug!("{} bytes still unprocessed", bytes_read_u64);
+ trace!("{} bytes still unprocessed", bytes_read_u64);
- debug!("Sending deserialized message downstream");
+ trace!("Sending deserialized message downstream");
return Poll::Ready(Some(data.take_payload()));
}
buffer.resize(BUFFER_SIZE, 0);
}
- Err(_err) => return Poll::Pending,
+ // Close the stream
+ Poll::Ready(Err(_e)) => {
+ self.close_stream();
+ return Poll::Ready(None)
+ }
}
}
}