]> git.lizzy.rs Git - mt.git/blob - rudp/process.go
rudp: optimize and refactor
[mt.git] / rudp / process.go
1 package rudp
2
3 import (
4         "errors"
5         "fmt"
6         "io"
7         "net"
8 )
9
10 // A PktError is an error that occured while processing a packet.
11 type PktError struct {
12         Type string // "net", "raw" or "rel".
13         Data []byte
14         Err  error
15 }
16
17 func (e PktError) Error() string {
18         return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err)
19 }
20
21 func (e PktError) Unwrap() error { return e.Err }
22
23 func (p *Peer) processNetPkts(pkts <-chan netPkt) {
24         for pkt := range pkts {
25                 if err := p.processNetPkt(pkt); err != nil {
26                         p.errs <- PktError{"net", pkt.Data, err}
27                 }
28         }
29
30         close(p.pkts)
31 }
32
33 // A TrailingDataError reports a packet with trailing data,
34 // it doesn't stop a packet from being processed.
35 type TrailingDataError []byte
36
37 func (e TrailingDataError) Error() string {
38         return fmt.Sprintf("trailing data: %x", []byte(e))
39 }
40
41 func (p *Peer) processNetPkt(pkt netPkt) (err error) {
42         if pkt.SrcAddr.String() != p.Addr().String() {
43                 return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
44         }
45
46         if len(pkt.Data) < MtHdrSize {
47                 return io.ErrUnexpectedEOF
48         }
49
50         if id := be.Uint32(pkt.Data[0:4]); id != protoID {
51                 return fmt.Errorf("unsupported protocol id: 0x%08x", id)
52         }
53
54         // src PeerID at pkt.Data[4:6]
55
56         chno := pkt.Data[6]
57         if chno >= ChannelCount {
58                 return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
59         }
60
61         p.mu.RLock()
62         if p.timeout != nil {
63                 p.timeout.Reset(ConnTimeout)
64         }
65         p.mu.RUnlock()
66
67         rpkt := rawPkt{
68                 Data:  pkt.Data[MtHdrSize:],
69                 ChNo:  chno,
70                 Unrel: true,
71         }
72         if err := p.processRawPkt(rpkt); err != nil {
73                 p.errs <- PktError{"raw", rpkt.Data, err}
74         }
75
76         return nil
77 }
78
79 func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
80         errWrap := func(format string, a ...interface{}) {
81                 if err != nil {
82                         err = fmt.Errorf(format, append(a, err)...)
83                 }
84         }
85
86         c := &p.chans[pkt.ChNo]
87
88         if len(pkt.Data) < 1 {
89                 return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
90         }
91         switch t := rawType(pkt.Data[0]); t {
92         case rawTypeCtl:
93                 defer errWrap("ctl: %w")
94
95                 if len(pkt.Data) < 1+1 {
96                         return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
97                 }
98                 switch ct := ctlType(pkt.Data[1]); ct {
99                 case ctlAck:
100                         defer errWrap("ack: %w")
101
102                         if len(pkt.Data) < 1+1+2 {
103                                 return io.ErrUnexpectedEOF
104                         }
105
106                         sn := seqnum(be.Uint16(pkt.Data[2:4]))
107
108                         if ack, ok := c.ackChans.LoadAndDelete(sn); ok {
109                                 close(ack.(chan struct{}))
110                         }
111
112                         if len(pkt.Data) > 1+1+2 {
113                                 return TrailingDataError(pkt.Data[1+1+2:])
114                         }
115                 case ctlSetPeerID:
116                         defer errWrap("set peer id: %w")
117
118                         if len(pkt.Data) < 1+1+2 {
119                                 return io.ErrUnexpectedEOF
120                         }
121
122                         // Ensure no concurrent senders while peer id changes.
123                         p.mu.Lock()
124                         if p.idOfPeer != PeerIDNil {
125                                 return errors.New("peer id already set")
126                         }
127
128                         p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4]))
129                         p.mu.Unlock()
130
131                         if len(pkt.Data) > 1+1+2 {
132                                 return TrailingDataError(pkt.Data[1+1+2:])
133                         }
134                 case ctlPing:
135                         defer errWrap("ping: %w")
136
137                         if len(pkt.Data) > 1+1 {
138                                 return TrailingDataError(pkt.Data[1+1:])
139                         }
140                 case ctlDisco:
141                         defer errWrap("disco: %w")
142
143                         p.Close()
144
145                         if len(pkt.Data) > 1+1 {
146                                 return TrailingDataError(pkt.Data[1+1:])
147                         }
148                 default:
149                         return fmt.Errorf("unsupported ctl type: %d", ct)
150                 }
151         case rawTypeOrig:
152                 p.pkts <- Pkt{
153                         Data:  pkt.Data[1:],
154                         ChNo:  pkt.ChNo,
155                         Unrel: pkt.Unrel,
156                 }
157         case rawTypeSplit:
158                 defer errWrap("split: %w")
159
160                 if len(pkt.Data) < 1+2+2+2 {
161                         return io.ErrUnexpectedEOF
162                 }
163
164                 sn := seqnum(be.Uint16(pkt.Data[1:3]))
165                 count := be.Uint16(pkt.Data[3:5])
166                 i := be.Uint16(pkt.Data[5:7])
167
168                 if i >= count {
169                         return nil
170                 }
171
172                 splits := p.chans[pkt.ChNo].inSplit
173
174                 // Delete old incomplete split packets
175                 // so new ones don't get corrupted.
176                 splits[sn-0x8000] = nil
177
178                 if splits[sn] == nil {
179                         splits[sn] = &inSplit{chunks: make([][]byte, count)}
180                 }
181
182                 s := splits[sn]
183
184                 if int(count) != len(s.chunks) {
185                         return fmt.Errorf("chunk count changed on split packet: %d", sn)
186                 }
187
188                 s.chunks[i] = pkt.Data[7:]
189                 s.size += len(s.chunks[i])
190                 s.got++
191
192                 if s.got == len(s.chunks) {
193                         data := make([]byte, 0, s.size)
194                         for _, chunk := range s.chunks {
195                                 data = append(data, chunk...)
196                         }
197
198                         p.pkts <- Pkt{
199                                 Data:  data,
200                                 ChNo:  pkt.ChNo,
201                                 Unrel: pkt.Unrel,
202                         }
203
204                         splits[sn] = nil
205                 }
206         case rawTypeRel:
207                 defer errWrap("rel: %w")
208
209                 if len(pkt.Data) < 1+2 {
210                         return io.ErrUnexpectedEOF
211                 }
212
213                 sn := seqnum(be.Uint16(pkt.Data[1:3]))
214
215                 ack := make([]byte, 1+1+2)
216                 ack[0] = uint8(rawTypeCtl)
217                 ack[1] = uint8(ctlAck)
218                 be.PutUint16(ack[2:4], uint16(sn))
219                 if _, err := p.sendRaw(rawPkt{
220                         Data:  ack,
221                         ChNo:  pkt.ChNo,
222                         Unrel: true,
223                 }); err != nil {
224                         if errors.Is(err, net.ErrClosed) {
225                                 return nil
226                         }
227                         return fmt.Errorf("can't ack %d: %w", sn, err)
228                 }
229
230                 if sn-c.inRelSN >= 0x8000 {
231                         return nil // Already received.
232                 }
233
234                 c.inRel[sn] = pkt.Data[3:]
235
236                 for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ {
237                         rpkt := rawPkt{
238                                 Data:  c.inRel[c.inRelSN],
239                                 ChNo:  pkt.ChNo,
240                                 Unrel: false,
241                         }
242                         c.inRel[c.inRelSN] = nil
243
244                         if err := p.processRawPkt(rpkt); err != nil {
245                                 p.errs <- PktError{"rel", rpkt.Data, err}
246                         }
247                 }
248         default:
249                 return fmt.Errorf("unsupported pkt type: %d", t)
250         }
251
252         return nil
253 }