git.lizzy.rs Git - connect-rs.git/blob - examples/tls-echo-server/src/main.rs
make async-oriented, remove block_on
[connect-rs.git] / examples / tls-echo-server / src / main.rs
1 mod schema;
2
3 use crate::schema::hello_world::HelloWorld;
4 use async_std::{io, task};
5 use connect::tls::rustls::internal::pemfile::{certs, rsa_private_keys};
6 use connect::tls::rustls::{NoClientAuth, ServerConfig};
7 use connect::tls::TlsServer;
8 use connect::{SinkExt, StreamExt};
9 use log::*;
10 use std::env;
11 use std::fs::File;
12 use std::io::BufReader;
13
14 #[async_std::main]
15 async fn main() -> anyhow::Result<()> {
16     env_logger::init();
17
18     // Get ip address from cmd line args
19     let (ip_address, cert_path, key_path) = parse_args();
20
21     let certs = certs(&mut BufReader::new(File::open(cert_path)?))
22         .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid cert"))?;
23
24     let mut keys = rsa_private_keys(&mut BufReader::new(File::open(key_path)?))
25         .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid key"))?;
26
27     let mut config = ServerConfig::new(NoClientAuth::new());
28     config
29         .set_single_cert(certs, keys.remove(0))
30         .map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
31
32     // create a server
33     let server = TlsServer::new(ip_address, config.into()).await?;
34
35     // handle server connections
36     // wait for a connection to come in and be accepted
37     loop {
38         match server.accept().await {
39             Ok(Some(mut conn)) => {
40                 info!("Handling connection from {}", conn.peer_addr());
41
42                 task::spawn(async move {
43                     while let Some(msg) = conn.reader().next().await {
44                         if msg.is::<HelloWorld>() {
45                             if let Ok(Some(contents)) = msg.unpack::<HelloWorld>() {
46                                 info!(
47                                     "Received a message \"{}\" from {}",
48                                     contents.get_message(),
49                                     conn.peer_addr()
50                                 );
51
52                                 conn.writer()
53                                     .send(contents)
54                                     .await
55                                     .expect("Could not send message back to source connection");
56                                 info!("Sent message back to original sender");
57                             }
58                         } else {
59                             error!("Received a message of unknown type")
60                         }
61                     }
62                 });
63             }
64
65             Ok(None) => (),
66
67             Err(e) => {
68                 error!("Encountered error when accepting connection: {}", e);
69                 break
70             }
71         }
72     }
73
74     Ok(())
75 }
76
77 fn parse_args() -> (String, String, String) {
78     let args: Vec<String> = env::args().collect();
79
80     let ip_address = match args.get(1) {
81         Some(addr) => addr,
82         None => {
83             error!("Need to pass IP address to connect to as first command line argument");
84             panic!();
85         }
86     };
87
88     let cert_path = match args.get(2) {
89         Some(d) => d,
90         None => {
91             error!("Need to pass path to cert file as second command line argument");
92             panic!();
93         }
94     };
95
96     let key_path = match args.get(3) {
97         Some(d) => d,
98         None => {
99             error!("Need to pass path to key file as third command line argument");
100             panic!();
101         }
102     };
103
104     (
105         ip_address.to_string(),
106         cert_path.to_string(),
107         key_path.to_string(),
108     )
109 }