]> git.lizzy.rs Git - mt.git/blob - rudp/process.go
70c04a23f3eb161428b15bbd13da47a3472e0e45
[mt.git] / rudp / process.go
1 package rudp
2
3 import (
4         "encoding/binary"
5         "errors"
6         "fmt"
7         "io"
8 )
9
10 // A PktError is an error that occured while processing a packet.
11 type PktError struct {
12         Type string // "net", "raw" or "rel".
13         Data []byte
14         Err  error
15 }
16
17 func (e PktError) Error() string {
18         return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err)
19 }
20
21 func (e PktError) Unwrap() error { return e.Err }
22
23 func (p *Peer) processNetPkts(pkts <-chan netPkt) {
24         for pkt := range pkts {
25                 if err := p.processNetPkt(pkt); err != nil {
26                         p.errs <- PktError{"net", pkt.Data, err}
27                 }
28         }
29
30         close(p.pkts)
31 }
32
33 // A TrailingDataError reports a packet with trailing data,
34 // it doesn't stop a packet from being processed.
35 type TrailingDataError []byte
36
37 func (e TrailingDataError) Error() string {
38         return fmt.Sprintf("trailing data: %x", []byte(e))
39 }
40
41 func (p *Peer) processNetPkt(pkt netPkt) (err error) {
42         if pkt.SrcAddr.String() != p.Addr().String() {
43                 return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
44         }
45
46         if len(pkt.Data) < MtHdrSize {
47                 return io.ErrUnexpectedEOF
48         }
49
50         if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
51                 return fmt.Errorf("unsupported protocol id: 0x%08x", id)
52         }
53
54         // src PeerID at pkt.Data[4:6]
55
56         chno := pkt.Data[6]
57         if chno >= ChannelCount {
58                 return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
59         }
60
61         p.mu.RLock()
62         if p.timeout != nil {
63                 p.timeout.Reset(ConnTimeout)
64         }
65         p.mu.RUnlock()
66
67         rpkt := rawPkt{
68                 Data:  pkt.Data[MtHdrSize:],
69                 ChNo:  chno,
70                 Unrel: true,
71         }
72         if err := p.processRawPkt(rpkt); err != nil {
73                 p.errs <- PktError{"raw", rpkt.Data, err}
74         }
75
76         return nil
77 }
78
79 func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
80         errWrap := func(format string, a ...interface{}) {
81                 if err != nil {
82                         err = fmt.Errorf(format, append(a, err)...)
83                 }
84         }
85
86         c := &p.chans[pkt.ChNo]
87
88         if len(pkt.Data) < 1 {
89                 return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
90         }
91         switch t := rawType(pkt.Data[0]); t {
92         case rawTypeCtl:
93                 defer errWrap("ctl: %w")
94
95                 if len(pkt.Data) < 1+1 {
96                         return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
97                 }
98                 switch ct := ctlType(pkt.Data[1]); ct {
99                 case ctlAck:
100                         defer errWrap("ack: %w")
101
102                         if len(pkt.Data) < 1+1+2 {
103                                 return io.ErrUnexpectedEOF
104                         }
105
106                         sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
107
108                         if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
109                                 close(ack.(chan struct{}))
110                         }
111
112                         if len(pkt.Data) > 1+1+2 {
113                                 return TrailingDataError(pkt.Data[1+1+2:])
114                         }
115                 case ctlSetPeerID:
116                         defer errWrap("set peer id: %w")
117
118                         if len(pkt.Data) < 1+1+2 {
119                                 return io.ErrUnexpectedEOF
120                         }
121
122                         // Ensure no concurrent senders while peer id changes.
123                         p.mu.Lock()
124                         if p.idOfPeer != PeerIDNil {
125                                 return errors.New("peer id already set")
126                         }
127
128                         p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
129                         p.mu.Unlock()
130
131                         if len(pkt.Data) > 1+1+2 {
132                                 return TrailingDataError(pkt.Data[1+1+2:])
133                         }
134                 case ctlPing:
135                         defer errWrap("ping: %w")
136
137                         if len(pkt.Data) > 1+1 {
138                                 return TrailingDataError(pkt.Data[1+1:])
139                         }
140                 case ctlDisco:
141                         defer errWrap("disco: %w")
142
143                         if err := p.Close(); err != nil {
144                                 return fmt.Errorf("can't close: %w", err)
145                         }
146
147                         if len(pkt.Data) > 1+1 {
148                                 return TrailingDataError(pkt.Data[1+1:])
149                         }
150                 default:
151                         return fmt.Errorf("unsupported ctl type: %d", ct)
152                 }
153         case rawTypeOrig:
154                 p.pkts <- Pkt{
155                         Data:  pkt.Data[1:],
156                         ChNo:  pkt.ChNo,
157                         Unrel: pkt.Unrel,
158                 }
159         case rawTypeSplit:
160                 defer errWrap("split: %w")
161
162                 if len(pkt.Data) < 1+2+2+2 {
163                         return io.ErrUnexpectedEOF
164                 }
165
166                 sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
167                 count := binary.BigEndian.Uint16(pkt.Data[3:5])
168                 i := binary.BigEndian.Uint16(pkt.Data[5:7])
169
170                 if i >= count {
171                         return nil
172                 }
173
174                 splitpkts := p.chans[pkt.ChNo].insplit
175
176                 // Delete old incomplete split packets
177                 // so new ones don't get corrupted.
178                 delete(splitpkts, sn-0x8000)
179
180                 if splitpkts[sn] == nil {
181                         splitpkts[sn] = make([][]byte, count)
182                 }
183
184                 chunks := splitpkts[sn]
185
186                 if int(count) != len(chunks) {
187                         return fmt.Errorf("chunk count changed on seqnum: %d", sn)
188                 }
189
190                 chunks[i] = pkt.Data[7:]
191
192                 for _, chunk := range chunks {
193                         if chunk == nil {
194                                 return nil
195                         }
196                 }
197
198                 var data []byte
199                 for _, chunk := range chunks {
200                         data = append(data, chunk...)
201                 }
202
203                 p.pkts <- Pkt{
204                         Data:  data,
205                         ChNo:  pkt.ChNo,
206                         Unrel: pkt.Unrel,
207                 }
208
209                 delete(splitpkts, sn)
210         case rawTypeRel:
211                 defer errWrap("rel: %w")
212
213                 if len(pkt.Data) < 1+2 {
214                         return io.ErrUnexpectedEOF
215                 }
216
217                 sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
218
219                 ackdata := make([]byte, 1+1+2)
220                 ackdata[0] = uint8(rawTypeCtl)
221                 ackdata[1] = uint8(ctlAck)
222                 binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
223                 ack := rawPkt{
224                         Data:  ackdata,
225                         ChNo:  pkt.ChNo,
226                         Unrel: true,
227                 }
228                 if _, err := p.sendRaw(ack); err != nil {
229                         return fmt.Errorf("can't ack %d: %w", sn, err)
230                 }
231
232                 if sn-c.inrelsn >= 0x8000 {
233                         return nil // Already received.
234                 }
235
236                 c.inrel[sn] = pkt.Data[3:]
237
238                 for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
239                         data := c.inrel[c.inrelsn]
240                         delete(c.inrel, c.inrelsn)
241
242                         rpkt := rawPkt{
243                                 Data:  data,
244                                 ChNo:  pkt.ChNo,
245                                 Unrel: false,
246                         }
247                         if err := p.processRawPkt(rpkt); err != nil {
248                                 p.errs <- PktError{"rel", rpkt.Data, err}
249                         }
250                 }
251         default:
252                 return fmt.Errorf("unsupported pkt type: %d", t)
253         }
254
255         return nil
256 }