]> git.lizzy.rs Git - mt_rudp.git/blob - src/send.rs
Implement sending splits
[mt_rudp.git] / src / send.rs
1 use super::*;
2 use byteorder::{BigEndian, WriteBytesExt};
3 use std::{
4     borrow::Cow,
5     collections::HashMap,
6     io::{self, Write},
7     sync::Arc,
8 };
9 use tokio::sync::{watch, Mutex, RwLock};
10
11 pub type Ack = Option<watch::Receiver<bool>>;
12
13 #[derive(Debug)]
14 pub(crate) struct AckWait {
15     pub(crate) tx: watch::Sender<bool>,
16     pub(crate) rx: watch::Receiver<bool>,
17     pub(crate) data: Vec<u8>,
18 }
19
20 #[derive(Debug)]
21 pub(crate) struct Chan {
22     pub(crate) acks: HashMap<u16, AckWait>,
23     pub(crate) seqnum: u16,
24     pub(crate) splits_seqnum: u16,
25 }
26
27 #[derive(Debug)]
28 pub struct Sender<S: UdpSender> {
29     pub(crate) id: u16,
30     pub(crate) remote_id: RwLock<u16>,
31     pub(crate) chans: [Mutex<Chan>; NUM_CHANS],
32     udp: S,
33     close: watch::Sender<bool>,
34 }
35
36 impl<S: UdpSender> Sender<S> {
37     pub fn new(udp: S, close: watch::Sender<bool>, id: u16, remote_id: u16) -> Arc<Self> {
38         Arc::new(Self {
39             id,
40             remote_id: RwLock::new(remote_id),
41             udp,
42             close,
43             chans: std::array::from_fn(|_| {
44                 Mutex::new(Chan {
45                     acks: HashMap::new(),
46                     seqnum: INIT_SEQNUM,
47                     splits_seqnum: INIT_SEQNUM,
48                 })
49             }),
50         })
51     }
52
53     pub async fn send_rudp(&self, pkt: Pkt<'_>) -> io::Result<Ack> {
54         if pkt.size() > UDP_PKT_SIZE {
55             let chunks = pkt
56                 .data
57                 .chunks(UDP_PKT_SIZE - (pkt.header_size() + 1 + 2 + 2 + 2));
58             let num_chunks: u16 = chunks
59                 .len()
60                 .try_into()
61                 .map_err(|_| io::Error::new(io::ErrorKind::Other, "too many chunks"))?;
62
63             let seqnum = {
64                 let mut chan = self.chans[pkt.chan as usize].lock().await;
65                 let sn = chan.splits_seqnum;
66                 chan.splits_seqnum = chan.splits_seqnum.overflowing_add(1).0;
67
68                 sn
69             };
70
71             for (i, ch) in chunks.enumerate() {
72                 self.send_rudp_type(
73                     PktType::Orig,
74                     Some((seqnum, num_chunks, i as u16)),
75                     Pkt {
76                         unrel: pkt.unrel,
77                         chan: pkt.chan,
78                         data: Cow::Borrowed(ch),
79                     },
80                 )
81                 .await?;
82             }
83
84             Ok(None) // TODO: ack
85         } else {
86             self.send_rudp_type(PktType::Orig, None, pkt).await
87         }
88     }
89
90     pub async fn send_rudp_type(
91         &self,
92         tp: PktType,
93         chunk: Option<(u16, u16, u16)>,
94         pkt: Pkt<'_>,
95     ) -> io::Result<Ack> {
96         let mut buf =
97             Vec::with_capacity(pkt.size() + if chunk.is_some() { 1 + 2 + 2 + 2 } else { 0 });
98
99         buf.write_u32::<BigEndian>(PROTO_ID)?;
100         buf.write_u16::<BigEndian>(*self.remote_id.read().await)?;
101         buf.write_u8(pkt.chan)?;
102
103         let mut chan = self.chans[pkt.chan as usize].lock().await;
104         let seqnum = chan.seqnum;
105
106         if !pkt.unrel {
107             buf.write_u8(PktType::Rel as u8)?;
108             buf.write_u16::<BigEndian>(seqnum)?;
109         }
110
111         if let Some((seqnum, count, index)) = chunk {
112             buf.write_u8(PktType::Split as u8)?;
113             buf.write_u16::<BigEndian>(seqnum)?;
114             buf.write_u16::<BigEndian>(count)?;
115             buf.write_u16::<BigEndian>(index)?;
116         } else {
117             buf.write_u8(tp as u8)?;
118         }
119
120         buf.write_all(pkt.data.as_ref())?;
121
122         self.send_udp(&buf).await?;
123
124         if pkt.unrel {
125             Ok(None)
126         } else {
127             // TODO: reliable window
128             let (tx, rx) = watch::channel(false);
129             chan.acks.insert(
130                 seqnum,
131                 AckWait {
132                     tx,
133                     rx: rx.clone(),
134                     data: buf,
135                 },
136             );
137             chan.seqnum = chan.seqnum.overflowing_add(1).0;
138
139             Ok(Some(rx))
140         }
141     }
142
143     pub async fn send_udp(&self, data: &[u8]) -> io::Result<()> {
144         if data.len() > UDP_PKT_SIZE {
145             panic!(
146                 "attempted to send a packet with len {} > {UDP_PKT_SIZE}",
147                 data.len()
148             );
149         }
150
151         self.udp.send(data).await
152     }
153
154     pub async fn peer_id(&self) -> u16 {
155         self.id
156     }
157
158     pub async fn is_server(&self) -> bool {
159         self.id == PeerID::Srv as u16
160     }
161
162     pub fn close(&self) {
163         self.close.send(true).ok();
164     }
165 }