]> git.lizzy.rs Git - mt.git/blobdiff - rudp/send.go
rudp: partial rewrite with new API supporting io.Readers
[mt.git] / rudp / send.go
index c43d0561a1f041a0adab453e3b082636975c0390..5522ee0f08417fafd4c111f044613474ef7b5fe6 100644 (file)
 package rudp
 
 import (
+       "bytes"
        "errors"
        "fmt"
-       "math"
+       "io"
        "net"
        "sync"
+       "sync/atomic"
        "time"
 )
 
-const (
-       // protoID + src PeerID + channel number
-       MtHdrSize = 4 + 2 + 1
-
-       // rawTypeOrig
-       OrigHdrSize = 1
-
-       // rawTypeSpilt + seqnum + chunk count + chunk number
-       SplitHdrSize = 1 + 2 + 2 + 2
-
-       // rawTypeRel + seqnum
-       RelHdrSize = 1 + 2
-)
-
-const (
-       MaxNetPktSize = 512
-
-       MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize
-       MaxRelRawPktSize   = MaxUnrelRawPktSize - RelHdrSize
-
-       MaxRelPktSize   = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16
-       MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16
-)
-
 var ErrPktTooBig = errors.New("can't send pkt: too big")
-var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount")
-
-// Send sends a packet to the Peer.
-// It returns a channel that's closed when all chunks are acked or an error.
-// The ack channel is nil if pkt.Unrel is true.
-func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
-       if pkt.ChNo >= ChannelCount {
-               return nil, ErrChNoTooBig
-       }
-
-       hdrSize := MtHdrSize
-       if !pkt.Unrel {
-               hdrSize += RelHdrSize
-       }
-
-       if hdrSize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize {
-               c := &p.chans[pkt.ChNo]
-
-               c.outSplitMu.Lock()
-               sn := c.outSplitSN
-               c.outSplitSN++
-               c.outSplitMu.Unlock()
 
-               chunks := split(pkt.Data, MaxNetPktSize-(hdrSize+SplitHdrSize))
+// A TooBigChError reports a Channel greater than or equal to ChannelCount.
+type TooBigChError Channel
 
-               if len(chunks) > math.MaxUint16 {
-                       return nil, ErrPktTooBig
-               }
+func (e TooBigChError) Error() string {
+       return fmt.Sprintf("channel >= ChannelCount (%d): %d", ChannelCount, e)
+}
 
-               var wg sync.WaitGroup
+// Send sends a Pkt to the Conn.
+// Ack is closed when the packet is acknowledged.
+// Ack is nil if pkt.Unrel is true or err != nil.
+func (c *Conn) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+       if pkt.Channel >= ChannelCount {
+               return nil, TooBigChError(pkt.Channel)
+       }
 
-               for i, chunk := range chunks {
-                       data := make([]byte, SplitHdrSize+len(chunk))
-                       data[0] = uint8(rawTypeSplit)
-                       be.PutUint16(data[1:3], uint16(sn))
-                       be.PutUint16(data[3:5], uint16(len(chunks)))
-                       be.PutUint16(data[5:7], uint16(i))
-                       copy(data[SplitHdrSize:], chunk)
+       var e error
+       send := c.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawOrig)
 
-                       wg.Add(1)
-                       ack, err := p.sendRaw(rawPkt{
-                               Data:  data,
-                               ChNo:  pkt.ChNo,
-                               Unrel: pkt.Unrel,
-                       })
+               nn := 1
+               for nn < len(buf) {
+                       n, err := pkt.Read(buf[nn:])
+                       nn += n
                        if err != nil {
-                               return nil, err
-                       }
-                       if !pkt.Unrel {
-                               go func() {
-                                       <-ack
-                                       wg.Done()
-                               }()
+                               e = err
+                               return nn
                        }
                }
 
-               if pkt.Unrel {
-                       return nil, nil
-               } else {
-                       ack := make(chan struct{})
-
-                       go func() {
-                               wg.Wait()
-                               close(ack)
-                       }()
+               if _, e = pkt.Read(nil); e != nil {
+                       return nn
+               }
 
-                       return ack, nil
+               pkt.Reader = io.MultiReader(
+                       bytes.NewReader([]byte(buf[1:nn])),
+                       pkt.Reader,
+               )
+               return nn
+       }, pkt.PktInfo)
+       if e != nil {
+               if e == io.EOF {
+                       return send()
                }
+               return nil, e
        }
 
