]> git.lizzy.rs Git - connect-rs.git/blobdiff - src/tls/listener.rs
remove `block_on` in tls-listener
[connect-rs.git] / src / tls / listener.rs
index 047dcae88dbbe301aea273d921c18d3a7b2b432f..4378baa368db9f50e61696533ff4fb2d45a14e51 100644 (file)
@@ -1,10 +1,12 @@
 use crate::tls::TlsConnectionMetadata;
 use crate::Connection;
-use async_std::net::{SocketAddr, TcpListener, ToSocketAddrs};
+use async_std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
 use async_std::pin::Pin;
 use async_std::task::{Context, Poll};
-use async_tls::TlsAcceptor;
-use futures::{Stream, StreamExt};
+use async_stream::stream;
+use async_tls::{server::TlsStream, TlsAcceptor};
+use futures::Stream;
+use futures_lite::StreamExt;
 use log::*;
 
 /// Listens on a bound socket for incoming TLS connections to be handled as independent
@@ -14,6 +16,9 @@ use log::*;
 ///
 /// # Example
 ///
+/// Please see the [tls-echo-server](https://github.com/sachanganesh/connect-rs/blob/main/examples/tls-echo-server/src/main.rs)
+/// example program for a more thorough showcase.
+///
 /// Basic usage:
 ///
 /// ```ignore
@@ -27,8 +32,18 @@ use log::*;
 #[allow(dead_code)]
 pub struct TlsListener {
     local_addrs: SocketAddr,
-    listener: TcpListener,
-    acceptor: TlsAcceptor,
+    conn_stream: Pin<
+        Box<
+            dyn Stream<
+                    Item = Option<
+                        Option<(SocketAddr, Result<TlsStream<TcpStream>, std::io::Error>)>,
+                    >,
+                > + Send
+                + Sync,
+        >,
+    >,
+    // listener: TcpListener,
+    // acceptor: TlsAcceptor,
 }
 
 impl TlsListener {
@@ -46,87 +61,103 @@ impl TlsListener {
         ip_addrs: A,
         acceptor: TlsAcceptor,
     ) -> anyhow::Result<Self> {
-        let listener = TcpListener::bind(ip_addrs).await?;
-        info!("Started TLS server at {}", listener.local_addr()?);
+        let listener = TcpListener::bind(&ip_addrs).await?;
+        info!("Started TLS server at {}", &ip_addrs);
 
-        Ok(Self {
-            local_addrs: listener.local_addr()?,
-            listener,
-            acceptor,
-        })
-    }
+        let local_addrs = listener.local_addr()?;
 
-    /// Creates a [`Connection`] for the next `accept`ed TCP connection at the bound socket.
-    ///
-    /// # Example
-    ///
-    /// Basic usage:
-    ///
-    /// ```ignore
-    /// let mut server = TlsListener::bind("127.0.0.1:3456", config.into()).await?;
-    /// while let Some(mut conn) = server.next().await {
-    ///     // do something with connection
-    /// }
-    /// ```
-    pub async fn accept(&self) -> anyhow::Result<Connection> {
-        let (tcp_stream, peer_addr) = self.listener.accept().await?;
-        debug!("Received connection attempt from {}", peer_addr);
+        let stream = Box::pin(stream! {
+            loop {
+                yield match listener.incoming().next().await {
+                    Some(Ok(tcp_stream)) => {
+                        let peer_addr = tcp_stream
+                            .peer_addr()
+                            .expect("Could not retrieve peer IP address");
+                        debug!("Received connection attempt from {}", peer_addr);
 
-        match self.acceptor.accept(tcp_stream).await {
-            Ok(tls_stream) => {
-                debug!("Completed TLS handshake with {}", peer_addr);
-                Ok(Connection::from(TlsConnectionMetadata::Listener {
-                    local_addr: self.local_addrs.clone(),
-                    peer_addr,
-                    stream: tls_stream,
-                }))
-            }
+                        Some(Some((peer_addr, acceptor.accept(tcp_stream).await)))
+                    }
 
-            Err(e) => {
-                warn!("Could not encrypt connection with TLS from {}", peer_addr);
-                Err(anyhow::Error::new(e))
+                    Some(Err(err)) => {
+                        error!(
+                            "Encountered error when trying to accept new connection {}",
+                            err
+                        );
+                        Some(None)
+                    }
+
+                    None => None,
+                }
             }
-        }
+        });
+
+        Ok(Self {
+            local_addrs,
+            conn_stream: stream,
+            // listener,
+            // acceptor,
+        })
     }
+
+    // /// Creates a [`Connection`] for the next `accept`ed TCP connection at the bound socket.
+    // ///
+    // /// # Example
+    // ///
+    // /// Basic usage:
+    // ///
+    // /// ```ignore
+    // /// let mut server = TlsListener::bind("127.0.0.1:3456", config.into()).await?;
+    // /// while let Some(mut conn) = server.next().await {
+    // ///     // do something with connection
+    // /// }
+    // /// ```
+    // pub async fn accept(&self) -> anyhow::Result<Connection> {
+    //     let (tcp_stream, peer_addr) = self.listener.accept().await?;
+    //     debug!("Received connection attempt from {}", peer_addr);
+    //
+    //     match self.acceptor.accept(tcp_stream).await {
+    //         Ok(tls_stream) => {
+    //             debug!("Completed TLS handshake with {}", peer_addr);
+    //             Ok(Connection::from(TlsConnectionMetadata::Listener {
+    //                 local_addr: self.local_addrs.clone(),
+    //                 peer_addr,
+    //                 stream: tls_stream,
+    //             }))
+    //         }
+    //
+    //         Err(e) => {
+    //             warn!("Could not encrypt connection with TLS from {}", peer_addr);
+    //             Err(anyhow::Error::new(e))
+    //         }
+    //     }
+    // }
 }
 
 impl Stream for TlsListener {
     type Item = Connection;
 
-    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
-        match futures::executor::block_on(self.listener.incoming().next()) {
-            Some(Ok(tcp_stream)) => {
-                let peer_addr = tcp_stream
-                    .peer_addr()
-                    .expect("Could not retrieve peer IP address");
-                debug!("Received connection attempt from {}", peer_addr);
-
-                match futures::executor::block_on(self.acceptor.accept(tcp_stream)) {
-                    Ok(tls_stream) => {
-                        debug!("Completed TLS handshake with {}", peer_addr);
-                        Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Listener {
-                            local_addr: self.local_addrs.clone(),
-                            peer_addr,
-                            stream: tls_stream,
-                        })))
-                    }
-
-                    Err(_e) => {
-                        warn!("Could not encrypt connection with TLS from {}", peer_addr);
-                        Poll::Pending
-                    }
-                }
+    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        match self.conn_stream.poll_next(cx) {
+            Poll::Ready(Some(Some(Some((peer_addr, Ok(tls_stream)))))) => {
+                debug!("Completed TLS handshake with {}", peer_addr);
+                Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Listener {
+                    local_addr: self.local_addrs.clone(),
+                    peer_addr,
+                    stream: tls_stream,
+                })))
             }
 
-            Some(Err(e)) => {
-                error!(
-                    "Encountered error when trying to accept new connection {}",
-                    e
+            Poll::Ready(Some(Some(Some((peer_addr, Err(err)))))) => {
+                warn!(
+                    "Could not encrypt connection with TLS from {}: {}",
+                    peer_addr, err
                 );
                 Poll::Pending
             }
 
-            None => Poll::Ready(None),
+            Poll::Pending => Poll::Pending,
+
+            _ => Poll::Ready(None),
         }
     }
 }