git.lizzy.rs Git - connect-rs.git/blob - src/reader.rs
rename stitch-net to connect
[connect-rs.git] / src / reader.rs
1 use crate::schema::StitchMessage;
2 use async_std::net::SocketAddr;
3 use async_std::pin::Pin;
4 use bytes::{Buf, BytesMut};
5 use futures::task::{Context, Poll};
6 use futures::{AsyncRead, AsyncReadExt, Stream};
7 use log::*;
8 use protobuf::Message;
9 use std::convert::TryInto;
10
11 pub use futures::SinkExt;
12 pub use futures::StreamExt;
13 use protobuf::well_known_types::Any;
14
/// Number of bytes requested from the underlying reader per read (8 KiB).
const BUFFER_SIZE: usize = 8 * 1024;
16
17 pub struct StitchConnectionReader {
18     local_addr: SocketAddr,
19     peer_addr: SocketAddr,
20     read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
21     pending_read: Option<BytesMut>,
22 }
23
24 impl StitchConnectionReader {
25     pub fn new(
26         local_addr: SocketAddr,
27         peer_addr: SocketAddr,
28         read_stream: Box<dyn AsyncRead + Send + Sync + Unpin>,
29     ) -> Self {
30         Self {
31             local_addr,
32             peer_addr,
33             read_stream,
34             pending_read: None,
35         }
36     }
37
38     pub fn local_addr(&self) -> SocketAddr {
39         self.local_addr.clone()
40     }
41
42     pub fn peer_addr(&self) -> SocketAddr {
43         self.peer_addr.clone()
44     }
45 }
46
47 impl Stream for StitchConnectionReader {
48     type Item = Any;
49
50     fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
51         let mut buffer = BytesMut::new();
52         buffer.resize(BUFFER_SIZE, 0);
53
54         debug!("Starting new read loop for {}", self.local_addr);
55         loop {
56             trace!("Reading from the stream");
57             match futures::executor::block_on(self.read_stream.read(&mut buffer)) {
58                 Ok(mut bytes_read) => {
59                     if bytes_read > 0 {
60                         debug!("Read {} bytes from the network stream", bytes_read)
61                     }
62
63                     if let Some(mut pending_buf) = self.pending_read.take() {
64                         debug!("Prepending broken data ({} bytes) encountered from earlier read of network stream", pending_buf.len());
65                         bytes_read += pending_buf.len();
66
67                         pending_buf.unsplit(buffer);
68                         buffer = pending_buf;
69                     }
70
71                     let mut bytes_read_u64: u64 = bytes_read.try_into().expect(
72                         format!("Conversion from usize ({}) to u64 failed", bytes_read).as_str(),
73                     );
74                     while bytes_read_u64 > 0 {
75                         debug!(
76                             "{} bytes from network stream still unprocessed",
77                             bytes_read_u64
78                         );
79
80                         buffer.resize(bytes_read, 0);
81                         debug!("{:?}", buffer.as_ref());
82
83                         match StitchMessage::parse_from_bytes(buffer.as_ref()) {
84                             Ok(mut data) => {
85                                 let serialized_size = data.compute_size();
86                                 debug!("Deserialized message of size {} bytes", serialized_size);
87
88                                 buffer.advance(serialized_size as usize);
89
90                                 let serialized_size_u64: u64 = serialized_size.try_into().expect(
91                                     format!(
92                                         "Conversion from usize ({}) to u64 failed",
93                                         serialized_size
94                                     )
95                                     .as_str(),
96                                 );
97                                 bytes_read_u64 -= serialized_size_u64;
98                                 debug!("{} bytes still unprocessed", bytes_read_u64);
99
100                                 debug!("Sending deserialized message downstream");
101                                 return Poll::Ready(Some(data.take_payload()));
102                             }
103
104                             Err(err) => {
105                                 warn!(
106                                     "Could not deserialize data from the received bytes: {:#?}",
107                                     err
108                                 );
109
110                                 self.pending_read = Some(buffer);
111                                 buffer = BytesMut::new();
112                                 break;
113                             }
114                         }
115                     }
116
117                     buffer.resize(BUFFER_SIZE, 0);
118                 }
119
120                 Err(_err) => return Poll::Pending,
121             }
122         }
123     }
124 }