]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
refactor to have datagram already serialized in memory
[connect-rs.git] / src / writer.rs
1 use crate::protocol::ConnectDatagram;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use futures::io::IoSlice;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncWrite, Sink};
7 use log::*;
8 use std::error::Error;
9
10 pub use futures::{SinkExt, StreamExt};
11 use std::fmt::Debug;
12
/// Encountered when there is an issue with writing messages on the network stream.
///
/// The [`ConnectionWriteError::IoError`] variant wraps the underlying [`std::io::Error`], which is
/// also reachable through [`Error::source`] so callers can walk or downcast the error chain.
#[derive(Debug)]
pub enum ConnectionWriteError {
    /// Encountered when trying to send a message while the connection is closed.
    ConnectionClosed,

    /// Encountered when there is an IO-level error with the connection.
    IoError(std::io::Error),
}

impl Error for ConnectionWriteError {
    /// Returns the wrapped IO error for [`ConnectionWriteError::IoError`], and `None` for
    /// variants that do not originate from a lower-level error.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            ConnectionWriteError::ConnectionClosed => None,
            ConnectionWriteError::IoError(err) => Some(err),
        }
    }
}

impl std::fmt::Display for ConnectionWriteError {
    /// Writes a fixed message for a closed connection, or the wrapped IO error's own message.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConnectionWriteError::ConnectionClosed => {
                formatter.write_str("cannot send message when connection is closed")
            }
            ConnectionWriteError::IoError(err) => std::fmt::Display::fmt(err, formatter),
        }
    }
}
36
37 /// An interface to write messages to the network connection.
38 ///
39 /// Implements the `Sink` trait to asynchronously write messages to the network connection.
40 ///
41 /// # Example
42 ///
43 /// Basic usage:
44 ///
45 /// ```ignore
46 /// writer.send(msg).await?;
47 /// ```
48 ///
49 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
50 /// example program or other client example programs for a more thorough showcase.
51 ///
52 pub struct ConnectionWriter {
53     local_addr: SocketAddr,
54     peer_addr: SocketAddr,
55     write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
56     pending_writes: Vec<Vec<u8>>,
57     closed: bool,
58 }
59
60 impl ConnectionWriter {
61     /// Creates a new [`ConnectionWriter`] from an [`AsyncWrite`] trait object and the local and peer
62     /// socket metadata.
63     pub fn new(
64         local_addr: SocketAddr,
65         peer_addr: SocketAddr,
66         write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
67     ) -> Self {
68         Self {
69             local_addr,
70             peer_addr,
71             write_stream,
72             pending_writes: Vec::new(),
73             closed: false,
74         }
75     }
76
77     /// Get the local IP address and port.
78     pub fn local_addr(&self) -> SocketAddr {
79         self.local_addr.clone()
80     }
81
82     /// Get the peer IP address and port.
83     pub fn peer_addr(&self) -> SocketAddr {
84         self.peer_addr.clone()
85     }
86
87     /// Check if the `Sink` of messages to the network is closed.
88     pub fn is_closed(&self) -> bool {
89         self.closed
90     }
91
92     pub(crate) fn write_pending_bytes(
93         &mut self,
94         cx: &mut Context<'_>,
95     ) -> Poll<Result<(), ConnectionWriteError>> {
96         if self.pending_writes.len() > 0 {
97             let stream = self.write_stream.as_mut();
98
99             match stream.poll_flush(cx) {
100                 Poll::Pending => Poll::Pending,
101
102                 Poll::Ready(Ok(_)) => {
103                     let stream = self.write_stream.as_mut();
104
105                     let split = self.pending_writes.split_off(0);
106                     let pending: Vec<IoSlice> = split.iter().map(|p| IoSlice::new(p)).collect();
107
108                     trace!("sending pending bytes to network stream");
109                     match stream.poll_write_vectored(cx, pending.as_slice()) {
110                         Poll::Pending => {
111                             self.pending_writes = split;
112                             Poll::Pending
113                         }
114
115                         Poll::Ready(Ok(bytes_written)) => {
116                             trace!("wrote {} bytes to network stream", bytes_written);
117                             Poll::Ready(Ok(()))
118                         }
119
120                         Poll::Ready(Err(err)) => {
121                             error!("Encountered error when writing to network stream");
122                             self.pending_writes = split;
123                             Poll::Ready(Err(ConnectionWriteError::IoError(err)))
124                         }
125                     }
126                 }
127
128                 Poll::Ready(Err(err)) => {
129                     error!("Encountered error when flushing network stream");
130                     Poll::Ready(Err(ConnectionWriteError::IoError(err)))
131                 }
132             }
133         } else {
134             Poll::Ready(Ok(()))
135         }
136     }
137 }
138
139 impl Sink<ConnectDatagram> for ConnectionWriter {
140     type Error = ConnectionWriteError;
141
142     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
143         if self.is_closed() {
144             trace!("connection is closed - cannot send message");
145             Poll::Ready(Err(ConnectionWriteError::ConnectionClosed))
146         } else {
147             trace!("connection ready to send message");
148             Poll::Ready(Ok(()))
149         }
150     }
151
152     fn start_send(mut self: Pin<&mut Self>, item: ConnectDatagram) -> Result<(), Self::Error> {
153         trace!("preparing datagram to be queued for sending");
154
155         let buffer = item.into_bytes();
156         let msg_size = buffer.len();
157         trace!("serialized pending message into {} bytes", msg_size);
158
159         self.pending_writes.push(buffer);
160
161         Ok(())
162     }
163
164     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
165         self.write_pending_bytes(cx)
166     }
167
168     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
169         self.closed = true;
170         debug!("Closing the sink for connection with {}", self.peer_addr);
171
172         match self.write_pending_bytes(cx) {
173             Poll::Pending => Poll::Pending,
174
175             Poll::Ready(Ok(_)) => {
176                 let stream = self.write_stream.as_mut();
177
178                 match stream.poll_close(cx) {
179                     Poll::Pending => Poll::Pending,
180
181                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
182
183                     Poll::Ready(Err(err)) => Poll::Ready(Err(ConnectionWriteError::IoError(err))),
184                 }
185             }
186
187             err => err,
188         }
189     }
190 }