]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
remove dependency on protobuf and introduce basic custom wire format
[connect-rs.git] / src / writer.rs
1 use async_channel::RecvError;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use futures::io::IoSlice;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncWrite, Sink};
7 use log::*;
8
9 use crate::protocol::ConnectDatagram;
10 pub use futures::SinkExt;
11 pub use futures::StreamExt;
12
13 /// An interface to write messages to the network connection
14 ///
15 /// Implements the [`Sink`] trait to asynchronously write messages to the network connection.
16 ///
17 /// # Example
18 ///
19 /// Basic usage:
20 ///
21 /// ```ignore
22 /// writer.send(msg).await?;
23 /// ```
24 ///
25 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
26 /// example program or other client example programs for a more thorough showcase.
27 ///
28 pub struct ConnectionWriter {
29     local_addr: SocketAddr,
30     peer_addr: SocketAddr,
31     write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
32     pending_writes: Vec<Vec<u8>>,
33     closed: bool,
34 }
35
36 impl ConnectionWriter {
37     /// Creates a new [`ConnectionWriter`] from an [`AsyncWrite`] trait object and the local and peer
38     /// socket metadata
39     pub fn new(
40         local_addr: SocketAddr,
41         peer_addr: SocketAddr,
42         write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
43     ) -> Self {
44         Self {
45             local_addr,
46             peer_addr,
47             write_stream,
48             pending_writes: Vec::new(),
49             closed: false,
50         }
51     }
52
53     /// Get the local IP address and port
54     pub fn local_addr(&self) -> SocketAddr {
55         self.local_addr.clone()
56     }
57
58     /// Get the peer IP address and port
59     pub fn peer_addr(&self) -> SocketAddr {
60         self.peer_addr.clone()
61     }
62
63     /// Check if the [`Sink`] of messages to the network is closed
64     pub fn is_closed(&self) -> bool {
65         self.closed
66     }
67
68     pub(crate) fn write_pending_bytes(
69         &mut self,
70         cx: &mut Context<'_>,
71     ) -> Poll<Result<(), RecvError>> {
72         if self.pending_writes.len() > 0 {
73             let stream = self.write_stream.as_mut();
74
75             match stream.poll_flush(cx) {
76                 Poll::Pending => Poll::Pending,
77
78                 Poll::Ready(Ok(_)) => {
79                     trace!("Sending pending bytes");
80
81                     let pending = self.pending_writes.split_off(0);
82                     let writeable_vec: Vec<IoSlice> =
83                         pending.iter().map(|p| IoSlice::new(p)).collect();
84
85                     let stream = self.write_stream.as_mut();
86                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
87                         Poll::Pending => Poll::Pending,
88
89                         Poll::Ready(Ok(bytes_written)) => {
90                             trace!("Wrote {} bytes to network stream", bytes_written);
91                             Poll::Ready(Ok(()))
92                         }
93
94                         Poll::Ready(Err(_e)) => {
95                             error!("Encountered error when writing to network stream");
96                             Poll::Ready(Err(RecvError))
97                         }
98                     }
99                 }
100
101                 Poll::Ready(Err(_e)) => {
102                     error!("Encountered error when flushing network stream");
103                     Poll::Ready(Err(RecvError))
104                 }
105             }
106         } else {
107             Poll::Ready(Ok(()))
108         }
109     }
110 }
111
112 impl Sink<ConnectDatagram> for ConnectionWriter {
113     type Error = RecvError;
114
115     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
116         if self.is_closed() {
117             trace!("Connection is closed, cannot send message");
118             Poll::Ready(Err(RecvError))
119         } else {
120             trace!("Connection ready to send message");
121             Poll::Ready(Ok(()))
122         }
123     }
124
125     fn start_send(mut self: Pin<&mut Self>, item: ConnectDatagram) -> Result<(), Self::Error> {
126         trace!("Preparing message to be sent next");
127
128         let buffer = item.encode();
129         let msg_size = buffer.len();
130         trace!("Serialized pending message into {} bytes", msg_size);
131
132         self.pending_writes.push(buffer);
133
134         Ok(())
135     }
136
137     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
138         self.write_pending_bytes(cx)
139     }
140
141     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
142         self.closed = true;
143
144         match self.write_pending_bytes(cx) {
145             Poll::Pending => Poll::Pending,
146
147             Poll::Ready(Ok(_)) => {
148                 let stream = self.write_stream.as_mut();
149
150                 match stream.poll_close(cx) {
151                     Poll::Pending => Poll::Pending,
152
153                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
154
155                     Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
156                 }
157             }
158
159             err => err,
160         }
161     }
162 }