]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
c65a4441cfc0b922ae7977de879493bcb52a2593
[connect-rs.git] / src / writer.rs
1 use crate::protocol::ConnectDatagram;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use futures::io::IoSlice;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncWrite, Sink};
7 use log::*;
8 use std::error::Error;
9
10 pub use futures::SinkExt;
11 pub use futures::StreamExt;
12 use std::fmt::Debug;
13
/// Encountered when there is an issue with writing messages on the network stream.
///
#[derive(Debug)]
pub enum ConnectionWriteError {
    /// Encountered when trying to send a message while the connection is closed.
    ConnectionClosed,

    /// Encountered when there is an IO-level error with the connection.
    ///
    /// This variant is transparent: both its [`Display`](std::fmt::Display) output and its
    /// [`source`](Error::source) are forwarded from the wrapped [`std::io::Error`].
    IoError(std::io::Error),
}

impl Error for ConnectionWriteError {
    /// Returns the lower-level cause of this error, if any.
    ///
    /// `ConnectionClosed` has no underlying cause. `IoError` forwards to the wrapped IO error's
    /// own source rather than returning the IO error itself, because the IO error's message is
    /// already what `Display` prints for this variant; returning it again would make it show up
    /// twice when an error chain is reported.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ConnectionWriteError::ConnectionClosed => None,
            ConnectionWriteError::IoError(err) => err.source(),
        }
    }
}

impl std::fmt::Display for ConnectionWriteError {
    /// Writes a fixed message for `ConnectionClosed` and the wrapped IO error's own message for
    /// `IoError`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConnectionWriteError::ConnectionClosed => {
                formatter.write_str("cannot send message when connection is closed")
            }
            ConnectionWriteError::IoError(err) => std::fmt::Display::fmt(err, formatter),
        }
    }
}
35
36 /// An interface to write messages to the network connection.
37 ///
38 /// Implements the [`Sink`] trait to asynchronously write messages to the network connection.
39 ///
40 /// # Example
41 ///
42 /// Basic usage:
43 ///
44 /// ```ignore
45 /// writer.send(msg).await?;
46 /// ```
47 ///
48 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
49 /// example program or other client example programs for a more thorough showcase.
50 ///
51 pub struct ConnectionWriter {
52     local_addr: SocketAddr,
53     peer_addr: SocketAddr,
54     write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
55     pending_writes: Vec<Vec<u8>>,
56     closed: bool,
57 }
58
59 impl ConnectionWriter {
60     /// Creates a new [`ConnectionWriter`] from an [`AsyncWrite`] trait object and the local and peer
61     /// socket metadata.
62     pub fn new(
63         local_addr: SocketAddr,
64         peer_addr: SocketAddr,
65         write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
66     ) -> Self {
67         Self {
68             local_addr,
69             peer_addr,
70             write_stream,
71             pending_writes: Vec::new(),
72             closed: false,
73         }
74     }
75
76     /// Get the local IP address and port.
77     pub fn local_addr(&self) -> SocketAddr {
78         self.local_addr.clone()
79     }
80
81     /// Get the peer IP address and port.
82     pub fn peer_addr(&self) -> SocketAddr {
83         self.peer_addr.clone()
84     }
85
86     /// Check if the [`Sink`] of messages to the network is closed.
87     pub fn is_closed(&self) -> bool {
88         self.closed
89     }
90
91     pub(crate) fn write_pending_bytes(
92         &mut self,
93         cx: &mut Context<'_>,
94     ) -> Poll<Result<(), ConnectionWriteError>> {
95         if self.pending_writes.len() > 0 {
96             let stream = self.write_stream.as_mut();
97
98             match stream.poll_flush(cx) {
99                 Poll::Pending => Poll::Pending,
100
101                 Poll::Ready(Ok(_)) => {
102                     trace!("Sending pending bytes");
103
104                     let pending = self.pending_writes.split_off(0);
105                     let writeable_vec: Vec<IoSlice> =
106                         pending.iter().map(|p| IoSlice::new(p)).collect();
107
108                     let stream = self.write_stream.as_mut();
109                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
110                         Poll::Pending => Poll::Pending,
111
112                         Poll::Ready(Ok(bytes_written)) => {
113                             trace!("Wrote {} bytes to network stream", bytes_written);
114                             Poll::Ready(Ok(()))
115                         }
116
117                         Poll::Ready(Err(err)) => {
118                             error!("Encountered error when writing to network stream");
119                             Poll::Ready(Err(ConnectionWriteError::IoError(err)))
120                         }
121                     }
122                 }
123
124                 Poll::Ready(Err(err)) => {
125                     error!("Encountered error when flushing network stream");
126                     Poll::Ready(Err(ConnectionWriteError::IoError(err)))
127                 }
128             }
129         } else {
130             Poll::Ready(Ok(()))
131         }
132     }
133 }
134
135 impl Sink<ConnectDatagram> for ConnectionWriter {
136     type Error = ConnectionWriteError;
137
138     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
139         if self.is_closed() {
140             trace!("Connection is closed, cannot send message");
141             Poll::Ready(Err(ConnectionWriteError::ConnectionClosed))
142         } else {
143             trace!("Connection ready to send message");
144             Poll::Ready(Ok(()))
145         }
146     }
147
148     fn start_send(mut self: Pin<&mut Self>, item: ConnectDatagram) -> Result<(), Self::Error> {
149         trace!("Preparing message to be sent next");
150
151         let buffer = item.encode();
152         let msg_size = buffer.len();
153         trace!("Serialized pending message into {} bytes", msg_size);
154
155         self.pending_writes.push(buffer);
156
157         Ok(())
158     }
159
160     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161         self.write_pending_bytes(cx)
162     }
163
164     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165         self.closed = true;
166
167         match self.write_pending_bytes(cx) {
168             Poll::Pending => Poll::Pending,
169
170             Poll::Ready(Ok(_)) => {
171                 let stream = self.write_stream.as_mut();
172
173                 match stream.poll_close(cx) {
174                     Poll::Pending => Poll::Pending,
175
176                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
177
178                     Poll::Ready(Err(err)) => Poll::Ready(Err(ConnectionWriteError::IoError(err))),
179                 }
180             }
181
182             err => err,
183         }
184     }
185 }