git.lizzy.rs Git - connect-rs.git/blob - src/reader.rs
Add return statement to fix bug
[connect-rs.git] / src / reader.rs
1 use crate::schema::ConnectionMessage;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use bytes::{Buf, BytesMut};
5 use futures::task::{Context, Poll};
6 use futures::{AsyncRead, Stream};
7 use log::*;
8 use protobuf::Message;
9 use std::convert::TryInto;
10
11 pub use futures::SinkExt;
12 pub use futures::StreamExt;
13 use protobuf::well_known_types::Any;
14
15 const BUFFER_SIZE: usize = 8192;
16
17 pub struct ConnectionReader {
18     local_addr:   SocketAddr,
19     peer_addr:    SocketAddr,
20     read_stream:  Pin<Box<dyn AsyncRead + Send + Sync>>,
21     pending_read: Option<BytesMut>,
22     closed:       bool,
23 }
24
25 impl ConnectionReader {
26     pub fn new(
27         local_addr: SocketAddr,
28         peer_addr: SocketAddr,
29         read_stream: Pin<Box<dyn AsyncRead + Send + Sync>>,
30     ) -> Self {
31         Self {
32             local_addr,
33             peer_addr,
34             read_stream,
35             pending_read: None,
36             closed: false,
37         }
38     }
39
40     pub fn local_addr(&self) -> SocketAddr {
41         self.local_addr.clone()
42     }
43
44     pub fn peer_addr(&self) -> SocketAddr {
45         self.peer_addr.clone()
46     }
47
48     pub fn is_closed(&self) -> bool {
49         self.closed
50     }
51
52     pub(crate) fn close_stream(&mut self) {
53         trace!("Closing the stream for connection with {}", self.peer_addr);
54         self.pending_read.take();
55         self.closed = true;
56     }
57 }
58
59 impl Stream for ConnectionReader {
60     type Item = Any;
61
62     fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
63         let mut buffer = BytesMut::new();
64         buffer.resize(BUFFER_SIZE, 0);
65
66         trace!("Starting new read loop for {}", self.local_addr);
67         loop {
68             trace!("Reading from the stream");
69             let stream = self.read_stream.as_mut();
70
71             match stream.poll_read(cx, &mut buffer) {
72                 Poll::Pending => return Poll::Pending,
73
74                 Poll::Ready(Ok(mut bytes_read)) => {
75                     if bytes_read > 0 {
76                         trace!("Read {} bytes from the network stream", bytes_read)
77                     } else if self.pending_read.is_none() {
78                         self.close_stream();
79                         return Poll::Ready(None)
80                     }
81
82                     if let Some(mut pending_buf) = self.pending_read.take() {
83                         trace!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
84                         bytes_read += pending_buf.len();
85
86                         pending_buf.unsplit(buffer);
87                         buffer = pending_buf;
88                     }
89
90                     let mut bytes_read_u64: u64 = bytes_read.try_into().expect(
91                         format!("Conversion from usize ({}) to u64 failed", bytes_read).as_str(),
92                     );
93                     while bytes_read_u64 > 0 {
94                         trace!(
95                             "{} bytes from network stream still unprocessed",
96                             bytes_read_u64
97                         );
98
99                         buffer.resize(bytes_read, 0);
100
101                         match ConnectionMessage::parse_from_bytes(buffer.as_ref()) {
102                             Ok(mut data) => {
103                                 let serialized_size = data.compute_size();
104                                 trace!("Deserialized message of size {} bytes", serialized_size);
105
106                                 buffer.advance(serialized_size as usize);
107
108                                 let serialized_size_u64: u64 = serialized_size.try_into().expect(
109                                     format!(
110                                         "Conversion from usize ({}) to u64 failed",
111                                         serialized_size
112                                     )
113                                     .as_str(),
114                                 );
115                                 bytes_read_u64 -= serialized_size_u64;
116                                 trace!("{} bytes still unprocessed", bytes_read_u64);
117
118                                 trace!("Sending deserialized message downstream");
119                                 return Poll::Ready(Some(data.take_payload()));
120                             }
121
122                             Err(err) => {
123                                 warn!(
124                                     "Could not deserialize data from the received bytes: {:#?}",
125                                     err
126                                 );
127
128                                 self.pending_read = Some(buffer);
129                                 buffer = BytesMut::new();
130                                 break;
131                             }
132                         }
133                     }
134
135                     buffer.resize(BUFFER_SIZE, 0);
136                 }
137
138                 // Close the stream
139                 Poll::Ready(Err(_e)) => {
140                     self.close_stream();
141                     return Poll::Ready(None)
142                 }
143             }
144         }
145     }
146 }