]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
better handling of stream/sink closing
[connect-rs.git] / src / writer.rs
1 use crate::schema::ConnectionMessage;
2 use async_channel::RecvError;
3 use async_std::net::SocketAddr;
4 use async_std::pin::Pin;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncWrite, AsyncWriteExt, Sink};
7 use log::*;
8 use protobuf::Message;
9
10 pub use futures::SinkExt;
11 pub use futures::StreamExt;
12
13 pub struct ConnectionWriter {
14     local_addr:    SocketAddr,
15     peer_addr:     SocketAddr,
16     write_stream:  Box<dyn AsyncWrite + Send + Sync + Unpin>,
17     pending_write: Option<ConnectionMessage>,
18     closed:        bool,
19 }
20
21 impl ConnectionWriter {
22     pub fn new(
23         local_addr: SocketAddr,
24         peer_addr: SocketAddr,
25         write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
26     ) -> Self {
27         Self {
28             local_addr,
29             peer_addr,
30             write_stream,
31             pending_write: None,
32             closed: false,
33         }
34     }
35
36     pub fn local_addr(&self) -> SocketAddr {
37         self.local_addr.clone()
38     }
39
40     pub fn peer_addr(&self) -> SocketAddr {
41         self.peer_addr.clone()
42     }
43
44     pub fn is_closed(&self) -> bool {
45         self.closed
46     }
47
48     fn send_to_conn(&mut self) -> Poll<Result<(), RecvError>> {
49         if let Some(pending_msg) = self.pending_write.take() {
50             trace!("Send pending message");
51             if let Ok(buffer) = pending_msg.write_to_bytes() {
52                 let msg_size = buffer.len();
53                 trace!("{} bytes to be sent over network connection", msg_size);
54
55                 return if let Ok(_) =
56                     futures::executor::block_on(self.write_stream.write_all(buffer.as_slice()))
57                 {
58                     if let Ok(_) = futures::executor::block_on(self.write_stream.flush()) {
59                         trace!("Sent message of {} bytes", msg_size);
60                         Poll::Ready(Ok(()))
61                     } else {
62                         trace!("Encountered error while flushing queued bytes to network stream");
63                         Poll::Ready(Err(RecvError))
64                     }
65                 } else {
66                     error!("Encountered error when writing to network stream");
67                     Poll::Ready(Err(RecvError))
68                 };
69             } else {
70                 error!("Encountered error when serializing message to bytes");
71                 return Poll::Ready(Err(RecvError));
72             }
73         } else {
74             trace!("No message to send over connection");
75         }
76
77         Poll::Ready(Ok(()))
78     }
79 }
80
81 impl<M: Message> Sink<M> for ConnectionWriter {
82     type Error = RecvError;
83
84     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
85         if self.pending_write.is_some() {
86             trace!("Connection not ready to send message yet, waiting for prior message");
87             Poll::Pending
88         } else {
89             trace!("Connection ready to send message");
90             Poll::Ready(Ok(()))
91         }
92     }
93
94     fn start_send(mut self: Pin<&mut Self>, item: M) -> Result<(), Self::Error> {
95         trace!("Preparing message to be sent next");
96         let stitch_msg: ConnectionMessage = ConnectionMessage::from_msg(item);
97         self.pending_write.replace(stitch_msg);
98
99         Ok(())
100     }
101
102     fn poll_flush(
103         mut self: Pin<&mut Self>,
104         _cx: &mut Context<'_>,
105     ) -> Poll<Result<(), Self::Error>> {
106         self.send_to_conn()
107     }
108
109     fn poll_close(
110         mut self: Pin<&mut Self>,
111         _cx: &mut Context<'_>,
112     ) -> Poll<Result<(), Self::Error>> {
113         let _ = self.send_to_conn();
114
115         self.closed = true;
116         if let Ok(_) = futures::executor::block_on(self.write_stream.close()) {
117             Poll::Ready(Ok(()))
118         } else {
119             Poll::Ready(Err(RecvError))
120         }
121     }
122 }