]> git.lizzy.rs Git - connect-rs.git/blob - src/reader.rs
refactor to have datagram already serialized in memory
[connect-rs.git] / src / reader.rs
1 use crate::SIZE_PREFIX_BYTE_SIZE;
2 use crate::{protocol::ConnectDatagram, DATAGRAM_HEADER_BYTE_SIZE};
3 use async_std::net::SocketAddr;
4 use async_std::pin::Pin;
5 use bytes::BytesMut;
6 use futures::task::{Context, Poll};
7 use futures::{AsyncRead, Stream};
8 use log::*;
9 use std::convert::TryInto;
10
11 pub use futures::{SinkExt, StreamExt};
12
/// Size in bytes of the scratch buffer that each network read is written into before the
/// received bytes are deserialized into messages (8 KiB).
pub(crate) const BUFFER_SIZE: usize = 8 * 1024;
15
16 /// An interface to read messages from the network connection.
17 ///
18 /// Implements the `Stream` trait to asynchronously read messages from the network connection.
19 ///
20 /// # Example
21 ///
22 /// Basic usage:
23 ///
24 /// ```ignore
25 /// while let Some(msg) = reader.next().await {
26 ///   // handle the received message
27 /// }
28 /// ```
29 ///
30 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
31 /// example program or other client example programs for a more thorough showcase.
32 ///
33
34 pub struct ConnectionReader {
35     local_addr: SocketAddr,
36     peer_addr: SocketAddr,
37     read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
38     buffer: Option<BytesMut>,
39     pending_read: Option<BytesMut>,
40     pending_datagram: Option<usize>,
41     closed: bool,
42 }
43
44 impl ConnectionReader {
45     /// Creates a new [`ConnectionReader`] from an [`AsyncRead`] trait object and the local and peer
46     /// socket metadata.
47     pub fn new(
48         local_addr: SocketAddr,
49         peer_addr: SocketAddr,
50         read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
51     ) -> Self {
52         let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
53         buffer.resize(BUFFER_SIZE, 0);
54
55         Self {
56             local_addr,
57             peer_addr,
58             read_stream,
59             buffer: Some(buffer),
60             pending_read: None,
61             pending_datagram: None,
62             closed: false,
63         }
64     }
65
66     /// Get the local IP address and port.
67     pub fn local_addr(&self) -> SocketAddr {
68         self.local_addr.clone()
69     }
70
71     /// Get the peer IP address and port.
72     pub fn peer_addr(&self) -> SocketAddr {
73         self.peer_addr.clone()
74     }
75
76     /// Check if the `Stream` of messages from the network is closed.
77     pub fn is_closed(&self) -> bool {
78         self.closed
79     }
80
81     pub(crate) fn close_stream(&mut self) {
82         debug!("Closing the stream for connection with {}", self.peer_addr);
83         self.buffer.take();
84         self.pending_datagram.take();
85         self.pending_read.take();
86         self.closed = true;
87     }
88 }
89
90 impl Stream for ConnectionReader {
91     type Item = ConnectDatagram;
92
93     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
94         loop {
95             if let Some(size) = self.pending_datagram.take() {
96                 if let Some(pending_buf) = self.pending_read.take() {
97                     if pending_buf.len() >= size {
98                         trace!("{} pending bytes is large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
99                         let mut data_buf = pending_buf;
100                         let pending_buf = data_buf.split_off(size);
101
102                         let datagram = ConnectDatagram::from_bytes_without_prefix(
103                             data_buf.as_ref(),
104                         )
105                         .expect(
106                             "could not construct ConnectDatagram from bytes despite explicit check",
107                         );
108
109                         trace!(
110                             "deserialized message of size {} bytes",
111                             datagram.serialized_size()
112                         );
113                         return match datagram.version() {
114                             // do some special work based on version number if necessary
115                             _ => {
116                                 if pending_buf.len() >= DATAGRAM_HEADER_BYTE_SIZE {
117                                     trace!("can deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
118
119                                     let mut size_buf = pending_buf;
120                                     let pending_buf = size_buf.split_off(SIZE_PREFIX_BYTE_SIZE);
121
122                                     let size = u32::from_be_bytes(
123                                         size_buf
124                                             .to_vec()
125                                             .as_slice()
126                                             .try_into()
127                                             .expect("could not parse bytes into u32"),
128                                     ) as usize;
129
130                                     trace!("removed size of next datagram from pending bytes ({}), leaving {} pending bytes remaining", size, pending_buf.len());
131                                     self.pending_datagram.replace(size);
132                                     self.pending_read.replace(pending_buf);
133                                 } else {
134                                     trace!("cannot deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
135                                     self.pending_read.replace(pending_buf);
136                                 }
137
138                                 trace!("returning deserialized datagram to user");
139                                 Poll::Ready(Some(datagram))
140                             }
141                         };
142                     } else {
143                         trace!("{} pending bytes is not large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
144                         self.pending_datagram.replace(size);
145                         self.pending_read.replace(pending_buf);
146                     }
147                 } else {
148                     unreachable!()
149                 }
150             }
151
152             let mut buffer = if let Some(buffer) = self.buffer.take() {
153                 trace!("prepare buffer to read from the network stream");
154                 buffer
155             } else {
156                 trace!("construct new buffer to read from the network stream");
157                 let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
158                 buffer.resize(BUFFER_SIZE, 0);
159                 buffer
160             };
161
162             trace!("reading from the network stream");
163             let stream = self.read_stream.as_mut();
164             match stream.poll_read(cx, &mut buffer) {
165                 Poll::Ready(Ok(bytes_read)) => {
166                     if bytes_read > 0 {
167                         trace!("read {} bytes from the network stream", bytes_read);
168                     } else {
169                         self.close_stream();
170                         return Poll::Ready(None);
171                     }
172
173                     let mut pending_buf = if let Some(pending_buf) = self.pending_read.take() {
174                         trace!("preparing {} pending bytes", pending_buf.len());
175                         pending_buf
176                     } else {
177                         trace!("constructing new pending bytes");
178                         BytesMut::new()
179                     };
180
181                     trace!(
182                         "prepending incomplete data ({} bytes) from earlier read of network stream",
183                         pending_buf.len()
184                     );
185                     pending_buf.extend_from_slice(&buffer[0..bytes_read]);
186
187                     if self.pending_datagram.is_none() && pending_buf.len() >= SIZE_PREFIX_BYTE_SIZE
188                     {
189                         trace!(
190                             "can deserialize size of next datagram from remaining {} pending bytes",
191                             pending_buf.len()
192                         );
193                         let mut size_buf = pending_buf;
194                         let pending_buf = size_buf.split_off(SIZE_PREFIX_BYTE_SIZE);
195
196                         let size = u32::from_be_bytes(
197                             size_buf
198                                 .to_vec()
199                                 .as_slice()
200                                 .try_into()
201                                 .expect("could not parse bytes into u32"),
202                         ) as usize;
203
204                         trace!("removed size of next datagram from pending bytes ({}), leaving {} pending bytes remaining", size, pending_buf.len());
205                         self.pending_datagram.replace(size);
206                         self.pending_read.replace(pending_buf);
207                     } else {
208                         trace!("size of next datagram already deserialized");
209                         self.pending_read.replace(pending_buf);
210                     }
211
212                     trace!("finished reading from stream and storing buffer");
213                     self.buffer.replace(buffer);
214                 }
215
216                 Poll::Ready(Err(err)) => {
217                     error!(
218                         "Encountered error when trying to read from network stream {}",
219                         err
220                     );
221                     self.close_stream();
222                     return Poll::Ready(None);
223                 }
224
225                 Poll::Pending => {
226                     self.buffer.replace(buffer);
227                     return Poll::Pending;
228                 }
229             }
230         }
231     }
232 }