git.lizzy.rs Git - mt.git/blob - rudp/send.go
rudp: partial rewrite with new API supporting io.Readers
[mt.git] / rudp / send.go
1 package rudp
2
3 import (
4         "bytes"
5         "errors"
6         "fmt"
7         "io"
8         "net"
9         "sync"
10         "sync/atomic"
11         "time"
12 )
13
14 var ErrPktTooBig = errors.New("can't send pkt: too big")
15
16 // A TooBigChError reports a Channel greater than or equal to ChannelCount.
17 type TooBigChError Channel
18
19 func (e TooBigChError) Error() string {
20         return fmt.Sprintf("channel >= ChannelCount (%d): %d", ChannelCount, e)
21 }
22
23 // Send sends a Pkt to the Conn.
24 // Ack is closed when the packet is acknowledged.
25 // Ack is nil if pkt.Unrel is true or err != nil.
26 func (c *Conn) Send(pkt Pkt) (ack <-chan struct{}, err error) {
27         if pkt.Channel >= ChannelCount {
28                 return nil, TooBigChError(pkt.Channel)
29         }
30
31         var e error
32         send := c.sendRaw(func(buf []byte) int {
33                 buf[0] = uint8(rawOrig)
34
35                 nn := 1
36                 for nn < len(buf) {
37                         n, err := pkt.Read(buf[nn:])
38                         nn += n
39                         if err != nil {
40                                 e = err
41                                 return nn
42                         }
43                 }
44
45                 if _, e = pkt.Read(nil); e != nil {
46                         return nn
47                 }
48
49                 pkt.Reader = io.MultiReader(
50                         bytes.NewReader([]byte(buf[1:nn])),
51                         pkt.Reader,
52                 )
53                 return nn
54         }, pkt.PktInfo)
55         if e != nil {
56                 if e == io.EOF {
57                         return send()
58                 }
59                 return nil, e
60         }
61
62         var (
63                 sn seqnum
64                 i  uint16
65
66                 sends []func() (<-chan struct{}, error)
67         )
68
69         for {
70                 var (
71                         b []byte
72                         e error
73                 )
74                 send := c.sendRaw(func(buf []byte) int {
75                         buf[0] = uint8(rawSplit)
76
77                         n, err := io.ReadFull(pkt, buf[7:])
78                         if err != nil && err != io.ErrUnexpectedEOF {
79                                 e = err
80                                 return 0
81                         }
82
83                         be.PutUint16(buf[5:7], i)
84                         if i++; i == 0 {
85                                 e = ErrPktTooBig
86                                 return 0
87                         }
88
89                         b = buf
90                         return 7 + n
91                 }, pkt.PktInfo)
92                 if e != nil {
93                         if e == io.EOF {
94                                 break
95                         }
96                         return nil, e
97                 }
98
99                 sends = append(sends, func() (<-chan struct{}, error) {
100                         be.PutUint16(b[1:3], uint16(sn))
101                         be.PutUint16(b[3:5], i)
102                         return send()
103                 })
104         }
105
106         ch := &c.chans[pkt.Channel]
107
108         ch.outSplitMu.Lock()
109         sn = ch.outSplitSN
110         ch.outSplitSN++
111         ch.outSplitMu.Unlock()
112
113         var wg sync.WaitGroup
114
115         for _, send := range sends {
116                 ack, err := send()
117                 if err != nil {
118                         return nil, err
119                 }
120                 if !pkt.Unrel {
121                         wg.Add(1)
122                         go func() {
123                                 <-ack
124                                 wg.Done()
125                         }()
126                 }
127         }
128
129         if !pkt.Unrel {
130                 ack := make(chan struct{})
131                 go func() {
132                         wg.Wait()
133                         close(ack)
134                 }()
135                 return ack, nil
136         }
137
138         return nil, nil
139 }
140
141 func (c *Conn) sendRaw(read func([]byte) int, pi PktInfo) func() (<-chan struct{}, error) {
142         if pi.Unrel {
143                 buf := make([]byte, maxUDPPktSize)
144                 be.PutUint32(buf[0:4], protoID)
145                 c.mu.RLock()
146                 be.PutUint16(buf[4:6], uint16(c.remoteID))
147                 c.mu.RUnlock()
148                 buf[6] = uint8(pi.Channel)
149                 buf = buf[:7+read(buf[7:])]
150
151                 return func() (<-chan struct{}, error) {
152                         if _, err := c.udpConn.Write(buf); err != nil {
153                                 c.close(err)
154                                 return nil, net.ErrClosed
155                         }
156
157                         c.ping.Reset(PingTimeout)
158                         if atomic.LoadUint32(&c.closing) == 1 {
159                                 c.ping.Stop()
160                         }
161
162                         return nil, nil
163                 }
164         }
165
166         pi.Unrel = true
167         var snBuf []byte
168         send := c.sendRaw(func(buf []byte) int {
169                 buf[0] = uint8(rawRel)
170                 snBuf = buf[1:3]
171                 return 3 + read(buf[3:])
172         }, pi)
173
174         return func() (<-chan struct{}, error) {
175                 ch := &c.chans[pi.Channel]
176
177                 ch.outRelMu.Lock()
178                 defer ch.outRelMu.Unlock()
179
180                 sn := ch.outRelSN
181                 be.PutUint16(snBuf, uint16(sn))
182                 for ; sn-ch.outRelWin >= 0x8000; ch.outRelWin++ {
183                         if ack, ok := ch.ackChans.Load(ch.outRelWin); ok {
184                                 select {
185                                 case <-ack.(chan struct{}):
186                                 case <-c.Closed():
187                                 }
188                         }
189                 }
190
191                 ack := make(chan struct{})
192                 ch.ackChans.Store(sn, ack)
193
194                 if _, err := send(); err != nil {
195                         if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
196                                 close(ack.(chan struct{}))
197                         }
198                         return nil, err
199                 }
200                 ch.outRelSN++
201
202                 go func() {
203                         t := time.NewTimer(500 * time.Millisecond)
204                         defer t.Stop()
205
206                         for {
207                                 select {
208                                 case <-ack:
209                                         return
210                                 case <-t.C:
211                                         send()
212                                         t.Reset(500 * time.Millisecond)
213                                 case <-c.Closed():
214                                         return
215                                 }
216                         }
217                 }()
218
219                 return ack, nil
220         }
221 }