use async_std::pin::Pin;
use bytes::{Buf, BytesMut};
use futures::task::{Context, Poll};
-use futures::{AsyncRead, AsyncReadExt, Stream};
+use futures::{AsyncRead, Stream};
use log::*;
use protobuf::Message;
use std::convert::TryInto;
pub struct ConnectionReader {
local_addr: SocketAddr,
peer_addr: SocketAddr,
- read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+ read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
pending_read: Option<BytesMut>,
closed: bool,
}
pub fn new(
local_addr: SocketAddr,
peer_addr: SocketAddr,
- read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
+ read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
) -> Self {
Self {
local_addr,
pub fn is_closed(&self) -> bool {
self.closed
}
+
+ pub(crate) fn close_stream(&mut self) {
+ trace!("Closing the stream for connection with {}", self.peer_addr);
+ self.pending_read.take();
+ self.closed = true;
+ }
}
impl Stream for ConnectionReader {
type Item = Any;
- fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+ fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let mut buffer = BytesMut::new();
buffer.resize(BUFFER_SIZE, 0);
trace!("Starting new read loop for {}", self.local_addr);
loop {
trace!("Reading from the stream");
- match futures::executor::block_on(self.read_stream.read(&mut buffer)) {
- Ok(mut bytes_read) => {
+ let stream = self.read_stream.as_mut();
+
+ match stream.poll_read(cx, &mut buffer) {
+ Poll::Pending => return Poll::Pending,
+
+ Poll::Ready(Ok(mut bytes_read)) => {
if bytes_read > 0 {
trace!("Read {} bytes from the network stream", bytes_read)
- } else if bytes_read == 0 && self.pending_read.is_none() {
- return Poll::Pending;
+ } else if self.pending_read.is_none() {
+ self.close_stream();
+ return Poll::Ready(None)
}
if let Some(mut pending_buf) = self.pending_read.take() {
}
// Close the stream
- Err(_err) => {
- trace!("Closing the stream");
- self.pending_read.take();
- self.closed = true;
- return Poll::Ready(None);
+ Poll::Ready(Err(_e)) => {
+ self.close_stream();
+ Poll::Ready(None)
}
}
}
use async_channel::RecvError;
use async_std::net::SocketAddr;
use async_std::pin::Pin;
+use futures::{AsyncWrite, Sink};
+use futures::io::IoSlice;
use futures::task::{Context, Poll};
-use futures::{AsyncWrite, AsyncWriteExt, Sink};
use log::*;
use protobuf::Message;
pub use futures::StreamExt;
pub struct ConnectionWriter {
- local_addr: SocketAddr,
- peer_addr: SocketAddr,
- write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
- pending_write: Option<ConnectionMessage>,
- closed: bool,
+ local_addr: SocketAddr,
+ peer_addr: SocketAddr,
+ write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
+ pending_writes: Vec<Vec<u8>>,
+ closed: bool,
}
impl ConnectionWriter {
pub fn new(
local_addr: SocketAddr,
peer_addr: SocketAddr,
- write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
+ write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
) -> Self {
Self {
local_addr,
peer_addr,
write_stream,
- pending_write: None,
+ pending_writes: Vec::new(),
closed: false,
}
}
pub fn is_closed(&self) -> bool {
self.closed
}
-
- fn send_to_conn(&mut self) -> Poll<Result<(), RecvError>> {
- if let Some(pending_msg) = self.pending_write.take() {
- trace!("Send pending message");
- if let Ok(buffer) = pending_msg.write_to_bytes() {
- let msg_size = buffer.len();
- trace!("{} bytes to be sent over network connection", msg_size);
-
- return if let Ok(_) =
- futures::executor::block_on(self.write_stream.write_all(buffer.as_slice()))
- {
- if let Ok(_) = futures::executor::block_on(self.write_stream.flush()) {
- trace!("Sent message of {} bytes", msg_size);
- Poll::Ready(Ok(()))
- } else {
- trace!("Encountered error while flushing queued bytes to network stream");
- Poll::Ready(Err(RecvError))
- }
- } else {
- error!("Encountered error when writing to network stream");
- Poll::Ready(Err(RecvError))
- };
- } else {
- error!("Encountered error when serializing message to bytes");
- return Poll::Ready(Err(RecvError));
- }
- } else {
- trace!("No message to send over connection");
- }
-
- Poll::Ready(Ok(()))
- }
}
impl<M: Message> Sink<M> for ConnectionWriter {
type Error = RecvError;
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- if self.pending_write.is_some() {
- trace!("Connection not ready to send message yet, waiting for prior message");
- Poll::Pending
+ if self.is_closed() {
+ trace!("Connection is closed, cannot send message");
+ Poll::Ready(Err(RecvError))
} else {
trace!("Connection ready to send message");
Poll::Ready(Ok(()))
fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
trace!("Preparing message to be sent next");
- let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item);
- self.pending_write.replace(stitch_msg);
+ let msg: ConnectionMessage = ConnectionMessage::from_msg(item);
- Ok(())
+ if let Ok(buffer) = msg.write_to_bytes() {
+ let msg_size = buffer.len();
+ trace!("Serialized pending message into {} bytes", msg_size);
+
+ self.pending_writes.push(buffer);
+
+ Ok(())
+ } else {
+ error!("Encountered error when serializing message to bytes");
+ Err(RecvError)
+ }
}
fn poll_flush(
mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
+ cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
- self.send_to_conn()
+ if self.pending_writes.len() > 0 {
+ let stream = self.write_stream.as_mut();
+
+ match stream.poll_flush(cx) {
+ Poll::Pending => Poll::Pending,
+
+ Poll::Ready(Ok(_)) => {
+ trace!("Sending pending bytes");
+
+ let pending = self.pending_writes.split_off(0);
+ let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
+ IoSlice::new(p)
+ }).collect();
+
+ let stream = self.write_stream.as_mut();
+ match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
+ Poll::Pending => Poll::Pending,
+
+ Poll::Ready(Ok(bytes_written)) => {
+ trace!("Wrote {} bytes to network stream", bytes_written);
+ Poll::Ready(Ok(()))
+ },
+
+ Poll::Ready(Err(_e)) => {
+ error!("Encountered error when writing to network stream");
+ Poll::Ready(Err(RecvError))
+ },
+ }
+ },
+
+ Poll::Ready(Err(_e)) => {
+ error!("Encountered error when flushing network stream");
+ Poll::Ready(Err(RecvError))
+ }
+ }
+ } else {
+ Poll::Ready(Ok(()))
+ }
}
fn poll_close(
mut self: Pin<&mut Self>,
- _cx: &mut Context<'_>,
+ cx: &mut Context<'_>,
) -> Poll<Result<(), Self::Error>> {
- let _ = self.send_to_conn();
-
self.closed = true;
- if let Ok(_) = futures::executor::block_on(self.write_stream.close()) {
- Poll::Ready(Ok(()))
+
+ let flush = if self.pending_writes.len() > 0 {
+ let stream = self.write_stream.as_mut();
+
+ match stream.poll_flush(cx) {
+ Poll::Pending => Poll::Pending,
+
+ Poll::Ready(Ok(_)) => {
+ trace!("Sending pending bytes");
+
+ let pending = self.pending_writes.split_off(0);
+ let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
+ IoSlice::new(p)
+ }).collect();
+
+ let stream = self.write_stream.as_mut();
+ match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
+ Poll::Pending => Poll::Pending,
+
+ Poll::Ready(Ok(bytes_written)) => {
+ trace!("Wrote {} bytes to network stream", bytes_written);
+ Poll::Ready(Ok(()))
+ },
+
+ Poll::Ready(Err(_e)) => {
+ error!("Encountered error when writing to network stream");
+ Poll::Ready(Err(RecvError))
+ },
+ }
+ },
+
+ Poll::Ready(Err(_e)) => {
+ error!("Encountered error when flushing network stream");
+ Poll::Ready(Err(RecvError))
+ }
+ }
} else {
- Poll::Ready(Err(RecvError))
+ Poll::Ready(Ok(()))
+ };
+
+ match flush {
+ Poll::Pending => Poll::Pending,
+
+ Poll::Ready(Ok(_)) => {
+ let stream = self.write_stream.as_mut();
+
+ match stream.poll_close(cx) {
+ Poll::Pending => Poll::Pending,
+
+ Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
+
+ Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
+ }
+ },
+
+ err => err,
}
}
}