]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
f26ab4b12eba3719b98947c25fc80dc96fa5f24d
[connect-rs.git] / src / writer.rs
1 use crate::schema::ConnectionMessage;
2 use async_channel::RecvError;
3 use async_std::net::SocketAddr;
4 use async_std::pin::Pin;
5 use futures::{AsyncWrite, Sink};
6 use futures::io::IoSlice;
7 use futures::task::{Context, Poll};
8 use log::*;
9 use protobuf::Message;
10
11 pub use futures::SinkExt;
12 pub use futures::StreamExt;
13
14 pub struct ConnectionWriter {
15     local_addr:     SocketAddr,
16     peer_addr:      SocketAddr,
17     write_stream:   Pin<Box<dyn AsyncWrite + Send + Sync>>,
18     pending_writes: Vec<Vec<u8>>,
19     closed:         bool,
20 }
21
22 impl ConnectionWriter {
23     pub fn new(
24         local_addr: SocketAddr,
25         peer_addr: SocketAddr,
26         write_stream: Pin<Box<dyn AsyncWrite + Send + Sync>>,
27     ) -> Self {
28         Self {
29             local_addr,
30             peer_addr,
31             write_stream,
32             pending_writes: Vec::new(),
33             closed: false,
34         }
35     }
36
37     pub fn local_addr(&self) -> SocketAddr {
38         self.local_addr.clone()
39     }
40
41     pub fn peer_addr(&self) -> SocketAddr {
42         self.peer_addr.clone()
43     }
44
45     pub fn is_closed(&self) -> bool {
46         self.closed
47     }
48 }
49
50 impl<M: Message> Sink<M> for ConnectionWriter {
51     type Error = RecvError;
52
53     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
54         if self.is_closed() {
55             trace!("Connection is closed, cannot send message");
56             Poll::Ready(Err(RecvError))
57         } else {
58             trace!("Connection ready to send message");
59             Poll::Ready(Ok(()))
60         }
61     }
62
63     fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
64         trace!("Preparing message to be sent next");
65         let msg: ConnectionMessage = ConnectionMessage::from_msg(item);
66
67         if let Ok(buffer) = msg.write_to_bytes() {
68             let msg_size = buffer.len();
69             trace!("Serialized pending message into {} bytes", msg_size);
70
71             self.pending_writes.push(buffer);
72
73             Ok(())
74         } else {
75             error!("Encountered error when serializing message to bytes");
76             Err(RecvError)
77         }
78     }
79
80     fn poll_flush(
81         mut self: Pin<&mut Self>,
82         cx: &mut Context<'_>,
83     ) -> Poll<Result<(), Self::Error>> {
84         if self.pending_writes.len() > 0 {
85             let stream = self.write_stream.as_mut();
86
87             match stream.poll_flush(cx) {
88                 Poll::Pending => Poll::Pending,
89
90                 Poll::Ready(Ok(_)) => {
91                     trace!("Sending pending bytes");
92
93                     let pending = self.pending_writes.split_off(0);
94                     let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
95                         IoSlice::new(p)
96                     }).collect();
97
98                     let stream = self.write_stream.as_mut();
99                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
100                         Poll::Pending => Poll::Pending,
101
102                         Poll::Ready(Ok(bytes_written)) => {
103                             trace!("Wrote {} bytes to network stream", bytes_written);
104                             Poll::Ready(Ok(()))
105                         },
106
107                         Poll::Ready(Err(_e)) => {
108                             error!("Encountered error when writing to network stream");
109                             Poll::Ready(Err(RecvError))
110                         },
111                     }
112                 },
113
114                 Poll::Ready(Err(_e)) => {
115                     error!("Encountered error when flushing network stream");
116                     Poll::Ready(Err(RecvError))
117                 }
118             }
119         } else {
120             Poll::Ready(Ok(()))
121         }
122     }
123
124     fn poll_close(
125         mut self: Pin<&mut Self>,
126         cx: &mut Context<'_>,
127     ) -> Poll<Result<(), Self::Error>> {
128         self.closed = true;
129
130         let flush = if self.pending_writes.len() > 0 {
131             let stream = self.write_stream.as_mut();
132
133             match stream.poll_flush(cx) {
134                 Poll::Pending => Poll::Pending,
135
136                 Poll::Ready(Ok(_)) => {
137                     trace!("Sending pending bytes");
138
139                     let pending = self.pending_writes.split_off(0);
140                     let writeable_vec: Vec<IoSlice> = pending.iter().map(|p| {
141                         IoSlice::new(p)
142                     }).collect();
143
144                     let stream = self.write_stream.as_mut();
145                     match stream.poll_write_vectored(cx, writeable_vec.as_slice()) {
146                         Poll::Pending => Poll::Pending,
147
148                         Poll::Ready(Ok(bytes_written)) => {
149                             trace!("Wrote {} bytes to network stream", bytes_written);
150                             Poll::Ready(Ok(()))
151                         },
152
153                         Poll::Ready(Err(_e)) => {
154                             error!("Encountered error when writing to network stream");
155                             Poll::Ready(Err(RecvError))
156                         },
157                     }
158                 },
159
160                 Poll::Ready(Err(_e)) => {
161                     error!("Encountered error when flushing network stream");
162                     Poll::Ready(Err(RecvError))
163                 }
164             }
165         } else {
166             Poll::Ready(Ok(()))
167         };
168
169         match flush {
170             Poll::Pending => Poll::Pending,
171
172             Poll::Ready(Ok(_)) => {
173                 let stream = self.write_stream.as_mut();
174
175                 match stream.poll_close(cx) {
176                     Poll::Pending => Poll::Pending,
177
178                     Poll::Ready(Ok(_)) => Poll::Ready(Ok(())),
179
180                     Poll::Ready(Err(_e)) => Poll::Ready(Err(RecvError)),
181                 }
182             },
183
184             err => err,
185         }
186     }
187 }