]> git.lizzy.rs Git - connect-rs.git/commitdiff
fix tls and cull warnings
authorSachandhan Ganesh <sachan.ganesh@gmail.com>
Fri, 15 Jan 2021 07:42:33 +0000 (23:42 -0800)
committerSachandhan Ganesh <sachan.ganesh@gmail.com>
Fri, 15 Jan 2021 07:42:33 +0000 (23:42 -0800)
15 files changed:
Cargo.toml
examples/tcp-client/Cargo.toml
examples/tcp-client/src/main.rs
examples/tcp-echo-server/Cargo.toml
examples/tcp-echo-server/src/main.rs
schema/message.proto
src/lib.rs
src/reader.rs
src/schema/message.rs
src/schema/mod.rs
src/tcp/client.rs
src/tcp/server.rs
src/tls/client.rs
src/tls/server.rs
src/writer.rs

index f5bf2a0303b083c212ff0efcdb4425f3578a6b76..e8c34c968ab383fd503d01183df777d59490be98 100644 (file)
@@ -3,6 +3,8 @@ name = "connect"
 version = "0.0.2"
 authors = ["Sachandhan Ganesh <sachan.ganesh@gmail.com>"]
 edition = "2018"
+description = "message queue abstraction over async network streams"
+
 
 [dependencies]
 anyhow = "1.0.31"
index 1449e16ab880ad33aa2acaa054135b4131bbc57c..1fb8e2b06773903c59722b598a6ef900808c2746 100644 (file)
@@ -13,7 +13,7 @@ env_logger = "0.7"
 log = "0.4"
 protobuf = "2.18.1"
 
-stitch-net = { path = "../../" }
+connect = { path = "../../" }
 
 [build-dependencies]
-protobuf-codegen-pure = "2.18.1"
\ No newline at end of file
+protobuf-codegen-pure = "2.18.1"
index a5de8e5a3bc81719db23ac990ab2ef21ed59e084..f1b1d5fc084cab92cef447afe87fa98c2a5f75b3 100644 (file)
@@ -1,10 +1,10 @@
 pub mod schema;
 
 use crate::schema::hello_world::HelloWorld;
+use connect::{Connection, SinkExt, StreamExt};
 use log::*;
 use protobuf::well_known_types::Any;
 use std::env;
