]> git.lizzy.rs Git - connect-rs.git/blob - src/reader.rs
3954c4eee7ac3a22a03f34b82386e00da68e87cb
[connect-rs.git] / src / reader.rs
1 use crate::protocol::ConnectDatagram;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use bytes::BytesMut;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncRead, Stream};
7 use log::*;
8 use std::convert::TryInto;
9
10 pub use futures::{SinkExt, StreamExt};
11
/// Size, in bytes, of the scratch buffer that each read from the network fills before the
/// received bytes are deserialized into messages (8 KiB).
const BUFFER_SIZE: usize = 8 * 1024;
14
15 /// An interface to read messages from the network connection.
16 ///
17 /// Implements the [`Stream`] trait to asynchronously read messages from the network connection.
18 ///
19 /// # Example
20 ///
21 /// Basic usage:
22 ///
23 /// ```ignore
24 /// while let Some(msg) = reader.next().await {
25 ///   // handle the received message
26 /// }
27 /// ```
28 ///
29 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
30 /// example program or other client example programs for a more thorough showcase.
31 ///
32
33 pub struct ConnectionReader {
34     local_addr: SocketAddr,
35     peer_addr: SocketAddr,
36     read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
37     buffer: Option<BytesMut>,
38     pending_read: Option<BytesMut>,
39     pending_datagram: Option<usize>,
40     closed: bool,
41 }
42
43 impl ConnectionReader {
44     /// Creates a new [`ConnectionReader`] from an [`AsyncRead`] trait object and the local and peer
45     /// socket metadata.
46     pub fn new(
47         local_addr: SocketAddr,
48         peer_addr: SocketAddr,
49         read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
50     ) -> Self {
51         let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
52         buffer.resize(BUFFER_SIZE, 0);
53
54         Self {
55             local_addr,
56             peer_addr,
57             read_stream,
58             buffer: Some(buffer),
59             pending_read: None,
60             pending_datagram: None,
61             closed: false,
62         }
63     }
64
65     /// Get the local IP address and port.
66     pub fn local_addr(&self) -> SocketAddr {
67         self.local_addr.clone()
68     }
69
70     /// Get the peer IP address and port.
71     pub fn peer_addr(&self) -> SocketAddr {
72         self.peer_addr.clone()
73     }
74
75     /// Check if the [`Stream`] of messages from the network is closed.
76     pub fn is_closed(&self) -> bool {
77         self.closed
78     }
79
80     pub(crate) fn close_stream(&mut self) {
81         trace!("closing the stream for connection with {}", self.peer_addr);
82         self.buffer.take();
83         self.pending_datagram.take();
84         self.pending_read.take();
85         self.closed = true;
86     }
87 }
88
89 impl Stream for ConnectionReader {
90     type Item = ConnectDatagram;
91
92     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
93         loop {
94             if let Some(size) = self.pending_datagram.take() {
95                 if let Some(pending_buf) = self.pending_read.take() {
96                     if pending_buf.len() >= size {
97                         trace!("{} pending bytes is large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
98                         let mut data_buf = pending_buf;
99                         let pending_buf = data_buf.split_off(size);
100
101                         let datagram = ConnectDatagram::decode(data_buf.to_vec()).expect(
102                             "could not construct ConnectDatagram from bytes despite explicit check",
103                         );
104
105                         trace!("deserialized message of size {} bytes", datagram.size());
106                         return match datagram.version() {
107                             // do some special work based on version number if necessary
108                             _ => {
109                                 if pending_buf.len() >= std::mem::size_of::<u32>() {
110                                     trace!("can deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
111
112                                     let mut size_buf = pending_buf;
113                                     let pending_buf =
114                                         size_buf.split_off(std::mem::size_of::<u32>());
115                                     let size = u32::from_be_bytes(
116                                         size_buf
117                                             .to_vec()
118                                             .as_slice()
119                                             .try_into()
120                                             .expect("could not parse bytes into u32"),
121                                     ) as usize;
122
123                                     self.pending_datagram.replace(size);
124                                     self.pending_read.replace(pending_buf);
125                                 } else {
126                                     trace!("cannot deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
127                                     self.pending_read.replace(pending_buf);
128                                 }
129
130                                 trace!("returning deserialized datagram to user");
131                                 Poll::Ready(Some(datagram))
132                             }
133                         };
134                     } else {
135                         trace!("{} pending bytes is not large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
136                         self.pending_datagram.replace(size);
137                         self.pending_read.replace(pending_buf);
138                     }
139                 } else {
140                     unreachable!()
141                 }
142             }
143
144             let mut buffer = if let Some(buffer) = self.buffer.take() {
145                 trace!("prepare buffer to read from the network stream");
146                 buffer
147             } else {
148                 trace!("construct new buffer to read from the network stream");
149                 let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
150                 buffer.resize(BUFFER_SIZE, 0);
151                 buffer
152             };
153
154             trace!("reading from the network stream");
155             let stream = self.read_stream.as_mut();
156             match stream.poll_read(cx, &mut buffer) {
157                 Poll::Ready(Ok(bytes_read)) => {
158                     if bytes_read > 0 {
159                         trace!("read {} bytes from the network stream", bytes_read);
160                     } else {
161                         self.close_stream();
162                         return Poll::Ready(None);
163                     }
164
165                     let mut pending_buf = if let Some(pending_buf) = self.pending_read.take() {
166                         trace!("preparing {} pending bytes", pending_buf.len());
167                         pending_buf
168                     } else {
169                         trace!("constructing new pending bytes");
170                         BytesMut::new()
171                     };
172
173                     trace!(
174                         "prepending incomplete data ({} bytes) from earlier read of network stream",
175                         pending_buf.len()
176                     );
177                     pending_buf.extend_from_slice(&buffer[0..bytes_read]);
178
179                     if self.pending_datagram.is_none()
180                         && pending_buf.len() >= std::mem::size_of::<u32>()
181                     {
182                         trace!(
183                             "can deserialize size of next datagram from remaining {} pending bytes",
184                             pending_buf.len()
185                         );
186                         let mut size_buf = pending_buf;
187                         let pending_buf = size_buf.split_off(std::mem::size_of::<u32>());
188                         let size = u32::from_be_bytes(
189                             size_buf
190                                 .to_vec()
191                                 .as_slice()
192                                 .try_into()
193                                 .expect("could not parse bytes into u32"),
194                         ) as usize;
195
196                         self.pending_datagram.replace(size);
197                         self.pending_read.replace(pending_buf);
198                     } else {
199                         trace!("size of next datagram already deserialized");
200                         self.pending_read.replace(pending_buf);
201                     }
202
203                     trace!("finished reading from stream and storing buffer");
204                     self.buffer.replace(buffer);
205                 }
206
207                 Poll::Ready(Err(err)) => {
208                     error!(
209                         "Encountered error when trying to read from network stream {}",
210                         err
211                     );
212                     self.close_stream();
213                     return Poll::Ready(None);
214                 }
215
216                 Poll::Pending => {
217                     self.buffer.replace(buffer);
218                     return Poll::Pending;
219                 }
220             }
221         }
222     }
223 }