]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
add documentation prop
[connect-rs.git] / src / writer.rs
1 use crate::schema::ConnectionMessage;
2 use async_channel::RecvError;
3 use async_std::net::SocketAddr;
4 use async_std::pin::Pin;
5 use futures::io::IoSlice;
6 use futures::task::{Context, Poll};
7 use futures::{AsyncWrite, Sink};
8 use log::*;
9 use protobuf::Message;
10
11 pub use futures::SinkExt;
12 pub use futures::StreamExt;
13
14 /// An interface to write messages to the network connection
15 ///
16 /// Implements the [`Sink`] trait to asynchronously write messages to the network connection.
17 ///
18 /// # Example
19 ///
20 /// Basic usage:
21 ///
22 /// ```ignore
23 /// writer.send(msg).await?;
24 /// ```
25 ///
26 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
27 /// example program or other client example programs for a more thorough showcase.
28 ///
29 pub struct ConnectionWriter {
30     local_addr: SocketAddr,
31     peer_addr: SocketAddr,
32     write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
33     pending_writes: Vec<Vec<u8>>,
34     closed: bool,
35 }
36
37 impl ConnectionWriter {
38     /// Creates a new [`ConnectionWriter`] from an [`AsyncWrite`] trait object and the local and peer
39     /// socket metadata
40     pub fn new(
41         local_addr: SocketAddr,
42         peer_addr: SocketAddr,
43         write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
44     ) -> Self {
45         Self {
46             local_addr,
47             peer_addr,
48             write_stream,
49             pending_writes: Vec::new(),
50             closed: false,
51         }
52     }
53
54     /// Get the local IP address and port
55     pub fn local_addr(&self) -> SocketAddr {
56         self.local_addr.clone()
57     }
58
59     /// Get the peer IP address and port
60     pub fn peer_addr(&self) -> SocketAddr {
61         self.peer_addr.clone()
62     }
63
64     /// Check if the [`Sink`] of messages to the network is closed
65     pub fn is_closed(&self) -> bool {
66         self.closed
67     }
68 }
69
70 impl<M: Message> Sink<M> for ConnectionWriter {
71     type Error = RecvError;
72
73     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
74         if self.is_closed() {
75             trace!("Connection is closed, cannot send message");
76             Poll::Ready(Err(RecvError))
77         } else {
78             trace!("Connection ready to send message");
79             Poll::Ready(Ok(()))
80         }
81     }
82
83     fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
84         trace!("Preparing message to be sent next");
85         let msg: ConnectionMessage = ConnectionMessage::from_msg(item);
86
87         if let Ok(buffer) = msg.write_to_bytes() {
88             let msg_size = buffer.len();
89             trace!("Serialized pending message into {} bytes", msg_size);
90
91             self.pending_writes.push(buffer);
92
93             Ok(())
94         } else {
95             error!("Encountered error when serializing message to bytes");
96             Err(RecvError)
97         }
98     }
99
100     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
101         if self.pending_writes.len() > 0 {
102             let stream = self.write_stream.as_mut();
103
104             match stream.poll_flush(cx) {
105                 Poll::Pending => Poll::Pending,
106
107                 Poll::Ready(Ok(_)) => {
108                     trace!("Sending pending bytes");
109
110                     let pending = self.pending_writes.split_off(0);
111                     let writeable_vec: Vec<IoSlice> =
112                         pending.iter().map(|p| IoSlice::new(p)).collect();
113
114                     let stream = self.write_stream.as_mut();
115                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
116                         Poll::Pending => Poll::Pending,
117
118                         Poll::Ready(Ok(bytes_written)) => {
119                             trace!("Wrote {} bytes to network stream", bytes_written);
120                             Poll::Ready(Ok(()))
121                         }
122
123                         Poll::Ready(Err(_e)) => {
124                             error!("Encountered error when writing to network stream");
125                             Poll::Ready(Err(RecvError))
126                         }
127                     }
128                 }
129
130                 Poll::Ready(Err(_e)) => {
131                     error!("Encountered error when flushing network stream");
132                     Poll::Ready(Err(RecvError))
133                 }
134             }
135         } else {
136             Poll::Ready(Ok(()))
137         }
138     }
139
140     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
141         self.closed = true;
142
143         let flush = if self.pending_writes.len() > 0 {
144             let stream = self.write_stream.as_mut();
145
146             match stream.poll_flush(cx) {
147                 Poll::Pending => Poll::Pending,
148
149                 Poll::Ready(Ok(_)) => {
150                     trace!("Sending pending bytes");
151
152                     let pending = self.pending_writes.split_off(0);
153                     let writeable_vec: Vec<IoSlice> =
154                         pending.iter().map(|p| IoSlice::new(p)).collect();
155
156                     let stream = self.write_stream.as_mut();
157                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
158                         Poll::Pending => Poll::Pending,
159
160                         Poll::Ready(Ok(bytes_written)) => {
161                             trace!("Wrote {} bytes to network stream", bytes_written);
162                             Poll::Ready(Ok(()))
163                         }
164
165                         Poll::Ready(Err(_e)) => {
166                             error!("Encountered error when writing to network stream");
167                             Poll::Ready(Err(RecvError))
168                         }
169                     }
170                 }
171
172                 Poll::Ready(Err(_e)) => {
173                     error!("Encountered error when flushing network stream");
174                     Poll::Ready(Err(RecvError))
175                 }
176             }
177         } else {
178             Poll::Ready(Ok(()))
179         };
180
181         match flush {
182             Poll::Pending => Poll::Pending,
183
184             Poll::Ready(Ok(_)) => {
185                 let stream = self.write_stream.as_mut();
186
187                 match stream.poll_close(cx) {
188                     Poll::Pending => Poll::Pending,
189
190                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
191
192                     Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
193                 }
194             }
195
196             err => err,
197         }
198     }
199 }