]> git.lizzy.rs Git - connect-rs.git/commitdiff
avoid block_on as much as possible
authorSachandhan Ganesh <sachan.ganesh@gmail.com>
Wed, 20 Jan 2021 07:09:29 +0000 (23:09 -0800)
committerSachandhan Ganesh <sachan.ganesh@gmail.com>
Wed, 20 Jan 2021 07:41:24 +0000 (23:41 -0800)
Cargo.toml
src/tcp/client.rs
src/tcp/server.rs
src/tls/client.rs
src/tls/server.rs

index 334bc0223de7e7e4fa46190443e1e5b4aef8390f..43a5b4ad1a1a0cdb99233af1097180010522ae30 100644 (file)
@@ -13,6 +13,7 @@ async-std = { version = "1.6.2", features = ["unstable"] }
 async-tls = { version = "0.9.0", default-features = false, features = ["client", "server"]}
 bytes = "0.5.5"
 futures = "0.3.8"
+futures-lite = "1.11.3"
 log = "0.4"
 protobuf = "2.18.1"
 rustls = "0.18.0"
index e9cd42bfc395b511675c4dabaf092ed11822f2e1..354f7bb18880a13acb416b16e4f813e1ee802f38 100644 (file)
@@ -4,8 +4,8 @@ use crate::Connection;
 use async_std::net::{TcpStream, ToSocketAddrs};
 
 impl Connection {
-    pub fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
-        let stream = futures::executor::block_on(TcpStream::connect(&ip_addrs))?;
+    pub async fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
+        let stream = TcpStream::connect(&ip_addrs).await?;
         info!("Established client TCP connection to {}", ip_addrs);
 
         stream.set_nodelay(true)?;
index 674227b6350c9a921b2deebd423929d329d8f6ea..3ecb55ac4ac220500c7e5036e68f71aa4f6b0c22 100644 (file)
@@ -2,7 +2,8 @@ use crate::Connection;
 use async_std::net::{SocketAddr, TcpListener, ToSocketAddrs};
 use async_std::pin::Pin;
 use futures::task::{Context, Poll};
-use futures::{Stream, StreamExt};
+use futures::Stream;
+use futures_lite::stream::StreamExt;
 use log::*;
 
 #[allow(dead_code)]
@@ -12,8 +13,8 @@ pub struct TcpServer {
 }
 
 impl TcpServer {
-    pub fn new<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
-        let listener = futures::executor::block_on(TcpListener::bind(&ip_addrs))?;
+    pub async fn new<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
+        let listener = TcpListener::bind(&ip_addrs).await?;
         info!("Started TCP server at {}", &ip_addrs);
 
         Ok(Self {
@@ -26,17 +27,32 @@ impl TcpServer {
 impl Stream for TcpServer {
     type Item = Connection;
 
-    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        if let Some(Ok(conn)) = futures::executor::block_on(self.listener.incoming().next()) {
-            debug!(
-                "Received connection attempt from {}",
-                conn.peer_addr()
-                    .expect("Peer address could not be retrieved")
-            );
-            Poll::Ready(Some(Connection::from(conn)))
-        } else {
-            info!("Shutting TCP server down at {}", self.local_addrs);
-            Poll::Ready(None)
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.listener.incoming().poll_next(cx) {
+            Poll::Pending => Poll::Pending,
+
+            Poll::Ready(Some(Ok(conn))) => {
+                debug!(
+                    "Received connection attempt from {}",
+                    conn.peer_addr()
+                        .expect("Peer address could not be retrieved")
+                );
+
+                Poll::Ready(Some(Connection::from(conn)))
+            },
+
+            Poll::Ready(Some(Err(e))) => {
+                error!(
+                    "Encountered error when accepting connection attempt: {}", e
+                );
+
+                Poll::Pending
+            },
+
+            Poll::Ready(None) => {
+                info!("Shutting TCP server down at {}", self.local_addrs);
+                Poll::Ready(None)
+            },
         }
     }
 }
index afb95656eea924ab375ef168df5fde7357e72050..1bb345bf242509c9a1d299882c2d69f29b101c90 100644 (file)
@@ -21,20 +21,19 @@ pub enum TlsConnectionMetadata {
 }
 
 impl Connection {
-    pub fn tls_client<A: ToSocketAddrs + std::fmt::Display>(
+    pub async fn tls_client<A: ToSocketAddrs + std::fmt::Display>(
         ip_addrs: A,
         domain: &str,
         connector: TlsConnector,
     ) -> anyhow::Result<Self> {
-        let stream = futures::executor::block_on(TcpStream::connect(&ip_addrs))?;
+        let stream = TcpStream::connect(&ip_addrs).await?;
         info!("Established client TCP connection to {}", ip_addrs);
         stream.set_nodelay(true)?;
 
         let local_addr = stream.peer_addr()?;
         let peer_addr = stream.peer_addr()?;
 
-        let encrypted_stream: client::TlsStream<TcpStream> =
-            futures::executor::block_on(connector.connect(domain, stream))?;
+        let encrypted_stream: client::TlsStream<TcpStream> = connector.connect(domain, stream).await?;
         info!("Completed TLS handshake with {}", peer_addr);
 
         Ok(Self::from(TlsConnectionMetadata::Client {
index dde4b16833dd9eb9e50497be16443ba9d102ec39..41436413ba736dea5df09b1190c769ecde7be81f 100644 (file)
@@ -2,9 +2,10 @@ use crate::tls::TlsConnectionMetadata;
 use crate::Connection;
 use async_std::net::*;
 use async_std::pin::Pin;
-use async_std::prelude::*;
 use async_tls::TlsAcceptor;
+use futures::Stream;
 use futures::task::{Context, Poll};
+use futures_lite::StreamExt;
 use log::*;
 
 #[allow(dead_code)]
@@ -15,11 +16,11 @@ pub struct TlsServer {
 }
 
 impl TlsServer {
-    pub fn new<A: ToSocketAddrs + std::fmt::Display>(
+    pub async fn new<A: ToSocketAddrs + std::fmt::Display>(
         ip_addrs: A,
         acceptor: TlsAcceptor,
     ) -> anyhow::Result<Self> {
-        let listener = futures::executor::block_on(TcpListener::bind(ip_addrs))?;
+        let listener = TcpListener::bind(ip_addrs).await?;
         info!("Started TLS server at {}", listener.local_addr()?);
 
         Ok(Self {
@@ -33,32 +34,48 @@ impl TlsServer {
 impl Stream for TlsServer {
     type Item = Connection;
 
-    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        if let Some(Ok(tcp_stream)) = futures::executor::block_on(self.listener.incoming().next()) {
-            let local_addr = tcp_stream
-                .local_addr()
-                .expect("Local address could not be retrieved");
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.listener.incoming().poll_next(cx) {
+            Poll::Pending => Poll::Pending,
 
-            let peer_addr = tcp_stream
-                .peer_addr()
-                .expect("Peer address could not be retrieved");
-            debug!("Received connection attempt from {}", peer_addr);
+            Poll::Ready(Some(Ok(tcp_stream))) => {
+                let local_addr = tcp_stream
+                    .local_addr()
+                    .expect("Local address could not be retrieved");
+
+                let peer_addr = tcp_stream
+                    .peer_addr()
+                    .expect("Peer address could not be retrieved");
+
+                debug!(
+                    "Received connection attempt from {}", peer_addr
+                );
+
+                if let Ok(tls_stream) = futures::executor::block_on(self.acceptor.accept(tcp_stream)) {
+                    debug!("Completed TLS handshake with {}", peer_addr);
+                    Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Server {
+                        local_addr,
+                        peer_addr,
+                        stream: tls_stream,
+                    })))
+                } else {
+                    warn!("Could not encrypt connection with TLS from {}", peer_addr);
+                    Poll::Pending
+                }
+            },
+
+            Poll::Ready(Some(Err(e))) => {
+                error!(
+                    "Encountered error when accepting connection attempt: {}", e
+                );
 
-            if let Ok(tls_stream) = futures::executor::block_on(self.acceptor.accept(tcp_stream)) {
-                debug!("Completed TLS handshake with {}", peer_addr);
-                Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Server {
-                    local_addr,
-                    peer_addr,
-                    stream: tls_stream,
-                })))
-            } else {
-                warn!("Could not encrypt connection with TLS from {}", peer_addr);
-                // @otodo close the tcp-stream connection
                 Poll::Pending
             }
-        } else {
-            info!("Shutting TLS server down at {}", self.local_addrs);
-            Poll::Ready(None)
+
+            Poll::Ready(None) => {
+                info!("Shutting TLS server down at {}", self.local_addrs);
+                Poll::Ready(None)
+            },
         }
     }
 }