]> git.lizzy.rs Git - mt.git/blob - rudp/recv.go
rudp: partial rewrite with new API supporting io.Readers
[mt.git] / rudp / recv.go
1 package rudp
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "net"
9         "time"
10 )
11
12 // Recv receives a Pkt from the Conn.
13 func (c *Conn) Recv() (Pkt, error) {
14         select {
15         case pkt := <-c.pkts:
16                 return pkt, nil
17         case err := <-c.errs:
18                 return Pkt{}, err
19         case <-c.Closed():
20                 return Pkt{}, net.ErrClosed
21         }
22 }
23
24 func (c *Conn) gotPkt(pkt Pkt) {
25         select {
26         case c.pkts <- pkt:
27         case <-c.Closed():
28         }
29 }
30
31 func (c *Conn) gotErr(kind string, data []byte, err error) {
32         select {
33         case c.errs <- fmt.Errorf("%s: %x: %w", kind, data, err):
34         case <-c.Closed():
35         }
36 }
37
38 func (c *Conn) recvUDPPkts() {
39         for {
40                 pkt, err := c.udpConn.recvUDP()
41                 if err != nil {
42                         c.closeDisco(err)
43                         break
44                 }
45
46                 if err := c.processUDPPkt(pkt); err != nil {
47                         c.gotErr("udp", pkt, err)
48                 }
49         }
50 }
51
52 func (c *Conn) processUDPPkt(pkt []byte) error {
53         if c.timeout.Stop() {
54                 c.timeout.Reset(ConnTimeout)
55         }
56
57         if len(pkt) < 6 {
58                 return io.ErrUnexpectedEOF
59         }
60
61         if id := be.Uint32(pkt[0:4]); id != protoID {
62                 return fmt.Errorf("unsupported protocol id: 0x%08x", id)
63         }
64
65         ch := Channel(pkt[6])
66         if ch >= ChannelCount {
67                 return TooBigChError(ch)
68         }
69
70         if err := c.processRawPkt(pkt[7:], PktInfo{Channel: ch, Unrel: true}); err != nil {
71                 c.gotErr("raw", pkt, err)
72         }
73
74         return nil
75 }
76
// TrailingDataError is returned when bytes remain after a packet has been
// fully parsed. Its value is the leftover bytes.
type TrailingDataError []byte

