]> git.lizzy.rs Git - connect-rs.git/commitdiff
make async-oriented, remove block_on
authorSachandhan Ganesh <sachan.ganesh@gmail.com>
Thu, 21 Jan 2021 00:51:00 +0000 (16:51 -0800)
committerSachandhan Ganesh <sachan.ganesh@gmail.com>
Thu, 21 Jan 2021 00:51:00 +0000 (16:51 -0800)
Cargo.toml
examples/tcp-client/src/main.rs
examples/tcp-echo-server/src/main.rs
examples/tls-client/src/main.rs
examples/tls-echo-server/src/main.rs
src/reader.rs
src/tcp/client.rs
src/tcp/server.rs
src/tls/client.rs
src/tls/server.rs
src/writer.rs

index 43a5b4ad1a1a0cdb99233af1097180010522ae30..334bc0223de7e7e4fa46190443e1e5b4aef8390f 100644 (file)
@@ -13,7 +13,6 @@ async-std = { version = "1.6.2", features = ["unstable"] }
 async-tls = { version = "0.9.0", default-features = false, features = ["client", "server"]}
 bytes = "0.5.5"
 futures = "0.3.8"
-futures-lite = "1.11.3"
 log = "0.4"
 protobuf = "2.18.1"
 rustls = "0.18.0"
index f1b1d5fc084cab92cef447afe87fa98c2a5f75b3..c65a45548cb32894155e443b1409105b67237902 100644 (file)
@@ -21,16 +21,16 @@ async fn main() -> anyhow::Result<()> {
     };
 
     // create a client connection to the server
-    let mut conn = Connection::tcp_client(ip_address)?;
+    let mut conn = Connection::tcp_client(ip_address).await?;
 
     // send a message to the server
     let raw_msg = String::from("Hello world");
-    info!("Sending message: {}", raw_msg);
 
     let mut msg = HelloWorld::new();
-    msg.set_message(raw_msg);
+    msg.set_message(raw_msg.clone());
 
     conn.writer().send(msg).await?;
