]> git.lizzy.rs Git - mt.git/blob - rudp/process.go
c85aba4edafec01fb008bc174a4c20b4c0e066a7
[mt.git] / rudp / process.go
1 package rudp
2
3 import (
4         "encoding/binary"
5         "errors"
6         "fmt"
7         "io"
8         "net"
9 )
10
11 // A PktError is an error that occured while processing a packet.
12 type PktError struct {
13         Type string // "net", "raw" or "rel".
14         Data []byte
15         Err  error
16 }
17
18 func (e PktError) Error() string {
19         return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err)
20 }
21
22 func (e PktError) Unwrap() error { return e.Err }
23
24 func (p *Peer) processNetPkts(pkts <-chan netPkt) {
25         for pkt := range pkts {
26                 if err := p.processNetPkt(pkt); err != nil {
27                         p.errs <- PktError{"net", pkt.Data, err}
28                 }
29         }
30
31         close(p.pkts)
32 }
33
34 // A TrailingDataError reports a packet with trailing data,
35 // it doesn't stop a packet from being processed.
36 type TrailingDataError []byte
37
38 func (e TrailingDataError) Error() string {
39         return fmt.Sprintf("trailing data: %x", []byte(e))
40 }
41
42 func (p *Peer) processNetPkt(pkt netPkt) (err error) {
43         if pkt.SrcAddr.String() != p.Addr().String() {
44                 return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
45         }
46
47         if len(pkt.Data) < MtHdrSize {
48                 return io.ErrUnexpectedEOF
49         }
50
51         if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
52                 return fmt.Errorf("unsupported protocol id: 0x%08x", id)
53         }
54
55         // src PeerID at pkt.Data[4:6]
56
57         chno := pkt.Data[6]
58         if chno >= ChannelCount {
59                 return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
60         }
61
62         p.mu.RLock()
63         if p.timeout != nil {
64                 p.timeout.Reset(ConnTimeout)
65         }
66         p.mu.RUnlock()
67
68         rpkt := rawPkt{
69                 Data:  pkt.Data[MtHdrSize:],
70                 ChNo:  chno,
71                 Unrel: true,
72         }
73         if err := p.processRawPkt(rpkt); err != nil {
74                 p.errs <- PktError{"raw", rpkt.Data, err}
75         }
76
77         return nil
78 }
79
80 func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
81         errWrap := func(format string, a ...interface{}) {
82                 if err != nil {
83                         err = fmt.Errorf(format, append(a, err)...)
84                 }
85         }
86
87         c := &p.chans[pkt.ChNo]
88
89         if len(pkt.Data) < 1 {
90                 return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
91         }
92         switch t := rawType(pkt.Data[0]); t {
93         case rawTypeCtl:
94                 defer errWrap("ctl: %w")
95
96                 if len(pkt.Data) < 1+1 {
97                         return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
98                 }
99                 switch ct := ctlType(pkt.Data[1]); ct {
100                 case ctlAck:
101                         defer errWrap("ack: %w")
102
103                         if len(pkt.Data) < 1+1+2 {
104                                 return io.ErrUnexpectedEOF
105                         }
106
107                         sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
108
109                         if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
110                                 close(ack.(chan struct{}))
111                         }
112
113                         if len(pkt.Data) > 1+1+2 {
114                                 return TrailingDataError(pkt.Data[1+1+2:])
115                         }
116                 case ctlSetPeerID:
117                         defer errWrap("set peer id: %w")
118
119                         if len(pkt.Data) < 1+1+2 {
120                                 return io.ErrUnexpectedEOF
121                         }
122
123                         // Ensure no concurrent senders while peer id changes.
124                         p.mu.Lock()
125                         if p.idOfPeer != PeerIDNil {
126                                 return errors.New("peer id already set")
127                         }
128
129                         p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
130                         p.mu.Unlock()
131
132                         if len(pkt.Data) > 1+1+2 {
133                                 return TrailingDataError(pkt.Data[1+1+2:])
134                         }
135                 case ctlPing:
136                         defer errWrap("ping: %w")
137
138                         if len(pkt.Data) > 1+1 {
139                                 return TrailingDataError(pkt.Data[1+1:])
140                         }
141                 case ctlDisco:
142                         defer errWrap("disco: %w")
143
144                         p.Close()
145
146                         if len(pkt.Data) > 1+1 {
147                                 return TrailingDataError(pkt.Data[1+1:])
148                         }
149                 default:
150                         return fmt.Errorf("unsupported ctl type: %d", ct)
151                 }
152         case rawTypeOrig:
153                 p.pkts <- Pkt{
154                         Data:  pkt.Data[1:],
155                         ChNo:  pkt.ChNo,
156                         Unrel: pkt.Unrel,
157                 }
158         case rawTypeSplit:
159                 defer errWrap("split: %w")
160
161                 if len(pkt.Data) < 1+2+2+2 {
162                         return io.ErrUnexpectedEOF
163                 }
164
165                 sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
166                 count := binary.BigEndian.Uint16(pkt.Data[3:5])
167                 i := binary.BigEndian.Uint16(pkt.Data[5:7])
168
169                 if i >= count {
170                         return nil
171                 }
172
173                 splitpkts := p.chans[pkt.ChNo].insplit
174
175                 // Delete old incomplete split packets
176                 // so new ones don't get corrupted.
177                 delete(splitpkts, sn-0x8000)
178
179                 if splitpkts[sn] == nil {
180                         splitpkts[sn] = make([][]byte, count)
181                 }
182
183                 chunks := splitpkts[sn]
184
185                 if int(count) != len(chunks) {
186                         return fmt.Errorf("chunk count changed on seqnum: %d", sn)
187                 }
188
189                 chunks[i] = pkt.Data[7:]
190
191                 for _, chunk := range chunks {
192                         if chunk == nil {
193                                 return nil
194                         }
195                 }
196
197                 var data []byte
198                 for _, chunk := range chunks {
199                         data = append(data, chunk...)
200                 }
201
202                 p.pkts <- Pkt{
203                         Data:  data,
204                         ChNo:  pkt.ChNo,
205                         Unrel: pkt.Unrel,
206                 }
207
208                 delete(splitpkts, sn)
209         case rawTypeRel:
210                 defer errWrap("rel: %w")
211
212                 if len(pkt.Data) < 1+2 {
213                         return io.ErrUnexpectedEOF
214                 }
215
216                 sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
217
218                 ackdata := make([]byte, 1+1+2)
219                 ackdata[0] = uint8(rawTypeCtl)
220                 ackdata[1] = uint8(ctlAck)
221                 binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
222                 ack := rawPkt{
223                         Data:  ackdata,
224                         ChNo:  pkt.ChNo,
225                         Unrel: true,
226                 }
227                 if _, err := p.sendRaw(ack); err != nil {
228                         if errors.Is(err, net.ErrClosed) {
229                                 return nil
230                         }
231                         return fmt.Errorf("can't ack %d: %w", sn, err)
232                 }
233
234                 if sn-c.inrelsn >= 0x8000 {
235                         return nil // Already received.
236                 }
237
238                 c.inrel[sn] = pkt.Data[3:]
239
240                 for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
241                         data := c.inrel[c.inrelsn]
242                         delete(c.inrel, c.inrelsn)
243
244                         rpkt := rawPkt{
245                                 Data:  data,
246                                 ChNo:  pkt.ChNo,
247                                 Unrel: false,
248                         }
249                         if err := p.processRawPkt(rpkt); err != nil {
250                                 p.errs <- PktError{"rel", rpkt.Data, err}
251                         }
252                 }
253         default:
254                 return fmt.Errorf("unsupported pkt type: %d", t)
255         }
256
257         return nil
258 }