]> git.lizzy.rs Git - connect-rs.git/blob - src/reader.rs
39b1048f896fca459f491d4b566cf74c4acd11ad
[connect-rs.git] / src / reader.rs
1 use async_std::net::SocketAddr;
2 use async_std::pin::Pin;
3 use bytes::{Buf, BytesMut};
4 use futures::task::{Context, Poll};
5 use futures::{AsyncRead, Stream};
6 use log::*;
7
8 use crate::protocol::ConnectDatagram;
9 pub use futures::SinkExt;
10 pub use futures::StreamExt;
11 use std::io::Cursor;
12
/// Size, in bytes, of the scratch buffer handed to each read of the network stream; the bytes
/// read into it are then deserialized as messages.
const BUFFER_SIZE: usize = 8 * 1024;
15
16 /// An interface to read messages from the network connection
17 ///
18 /// Implements the [`Stream`] trait to asynchronously read messages from the network connection.
19 ///
20 /// # Example
21 ///
22 /// Basic usage:
23 ///
24 /// ```ignore
25 /// while let Some(msg) = reader.next().await {
26 ///   // handle the received message
27 /// }
28 /// ```
29 ///
30 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
31 /// example program or other client example programs for a more thorough showcase.
32 ///
33
34 pub struct ConnectionReader {
35     local_addr: SocketAddr,
36     peer_addr: SocketAddr,
37     read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
38     pending_read: Option<BytesMut>,
39     closed: bool,
40 }
41
42 impl ConnectionReader {
43     /// Creates a new [`ConnectionReader`] from an [`AsyncRead`] trait object and the local and peer
44     /// socket metadata
45     pub fn new(
46         local_addr: SocketAddr,
47         peer_addr: SocketAddr,
48         read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
49     ) -> Self {
50         Self {
51             local_addr,
52             peer_addr,
53             read_stream,
54             pending_read: None,
55             closed: false,
56         }
57     }
58
59     /// Get the local IP address and port
60     pub fn local_addr(&self) -> SocketAddr {
61         self.local_addr.clone()
62     }
63
64     /// Get the peer IP address and port
65     pub fn peer_addr(&self) -> SocketAddr {
66         self.peer_addr.clone()
67     }
68
69     /// Check if the [`Stream`] of messages from the network is closed
70     pub fn is_closed(&self) -> bool {
71         self.closed
72     }
73
74     pub(crate) fn close_stream(&mut self) {
75         trace!("Closing the stream for connection with {}", self.peer_addr);
76         self.pending_read.take();
77         self.closed = true;
78     }
79 }
80
81 impl Stream for ConnectionReader {
82     type Item = ConnectDatagram;
83
84     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
85         let mut buffer = BytesMut::new();
86         buffer.resize(BUFFER_SIZE, 0);
87
88         trace!("Starting new read loop for {}", self.local_addr);
89         loop {
90             trace!("Reading from the stream");
91             let stream = self.read_stream.as_mut();
92
93             match stream.poll_read(cx, &mut buffer) {
94                 Poll::Pending => return Poll::Pending,
95
96                 Poll::Ready(Ok(mut bytes_read)) => {
97                     if bytes_read > 0 {
98                         trace!("Read {} bytes from the network stream", bytes_read)
99                     } else if self.pending_read.is_none() {
100                         self.close_stream();
101                         return Poll::Ready(None);
102                     }
103
104                     if let Some(mut pending_buf) = self.pending_read.take() {
105                         trace!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
106                         bytes_read += pending_buf.len();
107
108                         pending_buf.unsplit(buffer);
109                         buffer = pending_buf;
110                     }
111
112                     while bytes_read > 0 {
113                         trace!("{} bytes from network stream still unprocessed", bytes_read);
114
115                         buffer.resize(bytes_read, 0);
116
117                         let mut cursor = Cursor::new(buffer.as_mut());
118                         match ConnectDatagram::decode(&mut cursor) {
119                             Ok(data) => {
120                                 return match data.version() {
121                                     _ => {
122                                         let serialized_size = data.size();
123                                         trace!(
124                                             "Deserialized message of size {} bytes",
125                                             serialized_size
126                                         );
127
128                                         buffer.advance(serialized_size);
129                                         bytes_read -= serialized_size;
130                                         trace!("{} bytes still unprocessed", bytes_read);
131
132                                         trace!("Sending deserialized message downstream");
133                                         Poll::Ready(Some(data))
134                                     }
135                                 }
136                             }
137
138                             Err(err) => {
139                                 warn!(
140                                     "Could not deserialize data from the received bytes: {:#?}",
141                                     err
142                                 );
143
144                                 self.pending_read = Some(buffer);
145                                 buffer = BytesMut::new();
146                                 break;
147                             }
148                         }
149                     }
150
151                     buffer.resize(BUFFER_SIZE, 0);
152                 }
153
154                 // Close the stream
155                 Poll::Ready(Err(_e)) => {
156                     self.close_stream();
157                     return Poll::Ready(None);
158                 }
159             }
160         }
161     }
162 }