From 9ee09e52916b771bfb7315294ea6cfe13ba97514 Mon Sep 17 00:00:00 2001 From: Sachandhan Ganesh Date: Wed, 20 Jan 2021 16:51:00 -0800 Subject: [PATCH] make async-oriented, remove block_on --- Cargo.toml | 1 - examples/tcp-client/src/main.rs | 6 +-- examples/tcp-echo-server/src/main.rs | 4 +- examples/tls-client/src/main.rs | 2 +- examples/tls-echo-server/src/main.rs | 57 +++++++++++++++---------- src/reader.rs | 12 +++--- src/tcp/client.rs | 4 +- src/tcp/server.rs | 38 ++--------------- src/tls/client.rs | 3 +- src/tls/server.rs | 64 ++++++---------------------- src/writer.rs | 44 ++++++++----------- 11 files changed, 87 insertions(+), 148 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 43a5b4a..334bc02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,7 +13,6 @@ async-std = { version = "1.6.2", features = ["unstable"] } async-tls = { version = "0.9.0", default-features = false, features = ["client", "server"]} bytes = "0.5.5" futures = "0.3.8" -futures-lite = "1.11.3" log = "0.4" protobuf = "2.18.1" rustls = "0.18.0" diff --git a/examples/tcp-client/src/main.rs b/examples/tcp-client/src/main.rs index f1b1d5f..c65a455 100644 --- a/examples/tcp-client/src/main.rs +++ b/examples/tcp-client/src/main.rs @@ -21,16 +21,16 @@ async fn main() -> anyhow::Result<()> { }; // create a client connection to the server - let mut conn = Connection::tcp_client(ip_address)?; + let mut conn = Connection::tcp_client(ip_address).await?; // send a message to the server let raw_msg = String::from("Hello world"); - info!("Sending message: {}", raw_msg); let mut msg = HelloWorld::new(); - msg.set_message(raw_msg); + msg.set_message(raw_msg.clone()); conn.writer().send(msg).await?; + info!("Sent message: {}", raw_msg); // wait for the server to reply with an ack while let Some(reply) = conn.reader().next().await { diff --git a/examples/tcp-echo-server/src/main.rs b/examples/tcp-echo-server/src/main.rs index cc5b5d7..1c6f588 100644 --- a/examples/tcp-echo-server/src/main.rs +++ b/examples/tcp-echo-server/src/main.rs @@ -23,11 +23,11 @@ async fn main() -> anyhow::Result<()> { }; // create a server - let mut server = TcpServer::new(ip_address)?; + let server = TcpServer::new(ip_address).await?; // handle server connections // wait for a connection to come in and be accepted - while let Some(mut conn) = server.next().await { + while let Ok(mut conn) = server.accept().await { info!("Handling connection from {}", conn.peer_addr()); task::spawn(async move { diff --git a/examples/tls-client/src/main.rs b/examples/tls-client/src/main.rs index b513a67..d9198e8 100644 --- a/examples/tls-client/src/main.rs +++ b/examples/tls-client/src/main.rs @@ -26,7 +26,7 @@ async fn main() -> anyhow::Result<()> { .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid cert"))?; // create a client connection to the server - let mut conn = Connection::tls_client(ip_addr, &domain, client_config.into())?; + let mut conn = Connection::tls_client(ip_addr, &domain, client_config.into()).await?; // send a message to the server let raw_msg = String::from("Hello world"); diff --git a/examples/tls-echo-server/src/main.rs b/examples/tls-echo-server/src/main.rs index 3e8b576..32ebf97 100644 --- a/examples/tls-echo-server/src/main.rs +++ b/examples/tls-echo-server/src/main.rs @@ -30,34 +30,45 @@ async fn main() -> anyhow::Result<()> { .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?; // create a server - let mut server = TlsServer::new(ip_address, config.into())?; + let server = TlsServer::new(ip_address, config.into()).await?; // handle server connections // wait for a connection to come in and be accepted - while let Some(mut conn) = server.next().await { - info!("Handling connection from {}", conn.peer_addr()); - - task::spawn(async move { - while let Some(msg) = conn.reader().next().await { - if msg.is::() { - if let Ok(Some(contents)) = msg.unpack::() { - info!( - "Received a message \"{}\" from {}", - contents.get_message(), - conn.peer_addr() - ); - - conn.writer() - .send(contents) - .await - .expect("Could not send message back to source connection"); - info!("Sent message back to original sender"); + loop { + match server.accept().await { + Ok(Some(mut conn)) => { + info!("Handling connection from {}", conn.peer_addr()); + + task::spawn(async move { + while let Some(msg) = conn.reader().next().await { + if msg.is::() { + if let Ok(Some(contents)) = msg.unpack::() { + info!( + "Received a message \"{}\" from {}", + contents.get_message(), + conn.peer_addr() + ); + + conn.writer() + .send(contents) + .await + .expect("Could not send message back to source connection"); + info!("Sent message back to original sender"); + } + } else { + error!("Received a message of unknown type") + } } - } else { - error!("Received a message of unknown type") - } + }); } - }); + + Ok(None) => (), + + Err(e) => { + error!("Encountered error when accepting connection: {}", e); + break + } + } } Ok(()) diff --git a/src/reader.rs b/src/reader.rs index 6844e39..c859dd5 100644 --- a/src/reader.rs +++ b/src/reader.rs @@ -15,11 +15,11 @@ use protobuf::well_known_types::Any; const BUFFER_SIZE: usize = 8192; pub struct ConnectionReader { - local_addr: SocketAddr, - peer_addr: SocketAddr, - read_stream: Pin>, + local_addr: SocketAddr, + peer_addr: SocketAddr, + read_stream: Pin>, pending_read: Option, - closed: bool, + closed: bool, } impl ConnectionReader { @@ -76,7 +76,7 @@ impl Stream for ConnectionReader { trace!("Read {} bytes from the network stream", bytes_read) } else if self.pending_read.is_none() { self.close_stream(); - return Poll::Ready(None) + return Poll::Ready(None); } if let Some(mut pending_buf) = self.pending_read.take() { @@ -138,7 +138,7 @@ impl Stream for ConnectionReader { // Close the stream Poll::Ready(Err(_e)) => { self.close_stream(); - return Poll::Ready(None) + return Poll::Ready(None); } } } diff --git a/src/tcp/client.rs b/src/tcp/client.rs index 354f7bb..221825d 100644 --- a/src/tcp/client.rs +++ b/src/tcp/client.rs @@ -4,7 +4,9 @@ use crate::Connection; use async_std::net::{TcpStream, ToSocketAddrs}; impl Connection { - pub async fn tcp_client(ip_addrs: A) -> anyhow::Result { + pub async fn tcp_client( + ip_addrs: A, + ) -> anyhow::Result { let stream = TcpStream::connect(&ip_addrs).await?; info!("Established client TCP connection to {}", ip_addrs); diff --git a/src/tcp/server.rs b/src/tcp/server.rs index 3ecb55a..4d22e7b 100644 --- a/src/tcp/server.rs +++ b/src/tcp/server.rs @@ -1,9 +1,5 @@ use crate::Connection; use async_std::net::{SocketAddr, TcpListener, ToSocketAddrs}; -use async_std::pin::Pin; -use futures::task::{Context, Poll}; -use futures::Stream; -use futures_lite::stream::StreamExt; use log::*; #[allow(dead_code)] @@ -22,37 +18,11 @@ impl TcpServer { listener, }) } -} - -impl Stream for TcpServer { - type Item = Connection; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.listener.incoming().poll_next(cx) { - Poll::Pending => Poll::Pending, - - Poll::Ready(Some(Ok(conn))) => { - debug!( - "Received connection attempt from {}", - conn.peer_addr() - .expect("Peer address could not be retrieved") - ); - - Poll::Ready(Some(Connection::from(conn))) - }, - - Poll::Ready(Some(Err(e))) => { - error!( - "Encountered error when accepting connection attempt: {}", e - ); - Poll::Pending - }, + pub async fn accept(&self) -> anyhow::Result { + let (stream, ip_addr) = self.listener.accept().await?; + debug!("Received connection attempt from {}", ip_addr); - Poll::Ready(None) => { - info!("Shutting TCP server down at {}", self.local_addrs); - Poll::Ready(None) - }, - } + Ok(Connection::from(stream)) } } diff --git a/src/tls/client.rs b/src/tls/client.rs index 1bb345b..60acdad 100644 --- a/src/tls/client.rs +++ b/src/tls/client.rs @@ -33,7 +33,8 @@ impl Connection { let local_addr = stream.peer_addr()?; let peer_addr = stream.peer_addr()?; - let encrypted_stream: client::TlsStream = connector.connect(domain, stream).await?; + let encrypted_stream: client::TlsStream = + connector.connect(domain, stream).await?; info!("Completed TLS handshake with {}", peer_addr); Ok(Self::from(TlsConnectionMetadata::Client { diff --git a/src/tls/server.rs b/src/tls/server.rs index 4143641..2011730 100644 --- a/src/tls/server.rs +++ b/src/tls/server.rs @@ -1,11 +1,7 @@ use crate::tls::TlsConnectionMetadata; use crate::Connection; use async_std::net::*; -use async_std::pin::Pin; use async_tls::TlsAcceptor; -use futures::Stream; -use futures::task::{Context, Poll}; -use futures_lite::StreamExt; use log::*; #[allow(dead_code)] @@ -29,53 +25,21 @@ impl TlsServer { acceptor, }) } -} - -impl Stream for TlsServer { - type Item = Connection; - - fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.listener.incoming().poll_next(cx) { - Poll::Pending => Poll::Pending, - - Poll::Ready(Some(Ok(tcp_stream))) => { - let local_addr = tcp_stream - .local_addr() - .expect("Local address could not be retrieved"); - - let peer_addr = tcp_stream - .peer_addr() - .expect("Peer address could not be retrieved"); - - debug!( - "Received connection attempt from {}", peer_addr - ); - - if let Ok(tls_stream) = futures::executor::block_on(self.acceptor.accept(tcp_stream)) { - debug!("Completed TLS handshake with {}", peer_addr); - Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Server { - local_addr, - peer_addr, - stream: tls_stream, - }))) - } else { - warn!("Could not encrypt connection with TLS from {}", peer_addr); - Poll::Pending - } - }, - - Poll::Ready(Some(Err(e))) => { - error!( - "Encountered error when accepting connection attempt: {}", e - ); - - Poll::Pending - } - Poll::Ready(None) => { - info!("Shutting TLS server down at {}", self.local_addrs); - Poll::Ready(None) - }, + pub async fn accept(&self) -> anyhow::Result> { + let (tcp_stream, peer_addr) = self.listener.accept().await?; + debug!("Received connection attempt from {}", peer_addr); + + if let Ok(tls_stream) = self.acceptor.accept(tcp_stream).await { + debug!("Completed TLS handshake with {}", peer_addr); + Ok(Some(Connection::from(TlsConnectionMetadata::Server { + local_addr: self.local_addrs.clone(), + peer_addr, + stream: tls_stream, + }))) + } else { + warn!("Could not encrypt connection with TLS from {}", peer_addr); + Ok(None) } } } diff --git a/src/writer.rs b/src/writer.rs index f26ab4b..2403024 100644 --- a/src/writer.rs +++ b/src/writer.rs @@ -2,9 +2,9 @@ use crate::schema::ConnectionMessage; use async_channel::RecvError; use async_std::net::SocketAddr; use async_std::pin::Pin; -use futures::{AsyncWrite, Sink}; use futures::io::IoSlice; use futures::task::{Context, Poll}; +use futures::{AsyncWrite, Sink}; use log::*; use protobuf::Message; @@ -12,11 +12,11 @@ pub use futures::SinkExt; pub use futures::StreamExt; pub struct ConnectionWriter { - local_addr: SocketAddr, - peer_addr: SocketAddr, - write_stream: Pin>, + local_addr: SocketAddr, + peer_addr: SocketAddr, + write_stream: Pin>, pending_writes: Vec>, - closed: bool, + closed: bool, } impl ConnectionWriter { @@ -77,10 +77,7 @@ impl Sink for ConnectionWriter { } } - fn poll_flush( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { if self.pending_writes.len() > 0 { let stream = self.write_stream.as_mut(); @@ -91,9 +88,8 @@ impl Sink for ConnectionWriter { trace!("Sending pending bytes"); let pending = self.pending_writes.split_off(0); - let writeable_vec: Vec = pending.iter().map(|p| { - IoSlice::new(p) - }).collect(); + let writeable_vec: Vec = + pending.iter().map(|p| IoSlice::new(p)).collect(); let stream = self.write_stream.as_mut(); match stream.poll_write_vectored(cx, writeable_vec.as_slice()) { @@ -102,14 +98,14 @@ impl Sink for ConnectionWriter { Poll::Ready(Ok(bytes_written)) => { trace!("Wrote {} bytes to network stream", bytes_written); Poll::Ready(Ok(())) - }, + } Poll::Ready(Err(_e)) => { error!("Encountered error when writing to network stream"); Poll::Ready(Err(RecvError)) - }, + } } - }, + } Poll::Ready(Err(_e)) => { error!("Encountered error when flushing network stream"); @@ -121,10 +117,7 @@ impl Sink for ConnectionWriter { } } - fn poll_close( - mut self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll> { + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { self.closed = true; let flush = if self.pending_writes.len() > 0 { @@ -137,9 +130,8 @@ impl Sink for ConnectionWriter { trace!("Sending pending bytes"); let pending = self.pending_writes.split_off(0); - let writeable_vec: Vec = pending.iter().map(|p| { - IoSlice::new(p) - }).collect(); + let writeable_vec: Vec = + pending.iter().map(|p| IoSlice::new(p)).collect(); let stream = self.write_stream.as_mut(); match stream.poll_write_vectored(cx, writeable_vec.as_slice()) { @@ -148,14 +140,14 @@ impl Sink for ConnectionWriter { Poll::Ready(Ok(bytes_written)) => { trace!("Wrote {} bytes to network stream", bytes_written); Poll::Ready(Ok(())) - }, + } Poll::Ready(Err(_e)) => { error!("Encountered error when writing to network stream"); Poll::Ready(Err(RecvError)) - }, + } } - }, + } Poll::Ready(Err(_e)) => { error!("Encountered error when flushing network stream"); @@ -179,7 +171,7 @@ impl Sink for ConnectionWriter { Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)), } - }, + } err => err, } -- 2.44.0