-use stitch_net::{SinkExt, StitchConnection, StreamExt};
 
 #[async_std::main]
 async fn main() -> anyhow::Result<()> {
@@ -21,7 +21,7 @@ async fn main() -> anyhow::Result<()> {
     };
 
     // create a client connection to the server
-    let mut conn = StitchConnection::tcp_client(ip_address)?;
+    let mut conn = Connection::tcp_client(ip_address)?;
 
     // send a message to the server
     let raw_msg = String::from("Hello world");
index d44763ca5163189df7613b2f15e2abb0c9bcfcc5..29af62d1b1c3fbc95946c70cee8af68fdb722b57 100644 (file)
@@ -13,7 +13,7 @@ env_logger = "0.7"
 log = "0.4"
 protobuf = "2.18.1"
 
-stitch-net = { path = "../../" }
+connect = { path = "../../" }
 
 [build-dependencies]
-protobuf-codegen-pure = "2.18.1"
\ No newline at end of file
+protobuf-codegen-pure = "2.18.1"
index 376ab0c1cfcd7eb3a7d0ce508bf43f69778d3f3e..cc5b5d7d3fa9e2a4f250fe649d18b6e24330ce1d 100644 (file)
@@ -2,10 +2,10 @@ mod schema;
 
 use crate::schema::hello_world::HelloWorld;
 use async_std::task;
+use connect::tcp::TcpServer;
+use connect::{SinkExt, StreamExt};
 use log::*;
 use std::env;
-use stitch_net::tcp::StitchTcpServer;
-use stitch_net::{SinkExt, StreamExt};
 
 #[async_std::main]
 async fn main() -> anyhow::Result<()> {
@@ -23,7 +23,7 @@ async fn main() -> anyhow::Result<()> {
     };
 
     // create a server
-    let mut server = StitchTcpServer::new(ip_address)?;
+    let mut server = TcpServer::new(ip_address)?;
 
     // handle server connections
     // wait for a connection to come in and be accepted
index e80daff36b2fa2ee8c85bbfa5a456adde75e1646..d1f6ee0f4fad5c81afebab070ec5b4cff924773c 100644 (file)
@@ -3,6 +3,6 @@ package message;
 
 import "google/protobuf/any.proto";
 
-message StitchMessage {
+message ConnectionMessage {
     google.protobuf.Any payload = 1;
 }
index c2277629303e8fb5b271866355a44429d2383670..b6163bed50a2cf404570ef8ee753d86ec50e4434 100644 (file)
@@ -1,35 +1,24 @@
+mod reader;
 pub mod schema;
 pub mod tcp;
-// @todo pub mod tls;
-mod reader;
+pub mod tls;
 mod writer;
 
-pub use crate::reader::StitchConnectionReader;
-use crate::schema::StitchMessage;
-pub use crate::writer::StitchConnectionWriter;
-use async_channel::RecvError;
+pub use crate::reader::ConnectionReader;
+pub use crate::writer::ConnectionWriter;
 use async_std::net::SocketAddr;
-use async_std::pin::Pin;
-use bytes::{Buf, BytesMut};
-use futures::task::{Context, Poll};
-use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, Sink, Stream};
-use log::*;
-use protobuf::Message;
-use std::convert::TryInto;
-
-pub use futures::SinkExt;
-pub use futures::StreamExt;
-use protobuf::well_known_types::Any;
+use futures::{AsyncRead, AsyncWrite};
+pub use futures::{SinkExt, StreamExt};
 
-pub struct StitchConnection {
+pub struct Connection {
     local_addr: SocketAddr,
     peer_addr: SocketAddr,
-    reader: StitchConnectionReader,
-    writer: StitchConnectionWriter,
+    reader: ConnectionReader,
+    writer: ConnectionWriter,
 }
 
 #[allow(dead_code)]
-impl StitchConnection {
+impl Connection {
     pub(crate) fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
@@ -39,8 +28,8 @@ impl StitchConnection {
         Self {
             local_addr,
             peer_addr,
-            reader: StitchConnectionReader::new(local_addr, peer_addr, read_stream),
-            writer: StitchConnectionWriter::new(local_addr, peer_addr, write_stream),
+            reader: ConnectionReader::new(local_addr, peer_addr, read_stream),
+            writer: ConnectionWriter::new(local_addr, peer_addr, write_stream),
         }
     }
 
@@ -52,11 +41,11 @@ impl StitchConnection {
         self.peer_addr.clone()
     }
 
-    pub fn split(self) -> (StitchConnectionReader, StitchConnectionWriter) {
+    pub fn split(self) -> (ConnectionReader, ConnectionWriter) {
         (self.reader, self.writer)
     }
 
-    pub fn join(reader: StitchConnectionReader, writer: StitchConnectionWriter) -> Self {
+    pub fn join(reader: ConnectionReader, writer: ConnectionWriter) -> Self {
         Self {
             local_addr: reader.local_addr(),
             peer_addr: reader.peer_addr(),
@@ -65,11 +54,11 @@ impl StitchConnection {
         }
     }
 
-    pub fn reader(&mut self) -> &mut StitchConnectionReader {
+    pub fn reader(&mut self) -> &mut ConnectionReader {
         &mut self.reader
     }
 
-    pub fn writer(&mut self) -> &mut StitchConnectionWriter {
+    pub fn writer(&mut self) -> &mut ConnectionWriter {
         &mut self.writer
     }
 
index 568e40a359ceaee18fbbef74c2675fa0d4b60745..1dd9f98145b03ad8f285ba835470048e7196150a 100644 (file)
@@ -1,4 +1,4 @@
-use crate::schema::StitchMessage;
+use crate::schema::ConnectionMessage;
 use async_std::net::SocketAddr;
 use async_std::pin::Pin;
 use bytes::{Buf, BytesMut};
@@ -14,14 +14,14 @@ use protobuf::well_known_types::Any;
 
 const BUFFER_SIZE: usize = 8192;
 
-pub struct StitchConnectionReader {
+pub struct ConnectionReader {
     local_addr: SocketAddr,
     peer_addr: SocketAddr,
     read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
     pending_read: Option<BytesMut>,
 }
 
-impl StitchConnectionReader {
+impl ConnectionReader {
     pub fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
@@ -44,7 +44,7 @@ impl StitchConnectionReader {
     }
 }
 
-impl Stream for StitchConnectionReader {
+impl Stream for ConnectionReader {
     type Item = Any;
 
     fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
@@ -80,7 +80,7 @@ impl Stream for StitchConnectionReader {
                         buffer.resize(bytes_read, 0);
                         debug!("{:?}", buffer.as_ref());
 
-                        match StitchMessage::parse_from_bytes(buffer.as_ref()) {
+                        match ConnectionMessage::parse_from_bytes(buffer.as_ref()) {
                             Ok(mut data) => {
                                 let serialized_size = data.compute_size();
                                 debug!("Deserialized message of size {} bytes", serialized_size);
index 352580a5721ba5864219c38740026e3698f686f3..c41fbb06e070fc36afc56ab4f2ae9383d39227aa 100644 (file)
@@ -24,7 +24,7 @@
 // const _PROTOBUF_VERSION_CHECK: () = ::protobuf::VERSION_2_19_0;
 
 #[derive(PartialEq,Clone,Default)]
-pub struct StitchMessage {
+pub struct ConnectionMessage {
     // message fields
     pub payload: ::protobuf::SingularPtrField<::protobuf::well_known_types::Any>,
     // special fields
@@ -32,14 +32,14 @@ pub struct StitchMessage {
     pub cached_size: ::protobuf::CachedSize,
 }
 
-impl<'a> ::std::default::Default for &'a StitchMessage {
-    fn default() -> &'a StitchMessage {
-        <StitchMessage as ::protobuf::Message>::default_instance()
+impl<'a> ::std::default::Default for &'a ConnectionMessage {
+    fn default() -> &'a ConnectionMessage {
+        <ConnectionMessage as ::protobuf::Message>::default_instance()
     }
 }
 
-impl StitchMessage {
-    pub fn new() -> StitchMessage {
+impl ConnectionMessage {
+    pub fn new() -> ConnectionMessage {
         ::std::default::Default::default()
     }
 
@@ -77,7 +77,7 @@ impl StitchMessage {
     }
 }
 
-impl ::protobuf::Message for StitchMessage {
+impl ::protobuf::Message for ConnectionMessage {
     fn is_initialized(&self) -> bool {
         for v in &self.payload {
             if !v.is_initialized() {
@@ -151,8 +151,8 @@ impl ::protobuf::Message for StitchMessage {
         Self::descriptor_static()
     }
 
-    fn new() -> StitchMessage {
-        StitchMessage::new()
+    fn new() -> ConnectionMessage {
+        ConnectionMessage::new()
     }
 
     fn descriptor_static() -> &'static ::protobuf::reflect::MessageDescriptor {
@@ -161,46 +161,46 @@ impl ::protobuf::Message for StitchMessage {
             let mut fields = ::std::vec::Vec::new();
             fields.push(::protobuf::reflect::accessor::make_singular_ptr_field_accessor::<_, ::protobuf::types::ProtobufTypeMessage<::protobuf::well_known_types::Any>>(
                 "payload",
-                |m: &StitchMessage| { &m.payload },
-                |m: &mut StitchMessage| { &mut m.payload },
+                |m: &ConnectionMessage| { &m.payload },
+                |m: &mut ConnectionMessage| { &mut m.payload },
             ));
-            ::protobuf::reflect::MessageDescriptor::new_pb_name::<StitchMessage>(
-                "StitchMessage",
+            ::protobuf::reflect::MessageDescriptor::new_pb_name::<ConnectionMessage>(
+                "ConnectionMessage",
                 fields,
                 file_descriptor_proto()
             )
         })
     }
 
-    fn default_instance() -> &'static StitchMessage {
-        static instance: ::protobuf::rt::LazyV2<StitchMessage> = ::protobuf::rt::LazyV2::INIT;
-        instance.get(StitchMessage::new)
+    fn default_instance() -> &'static ConnectionMessage {
+        static instance: ::protobuf::rt::LazyV2<ConnectionMessage> = ::protobuf::rt::LazyV2::INIT;
+        instance.get(ConnectionMessage::new)
     }
 }
 
-impl ::protobuf::Clear for StitchMessage {
+impl ::protobuf::Clear for ConnectionMessage {
     fn clear(&mut self) {
         self.payload.clear();
         self.unknown_fields.clear();
     }
 }
 
-impl ::std::fmt::Debug for StitchMessage {
+impl ::std::fmt::Debug for ConnectionMessage {
     fn fmt(&self, f: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
         ::protobuf::text_format::fmt(self, f)
     }
 }
 
-impl ::protobuf::reflect::ProtobufValue for StitchMessage {
+impl ::protobuf::reflect::ProtobufValue for ConnectionMessage {
     fn as_ref(&self) -> ::protobuf::reflect::ReflectValueRef {
         ::protobuf::reflect::ReflectValueRef::Message(self)
     }
 }
 
 static file_descriptor_proto_data: &'static [u8] = b"\
-    \n\rmessage.proto\x12\x07message\x1a\x19google/protobuf/any.proto\"C\n\r\
-    StitchMessage\x120\n\x07payload\x18\x01\x20\x01(\x0b2\x14.google.protobu\
-    f.AnyR\x07payloadB\0:\0B\0b\x06proto3\
+    \n\rmessage.proto\x12\x07message\x1a\x19google/protobuf/any.proto\"G\n\
+    \x11ConnectionMessage\x120\n\x07payload\x18\x01\x20\x01(\x0b2\x14.google\
+    .protobuf.AnyR\x07payloadB\0:\0B\0b\x06proto3\
 ";
 
 static file_descriptor_proto_lazy: ::protobuf::rt::LazyV2<::protobuf::descriptor::FileDescriptorProto> = ::protobuf::rt::LazyV2::INIT;
index 1b8ac931a77a936a0aeffd828896f2b358165334..49c0b7c849a0a354f47eadff247fc02ba0fd7816 100644 (file)
@@ -1,12 +1,11 @@
 mod message;
 
-pub use message::StitchMessage;
+pub use message::ConnectionMessage;
 use protobuf::well_known_types::Any;
 use protobuf::Message;
 
-impl StitchMessage {
-    // @todo make pub(crate)
-    pub fn from_msg<T: Message>(msg: T) -> Self {
+impl ConnectionMessage {
+    pub(crate) fn from_msg<T: Message>(msg: T) -> Self {
         let mut sm = Self::new();
         let payload = Any::pack(&msg).expect("Protobuf Message could not be packed into Any type");
 
index a0994d4694940788869e30a296e42e41cd220958..4c85138af01d95bfb13c629fc2163c529821786a 100644 (file)
@@ -1,21 +1,20 @@
 use async_std::task;
 use log::*;
 
-use crate::StitchConnection;
+use crate::Connection;
 use async_std::net::{TcpStream, ToSocketAddrs};
 
-impl StitchConnection {
-    pub fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(
-        ip_addrs: A,
-    ) -> anyhow::Result<StitchConnection> {
-        let read_stream = task::block_on(TcpStream::connect(&ip_addrs))?;
+impl Connection {
+    pub fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
+        let stream = task::block_on(TcpStream::connect(&ip_addrs))?;
         info!("Established client TCP connection to {}", ip_addrs);
 
-        Ok(Self::from(read_stream))
+        stream.set_nodelay(true)?;
+        Ok(Self::from(stream))
     }
 }
 
-impl From<TcpStream> for StitchConnection {
+impl From<TcpStream> for Connection {
     fn from(stream: TcpStream) -> Self {
         let write_stream = stream.clone();
 
index adae4ccf4faa948bb69d615e3d8be4d9b309b41e..43fd1dec1c0f6d1e468e6655afc452d81901826c 100644 (file)
@@ -1,4 +1,4 @@
-use crate::StitchConnection;
+use crate::Connection;
 use async_std::net::{SocketAddr, TcpListener, ToSocketAddrs};
 use async_std::pin::Pin;
 use async_std::task;
@@ -7,12 +7,12 @@ use futures::{Stream, StreamExt};
 use log::*;
 
 #[allow(dead_code)]
-pub struct StitchTcpServer {
+pub struct TcpServer {
     local_addrs: SocketAddr,
     listener: TcpListener,
 }
 
-impl StitchTcpServer {
+impl TcpServer {
     pub fn new<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
         let listener = task::block_on(TcpListener::bind(&ip_addrs))?;
         info!("Started TCP server at {}", &ip_addrs);
@@ -24,12 +24,12 @@ impl StitchTcpServer {
     }
 }
 
-impl Stream for StitchTcpServer {
-    type Item = StitchConnection;
+impl Stream for TcpServer {
+    type Item = Connection;
 
     fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         if let Some(Ok(conn)) = futures::executor::block_on(self.listener.incoming().next()) {
-            Poll::Ready(Some(StitchConnection::from(conn)))
+            Poll::Ready(Some(Connection::from(conn)))
         } else {
             Poll::Ready(None)
         }
index 5f9acb22507b968bd4ecfea0c9a263f13999b09a..ef0ea6401953c8687b3f2e3dc6719f5ae2e0f311 100644 (file)
@@ -1,76 +1,65 @@
-use async_channel::{Receiver, Sender};
-use async_std::io::*;
-use async_std::net::*;
 use async_std::task;
 use async_tls::TlsConnector;
-use futures_util::io::AsyncReadExt;
 use log::*;
 
-use crate::registry::StitchRegistry;
-use crate::StitchNetClient;
-use crate::{channel_factory, StitchMessage};
-use async_std::sync::{Arc, Condvar, Mutex};
+use crate::Connection;
+use async_std::net::{TcpStream, SocketAddr, ToSocketAddrs};
+use async_tls::client;
+use async_tls::server;
+use futures::AsyncReadExt;
 
-impl StitchNetClient {
-    pub fn tls_client<A: ToSocketAddrs + std::fmt::Display>(
-        ip_addrs: A,
-        domain: &str,
-        connector: TlsConnector,
-    ) -> Result<Self> {
-        Self::tls_client_with_bound(ip_addrs, domain, connector, None)
-    }
+pub enum TlsConnectionMetadata {
+    Client { local_addr: SocketAddr, peer_addr: SocketAddr, stream: client::TlsStream<TcpStream> },
+    Server { local_addr: SocketAddr, peer_addr: SocketAddr, stream: server::TlsStream<TcpStream> },
+}
 
-    pub fn tls_client_with_bound<A: ToSocketAddrs + std::fmt::Display>(
+impl Connection {
+    pub fn tls_client<A: ToSocketAddrs + std::fmt::Display>(
         ip_addrs: A,
         domain: &str,
         connector: TlsConnector,
-        cap: Option<usize>,
-    ) -> Result<Self> {
+    ) -> anyhow::Result<Self> {
         let stream = task::block_on(TcpStream::connect(&ip_addrs))?;
-        stream.set_nodelay(true)?;
         info!("Established client TCP connection to {}", ip_addrs);
+        stream.set_nodelay(true)?;
 
-        Self::tls_client_from_parts(stream, domain, connector, channel_factory(cap))
-    }
-
-    pub fn tls_client_from_parts(
-        stream: TcpStream,
-        domain: &str,
-        connector: TlsConnector,
-        (tls_write_sender, tls_write_receiver): (Sender<StitchMessage>, Receiver<StitchMessage>),
-    ) -> Result<Self> {
-        let local_addr = stream.local_addr()?;
+        let local_addr = stream.peer_addr()?;
         let peer_addr = stream.peer_addr()?;
 
-        let encrypted_stream = task::block_on(connector.connect(domain, stream))?;
-        let (read_stream, write_stream) = encrypted_stream.split();
+        let encrypted_stream: client::TlsStream<TcpStream> =
+            task::block_on(connector.connect(domain, stream))?;
         info!("Completed TLS handshake with {}", peer_addr);
 
-        let registry: StitchRegistry = crate::registry::new();
-        let read_readiness = Arc::new((Mutex::new(false), Condvar::new()));
-        let write_readiness = Arc::new((Mutex::new(false), Condvar::new()));
+        Ok(Self::from(TlsConnectionMetadata::Client { local_addr, peer_addr, stream: encrypted_stream }))
+    }
+}
+
+impl From<TlsConnectionMetadata> for Connection {
+    fn from(metadata: TlsConnectionMetadata) -> Self {
+        match metadata {
+            TlsConnectionMetadata::Client { local_addr, peer_addr, stream } => {
+                let (read_stream, write_stream) = stream.split();
+
+                Self::new(
+                    local_addr,
+                    peer_addr,
+                    Box::new(read_stream),
+                    Box::new(write_stream),
+                )
+            },
+
+            TlsConnectionMetadata::Server { local_addr, peer_addr, stream } => {
+                let (read_stream, write_stream) = stream.split();
 
-        let read_task = task::spawn(crate::tasks::read_from_stream(
-            registry.clone(),
-            read_stream,
-            read_readiness.clone(),
-        ));
+                Self::new(
+                    local_addr,
+                    peer_addr,
+                    Box::new(read_stream),
+                    Box::new(write_stream),
+                )
+            }
+        }
 
-        let write_task = task::spawn(crate::tasks::write_to_stream(
-            tls_write_receiver.clone(),
-            write_stream,
-            write_readiness.clone(),
-        ));
 
-        Ok(Self {
-            local_addr,
-            peer_addr,
-            registry,
-            stream_writer_chan: (tls_write_sender, tls_write_receiver),
-            read_readiness,
-            write_readiness,
-            read_task,
-            write_task,
-        })
     }
 }
index 79dba44033dd8c94015b823d0eb773d52f2df901..66e4206dcb844a67ffc4088ced0895cc0ba4a0de 100644 (file)
-use crate::channel_factory;
-use crate::registry::StitchRegistry;
-use crate::{ServerRegistry, StitchClient, StitchNetClient, StitchNetServer};
-use async_channel::{Receiver, Sender};
-use async_std::io::*;
+use crate::Connection;
+use crate::tls::TlsConnectionMetadata;
 use async_std::net::*;
+use async_std::pin::Pin;
 use async_std::prelude::*;
-use async_std::sync::{Arc, Condvar, Mutex};
 use async_std::task;
 use async_tls::TlsAcceptor;
-use dashmap::DashMap;
-use futures_util::AsyncReadExt;
+use futures::task::{Context, Poll};
 use log::*;
 
-impl StitchNetServer {
-    pub fn tls_server<A: ToSocketAddrs + std::fmt::Display>(
-        ip_addrs: A,
-        acceptor: TlsAcceptor,
-    ) -> Result<(StitchNetServer, Receiver<Arc<StitchNetClient>>)> {
-        Self::tls_server_with_bound(ip_addrs, acceptor, None)
-    }
+#[allow(dead_code)]
+pub struct TlsServer {
+    local_addrs: SocketAddr,
+    listener: TcpListener,
+    acceptor: TlsAcceptor,
+}
 
-    pub fn tls_server_with_bound<A: ToSocketAddrs + std::fmt::Display>(
-        ip_addrs: A,
-        acceptor: TlsAcceptor,
-        cap: Option<usize>,
-    ) -> Result<(Self, Receiver<Arc<StitchNetClient>>)> {
+impl TlsServer {
+    pub fn new<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A, acceptor: TlsAcceptor) -> anyhow::Result<Self> {
         let listener = task::block_on(TcpListener::bind(ip_addrs))?;
         info!("Started TLS server at {}", listener.local_addr()?);
 
-        let registry = Arc::new(DashMap::new());
-        let (sender, receiver) = channel_factory(cap);
-
-        let handler = task::spawn(handle_server_connections(
-            acceptor,
-            registry.clone(),
+        Ok(Self {
+            local_addrs: listener.local_addr()?,
             listener,
-            sender.clone(),
-            cap,
-        ));
-
-        Ok((
-            Self {
-                registry,
-                connections_chan: (sender, receiver.clone()),
-                accept_loop_task: handler,
-            },
-            receiver,
-        ))
+            acceptor,
+        })
     }
 }
 
-async fn handle_server_connections<'a>(
-    acceptor: TlsAcceptor,
-    registry: ServerRegistry,
-    input: TcpListener,
-    output: Sender<Arc<StitchNetClient>>,
-    cap: Option<usize>,
-) -> anyhow::Result<()> {
-    let mut conns = input.incoming();
-
-    debug!("Reading from the stream of incoming connections");
-    loop {
-        match conns.next().await {
-            Some(Ok(tcp_stream)) => {
-                let local_addr = tcp_stream.local_addr()?;
-                let peer_addr = tcp_stream.peer_addr()?;
-
-                debug!("Received connection attempt from {}", peer_addr);
-
-                let tls_stream = acceptor.accept(tcp_stream).await?;
-
-                let (read_stream, write_stream) = tls_stream.split();
-                let (tls_write_sender, tls_write_receiver) = channel_factory(cap);
-
-                let client_registry: StitchRegistry = crate::registry::new();
-                let read_readiness = Arc::new((Mutex::new(false), Condvar::new()));
-                let write_readiness = Arc::new((Mutex::new(false), Condvar::new()));
-
-                let read_task = task::spawn(crate::tasks::read_from_stream(
-                    client_registry.clone(),
-                    read_stream,
-                    read_readiness.clone(),
-                ));
-
-                let write_task = task::spawn(crate::tasks::write_to_stream(
-                    tls_write_receiver.clone(),
-                    write_stream,
-                    write_readiness.clone(),
-                ));
-
-                let conn = StitchNetClient {
-                    local_addr,
-                    peer_addr,
-                    registry: client_registry,
-                    stream_writer_chan: (tls_write_sender, tls_write_receiver),
-                    read_readiness,
-                    write_readiness,
-                    read_task,
-                    write_task,
-                };
-
-                debug!("Attempting to register connection from {}", peer_addr);
-                let conn = Arc::new(conn);
-                registry.insert(conn.peer_addr(), conn.clone());
-                debug!(
-                    "Registered client connection for {} in server registry",
-                    peer_addr
-                );
-
-                if let Err(err) = output.send(conn).await {
-                    error!(
-                        "Stopping the server accept loop - could not send accepted TLS client connection to channel: {:#?}",
-                        err
-                    );
-
-                    break Err(anyhow::Error::from(err));
-                } else {
-                    info!("Accepted connection from {}", peer_addr);
-                }
-            }
-
-            Some(Err(err)) => error!(
-                "Encountered error when accepting TLS connection: {:#?}",
-                err
-            ),
-
-            None => {
-                warn!("Stopping the server accept loop - unable to accept any more connections");
-
-                break Ok(());
+impl Stream for TlsServer {
+    type Item = Connection;
+
+    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        if let Some(Ok(tcp_stream)) = futures::executor::block_on(self.listener.incoming().next()) {
+            let local_addr = tcp_stream.local_addr().expect(
+                "Local address could not be retrieved",
+            );
+
+            let peer_addr = tcp_stream.peer_addr().expect(
+                "Peer address could not be retrieved",
+            );
+            debug!("Received connection attempt from {}", peer_addr);
+
+            if let Ok(tls_stream) = futures::executor::block_on(self.acceptor.accept(tcp_stream)) {
+                debug!("Established TLS connection from {}", peer_addr);
+                Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Server{ local_addr, peer_addr, stream: tls_stream })))
+            } else {
+                debug!("Could not encrypt connection with TLS from {}", peer_addr);
+                Poll::Pending
             }
+        } else {
+            Poll::Ready(None)
         }
     }
 }
index f6f267edf335d2c40cdd1cf1e9f89bde32642128..c2275ac9d39b45866ae5ef10d0ed947d1dcab686 100644 (file)
@@ -1,4 +1,4 @@
-use crate::schema::StitchMessage;
+use crate::schema::ConnectionMessage;
 use async_channel::RecvError;
 use async_std::net::SocketAddr;
 use async_std::pin::Pin;
@@ -10,14 +10,14 @@ use protobuf::Message;
 pub use futures::SinkExt;
 pub use futures::StreamExt;
 
-pub struct StitchConnectionWriter {
+pub struct ConnectionWriter {
     local_addr: SocketAddr,
     peer_addr: SocketAddr,
     write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
-    pending_write: Option<StitchMessage>,
+    pending_write: Option<ConnectionMessage>,
 }
 
-impl StitchConnectionWriter {
+impl ConnectionWriter {
     pub fn new(
         local_addr: SocketAddr,
         peer_addr: SocketAddr,
@@ -40,7 +40,7 @@ impl StitchConnectionWriter {
     }
 }
 
-impl<T: Message> Sink<T> for StitchConnectionWriter {
+impl<T: Message> Sink<T> for ConnectionWriter {
     type Error = RecvError;
 
     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
@@ -55,7 +55,7 @@ impl<T: Message> Sink<T> for StitchConnectionWriter {
 
     fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
         debug!("Preparing message to be sent next");
-        let stitch_msg: StitchMessage = StitchMessage::from_msg(item);
+        let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item);
         self.pending_write.replace(stitch_msg);
 
         Ok(())