]> git.lizzy.rs Git - connect-rs.git/blobdiff - src/tls/server.rs
fix tls and cull warnings
[connect-rs.git] / src / tls / server.rs
index 79dba44033dd8c94015b823d0eb773d52f2df901..66e4206dcb844a67ffc4088ced0895cc0ba4a0de 100644 (file)
-use crate::channel_factory;
-use crate::registry::StitchRegistry;
-use crate::{ServerRegistry, StitchClient, StitchNetClient, StitchNetServer};
-use async_channel::{Receiver, Sender};
-use async_std::io::*;
+use crate::Connection;
+use crate::tls::TlsConnectionMetadata;
 use async_std::net::*;
+use async_std::pin::Pin;
 use async_std::prelude::*;
-use async_std::sync::{Arc, Condvar, Mutex};
 use async_std::task;
 use async_tls::TlsAcceptor;
-use dashmap::DashMap;
-use futures_util::AsyncReadExt;
+use futures::task::{Context, Poll};
 use log::*;
 
-impl StitchNetServer {
-    pub fn tls_server<A: ToSocketAddrs + std::fmt::Display>(
-        ip_addrs: A,
-        acceptor: TlsAcceptor,
-    ) -> Result<(StitchNetServer, Receiver<Arc<StitchNetClient>>)> {
-        Self::tls_server_with_bound(ip_addrs, acceptor, None)
-    }
+#[allow(dead_code)]
+pub struct TlsServer {
+    local_addrs: SocketAddr,
+    listener: TcpListener,
+    acceptor: TlsAcceptor,
+}
 
-    pub fn tls_server_with_bound<A: ToSocketAddrs + std::fmt::Display>(
-        ip_addrs: A,
-        acceptor: TlsAcceptor,
-        cap: Option<usize>,
-    ) -> Result<(Self, Receiver<Arc<StitchNetClient>>)> {
+impl TlsServer {
+    pub fn new<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A, acceptor: TlsAcceptor) -> anyhow::Result<Self> {
         let listener = task::block_on(TcpListener::bind(ip_addrs))?;
         info!("Started TLS server at {}", listener.local_addr()?);
 
-        let registry = Arc::new(DashMap::new());
-        let (sender, receiver) = channel_factory(cap);
-
-        let handler = task::spawn(handle_server_connections(
-            acceptor,
-            registry.clone(),
+        Ok(Self {
+            local_addrs: listener.local_addr()?,
             listener,
-            sender.clone(),
-            cap,
-        ));
-
-        Ok((
-            Self {
-                registry,
-                connections_chan: (sender, receiver.clone()),
-                accept_loop_task: handler,
-            },
-            receiver,
-        ))
+            acceptor,
+        })
     }
 }
 
-async fn handle_server_connections<'a>(
-    acceptor: TlsAcceptor,
-    registry: ServerRegistry,
-    input: TcpListener,
-    output: Sender<Arc<StitchNetClient>>,
-    cap: Option<usize>,
-) -> anyhow::Result<()> {
-    let mut conns = input.incoming();
-
-    debug!("Reading from the stream of incoming connections");
-    loop {
-        match conns.next().await {
-            Some(Ok(tcp_stream)) => {
-                let local_addr = tcp_stream.local_addr()?;
-                let peer_addr = tcp_stream.peer_addr()?;
-
-                debug!("Received connection attempt from {}", peer_addr);
-
-                let tls_stream = acceptor.accept(tcp_stream).await?;
-
-                let (read_stream, write_stream) = tls_stream.split();
-                let (tls_write_sender, tls_write_receiver) = channel_factory(cap);
-
-                let client_registry: StitchRegistry = crate::registry::new();
-                let read_readiness = Arc::new((Mutex::new(false), Condvar::new()));
-                let write_readiness = Arc::new((Mutex::new(false), Condvar::new()));
-
-                let read_task = task::spawn(crate::tasks::read_from_stream(
-                    client_registry.clone(),
-                    read_stream,
-                    read_readiness.clone(),
-                ));
-
-                let write_task = task::spawn(crate::tasks::write_to_stream(
-                    tls_write_receiver.clone(),
-                    write_stream,
-                    write_readiness.clone(),
-                ));
-
-                let conn = StitchNetClient {
-                    local_addr,
-                    peer_addr,
-                    registry: client_registry,
-                    stream_writer_chan: (tls_write_sender, tls_write_receiver),
-                    read_readiness,
-                    write_readiness,
-                    read_task,
-                    write_task,
-                };
-
-                debug!("Attempting to register connection from {}", peer_addr);
-                let conn = Arc::new(conn);
-                registry.insert(conn.peer_addr(), conn.clone());
-                debug!(
-                    "Registered client connection for {} in server registry",
-                    peer_addr
-                );
-
-                if let Err(err) = output.send(conn).await {
-                    error!(
-                        "Stopping the server accept loop - could not send accepted TLS client connection to channel: {:#?}",
-                        err
-                    );
-
-                    break Err(anyhow::Error::from(err));
-                } else {
-                    info!("Accepted connection from {}", peer_addr);
-                }
-            }
-
-            Some(Err(err)) => error!(
-                "Encountered error when accepting TLS connection: {:#?}",
-                err
-            ),
-
-            None => {
-                warn!("Stopping the server accept loop - unable to accept any more connections");
-
-                break Ok(());
+impl Stream for TlsServer {
+    type Item = Connection;
+
+    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if let Some(Ok(tcp_stream)) = futures::executor::block_on(self.listener.incoming().next()) {
+            let local_addr = tcp_stream.local_addr().expect(
+                "Local address could not be retrieved",
+            );
+
+            let peer_addr = tcp_stream.peer_addr().expect(
+                "Peer address could not be retrieved",
+            );
+            debug!("Received connection attempt from {}", peer_addr);
+
+            if let Ok(tls_stream) = futures::executor::block_on(self.acceptor.accept(tcp_stream)) {
+                debug!("Established TLS connection from {}", peer_addr);
+                Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Server{ local_addr, peer_addr, stream: tls_stream })))
+            } else {
+                debug!("Could not encrypt connection with TLS from {}", peer_addr);
+                Poll::Pending
             }
+        } else {
+            Poll::Ready(None)
         }
     }
 }