]> git.lizzy.rs Git - mt.git/blob - rudp/send.go
rudp: deprecate ErrClosed and replace with net.ErrClosed
[mt.git] / rudp / send.go
1 package rudp
2
3 import (
4         "encoding/binary"
5         "errors"
6         "fmt"
7         "math"
8         "net"
9         "sync"
10         "time"
11 )
12
13 const (
14         // protoID + src PeerID + channel number
15         MtHdrSize = 4 + 2 + 1
16
17         // rawTypeOrig
18         OrigHdrSize = 1
19
20         // rawTypeSpilt + seqnum + chunk count + chunk number
21         SplitHdrSize = 1 + 2 + 2 + 2
22
23         // rawTypeRel + seqnum
24         RelHdrSize = 1 + 2
25 )
26
27 const (
28         MaxNetPktSize = 512
29
30         MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize
31         MaxRelRawPktSize   = MaxUnrelRawPktSize - RelHdrSize
32
33         MaxRelPktSize   = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16
34         MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16
35 )
36
37 var ErrPktTooBig = errors.New("can't send pkt: too big")
38 var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount")
39
40 // Send sends a packet to the Peer.
41 // It returns a channel that's closed when all chunks are acked or an error.
42 // The ack channel is nil if pkt.Unrel is true.
43 func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
44         if pkt.ChNo >= ChannelCount {
45                 return nil, ErrChNoTooBig
46         }
47
48         hdrsize := MtHdrSize
49         if !pkt.Unrel {
50                 hdrsize += RelHdrSize
51         }
52
53         if hdrsize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize {
54                 c := &p.chans[pkt.ChNo]
55
56                 c.outsplitmu.Lock()
57                 sn := c.outsplitsn
58                 c.outsplitsn++
59                 c.outsplitmu.Unlock()
60
61                 chunks := split(pkt.Data, MaxNetPktSize-(hdrsize+SplitHdrSize))
62
63                 if len(chunks) > math.MaxUint16 {
64                         return nil, ErrPktTooBig
65                 }
66
67                 var wg sync.WaitGroup
68
69                 for i, chunk := range chunks {
70                         data := make([]byte, SplitHdrSize+len(chunk))
71                         data[0] = uint8(rawTypeSplit)
72                         binary.BigEndian.PutUint16(data[1:3], uint16(sn))
73                         binary.BigEndian.PutUint16(data[3:5], uint16(len(chunks)))
74                         binary.BigEndian.PutUint16(data[5:7], uint16(i))
75                         copy(data[SplitHdrSize:], chunk)
76
77                         wg.Add(1)
78                         ack, err := p.sendRaw(rawPkt{
79                                 Data:  data,
80                                 ChNo:  pkt.ChNo,
81                                 Unrel: pkt.Unrel,
82                         })
83                         if err != nil {
84                                 return nil, err
85                         }
86                         if !pkt.Unrel {
87                                 if ack == nil {
88                                         panic("ack is nil")
89                                 }
90                                 go func() {
91                                         <-ack
92                                         wg.Done()
93                                 }()
94                         }
95                 }
96
97                 if pkt.Unrel {
98                         return nil, nil
99                 } else {
100                         ack := make(chan struct{})
101
102                         go func() {
103                                 wg.Wait()
104                                 close(ack)
105                         }()
106
107                         return ack, nil
108                 }
109         }
110
111         return p.sendRaw(rawPkt{
112                 Data:  append([]byte{uint8(rawTypeOrig)}, pkt.Data...),
113                 ChNo:  pkt.ChNo,
114                 Unrel: pkt.Unrel,
115         })
116 }
117
118 // sendRaw sends a raw packet to the Peer.
119 func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) {
120         if pkt.ChNo >= ChannelCount {
121                 return nil, ErrChNoTooBig
122         }
123
124         p.mu.RLock()
125         defer p.mu.RUnlock()
126
127         select {
128         case <-p.Disco():
129                 return nil, net.ErrClosed
130         default:
131         }
132
133         if !pkt.Unrel {
134                 return p.sendRel(pkt)
135         }
136
137         data := make([]byte, MtHdrSize+len(pkt.Data))
138         binary.BigEndian.PutUint32(data[0:4], protoID)
139         binary.BigEndian.PutUint16(data[4:6], uint16(p.idOfPeer))
140         data[6] = pkt.ChNo
141         copy(data[MtHdrSize:], pkt.Data)
142
143         if len(data) > MaxNetPktSize {
144                 return nil, ErrPktTooBig
145         }
146
147         _, err = p.Conn().WriteTo(data, p.Addr())
148         if errors.Is(err, net.ErrWriteToConnected) {
149                 conn, ok := p.Conn().(net.Conn)
150                 if !ok {
151                         return nil, err
152                 }
153                 _, err = conn.Write(data)
154         }
155         if err != nil {
156                 return nil, err
157         }
158
159         p.ping.Reset(PingTimeout)
160
161         return nil, nil
162 }
163
164 // sendRel sends a reliable raw packet to the Peer.
165 func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) {
166         if pkt.Unrel {
167                 panic("mt/rudp: sendRel: pkt.Unrel is true")
168         }
169
170         c := &p.chans[pkt.ChNo]
171
172         c.outrelmu.Lock()
173         defer c.outrelmu.Unlock()
174
175         sn := c.outrelsn
176         for ; sn-c.outrelwin >= 0x8000; c.outrelwin++ {
177                 if ack, ok := c.ackchans.Load(c.outrelwin); ok {
178                         <-ack.(chan struct{})
179                 }
180         }
181         c.outrelsn++
182
183         rwack := make(chan struct{}) // close-only
184         c.ackchans.Store(sn, rwack)
185         ack = rwack
186
187         reldata := make([]byte, RelHdrSize+len(pkt.Data))
188         reldata[0] = uint8(rawTypeRel)
189         binary.BigEndian.PutUint16(reldata[1:3], uint16(sn))
190         copy(reldata[RelHdrSize:], pkt.Data)
191         relpkt := rawPkt{
192                 Data:  reldata,
193                 ChNo:  pkt.ChNo,
194                 Unrel: true,
195         }
196
197         if _, err := p.sendRaw(relpkt); err != nil {
198                 c.ackchans.Delete(sn)
199
200                 return nil, err
201         }
202
203         go func() {
204                 for {
205                         select {
206                         case <-time.After(500 * time.Millisecond):
207                                 if _, err := p.sendRaw(relpkt); err != nil {
208                                         p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err)
209                                 }
210                         case <-ack:
211                                 return
212                         case <-p.Disco():
213                                 return
214                         }
215                 }
216         }()
217
218         return ack, nil
219 }
220
221 // SendDisco sends a disconnect packet to the Peer but does not close it.
222 // It returns a channel that's closed when it's acked or an error.
223 // The ack channel is nil if unrel is true.
224 func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) {
225         return p.sendRaw(rawPkt{
226                 Data:  []byte{uint8(rawTypeCtl), uint8(ctlDisco)},
227                 ChNo:  chno,
228                 Unrel: unrel,
229         })
230 }
231
232 func split(data []byte, chunksize int) [][]byte {
233         chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize)
234
235         for i := 0; i < len(data); i += chunksize {
236                 end := i + chunksize
237                 if end > len(data) {
238                         end = len(data)
239                 }
240
241                 chunks = append(chunks, data[i:end])
242         }
243
244         return chunks
245 }