]> git.lizzy.rs Git - mt_rudp.git/blobdiff - src/recv.rs
timeouts
[mt_rudp.git] / src / recv.rs
index 15811f2ab548b138b4bdf74ef24e45fe82da1e04..2fabe3ade59587d6564e2ec0ca122aa13e0f72a1 100644 (file)
@@ -5,6 +5,7 @@ use std::{
     cell::{Cell, OnceCell},
     collections::HashMap,
     io,
+    pin::Pin,
     sync::Arc,
     time::{Duration, Instant},
 };
@@ -65,6 +66,8 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
     }
 
     pub async fn run(&self) {
+        use Error::*;
+
         let cleanup_chans = Arc::clone(&self.chans);
         let mut cleanup_close = self.close.clone();
         self.share
@@ -90,36 +93,54 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
             });
 
         let mut close = self.close.clone();
+        let timeout = tokio::time::sleep(Duration::from_secs(TIMEOUT));
+        tokio::pin!(timeout);
+
         loop {
-            if let Err(e) = self.handle(self.recv_pkt(&mut close).await) {
-                if let Error::LocalDisco = e {
-                    self.share
-                        .send(
-                            PktType::Ctl,
-                            Pkt {
-                                unrel: true,
-                                chan: 0,
-                                data: &[CtlType::Disco as u8],
-                            },
-                        )
-                        .await
-                        .ok();
+            if let Err(e) = self.handle(self.recv_pkt(&mut close, timeout.as_mut()).await) {
+                // TODO: figure out whether this is a good idea
+                if let RemoteDisco(to) = e {
+                    self.pkt_tx.send(Err(RemoteDisco(to))).ok();
                 }
+
+                match e {
+                                       // anon5's mt notifies the peer on timeout, C++ MT does not
+                                       LocalDisco /*| RemoteDisco(true)*/ => drop(
+                                               self.share
+                                                       .send(
+                                                               PktType::Ctl,
+                                                               Pkt {
+                                                                       unrel: true,
+                                                                       chan: 0,
+                                                                       data: &[CtlType::Disco as u8],
+                                                               },
+                                                       )
+                                                       .await
+                                                       .ok(),
+                                       ),
+                                       _ => {}
+                               }
+
                 break;
             }
         }
     }
 
-    async fn recv_pkt(&self, close: &mut watch::Receiver<bool>) -> Result<()> {
+    async fn recv_pkt(
+        &self,
+        close: &mut watch::Receiver<bool>,
+        timeout: Pin<&mut tokio::time::Sleep>,
+    ) -> Result<()> {
         use Error::*;
 
         // TODO: reset timeout
         let mut cursor = io::Cursor::new(tokio::select! {
             pkt = self.udp_rx.recv() => pkt?,
+            _ = tokio::time::sleep_until(timeout.deadline()) => return Err(RemoteDisco(true)),
             _ = close.changed() => return Err(LocalDisco),
         });
 
-        println!("recv");
+        timeout.reset(tokio::time::Instant::now() + Duration::from_secs(TIMEOUT));
 
         let proto_id = cursor.read_u32::<BigEndian>()?;
         if proto_id != PROTO_ID {
@@ -179,7 +200,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
                 }
                 CtlType::Disco => {
                     println!("Disco");
-                    return Err(RemoteDisco);
+                    return Err(RemoteDisco(false));
                 }
             },
             PktType::Orig => {
@@ -275,7 +296,7 @@ impl<R: UdpReceiver, S: UdpSender> RecvWorker<R, S> {
 
         match res {
             Ok(v) => Ok(v),
-            Err(RemoteDisco) => Err(RemoteDisco),
+            Err(RemoteDisco(to)) => Err(RemoteDisco(to)),
             Err(LocalDisco) => Err(LocalDisco),
             Err(e) => Ok(self.pkt_tx.send(Err(e))?),
         }