]> git.lizzy.rs Git - connect-rs.git/blob - src/writer.rs
rename stitch-net to connect
[connect-rs.git] / src / writer.rs
1 use crate::schema::StitchMessage;
2 use async_channel::RecvError;
3 use async_std::net::SocketAddr;
4 use async_std::pin::Pin;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncWrite, AsyncWriteExt, Sink};
7 use log::*;
8 use protobuf::Message;
9
10 pub use futures::SinkExt;
11 pub use futures::StreamExt;
12
13 pub struct StitchConnectionWriter {
14     local_addr: SocketAddr,
15     peer_addr: SocketAddr,
16     write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
17     pending_write: Option<StitchMessage>,
18 }
19
20 impl StitchConnectionWriter {
21     pub fn new(
22         local_addr: SocketAddr,
23         peer_addr: SocketAddr,
24         write_stream: Box<dyn AsyncWrite + Send + Sync + Unpin>,
25     ) -> Self {
26         Self {
27             local_addr,
28             peer_addr,
29             write_stream,
30             pending_write: None,
31         }
32     }
33
34     pub fn local_addr(&self) -> SocketAddr {
35         self.local_addr.clone()
36     }
37
38     pub fn peer_addr(&self) -> SocketAddr {
39         self.peer_addr.clone()
40     }
41 }
42
43 impl<T: Message> Sink<T> for StitchConnectionWriter {
44     type Error = RecvError;
45
46     fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
47         if self.pending_write.is_some() {
48             debug!("Connection not ready to send message yet, waiting for prior message");
49             Poll::Pending
50         } else {
51             debug!("Connection ready to send message");
52             Poll::Ready(Ok(()))
53         }
54     }
55
56     fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
57         debug!("Preparing message to be sent next");
58         let stitch_msg: StitchMessage = StitchMessage::from_msg(item);
59         self.pending_write.replace(stitch_msg);
60
61         Ok(())
62     }
63
64     fn poll_flush(
65         mut self: Pin<&mut Self>,
66         _cx: &mut Context<'_>,
67     ) -> Poll<Result<(), Self::Error>> {
68         if let Some(pending_msg) = self.pending_write.take() {
69             debug!("Send pending message");
70             if let Ok(buffer) = pending_msg.write_to_bytes() {
71                 let msg_size = buffer.len();
72                 debug!("{} bytes to be sent over network connection", msg_size);
73
74                 debug!("{:?}", buffer.as_slice());
75
76                 return if let Ok(_) =
77                     futures::executor::block_on(self.write_stream.write_all(buffer.as_slice()))
78                 {
79                     if let Ok(_) = futures::executor::block_on(self.write_stream.flush()) {
80                         debug!("Sent message of {} bytes", msg_size);
81                         Poll::Ready(Ok(()))
82                     } else {
83                         debug!("Encountered error while flushing queued bytes to network stream");
84                         Poll::Ready(Err(RecvError))
85                     }
86                 } else {
87                     debug!("Encountered error when writing to network stream");
88                     Poll::Ready(Err(RecvError))
89                 };
90             } else {
91                 debug!("Encountered error when serializing message to bytes");
92                 return Poll::Ready(Err(RecvError));
93             }
94         } else {
95             debug!("No message to send over connection");
96         }
97
98         Poll::Ready(Ok(()))
99     }
100
101     fn poll_close(
102         mut self: Pin<&mut Self>,
103         _cx: &mut Context<'_>,
104     ) -> Poll<Result<(), Self::Error>> {
105         if let Ok(_) = futures::executor::block_on(self.write_stream.close()) {
106             Poll::Ready(Ok(()))
107         } else {
108             Poll::Ready(Err(RecvError))
109         }
110     }
111 }