async-tls = { version = "0.9.0", default-features = false, features = ["client", "server"]}
bytes = "0.5.5"
futures = "0.3.8"
-futures-lite = "1.11.3"
log = "0.4"
protobuf = "2.18.1"
rustls = "0.18.0"
};
// create a client connection to the server
- let mut conn = Connection::tcp_client(ip_address)?;
+ let mut conn = Connection::tcp_client(ip_address).await?;
// send a message to the server
let raw_msg = String::from("Hello world");
- info!("Sending message: {}", raw_msg);
let mut msg = HelloWorld::new();
- msg.set_message(raw_msg);
+ msg.set_message(raw_msg.clone());
conn.writer().send(msg).await?;
+ info!("Sent message: {}", raw_msg);
// wait for the server to reply with an ack
while let Some(reply) = conn.reader().next().await {
};
// create a server
- let mut server = TcpServer::new(ip_address)?;
+ let server = TcpServer::new(ip_address).await?;
// handle server connections
// wait for a connection to come in and be accepted
- while let Some(mut conn) = server.next().await {
+ while let Ok(mut conn) = server.accept().await {
info!("Handling connection from {}", conn.peer_addr());
task::spawn(async move {
.map_err(|_| std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid cert"))?;
// create a client connection to the server
- let mut conn = Connection::tls_client(ip_addr, &domain, client_config.into())?;
+ let mut conn = Connection::tls_client(ip_addr, &domain, client_config.into()).await?;
// send a message to the server
let raw_msg = String::from("Hello world");
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
// create a server
- let mut server = TlsServer::new(ip_address, config.into())?;
+ let server = TlsServer::new(ip_address, config.into()).await?;
// handle server connections
// wait for a connection to come in and be accepted
- while let Some(mut conn) = server.next().await {
- info!("Handling connection from {}", conn.peer_addr());
-
- task::spawn(async move {
- while let Some(msg) = conn.reader().next().await {
- if msg.is::<HelloWorld>() {
- if let Ok(Some(contents)) = msg.unpack::<HelloWorld>() {
- info!(
- "Received a message \"{}\" from {}",
- contents.get_message(),
- conn.peer_addr()
- );
-
- conn.writer()
- .send(contents)
- .await
- .expect("Could not send message back to source connection");
- info!("Sent message back to original sender");
+ loop {
+ match server.accept().await {
+ Ok(Some(mut conn)) => {
+ info!("Handling connection from {}", conn.peer_addr());
+
+ task::spawn(async move {
+ while let Some(msg) = conn.reader().next().await {
+ if msg.is::<HelloWorld>() {
+ if let Ok(Some(contents)) = msg.unpack::<HelloWorld>() {
+ info!(
+ "Received a message \"{}\" from {}",
+ contents.get_message(),
+ conn.peer_addr()
+ );
+
+ conn.writer()
+ .send(contents)
+ .await
+ .expect("Could not send message back to source connection");
+ info!("Sent message back to original sender");
+ }
+ } else {
+ error!("Received a message of unknown type")
+ }
}
- } else {
- error!("Received a message of unknown type")
- }
+ });
}
- });
+
+ Ok(None) => (),
+
+ Err(e) => {
+ error!("Encountered error when accepting connection: {}", e);
+ break
+ }
+ }
}
Ok(())
const BUFFER_SIZE: usize = 8192;
pub struct ConnectionReader {
- local_addr: SocketAddr,
- peer_addr: SocketAddr,
- read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
+ local_addr: SocketAddr,
+ peer_addr: SocketAddr,
+ read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
pending_read: Option<BytesMut>,
- closed: bool,
+ closed: bool,
}
impl ConnectionReader {
trace!("Read {} bytes from the network stream", bytes_read)
} else if self.pending_read.is_none() {
self.close_stream();
- return Poll::Ready(None)
+ return Poll::Ready(None);
}
if let Some(mut pending_buf) = self.pending_read.take() {
// Close the stream
Poll::Ready(Err(_e)) => {
self.close_stream();
- return Poll::Ready(None)
+ return Poll::Ready(None);
}
}
}
use async_std::net::{TcpStream, ToSocketAddrs};
impl Connection {
- pub async fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(ip_addrs: A) -> anyhow::Result<Self> {
+ pub async fn tcp_client<A: ToSocketAddrs + std::fmt::Display>(
+ ip_addrs: A,
+ ) -> anyhow::Result<Self> {
let stream = TcpStream::connect(&ip_addrs).await?;
info!("Established client TCP connection to {}", ip_addrs);
use crate::Connection;
use async_std::net::{SocketAddr, TcpListener, ToSocketAddrs};
-use async_std::pin::Pin;
-use futures::task::{Context, Poll};
-use futures::Stream;
-use futures_lite::stream::StreamExt;
use log::*;
#[allow(dead_code)]
listener,
})
}
-}
-
-impl Stream for TcpServer {
- type Item = Connection;
-
- fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- match self.listener.incoming().poll_next(cx) {
- Poll::Pending => Poll::Pending,
-
- Poll::Ready(Some(Ok(conn))) => {
- debug!(
- "Received connection attempt from {}",
- conn.peer_addr()
- .expect("Peer address could not be retrieved")
- );
-
- Poll::Ready(Some(Connection::from(conn)))
- },
-
- Poll::Ready(Some(Err(e))) => {
- error!(
- "Encountered error when accepting connection attempt: {}", e
- );
- Poll::Pending
- },
+ pub async fn accept(&self) -> anyhow::Result<Connection> {
+ let (stream, ip_addr) = self.listener.accept().await?;
+ debug!("Received connection attempt from {}", ip_addr);
- Poll::Ready(None) => {
- info!("Shutting TCP server down at {}", self.local_addrs);
- Poll::Ready(None)
- },
- }
+ Ok(Connection::from(stream))
}
}
let local_addr = stream.peer_addr()?;
let peer_addr = stream.peer_addr()?;
- let encrypted_stream: client::TlsStream<TcpStream> = connector.connect(domain, stream).await?;
+ let encrypted_stream: client::TlsStream<TcpStream> =
+ connector.connect(domain, stream).await?;
info!("Completed TLS handshake with {}", peer_addr);
Ok(Self::from(TlsConnectionMetadata::Client {
use crate::tls::TlsConnectionMetadata;
use crate::Connection;
use async_std::net::*;
-use async_std::pin::Pin;
use async_tls::TlsAcceptor;
-use futures::Stream;
-use futures::task::{Context, Poll};
-use futures_lite::StreamExt;
use log::*;
#[allow(dead_code)]
acceptor,
})
}
-}
-
-impl Stream for TlsServer {
- type Item = Connection;
-
- fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- match self.listener.incoming().poll_next(cx) {
- Poll::Pending => Poll::Pending,
-
- Poll::Ready(Some(Ok(tcp_stream))) => {
- let local_addr = tcp_stream
- .local_addr()
- .expect("Local address could not be retrieved");
-
- let peer_addr = tcp_stream
- .peer_addr()
- .expect("Peer address could not be retrieved");
-
- debug!(
- "Received connection attempt from {}", peer_addr
- );
-
- if let Ok(tls_stream) = futures::executor::block_on(self.acceptor.accept(tcp_stream)) {
- debug!("Completed TLS handshake with {}", peer_addr);
- Poll::Ready(Some(Connection::from(TlsConnectionMetadata::Server {
- local_addr,
- peer_addr,
- stream: tls_stream,
- })))
- } else {
- warn!("Could not encrypt connection with TLS from {}", peer_addr);
- Poll::Pending
- }
- },
-
- Poll::Ready(Some(Err(e))) => {
- error!(
- "Encountered error when accepting connection attempt: {}", e
- );
-
- Poll::Pending
- }
- Poll::Ready(None) => {
- info!("Shutting TLS server down at {}", self.local_addrs);
- Poll::Ready(None)
- },
+ pub async fn accept(&self) -> anyhow::Result<Option<Connection>> {
+ let (tcp_stream, peer_addr) = self.listener.accept().await?;
+ debug!("Received connection attempt from {}", peer_addr);
+
+ if let Ok(tls_stream) = self.acceptor.accept(tcp_stream).await {
+ debug!("Completed TLS handshake with {}", peer_addr);
+ Ok(Some(Connection::from(TlsConnectionMetadata::Server {
+ local_addr: self.local_addrs.clone(),
+ peer_addr,
+ stream: tls_stream,
+ })))
+ } else {
+ warn!("Could not encrypt connection with TLS from {}", peer_addr);
+ Ok(None)
}
}
}
use async_channel::RecvError;
use async_std::net::SocketAddr;
use async_std::pin::Pin;
-use futures::{AsyncWrite, Sink};
use futures::io::IoSlice;
use futures::task::{Context, Poll};
+use futures::{AsyncWrite, Sink};
use log::*;
use protobuf::Message;
pub use futures::StreamExt;
pub struct ConnectionWriter {
- local_addr: SocketAddr,
- peer_addr: SocketAddr,
- write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
+ local_addr: SocketAddr,
+ peer_addr: SocketAddr,
+ write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
pending_writes: Vec<Vec<u8>>,
- closed: bool,
+ closed: bool,
}
impl ConnectionWriter {
}
}
- fn poll_flush(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Result<(), Self::Error>> {
+ fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
if self.pending_writes.len() > 0 {
let stream = self.write_stream.as_mut();
trace!("Sending pending bytes");
let pending = self.pending_writes.split_off(0);
- let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
- IoSlice::new(p)
- }).collect();
+ let writeable_vec: Vec<IoSlice> =
+ pending.iter().map(|p| IoSlice::new(p)).collect();
let stream = self.write_stream.as_mut();
match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
Poll::Ready(Ok(bytes_written)) => {
trace!("Wrote {} bytes to network stream", bytes_written);
Poll::Ready(Ok(()))
- },
+ }
Poll::Ready(Err(_e)) => {
error!("Encountered error when writing to network stream");
Poll::Ready(Err(RecvError))
- },
+ }
}
- },
+ }
Poll::Ready(Err(_e)) => {
error!("Encountered error when flushing network stream");
}
}
- fn poll_close(
- mut self: Pin<&mut Self>,
- cx: &mut Context<'_>,
- ) -> Poll<Result<(), Self::Error>> {
+ fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.closed = true;
let flush = if self.pending_writes.len() > 0 {
trace!("Sending pending bytes");
let pending = self.pending_writes.split_off(0);
- let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
- IoSlice::new(p)
- }).collect();
+ let writeable_vec: Vec<IoSlice> =
+ pending.iter().map(|p| IoSlice::new(p)).collect();
let stream = self.write_stream.as_mut();
match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
Poll::Ready(Ok(bytes_written)) => {
trace!("Wrote {} bytes to network stream", bytes_written);
Poll::Ready(Ok(()))
- },
+ }
Poll::Ready(Err(_e)) => {
error!("Encountered error when writing to network stream");
Poll::Ready(Err(RecvError))
- },
+ }
}
- },
+ }
Poll::Ready(Err(_e)) => {
error!("Encountered error when flushing network stream");
Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
}
- },
+ }
err => err,
}