From: Sachandhan Ganesh Date: Wed, 20 Jan 2021 04:50:21 +0000 (-0800) Subject: don't block in poll_x fns, fixes conn closing issues X-Git-Url: https://git.lizzy.rs/?a=commitdiff_plain;h=5b7896f1a371a2977b6e9dee939b0e05ad9a0c65;p=connect-rs.git don't block in poll_x fns, fixes conn closing issues --- diff --git a/examples/tcp-client/src/schema/hello_world.rs b/examples/tcp-client/src/schema/hello_world.rs index 5af4935..1af771c 100644 --- a/examples/tcp-client/src/schema/hello_world.rs +++ b/examples/tcp-client/src/schema/hello_world.rs @@ -1,4 +1,4 @@ -// This file is generated by rust-protobuf 2.19.0. Do not edit +// This file is generated by rust-protobuf 2.20.0. Do not edit // @generated // https://github.com/rust-lang/rust-clippy/issues/702 @@ -21,7 +21,7 @@ /// Generated files are compatible only with the same version /// of protobuf runtime. -// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0; +// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0; #[derive(PartialEq,Clone,Default)] pub struct HelloWorld { diff --git a/examples/tcp-echo-server/src/schema/hello_world.rs b/examples/tcp-echo-server/src/schema/hello_world.rs index 5af4935..1af771c 100644 --- a/examples/tcp-echo-server/src/schema/hello_world.rs +++ b/examples/tcp-echo-server/src/schema/hello_world.rs @@ -1,4 +1,4 @@ -// This file is generated by rust-protobuf 2.19.0. Do not edit +// This file is generated by rust-protobuf 2.20.0. Do not edit // @generated // https://github.com/rust-lang/rust-clippy/issues/702 @@ -21,7 +21,7 @@ /// Generated files are compatible only with the same version /// of protobuf runtime. -// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0; +// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0; #[derive(PartialEq,Clone,Default)] pub struct HelloWorld { diff --git a/examples/tls-client/src/schema/hello_world.rs b/examples/tls-client/src/schema/hello_world.rs index 5af4935..1af771c 100644 --- a/examples/tls-client/src/schema/hello_world.rs +++ b/examples/tls-client/src/schema/hello_world.rs @@ -1,4 +1,4 @@ -// This file is generated by rust-protobuf 2.19.0. Do not edit +// This file is generated by rust-protobuf 2.20.0. Do not edit // @generated // https://github.com/rust-lang/rust-clippy/issues/702 @@ -21,7 +21,7 @@ /// Generated files are compatible only with the same version /// of protobuf runtime. -// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0; +// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0; #[derive(PartialEq,Clone,Default)] pub struct HelloWorld { diff --git a/examples/tls-echo-server/src/schema/hello_world.rs b/examples/tls-echo-server/src/schema/hello_world.rs index 5af4935..1af771c 100644 --- a/examples/tls-echo-server/src/schema/hello_world.rs +++ b/examples/tls-echo-server/src/schema/hello_world.rs @@ -1,4 +1,4 @@ -// This file is generated by rust-protobuf 2.19.0. Do not edit +// This file is generated by rust-protobuf 2.20.0. Do not edit // @generated // https://github.com/rust-lang/rust-clippy/issues/702 @@ -21,7 +21,7 @@ /// Generated files are compatible only with the same version /// of protobuf runtime. -// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0; +// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0; #[derive(PartialEq,Clone,Default)] pub struct HelloWorld { diff --git a/src/lib.rs b/src/lib.rs index 9239db5..c85c2cd 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ mod writer; pub use crate::reader::ConnectionReader; pub use crate::writer::ConnectionWriter; -use async_std::net::SocketAddr; +use async_std::{net::SocketAddr, pin::Pin}; use futures::{AsyncRead, AsyncWrite}; pub use futures::{SinkExt, StreamExt}; @@ -22,8 +22,8 @@ impl Connection { pub(crate) fn new( local_addr: SocketAddr, peer_addr: SocketAddr, - read_stream: Box, - write_stream: Box, + read_stream: Pin>, + write_stream: Pin>, ) -> Self { Self { local_addr, diff --git a/src/reader.rs b/src/reader.rs index 0452f6c..cfdc754 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -3,7 +3,7 @@ use async_std::net::SocketAddr; use async_std::pin::Pin; use bytes::{Buf, BytesMut}; use futures::task::{Context, Poll}; -use futures::{AsyncRead, AsyncReadExt, Stream}; +use futures::{AsyncRead, Stream}; use log::*; use protobuf::Message; use std::convert::TryInto; @@ -17,7 +17,7 @@ const BUFFER_SIZE: usize = 8192; pub struct ConnectionReader { local_addr: SocketAddr, peer_addr: SocketAddr, - read_stream: Box, + read_stream: Pin>, pending_read: Option, closed: bool, } @@ -26,7 +26,7 @@ impl ConnectionReader { pub fn new( local_addr: SocketAddr, peer_addr: SocketAddr, - read_stream: Box, + read_stream: Pin>, ) -> Self { Self { local_addr, @@ -48,24 +48,35 @@ impl ConnectionReader { pub fn is_closed(&self) -> bool { self.closed } + + pub(crate) fn close_stream(&mut self) { + trace!("Closing the stream for connection with {}", self.peer_addr); + self.pending_read.take(); + self.closed = true; + } } impl Stream for ConnectionReader { type Item = Any; - fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut buffer = BytesMut::new(); buffer.resize(BUFFER_SIZE, 0); trace!("Starting new read loop for {}", self.local_addr); loop { trace!("Reading from the stream"); - match futures::executor::block_on(self.read_stream.read(&mut buffer)) { - Ok(mut bytes_read) => { + let stream = self.read_stream.as_mut(); + + match stream.poll_read(cx, &mut buffer) { + Poll::Pending => return Poll::Pending, + + Poll::Ready(Ok(mut bytes_read)) => { if bytes_read > 0 { trace!("Read {} bytes from the network stream", bytes_read) - } else if bytes_read == 0 && self.pending_read.is_none() { - return Poll::Pending; + } else if self.pending_read.is_none() { + self.close_stream(); + return Poll::Ready(None) } if let Some(mut pending_buf) = self.pending_read.take() { @@ -125,11 +136,9 @@ impl Stream for ConnectionReader { } // Close the stream - Err(_err) => { - trace!("Closing the stream"); - self.pending_read.take(); - self.closed = true; - return Poll::Ready(None); + Poll::Ready(Err(_e)) => { + self.close_stream(); + Poll::Ready(None) } } } diff --git a/src/schema/message.rs b/src/schema/message.rs index c41fbb0..265075f 100644 --- a/src/schema/message.rs +++ b/src/schema/message.rs @@ -1,4 +1,4 @@ -// This file is generated by rust-protobuf 2.19.0. Do not edit +// This file is generated by rust-protobuf 2.20.0. Do not edit // @generated // https://github.com/rust-lang/rust-clippy/issues/702 @@ -21,7 +21,7 @@ /// Generated files are compatible only with the same version /// of protobuf runtime. -// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0; +// const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_20_0; #[derive(PartialEq,Clone,Default)] pub struct ConnectionMessage { diff --git a/src/schema/mod.rs b/src/schema/mod.rs index 49c0b7c..2774e64 100644 --- a/src/schema/mod.rs +++ b/src/schema/mod.rs @@ -5,7 +5,7 @@ use protobuf::well_known_types::Any; use protobuf::Message; impl ConnectionMessage { - pub(crate) fn from_msg(msg: T) -> Self { + pub(crate) fn from_msg(msg: M) -> Self { let mut sm = Self::new(); let payload = Any::pack(&msg).expect("Protobuf Message could not be packed into Any type"); diff --git a/src/tcp/client.rs b/src/tcp/client.rs index 140c00d..e9cd42b 100644 --- a/src/tcp/client.rs +++ b/src/tcp/client.rs @@ -28,8 +28,8 @@ impl From for Connection { Self::new( local_addr, peer_addr, - Box::new(stream), - Box::new(write_stream), + Box::pin(stream), + Box::pin(write_stream), ) } } diff --git a/src/tls/client.rs b/src/tls/client.rs index 225c0db..afb9565 100644 --- a/src/tls/client.rs +++ b/src/tls/client.rs @@ -58,8 +58,8 @@ impl From for Connection { Self::new( local_addr, peer_addr, - Box::new(read_stream), - Box::new(write_stream), + Box::pin(read_stream), + Box::pin(write_stream), ) } @@ -73,8 +73,8 @@ impl From for Connection { Self::new( local_addr, peer_addr, - Box::new(read_stream), - Box::new(write_stream), + Box::pin(read_stream), + Box::pin(write_stream), ) } } diff --git a/src/writer.rs b/src/writer.rs index f6442a6..f26ab4b 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -2,8 +2,9 @@ use crate::schema::ConnectionMessage; use async_channel::RecvError; use async_std::net::SocketAddr; use async_std::pin::Pin; +use futures::{AsyncWrite, Sink}; +use futures::io::IoSlice; use futures::task::{Context, Poll}; -use futures::{AsyncWrite, AsyncWriteExt, Sink}; use log::*; use protobuf::Message; @@ -11,24 +12,24 @@ pub use futures::SinkExt; pub use futures::StreamExt; pub struct ConnectionWriter { - local_addr: SocketAddr, - peer_addr: SocketAddr, - write_stream: Box, - pending_write: Option, - closed: bool, + local_addr: SocketAddr, + peer_addr: SocketAddr, + write_stream: Pin>, + pending_writes: Vec>, + closed: bool, } impl ConnectionWriter { pub fn new( local_addr: SocketAddr, peer_addr: SocketAddr, - write_stream: Box, + write_stream: Pin>, ) -> Self { Self { local_addr, peer_addr, write_stream, - pending_write: None, + pending_writes: Vec::new(), closed: false, } } @@ -44,47 +45,15 @@ impl ConnectionWriter { pub fn is_closed(&self) -> bool { self.closed } - - fn send_to_conn(&mut self) -> Poll> { - if let Some(pending_msg) = self.pending_write.take() { - trace!("Send pending message"); - if let Ok(buffer) = pending_msg.write_to_bytes() { - let msg_size = buffer.len(); - trace!("{} bytes to be sent over network connection", msg_size); - - return if let Ok(_) = - futures::executor::block_on(self.write_stream.write_all(buffer.as_slice())) - { - if let Ok(_) = futures::executor::block_on(self.write_stream.flush()) { - trace!("Sent message of {} bytes", msg_size); - Poll::Ready(Ok(())) - } else { - trace!("Encountered error while flushing queued bytes to network stream"); - Poll::Ready(Err(RecvError)) - } - } else { - error!("Encountered error when writing to network stream"); - Poll::Ready(Err(RecvError)) - }; - } else { - error!("Encountered error when serializing message to bytes"); - return Poll::Ready(Err(RecvError)); - } - } else { - trace!("No message to send over connection"); - } - - Poll::Ready(Ok(())) - } } impl Sink for ConnectionWriter { type Error = RecvError; fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll> { - if self.pending_write.is_some() { - trace!("Connection not ready to send message yet, waiting for prior message"); - Poll::Pending + if self.is_closed() { + trace!("Connection is closed, cannot send message"); + Poll::Ready(Err(RecvError)) } else { trace!("Connection ready to send message"); Poll::Ready(Ok(())) @@ -93,30 +62,126 @@ impl Sink for ConnectionWriter { fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> { trace!("Preparing message to be sent next"); - let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item); - self.pending_write.replace(stitch_msg); + let msg: ConnectionMessage = ConnectionMessage::from_msg(item); - Ok(()) + if let Ok(buffer) = msg.write_to_bytes() { + let msg_size = buffer.len(); + trace!("Serialized pending message into {} bytes", msg_size); + + self.pending_writes.push(buffer); + + Ok(()) + } else { + error!("Encountered error when serializing message to bytes"); + Err(RecvError) + } } fn poll_flush( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { - self.send_to_conn() + if self.pending_writes.len() > 0 { + let stream = self.write_stream.as_mut(); + + match stream.poll_flush(cx) { + Poll::Pending => Poll::Pending, + + Poll::Ready(Ok(_)) => { + trace!("Sending pending bytes"); + + let pending = self.pending_writes.split_off(0); + let writeable_vec: Vec = pending.iter().map(|p| { + IoSlice::new(p) + }).collect(); + + let stream = self.write_stream.as_mut(); + match stream.poll_write_vectored(cx, writeable_vec.as_slice()) { + Poll::Pending => Poll::Pending, + + Poll::Ready(Ok(bytes_written)) => { + trace!("Wrote {} bytes to network stream", bytes_written); + Poll::Ready(Ok(())) + }, + + Poll::Ready(Err(_e)) => { + error!("Encountered error when writing to network stream"); + Poll::Ready(Err(RecvError)) + }, + } + }, + + Poll::Ready(Err(_e)) => { + error!("Encountered error when flushing network stream"); + Poll::Ready(Err(RecvError)) + } + } + } else { + Poll::Ready(Ok(())) + } } fn poll_close( mut self: Pin<&mut Self>, - _cx: &mut Context<'_>, + cx: &mut Context<'_>, ) -> Poll> { - let _ = self.send_to_conn(); - self.closed = true; - if let Ok(_) = futures::executor::block_on(self.write_stream.close()) { - Poll::Ready(Ok(())) + + let flush = if self.pending_writes.len() > 0 { + let stream = self.write_stream.as_mut(); + + match stream.poll_flush(cx) { + Poll::Pending => Poll::Pending, + + Poll::Ready(Ok(_)) => { + trace!("Sending pending bytes"); + + let pending = self.pending_writes.split_off(0); + let writeable_vec: Vec = pending.iter().map(|p| { + IoSlice::new(p) + }).collect(); + + let stream = self.write_stream.as_mut(); + match stream.poll_write_vectored(cx, writeable_vec.as_slice()) { + Poll::Pending => Poll::Pending, + + Poll::Ready(Ok(bytes_written)) => { + trace!("Wrote {} bytes to network stream", bytes_written); + Poll::Ready(Ok(())) + }, + + Poll::Ready(Err(_e)) => { + error!("Encountered error when writing to network stream"); + Poll::Ready(Err(RecvError)) + }, + } + }, + + Poll::Ready(Err(_e)) => { + error!("Encountered error when flushing network stream"); + Poll::Ready(Err(RecvError)) + } + } } else { - Poll::Ready(Err(RecvError)) + Poll::Ready(Ok(())) + }; + + match flush { + Poll::Pending => Poll::Pending, + + Poll::Ready(Ok(_)) => { + let stream = self.write_stream.as_mut(); + + match stream.poll_close(cx) { + Poll::Pending => Poll::Pending, + + Poll::Ready(Ok(_)) => Poll::Ready(Ok(())), + + Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)), + } + }, + + err => err, } } }