-       return p.sendRaw(rawPkt{
-               Data:  append([]byte{uint8(rawTypeOrig)}, pkt.Data...),
-               ChNo:  pkt.ChNo,
-               Unrel: pkt.Unrel,
-       })
-}
+       var (
+               sn seqnum
+               i  uint16
+
+               sends []func() (<-chan struct{}, error)
+       )
+
+       for {
+               var (
+                       b []byte
+                       e error
+               )
+               send := c.sendRaw(func(buf []byte) int {
+                       buf[0] = uint8(rawSplit)
+
+                       n, err := io.ReadFull(pkt, buf[7:])
+                       if err != nil && err != io.ErrUnexpectedEOF {
+                               e = err
+                               return 0
+                       }
 
-// sendRaw sends a raw packet to the Peer.
-func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) {
-       if pkt.ChNo >= ChannelCount {
-               return nil, ErrChNoTooBig
-       }
+                       be.PutUint16(buf[5:7], i)
+                       if i++; i == 0 {
+                               e = ErrPktTooBig
+                               return 0
+                       }
 
-       p.mu.RLock()
-       defer p.mu.RUnlock()
+                       b = buf
+                       return 7 + n
+               }, pkt.PktInfo)
+               if e != nil {
+                       if e == io.EOF {
+                               break
+                       }
+                       return nil, e
+               }
 
-       select {
-       case <-p.Disco():
-               return nil, net.ErrClosed
-       default:
+               sends = append(sends, func() (<-chan struct{}, error) {
+                       be.PutUint16(b[1:3], uint16(sn))
+                       be.PutUint16(b[3:5], i)
+                       return send()
+               })
        }
 
-       if !pkt.Unrel {
-               return p.sendRel(pkt)
-       }
+       ch := &c.chans[pkt.Channel]
 
-       data := make([]byte, MtHdrSize+len(pkt.Data))
-       be.PutUint32(data[0:4], protoID)
-       be.PutUint16(data[4:6], uint16(p.idOfPeer))
-       data[6] = pkt.ChNo
-       copy(data[MtHdrSize:], pkt.Data)
+       ch.outSplitMu.Lock()
+       sn = ch.outSplitSN
+       ch.outSplitSN++
+       ch.outSplitMu.Unlock()
 
-       if len(data) > MaxNetPktSize {
-               return nil, ErrPktTooBig
-       }
+       var wg sync.WaitGroup
 
-       if p.conn != nil {
-               _, err = p.conn.Write(data)
-       } else {
-               _, err = p.pc.WriteTo(data, p.Addr())
-       }
-       if err != nil {
-               return nil, err
+       for _, send := range sends {
+               ack, err := send()
+               if err != nil {
+                       return nil, err
+               }
+               if !pkt.Unrel {
+                       wg.Add(1)
+                       go func() {
+                               <-ack
+                               wg.Done()
+                       }()
+               }
        }
 
-       p.ping.Reset(PingTimeout)
+       if !pkt.Unrel {
+               ack := make(chan struct{})
+               go func() {
+                       wg.Wait()
+                       close(ack)
+               }()
+               return ack, nil
+       }
 
        return nil, nil
 }
 
-// sendRel sends a reliable raw packet to the Peer.
-func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) {
-       if pkt.Unrel {
-               panic("pkt.Unrel is true")
-       }
-
-       c := &p.chans[pkt.ChNo]
+func (c *Conn) sendRaw(read func([]byte) int, pi PktInfo) func() (<-chan struct{}, error) {
+       if pi.Unrel {
+               buf := make([]byte, maxUDPPktSize)
+               be.PutUint32(buf[0:4], protoID)
+               c.mu.RLock()
+               be.PutUint16(buf[4:6], uint16(c.remoteID))
+               c.mu.RUnlock()
+               buf[6] = uint8(pi.Channel)
+               buf = buf[:7+read(buf[7:])]
+
+               return func() (<-chan struct{}, error) {
+                       if _, err := c.udpConn.Write(buf); err != nil {
+                               c.close(err)
+                               return nil, net.ErrClosed
+                       }
 
-       c.outRelMu.Lock()
-       defer c.outRelMu.Unlock()
+                       c.ping.Reset(PingTimeout)
+                       if atomic.LoadUint32(&c.closing) == 1 {
+                               c.ping.Stop()
+                       }
 
-       sn := c.outRelSN
-       for ; sn-c.outRelWin >= 0x8000; c.outRelWin++ {
-               if ack, ok := c.ackChans.Load(c.outRelWin); ok {
-                       <-ack.(chan struct{})
+                       return nil, nil
                }
        }
-       c.outRelSN++
-
-       rwack := make(chan struct{}) // close-only
-       c.ackChans.Store(sn, rwack)
-       ack = rwack
-
-       data := make([]byte, RelHdrSize+len(pkt.Data))
-       data[0] = uint8(rawTypeRel)
-       be.PutUint16(data[1:3], uint16(sn))
-       copy(data[RelHdrSize:], pkt.Data)
-       rel := rawPkt{
-               Data:  data,
-               ChNo:  pkt.ChNo,
-               Unrel: true,
-       }
-
-       if _, err := p.sendRaw(rel); err != nil {
-               c.ackChans.Delete(sn)
 
-               return nil, err
-       }
-
-       go func() {
-               for {
-                       select {
-                       case <-time.After(500 * time.Millisecond):
-                               if _, err := p.sendRaw(rel); err != nil {
-                                       if errors.Is(err, net.ErrClosed) {
-                                               return
-                                       }
-                                       p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err)
+       pi.Unrel = true
+       var snBuf []byte
+       send := c.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawRel)
+               snBuf = buf[1:3]
+               return 3 + read(buf[3:])
+       }, pi)
+
+       return func() (<-chan struct{}, error) {
+               ch := &c.chans[pi.Channel]
+
+               ch.outRelMu.Lock()
+               defer ch.outRelMu.Unlock()
+
+               sn := ch.outRelSN
+               be.PutUint16(snBuf, uint16(sn))
+               for ; sn-ch.outRelWin >= 0x8000; ch.outRelWin++ {
+                       if ack, ok := ch.ackChans.Load(ch.outRelWin); ok {
+                               select {
+                               case <-ack.(chan struct{}):
+                               case <-c.Closed():
                                }
-                       case <-ack:
-                               return
-                       case <-p.Disco():
-                               return
                        }
                }
-       }()
-
-       return ack, nil
-}
-
-// SendDisco sends a disconnect packet to the Peer but does not close it.
-// It returns a channel that's closed when it's acked or an error.
-// The ack channel is nil if unrel is true.
-func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) {
-       return p.sendRaw(rawPkt{
-               Data:  []byte{uint8(rawTypeCtl), uint8(ctlDisco)},
-               ChNo:  chno,
-               Unrel: unrel,
-       })
-}
 
-func split(data []byte, chunksize int) [][]byte {
-       chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize)
+               ack := make(chan struct{})
+               ch.ackChans.Store(sn, ack)
 
-       for i := 0; i < len(data); i += chunksize {
-               end := i + chunksize
-               if end > len(data) {
-                       end = len(data)
+               if _, err := send(); err != nil {
+                       if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
+                               close(ack.(chan struct{}))
+                       }
+                       return nil, err
                }
+               ch.outRelSN++
+
+               go func() {
+                       t := time.NewTimer(500 * time.Millisecond)
+                       defer t.Stop()
+
+                       for {
+                               select {
+                               case <-ack:
+                                       return
+                               case <-t.C:
+                                       send()
+                                       t.Reset(500 * time.Millisecond)
+                               case <-c.Closed():
+                                       return
+                               }
+                       }
+               }()
 
-               chunks = append(chunks, data[i:end])
+               return ack, nil
        }
-
-       return chunks
 }