]> git.lizzy.rs Git - mt.git/blob - rudp/process.go
c36af81f7066a311b19f3ee06f1cd2c77e2a051d
[mt.git] / rudp / process.go
1 package rudp
2
3 import (
4         "encoding/binary"
5         "encoding/hex"
6         "errors"
7         "fmt"
8         "io"
9 )
10
11 // A PktError is an error that occured while processing a packet.
12 type PktError struct {
13         Type string // "net", "raw" or "rel".
14         Data []byte
15         Err  error
16 }
17
18 func (e PktError) Error() string {
19         return "error processing " + e.Type + " pkt: " +
20                 hex.EncodeToString(e.Data) + ": " +
21                 e.Err.Error()
22 }
23
24 func (e PktError) Unwrap() error { return e.Err }
25
26 func (p *Peer) processNetPkts(pkts <-chan netPkt) {
27         for pkt := range pkts {
28                 if err := p.processNetPkt(pkt); err != nil {
29                         p.errs <- PktError{"net", pkt.Data, err}
30                 }
31         }
32
33         close(p.pkts)
34 }
35
36 // A TrailingDataError reports a packet with trailing data,
37 // it doesn't stop a packet from being processed.
38 type TrailingDataError []byte
39
40 func (e TrailingDataError) Error() string {
41         return "trailing data: " + hex.EncodeToString([]byte(e))
42 }
43
44 func (p *Peer) processNetPkt(pkt netPkt) (err error) {
45         if pkt.SrcAddr.String() != p.Addr().String() {
46                 return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
47         }
48
49         if len(pkt.Data) < MtHdrSize {
50                 return io.ErrUnexpectedEOF
51         }
52
53         if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
54                 return fmt.Errorf("unsupported protocol id: 0x%08x", id)
55         }
56
57         // src PeerID at pkt.Data[4:6]
58
59         chno := pkt.Data[6]
60         if chno >= ChannelCount {
61                 return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
62         }
63
64         p.mu.RLock()
65         if p.timeout != nil {
66                 p.timeout.Reset(ConnTimeout)
67         }
68         p.mu.RUnlock()
69
70         rpkt := rawPkt{
71                 Data:  pkt.Data[MtHdrSize:],
72                 ChNo:  chno,
73                 Unrel: true,
74         }
75         if err := p.processRawPkt(rpkt); err != nil {
76                 p.errs <- PktError{"raw", rpkt.Data, err}
77         }
78
79         return nil
80 }
81
82 func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
83         errWrap := func(format string, a ...interface{}) {
84                 if err != nil {
85                         err = fmt.Errorf(format, append(a, err)...)
86                 }
87         }
88
89         c := &p.chans[pkt.ChNo]
90
91         if len(pkt.Data) < 1 {
92                 return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
93         }
94         switch t := rawType(pkt.Data[0]); t {
95         case rawTypeCtl:
96                 defer errWrap("ctl: %w")
97
98                 if len(pkt.Data) < 1+1 {
99                         return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
100                 }
101                 switch ct := ctlType(pkt.Data[1]); ct {
102                 case ctlAck:
103                         defer errWrap("ack: %w")
104
105                         if len(pkt.Data) < 1+1+2 {
106                                 return io.ErrUnexpectedEOF
107                         }
108
109                         sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
110
111                         if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
112                                 close(ack.(chan struct{}))
113                         }
114
115                         if len(pkt.Data) > 1+1+2 {
116                                 return TrailingDataError(pkt.Data[1+1+2:])
117                         }
118                 case ctlSetPeerID:
119                         defer errWrap("set peer id: %w")
120
121                         if len(pkt.Data) < 1+1+2 {
122                                 return io.ErrUnexpectedEOF
123                         }
124
125                         // Ensure no concurrent senders while peer id changes.
126                         p.mu.Lock()
127                         if p.idOfPeer != PeerIDNil {
128                                 return errors.New("peer id already set")
129                         }
130
131                         p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
132                         p.mu.Unlock()
133
134                         if len(pkt.Data) > 1+1+2 {
135                                 return TrailingDataError(pkt.Data[1+1+2:])
136                         }
137                 case ctlPing:
138                         defer errWrap("ping: %w")
139
140                         if len(pkt.Data) > 1+1 {
141                                 return TrailingDataError(pkt.Data[1+1:])
142                         }
143                 case ctlDisco:
144                         defer errWrap("disco: %w")
145
146                         if err := p.Close(); err != nil {
147                                 return fmt.Errorf("can't close: %w", err)
148                         }
149
150                         if len(pkt.Data) > 1+1 {
151                                 return TrailingDataError(pkt.Data[1+1:])
152                         }
153                 default:
154                         return fmt.Errorf("unsupported ctl type: %d", ct)
155                 }
156         case rawTypeOrig:
157                 p.pkts <- Pkt{
158                         Data:  pkt.Data[1:],
159                         ChNo:  pkt.ChNo,
160                         Unrel: pkt.Unrel,
161                 }
162         case rawTypeSplit:
163                 defer errWrap("split: %w")
164
165                 if len(pkt.Data) < 1+2+2+2 {
166                         return io.ErrUnexpectedEOF
167                 }
168
169                 sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
170                 count := binary.BigEndian.Uint16(pkt.Data[3:5])
171                 i := binary.BigEndian.Uint16(pkt.Data[5:7])
172
173                 if i >= count {
174                         return nil
175                 }
176
177                 splitpkts := p.chans[pkt.ChNo].insplit
178
179                 // Delete old incomplete split packets
180                 // so new ones don't get corrupted.
181                 delete(splitpkts, sn-0x8000)
182
183                 if splitpkts[sn] == nil {
184                         splitpkts[sn] = make([][]byte, count)
185                 }
186
187                 chunks := splitpkts[sn]
188
189                 if int(count) != len(chunks) {
190                         return fmt.Errorf("chunk count changed on seqnum: %d", sn)
191                 }
192
193                 chunks[i] = pkt.Data[7:]
194
195                 for _, chunk := range chunks {
196                         if chunk == nil {
197                                 return nil
198                         }
199                 }
200
201                 var data []byte
202                 for _, chunk := range chunks {
203                         data = append(data, chunk...)
204                 }
205
206                 p.pkts <- Pkt{
207                         Data:  data,
208                         ChNo:  pkt.ChNo,
209                         Unrel: pkt.Unrel,
210                 }
211
212                 delete(splitpkts, sn)
213         case rawTypeRel:
214                 defer errWrap("rel: %w")
215
216                 if len(pkt.Data) < 1+2 {
217                         return io.ErrUnexpectedEOF
218                 }
219
220                 sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
221
222                 ackdata := make([]byte, 1+1+2)
223                 ackdata[0] = uint8(rawTypeCtl)
224                 ackdata[1] = uint8(ctlAck)
225                 binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
226                 ack := rawPkt{
227                         Data:  ackdata,
228                         ChNo:  pkt.ChNo,
229                         Unrel: true,
230                 }
231                 if _, err := p.sendRaw(ack); err != nil {
232                         return fmt.Errorf("can't ack %d: %w", sn, err)
233                 }
234
235                 if sn-c.inrelsn >= 0x8000 {
236                         return nil // Already received.
237                 }
238
239                 c.inrel[sn] = pkt.Data[3:]
240
241                 for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
242                         data := c.inrel[c.inrelsn]
243                         delete(c.inrel, c.inrelsn)
244
245                         rpkt := rawPkt{
246                                 Data:  data,
247                                 ChNo:  pkt.ChNo,
248                                 Unrel: false,
249                         }
250                         if err := p.processRawPkt(rpkt); err != nil {
251                                 p.errs <- PktError{"rel", rpkt.Data, err}
252                         }
253                 }
254         default:
255                 return fmt.Errorf("unsupported pkt type: %d", t)
256         }
257
258         return nil
259 }