// Error renders the leftover bytes as lowercase hex.
func (e TrailingDataError) Error() string {
	// Convert to a plain []byte first. Formatting e directly with %x would
	// make fmt call this Error method again and recurse forever.
	leftover := []byte(e)
	return "trailing data: " + fmt.Sprintf("%x", leftover)
}
83
84 func (c *Conn) processRawPkt(data []byte, pi PktInfo) (err error) {
85         errWrap := func(format string, a ...interface{}) {
86                 if err != nil {
87                         err = fmt.Errorf(format+": %w", append(a, err)...)
88                 }
89         }
90
91         eof := new(byte)
92         defer func() {
93                 switch r := recover(); r {
94                 case nil:
95                 case eof:
96                         err = io.ErrUnexpectedEOF
97                 default:
98                         panic(r)
99                 }
100         }()
101
102         off := 0
103         eat := func(n int) []byte {
104                 i := off
105                 off += n
106                 if i > len(data) {
107                         panic(eof)
108                 }
109                 return data[i:off]
110         }
111
112         ch := &c.chans[pi.Channel]
113
114         switch t := rawType(eat(1)[0]); t {
115         case rawCtl:
116                 defer errWrap("ctl")
117
118                 switch ct := ctlType(eat(1)[0]); ct {
119                 case ctlAck:
120                         defer errWrap("ack")
121
122                         sn := seqnum(be.Uint16(eat(2)))
123
124                         if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
125                                 close(ack.(chan struct{}))
126                         }
127                 case ctlSetPeerID:
128                         defer errWrap("set peer id")
129
130                         c.mu.Lock()
131                         if c.remoteID != PeerIDNil {
132                                 return errors.New("peer id already set")
133                         }
134
135                         c.remoteID = PeerID(be.Uint16(eat(2)))
136                         c.mu.Unlock()
137
138                         c.newAckBuf()
139                 case ctlPing:
140                         defer errWrap("ping")
141                 case ctlDisco:
142                         defer errWrap("disco")
143
144                         c.close(nil)
145                 default:
146                         return fmt.Errorf("unsupported ctl type: %d", ct)
147                 }
148
149                 if off < len(data) {
150                         return TrailingDataError(data[off:])
151                 }
152         case rawOrig:
153                 c.gotPkt(Pkt{
154                         Reader:  bytes.NewReader(data[off:]),
155                         PktInfo: pi,
156                 })
157         case rawSplit:
158                 defer errWrap("split")
159
160                 sn := seqnum(be.Uint16(eat(2)))
161                 n := be.Uint16(eat(2))
162                 i := be.Uint16(eat(2))
163
164                 defer errWrap("%d", sn)
165
166                 if i >= n {
167                         return fmt.Errorf("chunk number (%d) > chunk count (%d)", i, n)
168                 }
169
170                 ch.inSplitsMu.RLock()
171                 s := ch.inSplits[sn]
172                 ch.inSplitsMu.RUnlock()
173
174                 if s == nil {
175                         s = &inSplit{chunks: make([][]byte, n)}
176                         if pi.Unrel {
177                                 s.timeout = time.AfterFunc(ConnTimeout, func() {
178                                         ch.inSplitsMu.Lock()
179                                         delete(ch.inSplits, sn)
180                                         ch.inSplitsMu.Unlock()
181                                 })
182                         }
183
184                         ch.inSplitsMu.Lock()
185                         ch.inSplits[sn] = s
186                         ch.inSplitsMu.Unlock()
187                 }
188
189                 if int(n) != len(s.chunks) {
190                         return fmt.Errorf("chunk count changed from %d to %d", len(s.chunks), n)
191                 }
192
193                 if s.chunks[i] == nil {
194                         s.chunks[i] = data[off:]
195                         s.got++
196                 }
197
198                 if s.got < len(s.chunks) {
199                         if s.timeout != nil && s.timeout.Stop() {
200                                 s.timeout.Reset(ConnTimeout)
201                         }
202                         return
203                 }
204
205                 if s.timeout != nil {
206                         s.timeout.Stop()
207                 }
208
209                 ch.inSplitsMu.Lock()
210                 delete(ch.inSplits, sn)
211                 ch.inSplitsMu.Unlock()
212
213                 c.gotPkt(Pkt{
214                         Reader:  (*net.Buffers)(&s.chunks),
215                         PktInfo: pi,
216                 })
217         case rawRel:
218                 defer errWrap("rel")
219
220                 sn := seqnum(be.Uint16(eat(2)))
221
222                 defer errWrap("%d", sn)
223
224                 be.PutUint16(ch.ackBuf, uint16(sn))
225                 ch.sendAck()
226
227                 if sn-ch.inRelSN >= 0x8000 {
228                         // Already received.
229                         return nil
230                 }
231
232                 ch.inRels[sn&0x7fff] = data[off:]
233
234                 i := func() seqnum { return ch.inRelSN & 0x7fff }
235                 for ; ch.inRels[i()] != nil; ch.inRelSN++ {
236                         data := ch.inRels[i()]
237                         ch.inRels[i()] = nil
238                         if err := c.processRawPkt(data, PktInfo{Channel: pi.Channel}); err != nil {
239                                 c.gotErr("rel", data, err)
240                         }
241                 }
242         default:
243                 return fmt.Errorf("unsupported pkt type: %d", t)
244         }
245
246         return nil
247 }
248
249 func (c *Conn) newAckBuf() {
250         for i := range c.chans {
251                 ch := &c.chans[i]
252                 ch.sendAck = c.sendRaw(func(buf []byte) int {
253                         buf[0] = uint8(rawCtl)
254                         buf[1] = uint8(ctlAck)
255                         ch.ackBuf = buf[2:4]
256                         return 4
257                 }, PktInfo{Channel: Channel(i), Unrel: true})
258         }
259 }