1 use crate::protocol::ConnectDatagram;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
5 use futures::task::{Context, Poll};
6 use futures::{AsyncRead, Stream};
8 use std::convert::TryInto;
10 pub use futures::{SinkExt, StreamExt};
/// Size, in bytes, of the scratch buffer each read from the network stream fills;
/// the bytes received are then deserialized into messages.
const BUFFER_SIZE: usize = 8 * 1024;
15 /// An interface to read messages from the network connection.
17 /// Implements the [`Stream`] trait to asynchronously read messages from the network connection.
24 /// while let Some(msg) = reader.next().await {
25 /// // handle the received message
29 /// Please see the [tcp-client](https://github.com/sachanganesh/connect-rs/blob/main/examples/tcp-client/)
30 /// example program or other client example programs for a more thorough showcase.
33 pub struct ConnectionReader {
/// Local socket address of the connection, returned by `local_addr()`.
34 local_addr: SocketAddr,
/// Remote peer address, returned by `peer_addr()` and included in trace output.
35 peer_addr: SocketAddr,
/// Byte source that `poll_next` reads from via `AsyncRead::poll_read`.
36 read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
/// Scratch buffer of `BUFFER_SIZE` bytes that each network read writes into. It is
/// taken out for each read and put back after a successful or pending read; if it is
/// absent, a fresh zero-filled buffer is allocated.
37 buffer: Option<BytesMut>,
/// Bytes received from the network that have not yet been decoded into a datagram;
/// carried across calls to `poll_next`.
38 pending_read: Option<BytesMut>,
/// Length in bytes of the next datagram, parsed from its 4-byte big-endian `u32`
/// prefix; `None` until that prefix has been read.
39 pending_datagram: Option<usize>,
43 impl ConnectionReader {
44 /// Creates a new [`ConnectionReader`] from an [`AsyncRead`] trait object and the local and peer
47 local_addr: SocketAddr,
48 peer_addr: SocketAddr,
49 read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
// `with_capacity` alone yields a zero-length buffer. Resizing (zero-filling) gives it a
// length of `BUFFER_SIZE`, which is how many bytes a read into it can fill.
51 let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
52 buffer.resize(BUFFER_SIZE, 0);
60 pending_datagram: None,
65 /// Get the local IP address and port.
66 pub fn local_addr(&self) -> SocketAddr {
// NOTE(review): `SocketAddr` is `Copy`, so `.clone()` is redundant
// (clippy::clone_on_copy); returning `self.local_addr` would suffice.
67 self.local_addr.clone()
70 /// Get the peer IP address and port.
71 pub fn peer_addr(&self) -> SocketAddr {
// NOTE(review): `SocketAddr` is `Copy`, so `.clone()` is redundant here as well
// (clippy::clone_on_copy); returning `self.peer_addr` would suffice.
72 self.peer_addr.clone()
75 /// Check if the [`Stream`] of messages from the network is closed.
76 pub fn is_closed(&self) -> bool {
/// Closes the message stream for this connection. Among its effects, it discards any
/// partially received datagram: the parsed length (`pending_datagram`) and the buffered
/// bytes (`pending_read`). NOTE(review): confirm the full set of state it resets, so
/// that it stays consistent with `is_closed`.
80 pub(crate) fn close_stream(&mut self) {
81 trace!("closing the stream for connection with {}", self.peer_addr);
83 self.pending_datagram.take();
84 self.pending_read.take();
// Yields one `ConnectDatagram` per item. As parsed below, the wire format is a 4-byte
// big-endian `u32` length prefix followed by that many bytes, which are handed to
// `ConnectDatagram::decode`. Undecoded bytes are carried across polls in
// `pending_read`; an already-parsed length is kept in `pending_datagram`.
89 impl Stream for ConnectionReader {
90 type Item = ConnectDatagram;
92 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
// Step 1: if the next datagram's length is known and enough bytes are buffered,
// decode and return it without touching the network.
94 if let Some(size) = self.pending_datagram.take() {
95 if let Some(pending_buf) = self.pending_read.take() {
96 if pending_buf.len() >= size {
97 trace!("{} pending bytes is large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
// Split the buffered bytes into this datagram (`data_buf`) and the remainder.
98 let mut data_buf = pending_buf;
99 let pending_buf = data_buf.split_off(size);
// NOTE(review): the check above only guarantees the byte count. These bytes come from
// the peer; if `decode` can also fail on their content, a malformed datagram panics
// here instead of ending the stream. Confirm, and consider handling the error.
101 let datagram = ConnectDatagram::decode(data_buf.to_vec()).expect(
102 "could not construct ConnectDatagram from bytes despite explicit check",
105 trace!("deserialized message of size {} bytes", datagram.size());
106 return match datagram.version() {
107 // do some special work based on version number if necessary
109 if pending_buf.len() >= std::mem::size_of::<u32>() {
110 trace!("can deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
// Pre-parse the next datagram's length prefix from the leftover bytes.
// NOTE(review): as shown, the `split_off` result on the next line is not bound, yet
// `pending_buf` (moved into `size_buf` here) is stored again below. The equivalent
// code in the read path binds it with `let pending_buf = size_buf.split_off(..)`.
// Confirm that binding is present.
112 let mut size_buf = pending_buf;
114 size_buf.split_off(std::mem::size_of::<u32>());
115 let size = u32::from_be_bytes(
120 .expect("could not parse bytes into u32"),
123 self.pending_datagram.replace(size);
124 self.pending_read.replace(pending_buf);
126 trace!("cannot deserialize size of next datagram from remaining {} pending bytes", pending_buf.len());
127 self.pending_read.replace(pending_buf);
130 trace!("returning deserialized datagram to user");
131 Poll::Ready(Some(datagram))
// Not enough bytes yet for the announced datagram: restore both pieces of state,
// presumably so the network read below can top up the buffer.
135 trace!("{} pending bytes is not large enough to deserialize datagram of size {} bytes", pending_buf.len(), size);
136 self.pending_datagram.replace(size);
137 self.pending_read.replace(pending_buf);
// Step 2: read more bytes from the network. Reuse the scratch buffer if present;
// otherwise allocate a zero-filled one of `BUFFER_SIZE` bytes.
144 let mut buffer = if let Some(buffer) = self.buffer.take() {
145 trace!("prepare buffer to read from the network stream");
148 trace!("construct new buffer to read from the network stream");
149 let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
150 buffer.resize(BUFFER_SIZE, 0);
154 trace!("reading from the network stream");
155 let stream = self.read_stream.as_mut();
156 match stream.poll_read(cx, &mut buffer) {
157 Poll::Ready(Ok(bytes_read)) => {
159 trace!("read {} bytes from the network stream", bytes_read);
// NOTE(review): the condition guarding this early return is not visible here
// (presumably `bytes_read == 0`, i.e. end of stream). Confirm.
162 return Poll::Ready(None);
// Take any leftover bytes from earlier reads (or start empty) and append the bytes
// just read.
165 let mut pending_buf = if let Some(pending_buf) = self.pending_read.take() {
166 trace!("preparing {} pending bytes", pending_buf.len());
169 trace!("constructing new pending bytes");
174 "prepending incomplete data ({} bytes) from earlier read of network stream",
177 pending_buf.extend_from_slice(&buffer[0..bytes_read]);
// Parse a new 4-byte length prefix only if the previous length has been consumed and
// enough bytes are available; the prefix is split off from the buffered bytes.
179 if self.pending_datagram.is_none()
180 && pending_buf.len() >= std::mem::size_of::<u32>()
183 "can deserialize size of next datagram from remaining {} pending bytes",
186 let mut size_buf = pending_buf;
187 let pending_buf = size_buf.split_off(std::mem::size_of::<u32>());
// NOTE(review): this length is peer-controlled and no upper bound is visible in this
// excerpt. A huge value lets `pending_read` grow until that many bytes arrive.
// Consider enforcing a maximum datagram size.
188 let size = u32::from_be_bytes(
193 .expect("could not parse bytes into u32"),
196 self.pending_datagram.replace(size);
197 self.pending_read.replace(pending_buf);
// NOTE(review): this appears to be the `else` of the condition above. It also runs when
// no length is known and fewer than 4 bytes are buffered, so the trace text can mislead.
199 trace!("size of next datagram already deserialized");
200 self.pending_read.replace(pending_buf);
203 trace!("finished reading from stream and storing buffer");
204 self.buffer.replace(buffer);
// A read error ends the stream: `Poll::Ready(None)` is returned below.
207 Poll::Ready(Err(err)) => {
209 "Encountered error when trying to read from network stream {}",
213 return Poll::Ready(None);
// Presumably the `Poll::Pending` arm (its pattern line is not visible here). Keep the
// scratch buffer for the next poll. Per the `AsyncRead` contract, a `Pending` from
// `poll_read` has arranged for this task to be woken when data is available.
217 self.buffer.replace(buffer);
218 return Poll::Pending;