]> git.lizzy.rs Git - connect-rs.git/blob - src/tls/listener.rs
remove `block_on` in tls-listener
[connect-rs.git] / src / tls / listener.rs
1 use crate::tls::TlsConnectionMetadata;
2 use crate::Connection;
3 use async_std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
4 use async_std::pin::Pin;
5 use async_std::task::{Context, Poll};
6 use async_stream::stream;
7 use async_tls::{server::TlsStream, TlsAcceptor};
8 use futures::Stream;
9 use futures_lite::StreamExt;
10 use log::*;
11
12 /// Listens on a bound socket for incoming TLS connections to be handled as independent
13 /// [`Connection`]s.
14 ///
15 /// Implements the [`Stream`] trait to asynchronously accept incoming TLS connections.
16 ///
17 /// # Example
18 ///
19 /// Please see the [tls-echo-server](https://github.com/sachanganesh/connect-rs/blob/main/examples/tls-echo-server/src/main.rs)
20 /// example program for a more thorough showcase.
21 ///
22 /// Basic usage:
23 ///
24 /// ```ignore
25 /// let mut server = TlsListener::bind("127.0.0.1:3456", config.into()).await?;
26 ///
27 /// // wait for a connection to come in and be accepted
28 /// while let Some(mut conn) = server.next().await {
29 ///     // do something with connection
30 /// }
31 /// ```
32 #[allow(dead_code)]
33 pub struct TlsListener {
34     local_addrs: SocketAddr,
35     conn_stream: Pin<
36         Box<
37             dyn Stream<
38                     Item = Option<
39                         Option<(SocketAddr, Result<TlsStream<TcpStream>, std::io::Error>)>,
40                     >,
41                 > + Send
42                 + Sync,
43         >,
44     >,
45     // listener: TcpListener,
46     // acceptor: TlsAcceptor,
47 }
48
49 impl TlsListener {
50     /// Creates a [`TlsListener`] by binding to an IP address and port and listens for incoming TLS
51     /// connections that have successfully been accepted.
52     ///
53     /// # Example
54     ///
55     /// Basic usage:
56     ///
57     /// ```ignore
58     /// let mut server = TlsListener::bind("127.0.0.1:3456", config.into()).await?;
59     /// ```
60     pub async fn bind<A: ToSocketAddrs + std::fmt::Display>(
61         ip_addrs: A,
62         acceptor: TlsAcceptor,
63     ) -> anyhow::Result<Self> {
64         let listener = TcpListener::bind(&ip_addrs).await?;
65         info!("Started TLS server at {}", &ip_addrs);
66
67         let local_addrs = listener.local_addr()?;
68
69         let stream = Box::pin(stream! {
70             loop {
71                 yield match listener.incoming().next().await {
72                     Some(Ok(tcp_stream)) => {
73                         let peer_addr = tcp_stream
74                             .peer_addr()
75                             .expect("Could not retrieve peer IP address");
76                         debug!("Received connection attempt from {}", peer_addr);
77
78                         Some(Some((peer_addr, acceptor.accept(tcp_stream).await)))
79                     }
80
81                     Some(Err(err)) => {
82                         error!(
83                             "Encountered error when trying to accept new connection {}",
84                             err
85                         );
86                         Some(None)
87                     }
88
89                     None => None,
90                 }
91             }
92         });
93
94         Ok(Self {
95             local_addrs,
96             conn_stream: stream,
97             // listener,
98             // acceptor,
99         })
100     }
101
102     // /// Creates a [`Connection`] for the next `accept`ed TCP connection at the bound socket.
103     // ///
104     // /// # Example
105     // ///
106     // /// Basic usage:
107     // ///
108     // /// ```ignore
109     // /// let mut server = TlsListener::bind("127.0.0.1:3456", config.into()).await?;
110     // /// while let Some(mut conn) = server.next().await {
111     // ///     // do something with connection
112     // /// }
113     // /// ```
114     // pub async fn accept(&self) -> anyhow::Result<Connection> {
115     //     let (tcp_stream, peer_addr) = self.listener.accept().await?;
116     //     debug!("Received connection attempt from {}", peer_addr);
117     //
118     //     match self.acceptor.accept(tcp_stream).await {
119     //         Ok(tls_stream) => {
120     //             debug!("Completed TLS handshake with {}", peer_addr);
121     //             Ok(Connection::from(TlsConnectionMetadata::Listener {
122     //                 local_addr: self.local_addrs.clone(),
123     //                 peer_addr,
124     //                 stream: tls_stream,
125     //             }))
126     //         }
127     //
128     //         Err(e) => {
129     //             warn!("Could not encrypt connection with TLS from {}", peer_addr);
130     //             Err(anyhow::Error::new(e))
131     //         }
132     //     }
133     // }
134 }
135
136 impl Stream for TlsListener {
137     type Item = Connection;
138
139     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140         match self.conn_stream.poll_next(cx) {
141             Poll::Ready(Some(Some(Some((peer_addr, Ok(tls_stream)))))) => {
142                 debug!("Completed TLS handshake with {}", peer_addr);
143                 Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Listener {
144                     local_addr: self.local_addrs.clone(),
145                     peer_addr,
146                     stream: tls_stream,
147                 })))
148             }
149
150             Poll::Ready(Some(Some(Some((peer_addr, Err(err)))))) => {
151                 warn!(
152                     "Could not encrypt connection with TLS from {}: {}",
153                     peer_addr, err
154                 );
155                 Poll::Pending
156             }
157
158             Poll::Pending => Poll::Pending,
159
160             _ => Poll::Ready(None),
161         }
162     }
163 }