]> git.lizzy.rs Git - connect-rs.git/blob - src/reader.rs
428f6d19902f4c7afc614e8b1f66704d39a117e0
[connect-rs.git] / src / reader.rs
1 use crate::schema::ConnectionMessage;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use bytes::{Buf, BytesMut};
5 use futures::task::{Context, Poll};
6 use futures::{AsyncRead, Stream};
7 use log::*;
8 use protobuf::Message;
9 use std::convert::TryInto;
10
11 pub use futures::SinkExt;
12 pub use futures::StreamExt;
13 use protobuf::well_known_types::Any;
14
/// Size in bytes of the scratch buffer that each read from the network stream fills before the
/// received bytes are deserialized as messages
const BUFFER_SIZE: usize = 8192;
17
/// An interface to read messages from the network connection
///
/// Implements the [`Stream`] trait to asynchronously read messages from the network connection.
///
/// # Example
///
/// Basic usage:
///
/// ```ignore
/// while let Some(msg) = reader.next().await {
///   // handle the received message
/// }
/// ```
///
/// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
/// example program or other client example programs for a more thorough showcase.
pub struct ConnectionReader {
    // Local socket address, stored as given to `new` and returned by `local_addr()`.
    local_addr: SocketAddr,
    // Peer socket address, stored as given to `new` and returned by `peer_addr()`.
    peer_addr: SocketAddr,
    // Byte source that `poll_next` reads from.
    read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
    // Bytes from earlier reads that could not be deserialized; they are prepended to the bytes
    // of the next read. `None` when nothing is held over.
    pending_read: Option<BytesMut>,
    // Set by `close_stream()`; reported through `is_closed()`.
    closed: bool,
}
43
44 impl ConnectionReader {
45     /// Creates a new [`ConnectionReader`] from an [`AsyncRead`] trait object and the local and peer
46     /// socket metadata
47     pub fn new(
48         local_addr: SocketAddr,
49         peer_addr: SocketAddr,
50         read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
51     ) -> Self {
52         Self {
53             local_addr,
54             peer_addr,
55             read_stream,
56             pending_read: None,
57             closed: false,
58         }
59     }
60
61     /// Get the local IP address and port
62     pub fn local_addr(&self) -> SocketAddr {
63         self.local_addr.clone()
64     }
65
66     /// Get the peer IP address and port
67     pub fn peer_addr(&self) -> SocketAddr {
68         self.peer_addr.clone()
69     }
70
71     /// Check if the [`Stream`] of messages from the network is closed
72     pub fn is_closed(&self) -> bool {
73         self.closed
74     }
75
76     pub(crate) fn close_stream(&mut self) {
77         trace!("Closing the stream for connection with {}", self.peer_addr);
78         self.pending_read.take();
79         self.closed = true;
80     }
81 }
82
83 impl Stream for ConnectionReader {
84     type Item = Any;
85
86     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
87         let mut buffer = BytesMut::new();
88         buffer.resize(BUFFER_SIZE, 0);
89
90         trace!("Starting new read loop for {}", self.local_addr);
91         loop {
92             trace!("Reading from the stream");
93             let stream = self.read_stream.as_mut();
94
95             match stream.poll_read(cx, &mut buffer) {
96                 Poll::Pending => return Poll::Pending,
97
98                 Poll::Ready(Ok(mut bytes_read)) => {
99                     if bytes_read > 0 {
100                         trace!("Read {} bytes from the network stream", bytes_read)
101                     } else if self.pending_read.is_none() {
102                         self.close_stream();
103                         return Poll::Ready(None);
104                     }
105
106                     if let Some(mut pending_buf) = self.pending_read.take() {
107                         trace!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
108                         bytes_read += pending_buf.len();
109
110                         pending_buf.unsplit(buffer);
111                         buffer = pending_buf;
112                     }
113
114                     let mut bytes_read_u64: u64 = bytes_read.try_into().expect(
115                         format!("Conversion from usize ({}) to u64 failed", bytes_read).as_str(),
116                     );
117                     while bytes_read_u64 > 0 {
118                         trace!(
119                             "{} bytes from network stream still unprocessed",
120                             bytes_read_u64
121                         );
122
123                         buffer.resize(bytes_read, 0);
124
125                         match ConnectionMessage::parse_from_bytes(buffer.as_ref()) {
126                             Ok(mut data) => {
127                                 let serialized_size = data.compute_size();
128                                 trace!("Deserialized message of size {} bytes", serialized_size);
129
130                                 buffer.advance(serialized_size as usize);
131
132                                 let serialized_size_u64: u64 = serialized_size.try_into().expect(
133                                     format!(
134                                         "Conversion from usize ({}) to u64 failed",
135                                         serialized_size
136                                     )
137                                     .as_str(),
138                                 );
139                                 bytes_read_u64 -= serialized_size_u64;
140                                 trace!("{} bytes still unprocessed", bytes_read_u64);
141
142                                 trace!("Sending deserialized message downstream");
143                                 return Poll::Ready(Some(data.take_payload()));
144                             }
145
146                             Err(err) => {
147                                 warn!(
148                                     "Could not deserialize data from the received bytes: {:#?}",
149                                     err
150                                 );
151
152                                 self.pending_read = Some(buffer);
153                                 buffer = BytesMut::new();
154                                 break;
155                             }
156                         }
157                     }
158
159                     buffer.resize(BUFFER_SIZE, 0);
160                 }
161
162                 // Close the stream
163                 Poll::Ready(Err(_e)) => {
164                     self.close_stream();
165                     return Poll::Ready(None);
166                 }
167             }
168         }
169     }
170 }