]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
remove `block_on` in tls-listener
[connect-rs.git] / src / writer.rs
1 use crate::protocol::ConnectDatagram;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use futures::io::IoSlice;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncWrite, Sink};
7 use log::*;
8 use std::error::Error;
9
10 pub use futures::SinkExt;
11 pub use futures::StreamExt;
12 use std::fmt::Debug;
13
/// Error produced when a message cannot be written to the network stream.
#[derive(Debug)]
pub enum ConnectionWriteError {
    /// The writer has already been closed, so no further messages can be sent through it.
    ConnectionClosed,

    /// The underlying stream reported an I/O failure; the original error is carried along.
    IoError(std::io::Error),
}
24
25 impl Error for ConnectionWriteError {}
26
27 impl std::fmt::Display for ConnectionWriteError {
28     fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
29         match self {
30             ConnectionWriteError::ConnectionClosed => {
31                 formatter.write_str("cannot send message when connection is closed")
32             }
33             ConnectionWriteError::IoError(err) => std::fmt::Display::fmt(&err, formatter),
34         }
35     }
36 }
37
38 /// An interface to write messages to the network connection.
39 ///
40 /// Implements the [`Sink`] trait to asynchronously write messages to the network connection.
41 ///
42 /// # Example
43 ///
44 /// Basic usage:
45 ///
46 /// ```ignore
47 /// writer.send(msg).await?;
48 /// ```
49 ///
50 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
51 /// example program or other client example programs for a more thorough showcase.
52 ///
53 pub struct ConnectionWriter {
54     local_addr: SocketAddr,
55     peer_addr: SocketAddr,
56     write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
57     pending_writes: Vec<Vec<u8>>,
58     closed: bool,
59 }
60
61 impl ConnectionWriter {
62     /// Creates a new [`ConnectionWriter`] from an [`AsyncWrite`] trait object and the local and peer
63     /// socket metadata.
64     pub fn new(
65         local_addr: SocketAddr,
66         peer_addr: SocketAddr,
67         write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
68     ) -> Self {
69         Self {
70             local_addr,
71             peer_addr,
72             write_stream,
73             pending_writes: Vec::new(),
74             closed: false,
75         }
76     }
77
78     /// Get the local IP address and port.
79     pub fn local_addr(&self) -> SocketAddr {
80         self.local_addr.clone()
81     }
82
83     /// Get the peer IP address and port.
84     pub fn peer_addr(&self) -> SocketAddr {
85         self.peer_addr.clone()
86     }
87
88     /// Check if the [`Sink`] of messages to the network is closed.
89     pub fn is_closed(&self) -> bool {
90         self.closed
91     }
92
93     pub(crate) fn write_pending_bytes(
94         &mut self,
95         cx: &mut Context<'_>,
96     ) -> Poll<Result<(), ConnectionWriteError>> {
97         if self.pending_writes.len() > 0 {
98             let stream = self.write_stream.as_mut();
99
100             match stream.poll_flush(cx) {
101                 Poll::Pending => Poll::Pending,
102
103                 Poll::Ready(Ok(_)) => {
104                     trace!("Sending pending bytes");
105
106                     let pending = self.pending_writes.split_off(0);
107                     let writeable_vec: Vec<IoSlice> =
108                         pending.iter().map(|p| IoSlice::new(p)).collect();
109
110                     let stream = self.write_stream.as_mut();
111                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
112                         Poll::Pending => Poll::Pending,
113
114                         Poll::Ready(Ok(bytes_written)) => {
115                             trace!("Wrote {} bytes to network stream", bytes_written);
116                             Poll::Ready(Ok(()))
117                         }
118
119                         Poll::Ready(Err(err)) => {
120                             error!("Encountered error when writing to network stream");
121                             Poll::Ready(Err(ConnectionWriteError::IoError(err)))
122                         }
123                     }
124                 }
125
126                 Poll::Ready(Err(err)) => {
127                     error!("Encountered error when flushing network stream");
128                     Poll::Ready(Err(ConnectionWriteError::IoError(err)))
129                 }
130             }
131         } else {
132             Poll::Ready(Ok(()))
133         }
134     }
135 }
136
137 impl Sink<ConnectDatagram> for ConnectionWriter {
138     type Error = ConnectionWriteError;
139
140     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141         if self.is_closed() {
142             trace!("Connection is closed, cannot send message");
143             Poll::Ready(Err(ConnectionWriteError::ConnectionClosed))
144         } else {
145             trace!("Connection ready to send message");
146             Poll::Ready(Ok(()))
147         }
148     }
149
150     fn start_send(mut self: Pin<&mut Self>, item: ConnectDatagram) -> Result<(), Self::Error> {
151         trace!("Preparing message to be sent next");
152
153         let buffer = item.encode();
154         let msg_size = buffer.len();
155         trace!("Serialized pending message into {} bytes", msg_size);
156
157         self.pending_writes.push(buffer);
158
159         Ok(())
160     }
161
162     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
163         self.write_pending_bytes(cx)
164     }
165
166     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
167         self.closed = true;
168
169         match self.write_pending_bytes(cx) {
170             Poll::Pending => Poll::Pending,
171
172             Poll::Ready(Ok(_)) => {
173                 let stream = self.write_stream.as_mut();
174
175                 match stream.poll_close(cx) {
176                     Poll::Pending => Poll::Pending,
177
178                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
179
180                     Poll::Ready(Err(err)) => Poll::Ready(Err(ConnectionWriteError::IoError(err))),
181                 }
182             }
183
184             err => err,
185         }
186     }
187 }