]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
24030249110beb3768a0211e46654c779e488d71
[connect-rs.git] / src / writer.rs
1 use crate::schema::ConnectionMessage;
2 use async_channel::RecvError;
3 use async_std::net::SocketAddr;
4 use async_std::pin::Pin;
5 use futures::io::IoSlice;
6 use futures::task::{Context, Poll};
7 use futures::{AsyncWrite, Sink};
8 use log::*;
9 use protobuf::Message;
10
11 pub use futures::SinkExt;
12 pub use futures::StreamExt;
13
14 pub struct ConnectionWriter {
15     local_addr: SocketAddr,
16     peer_addr: SocketAddr,
17     write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
18     pending_writes: Vec<Vec<u8>>,
19     closed: bool,
20 }
21
22 impl ConnectionWriter {
23     pub fn new(
24         local_addr: SocketAddr,
25         peer_addr: SocketAddr,
26         write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
27     ) -> Self {
28         Self {
29             local_addr,
30             peer_addr,
31             write_stream,
32             pending_writes: Vec::new(),
33             closed: false,
34         }
35     }
36
37     pub fn local_addr(&self) -> SocketAddr {
38         self.local_addr.clone()
39     }
40
41     pub fn peer_addr(&self) -> SocketAddr {
42         self.peer_addr.clone()
43     }
44
45     pub fn is_closed(&self) -> bool {
46         self.closed
47     }
48 }
49
50 impl<M: Message> Sink<M> for ConnectionWriter {
51     type Error = RecvError;
52
53     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54         if self.is_closed() {
55             trace!("Connection is closed, cannot send message");
56             Poll::Ready(Err(RecvError))
57         } else {
58             trace!("Connection ready to send message");
59             Poll::Ready(Ok(()))
60         }
61     }
62
63     fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
64         trace!("Preparing message to be sent next");
65         let msg: ConnectionMessage = ConnectionMessage::from_msg(item);
66
67         if let Ok(buffer) = msg.write_to_bytes() {
68             let msg_size = buffer.len();
69             trace!("Serialized pending message into {} bytes", msg_size);
70
71             self.pending_writes.push(buffer);
72
73             Ok(())
74         } else {
75             error!("Encountered error when serializing message to bytes");
76             Err(RecvError)
77         }
78     }
79
80     fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
81         if self.pending_writes.len() > 0 {
82             let stream = self.write_stream.as_mut();
83
84             match stream.poll_flush(cx) {
85                 Poll::Pending => Poll::Pending,
86
87                 Poll::Ready(Ok(_)) => {
88                     trace!("Sending pending bytes");
89
90                     let pending = self.pending_writes.split_off(0);
91                     let writeable_vec: Vec<IoSlice> =
92                         pending.iter().map(|p| IoSlice::new(p)).collect();
93
94                     let stream = self.write_stream.as_mut();
95                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
96                         Poll::Pending => Poll::Pending,
97
98                         Poll::Ready(Ok(bytes_written)) => {
99                             trace!("Wrote {} bytes to network stream", bytes_written);
100                             Poll::Ready(Ok(()))
101                         }
102
103                         Poll::Ready(Err(_e)) => {
104                             error!("Encountered error when writing to network stream");
105                             Poll::Ready(Err(RecvError))
106                         }
107                     }
108                 }
109
110                 Poll::Ready(Err(_e)) => {
111                     error!("Encountered error when flushing network stream");
112                     Poll::Ready(Err(RecvError))
113                 }
114             }
115         } else {
116             Poll::Ready(Ok(()))
117         }
118     }
119
120     fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
121         self.closed = true;
122
123         let flush = if self.pending_writes.len() > 0 {
124             let stream = self.write_stream.as_mut();
125
126             match stream.poll_flush(cx) {
127                 Poll::Pending => Poll::Pending,
128
129                 Poll::Ready(Ok(_)) => {
130                     trace!("Sending pending bytes");
131
132                     let pending = self.pending_writes.split_off(0);
133                     let writeable_vec: Vec<IoSlice> =
134                         pending.iter().map(|p| IoSlice::new(p)).collect();
135
136                     let stream = self.write_stream.as_mut();
137                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
138                         Poll::Pending => Poll::Pending,
139
140                         Poll::Ready(Ok(bytes_written)) => {
141                             trace!("Wrote {} bytes to network stream", bytes_written);
142                             Poll::Ready(Ok(()))
143                         }
144
145                         Poll::Ready(Err(_e)) => {
146                             error!("Encountered error when writing to network stream");
147                             Poll::Ready(Err(RecvError))
148                         }
149                     }
150                 }
151
152                 Poll::Ready(Err(_e)) => {
153                     error!("Encountered error when flushing network stream");
154                     Poll::Ready(Err(RecvError))
155                 }
156             }
157         } else {
158             Poll::Ready(Ok(()))
159         };
160
161         match flush {
162             Poll::Pending => Poll::Pending,
163
164             Poll::Ready(Ok(_)) => {
165                 let stream = self.write_stream.as_mut();
166
167                 match stream.poll_close(cx) {
168                     Poll::Pending => Poll::Pending,
169
170                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
171
172                     Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
173                 }
174             }
175
176             err => err,
177         }
178     }
179 }