+    info!("Sent message: {}", raw_msg);
 
     // wait for the server to reply with an ack
     while let Some(reply) = conn.reader().next().await {
index cc5b5d7d3fa9e2a4f250fe649d18b6e24330ce1d..1c6f588ca710edd3f8904f653145a9f9210a8336 100644 (file)
@@ -23,11 +23,11 @@ async fn main() -> anyhow::Result<()> {
     };
 
     // create a server
-    let mut server = TcpServer::new(ip_address)?;
+    let server = TcpServer::new(ip_address).await?;
 
     // handle server connections
     // wait for a connection to come in and be accepted
-    while let Some(mut conn) = server.next().await {
+    while let Ok(mut conn) = server.accept().await {
         info!("Handling connection from {}", conn.peer_addr());
 
         task::spawn(async move {
index b513a67bd6660567ba7149707a26173732b2e9f1..d9198e8ca2d756c6a20e657919a3fb75e5b5d5ad 100644 (file)
@@ -26,7 +26,7 @@ async fn main() -> anyhow::Result<()> {
         .map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid cert"))?;
 
     // create a client connection to the server
-    let mut conn = Connection::tls_client(ip_addr, &domain, client_config.into())?;
+    let mut conn = Connection::tls_client(ip_addr, &domain, client_config.into()).await?;
 
     // send a message to the server
     let raw_msg = String::from("Hello world");
index 3e8b57699d014d005b5f906bc7cac0312c90c8f3..32ebf972f939d3c11678997125854d3c5c05487b 100644 (file)
@@ -30,34 +30,45 @@ async fn main() -> anyhow::Result<()> {
         .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
 
     // create a server
-    let mut server = TlsServer::new(ip_address, config.into())?;
+    let server = TlsServer::new(ip_address, config.into()).await?;
 
     // handle server connections
     // wait for a connection to come in and be accepted
-    while let Some(mut conn) = server.next().await {
-        info!("Handling connection from {}", conn.peer_addr());
-
-        task::spawn(async move {
-            while let Some(msg) = conn.reader().next().await {
-                if msg.is::<HelloWorld>() {
-                    if let Ok(Some(contents)) = msg.unpack::<HelloWorld>() {
-                        info!(
-                            "Received a message \"{}\" from {}",
-                            contents.get_message(),
-                            conn.peer_addr()
-                        );
-
-                        conn.writer()
-                            .send(contents)
-                            .await
-                            .expect("Could not send message back to source connection");
-                        info!("Sent message back to original sender");
+    loop {
+        match server.accept().await {
+            Ok(Some(mut conn)) => {
+                info!("Handling connection from {}", conn.peer_addr());
+
+                task::spawn(async move {
+                    while let Some(msg) = conn.reader().next().await {
+                        if msg.is::<HelloWorld>() {
+                            if let Ok(Some(contents)) = msg.unpack::<HelloWorld>() {
+                                info!(
+                                    "Received a message \"{}\" from {}",
+                                    contents.get_message(),
+                                    conn.peer_addr()
+                                );
+
+                                conn.writer()
+                                    .send(contents)
+                                    .await
+                                    .expect("Could not send message back to source connection");
+                                info!("Sent message back to original sender");
+                            }
+                        } else {
+                            error!("Received a message of unknown type")
+                        }
                     }
-                } else {
-                    error!("Received a message of unknown type")
-                }
+                });
             }
-        });
+
+            Ok(None) => (),
+
+            Err(e) => {
+                error!("Encountered error when accepting connection: {}", e);
+                break
+            }
+        }
     }
 
     Ok(())
index 6844e393f297148f09b0e5ef0b64c577fc923826..c859dd510d9afe9a7811a87654146b9383a4297b 100644 (file)
@@ -15,11 +15,11 @@ use protobuf::well_known_types::Any;
 const BUFFER_SIZE: usize = 8192;
 
 pub struct ConnectionReader {
-    local_addr:   SocketAddr,
-    peer_addr:    SocketAddr,
-    read_stream:  Pin<Box<dyn AsyncRead + Send + Sync>>,
+    local_addr: SocketAddr,
+    peer_addr: SocketAddr,
+    read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
     pending_read: Option<BytesMut>,
-    closed:       bool,
+    closed: bool,
 }
 
 impl ConnectionReader {
@@ -76,7 +76,7 @@ impl Stream for ConnectionReader {
                         trace!("Read {} bytes from the network stream", bytes_read)
                     } else if self.pending_read.is_none() {
                         self.close_stream();
-                        return Poll::Ready(None)
+                        return Poll::Ready(None);
                     }
 
                     if let Some(mut pending_buf) = self.pending_read.take() {
@@ -138,7 +138,7 @@ impl Stream for ConnectionReader {
                 // Close the stream
                 Poll::Ready(Err(_e)) => {
                     self.close_stream();
-                    return Poll::Ready(None)
+                    return Poll::Ready(None);
                 }
             }
         }
index 354f7bb18880a13acb416b16e4f813e1ee802f38..221825da27399ad7ab477a13be1a979b7179abcf 100644 (file)
@@ -4,7 +4,9 @@ use crate::Connection;
 use async_std::net::{TcpStream, ToSocketAddrs};
 
 impl Connection {
-    pub async fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
+    pub async fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(
+        ip_addrs: A,
+    ) -> anyhow::Result<Self> {
         let stream = TcpStream::connect(&ip_addrs).await?;
         info!("Established client TCP connection to {}", ip_addrs);
 
index 3ecb55ac4ac220500c7e5036e68f71aa4f6b0c22..4d22e7b4c37b7ca7141834f18e9b442ae35125ab 100644 (file)
@@ -1,9 +1,5 @@
 use crate::Connection;
 use async_std::net::{SocketAddr, TcpListener, ToSocketAddrs};
-use async_std::pin::Pin;
-use futures::task::{Context, Poll};
-use futures::Stream;
-use futures_lite::stream::StreamExt;
 use log::*;
 
 #[allow(dead_code)]
@@ -22,37 +18,11 @@ impl TcpServer {
             listener,
         })
     }
-}
-
-impl Stream for TcpServer {
-    type Item = Connection;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.listener.incoming().poll_next(cx) {
-            Poll::Pending => Poll::Pending,
-
-            Poll::Ready(Some(Ok(conn))) => {
-                debug!(
-                    "Received connection attempt from {}",
-                    conn.peer_addr()
-                        .expect("Peer address could not be retrieved")
-                );
-
-                Poll::Ready(Some(Connection::from(conn)))
-            },
-
-            Poll::Ready(Some(Err(e))) => {
-                error!(
-                    "Encountered error when accepting connection attempt: {}", e
-                );
 
-                Poll::Pending
-            },
+    pub async fn accept(&self) -> anyhow::Result<Connection> {
+        let (stream, ip_addr) = self.listener.accept().await?;
+        debug!("Received connection attempt from {}", ip_addr);
 
-            Poll::Ready(None) => {
-                info!("Shutting TCP server down at {}", self.local_addrs);
-                Poll::Ready(None)
-            },
-        }
+        Ok(Connection::from(stream))
     }
 }
index 1bb345bf242509c9a1d299882c2d69f29b101c90..60acdad99aad7ad990b43f539ffd782542371364 100644 (file)
@@ -33,7 +33,8 @@ impl Connection {
         let local_addr = stream.peer_addr()?;
         let peer_addr = stream.peer_addr()?;
 
-        let encrypted_stream: client::TlsStream<TcpStream> = connector.connect(domain, stream).await?;
+        let encrypted_stream: client::TlsStream<TcpStream> =
+            connector.connect(domain, stream).await?;
         info!("Completed TLS handshake with {}", peer_addr);
 
         Ok(Self::from(TlsConnectionMetadata::Client {
index 41436413ba736dea5df09b1190c769ecde7be81f..2011730fa410e40001aa88c4f0d814ce6741244f 100644 (file)
@@ -1,11 +1,7 @@
 use crate::tls::TlsConnectionMetadata;
 use crate::Connection;
 use async_std::net::*;
-use async_std::pin::Pin;
 use async_tls::TlsAcceptor;
-use futures::Stream;
-use futures::task::{Context, Poll};
-use futures_lite::StreamExt;
 use log::*;
 
 #[allow(dead_code)]
@@ -29,53 +25,21 @@ impl TlsServer {
             acceptor,
         })
     }
-}
-
-impl Stream for TlsServer {
-    type Item = Connection;
-
-    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match self.listener.incoming().poll_next(cx) {
-            Poll::Pending => Poll::Pending,
-
-            Poll::Ready(Some(Ok(tcp_stream))) => {
-                let local_addr = tcp_stream
-                    .local_addr()
-                    .expect("Local address could not be retrieved");
-
-                let peer_addr = tcp_stream
-                    .peer_addr()
-                    .expect("Peer address could not be retrieved");
-
-                debug!(
-                    "Received connection attempt from {}", peer_addr
-                );
-
-                if let Ok(tls_stream) = futures::executor::block_on(self.acceptor.accept(tcp_stream)) {
-                    debug!("Completed TLS handshake with {}", peer_addr);
-                    Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Server {
-                        local_addr,
-                        peer_addr,
-                        stream: tls_stream,
-                    })))
-                } else {
-                    warn!("Could not encrypt connection with TLS from {}", peer_addr);
-                    Poll::Pending
-                }
-            },
-
-            Poll::Ready(Some(Err(e))) => {
-                error!(
-                    "Encountered error when accepting connection attempt: {}", e
-                );
-
-                Poll::Pending
-            }
 
-            Poll::Ready(None) => {
-                info!("Shutting TLS server down at {}", self.local_addrs);
-                Poll::Ready(None)
-            },
+    pub async fn accept(&self) -> anyhow::Result<Option<Connection>> {
+        let (tcp_stream, peer_addr) = self.listener.accept().await?;
+        debug!("Received connection attempt from {}", peer_addr);
+
+        if let Ok(tls_stream) = self.acceptor.accept(tcp_stream).await {
+            debug!("Completed TLS handshake with {}", peer_addr);
+            Ok(Some(Connection::from(TlsConnectionMetadata::Server {
+                local_addr: self.local_addrs.clone(),
+                peer_addr,
+                stream: tls_stream,
+            })))
+        } else {
+            warn!("Could not encrypt connection with TLS from {}", peer_addr);
+            Ok(None)
         }
     }
 }
index f26ab4b12eba3719b98947c25fc80dc96fa5f24d..24030249110beb3768a0211e46654c779e488d71 100644 (file)
@@ -2,9 +2,9 @@ use crate::schema::ConnectionMessage;
 use async_channel::RecvError;
 use async_std::net::SocketAddr;
 use async_std::pin::Pin;
-use futures::{AsyncWrite, Sink};
 use futures::io::IoSlice;
 use futures::task::{Context, Poll};
+use futures::{AsyncWrite, Sink};
 use log::*;
 use protobuf::Message;
 
@@ -12,11 +12,11 @@ pub use futures::SinkExt;
 pub use futures::StreamExt;
 
 pub struct ConnectionWriter {
-    local_addr:     SocketAddr,
-    peer_addr:      SocketAddr,
-    write_stream:   Pin<Box<dyn AsyncWrite + Send + Sync>>,
+    local_addr: SocketAddr,
+    peer_addr: SocketAddr,
+    write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
     pending_writes: Vec<Vec<u8>>,
-    closed:         bool,
+    closed: bool,
 }
 
 impl ConnectionWriter {
@@ -77,10 +77,7 @@ impl<M: Message> Sink<M> for ConnectionWriter {
         }
     }
 
-    fn poll_flush(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<(), Self::Error>> {
+    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         if self.pending_writes.len() > 0 {
             let stream = self.write_stream.as_mut();
 
@@ -91,9 +88,8 @@ impl<M: Message> Sink<M> for ConnectionWriter {
                     trace!("Sending pending bytes");
 
                     let pending = self.pending_writes.split_off(0);
-                    let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
-                        IoSlice::new(p)
-                    }).collect();
+                    let writeable_vec: Vec<IoSlice> =
+                        pending.iter().map(|p| IoSlice::new(p)).collect();
 
                     let stream = self.write_stream.as_mut();
                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
@@ -102,14 +98,14 @@ impl<M: Message> Sink<M> for ConnectionWriter {
                         Poll::Ready(Ok(bytes_written)) => {
                             trace!("Wrote {} bytes to network stream", bytes_written);
                             Poll::Ready(Ok(()))
-                        },
+                        }
 
                         Poll::Ready(Err(_e)) => {
                             error!("Encountered error when writing to network stream");
                             Poll::Ready(Err(RecvError))
-                        },
+                        }
                     }
-                },
+                }
 
                 Poll::Ready(Err(_e)) => {
                     error!("Encountered error when flushing network stream");
@@ -121,10 +117,7 @@ impl<M: Message> Sink<M> for ConnectionWriter {
         }
     }
 
-    fn poll_close(
-        mut self: Pin<&mut Self>,
-        cx: &mut Context<'_>,
-    ) -> Poll<Result<(), Self::Error>> {
+    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
         self.closed = true;
 
         let flush = if self.pending_writes.len() > 0 {
@@ -137,9 +130,8 @@ impl<M: Message> Sink<M> for ConnectionWriter {
                     trace!("Sending pending bytes");
 
                     let pending = self.pending_writes.split_off(0);
-                    let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
-                        IoSlice::new(p)
-                    }).collect();
+                    let writeable_vec: Vec<IoSlice> =
+                        pending.iter().map(|p| IoSlice::new(p)).collect();
 
                     let stream = self.write_stream.as_mut();
                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
@@ -148,14 +140,14 @@ impl<M: Message> Sink<M> for ConnectionWriter {
                         Poll::Ready(Ok(bytes_written)) => {
                             trace!("Wrote {} bytes to network stream", bytes_written);
                             Poll::Ready(Ok(()))
-                        },
+                        }
 
                         Poll::Ready(Err(_e)) => {
                             error!("Encountered error when writing to network stream");
                             Poll::Ready(Err(RecvError))
-                        },
+                        }
                     }
-                },
+                }
 
                 Poll::Ready(Err(_e)) => {
                     error!("Encountered error when flushing network stream");
@@ -179,7 +171,7 @@ impl<M: Message> Sink<M> for ConnectionWriter {
 
                     Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
                 }
-            },
+            }
 
             err => err,
         }