]> git.lizzy.rs Git - mt.git/commitdiff
rudp: partial rewrite with new API supporting io.Readers
authoranon5 <anon5clam@protonmail.com>
Mon, 22 Mar 2021 18:37:36 +0000 (18:37 +0000)
committeranon5 <anon5clam@protonmail.com>
Mon, 22 Mar 2021 18:37:36 +0000 (18:37 +0000)
rudp/conn.go [new file with mode: 0644]
rudp/connect.go [new file with mode: 0644]
rudp/listen.go
rudp/net.go [deleted file]
rudp/peer.go [deleted file]
rudp/process.go [deleted file]
rudp/proxy/proxy.go
rudp/recv.go [new file with mode: 0644]
rudp/rudp.go
rudp/send.go
rudp/udp.go [new file with mode: 0644]

diff --git a/rudp/conn.go b/rudp/conn.go
new file mode 100644 (file)
index 0000000..7e241a8
--- /dev/null
@@ -0,0 +1,173 @@
+package rudp
+
+import (
+       "net"
+       "sync"
+       "sync/atomic"
+       "time"
+)
+
+// A Conn is a connection to a client or server.
+// All Conn's methods are safe for concurrent use.
+type Conn struct {
+       udpConn udpConn
+
+       id PeerID
+
+       pkts chan Pkt
+       errs chan error
+
+       timeout *time.Timer
+       ping    *time.Ticker
+
+       closing uint32
+       closed  chan struct{}
+       err     error
+
+       mu       sync.RWMutex
+       remoteID PeerID
+
+       chans [ChannelCount]pktChan // read/write
+}
+
+// ID returns the PeerID of the Conn.
+func (c *Conn) ID() PeerID { return c.id }
+
+// IsSrv reports whether the Conn is a connection to a server.
+func (c *Conn) IsSrv() bool { return c.ID() == PeerIDSrv }
+
+// Closed returns a channel which is closed when the Conn is closed.
+func (c *Conn) Closed() <-chan struct{} { return c.closed }
+
+// WhyClosed returns the error that caused the Conn to be closed or nil
+// if the Conn was closed using the Close method or by the peer.
+// WhyClosed returns nil if the Conn is not closed.
+func (c *Conn) WhyClosed() error {
+       select {
+       case <-c.Closed():
+               return c.err
+       default:
+               return nil
+       }
+}
+
+// LocalAddr returns the local network address.
+func (c *Conn) LocalAddr() net.Addr { return c.udpConn.LocalAddr() }
+
+// RemoteAddr returns the remote network address.
+func (c *Conn) RemoteAddr() net.Addr { return c.udpConn.RemoteAddr() }
+
+type pktChan struct {
+       // Only accessed by Conn.recvUDPPkts goroutine.
+       inRels  *[0x8000][]byte
+       inRelSN seqnum
+       sendAck func() (<-chan struct{}, error)
+       ackBuf  []byte
+
+       inSplitsMu sync.RWMutex
+       inSplits   map[seqnum]*inSplit
+
+       ackChans sync.Map // map[seqnum]chan struct{}
+
+       outSplitMu sync.Mutex
+       outSplitSN seqnum
+
+       outRelMu  sync.Mutex
+       outRelSN  seqnum
+       outRelWin seqnum
+}
+
+type inSplit struct {
+       chunks  [][]byte
+       got     int
+       timeout *time.Timer
+}
+
+// Close closes the Conn.
+// Any blocked Send or Recv calls will return net.ErrClosed.
+func (c *Conn) Close() error {
+       return c.closeDisco(nil)
+}
+
+func (c *Conn) closeDisco(err error) error {
+       c.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawCtl)
+               buf[1] = uint8(ctlDisco)
+               return 2
+       }, PktInfo{Unrel: true})()
+
+       return c.close(err)
+}
+
+func (c *Conn) close(err error) error {
+       if atomic.SwapUint32(&c.closing, 1) == 1 {
+               return net.ErrClosed
+       }
+
+       c.timeout.Stop()
+       c.ping.Stop()
+
+       c.err = err
+       defer close(c.closed)
+
+       return c.udpConn.Close()
+}
+
+func newConn(uc udpConn, id, remoteID PeerID) *Conn {
+       var c *Conn
+       c = &Conn{
+               udpConn: uc,
+
+               id: id,
+
+               pkts: make(chan Pkt),
+               errs: make(chan error),
+
+               timeout: time.AfterFunc(ConnTimeout, func() {
+                       c.closeDisco(ErrTimedOut)
+               }),
+               ping: time.NewTicker(PingTimeout),
+
+               closed: make(chan struct{}),
+
+               remoteID: remoteID,
+       }
+
+       for i := range c.chans {
+               c.chans[i] = pktChan{
+                       inRels:  new([0x8000][]byte),
+                       inRelSN: initSeqnum,
+
+                       inSplits: make(map[seqnum]*inSplit),
+
+                       outSplitSN: initSeqnum,
+
+                       outRelSN:  initSeqnum,
+                       outRelWin: initSeqnum,
+               }
+       }
+
+       c.newAckBuf()
+
+       go c.sendPings(c.ping.C)
+       go c.recvUDPPkts()
+
+       return c
+}
+
+func (c *Conn) sendPings(ping <-chan time.Time) {
+       send := c.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawCtl)
+               buf[1] = uint8(ctlPing)
+               return 2
+       }, PktInfo{})
+
+       for {
+               select {
+               case <-ping:
+                       send()
+               case <-c.Closed():
+                       return
+               }
+       }
+}
diff --git a/rudp/connect.go b/rudp/connect.go
new file mode 100644 (file)
index 0000000..548ab15
--- /dev/null
@@ -0,0 +1,18 @@
+package rudp
+
+import "net"
+
+type udpSrv struct {
+       net.Conn
+}
+
+func (us udpSrv) recvUDP() ([]byte, error) {
+       buf := make([]byte, maxUDPPktSize)
+       n, err := us.Read(buf)
+       return buf[:n], err
+}
+
+// Connect returns a Conn connected to conn.
+func Connect(conn net.Conn) *Conn {
+       return newConn(udpSrv{conn}, PeerIDSrv, PeerIDNil)
+}
index 5b7154a56c7dc02ef91e89a5ff5888e074a4c21f..e1cacf49fb267ceab8a1ccdda932dc786934744f 100644 (file)
@@ -2,155 +2,213 @@ package rudp
 
 import (
        "errors"
-       "fmt"
        "net"
        "sync"
 )
 
+func tryClose(ch chan struct{}) (ok bool) {
+       defer func() { recover() }()
+       close(ch)
+       return true
+}
+
+type udpClt struct {
+       l      *Listener
+       id     PeerID
+       addr   net.Addr
+       pkts   chan []byte
+       closed chan struct{}
+}
+
+func (c *udpClt) mkConn() {
+       conn := newConn(c, c.id, PeerIDSrv)
+       go func() {
+               <-conn.Closed()
+               c.l.wg.Done()
+       }()
+       conn.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawCtl)
+               buf[1] = uint8(ctlSetPeerID)
+               be.PutUint16(buf[2:4], uint16(conn.ID()))
+               return 4
+       }, PktInfo{})()
+       select {
+       case c.l.conns <- conn:
+       case <-c.l.closed:
+               conn.Close()
+       }
+}
+
+func (c *udpClt) Write(pkt []byte) (int, error) {
+       select {
+       case <-c.closed:
+               return 0, net.ErrClosed
+       default:
+       }
+
+       return c.l.pc.WriteTo(pkt, c.addr)
+}
+
+func (c *udpClt) recvUDP() ([]byte, error) {
+       select {
+       case pkt := <-c.pkts:
+               return pkt, nil
+       case <-c.closed:
+               return nil, net.ErrClosed
+       }
+}
+
+func (c *udpClt) Close() error {
+       if !tryClose(c.closed) {
+               return net.ErrClosed
+       }
+
+       c.l.mu.Lock()
+       defer c.l.mu.Unlock()
+
+       delete(c.l.ids, c.id)
+       delete(c.l.clts, c.addr.String())
+
+       return nil
+}
+
+func (c *udpClt) LocalAddr() net.Addr  { return c.l.pc.LocalAddr() }
+func (c *udpClt) RemoteAddr() net.Addr { return c.addr }
+
+// All Listener's methods are safe for concurrent use.
 type Listener struct {
-       conn net.PacketConn
+       pc net.PacketConn
 
-       clts chan cltPeer
-       errs chan error
+       peerID PeerID
+       conns  chan *Conn
+       errs   chan error
+       closed chan struct{}
+       wg     sync.WaitGroup
 
-       mu        sync.Mutex
-       addr2peer map[string]cltPeer
-       id2peer   map[PeerID]cltPeer
-       peerID    PeerID
+       mu   sync.RWMutex
+       ids  map[PeerID]bool
+       clts map[string]*udpClt
 }
 
-// Listen listens for packets on conn until it is closed.
-func Listen(conn net.PacketConn) *Listener {
+// Listen listens for connections on pc, pc is closed once the returned Listener
+// and all Conns connected through it are closed.
+func Listen(pc net.PacketConn) *Listener {
        l := &Listener{
-               conn: conn,
+               pc: pc,
 
-               clts: make(chan cltPeer),
-               errs: make(chan error),
+               conns:  make(chan *Conn),
+               closed: make(chan struct{}),
 
-               addr2peer: make(map[string]cltPeer),
-               id2peer:   make(map[PeerID]cltPeer),
+               ids:  make(map[PeerID]bool),
+               clts: make(map[string]*udpClt),
        }
 
-       pkts := make(chan netPkt)
-       go readNetPkts(l.conn, pkts, l.errs)
        go func() {
-               for pkt := range pkts {
-                       if err := l.processNetPkt(pkt); err != nil {
-                               l.errs <- err
+               for {
+                       if err := l.processNetPkt(); err != nil {
+                               if errors.Is(err, net.ErrClosed) {
+                                       break
+                               }
+                               select {
+                               case l.errs <- err:
+                               case <-l.closed:
+                               }
                        }
                }
-
-               close(l.clts)
-
-               for _, clt := range l.addr2peer {
-                       clt.Close()
-               }
        }()
 
        return l
 }
 
-// Accept waits for and returns a connecting Peer.
-// You should keep calling this until it returns net.ErrClosed
-// so it doesn't leak a goroutine.
-func (l *Listener) Accept() (*Peer, error) {
+// Accept waits for and returns the next incoming Conn or an error.
+func (l *Listener) Accept() (*Conn, error) {
        select {
-       case clt, ok := <-l.clts:
-               if !ok {
-                       select {
-                       case err := <-l.errs:
-                               return nil, err
-                       default:
-                               return nil, net.ErrClosed
-                       }
-               }
-               close(clt.accepted)
-               return clt.Peer, nil
+       case c := <-l.conns:
+               return c, nil
        case err := <-l.errs:
                return nil, err
+       case <-l.closed:
+               return nil, net.ErrClosed
        }
 }
 
-// Addr returns the net.PacketConn the Listener is listening on.
-func (l *Listener) Conn() net.PacketConn { return l.conn }
+// Close makes the Listener stop listening for new Conns.
+// Blocked Accept calls will return net.ErrClosed.
+// Already Accepted Conns are not closed.
+func (l *Listener) Close() error {
+       if !tryClose(l.closed) {
+               return net.ErrClosed
+       }
 
-var ErrOutOfPeerIDs = errors.New("out of peer ids")
+       go func() {
+               l.wg.Wait()
+               l.pc.Close()
+       }()
 
-type cltPeer struct {
-       *Peer
-       pkts     chan<- netPkt
-       accepted chan struct{} // close-only
+       return nil
 }
 
-func (l *Listener) processNetPkt(pkt netPkt) error {
-       l.mu.Lock()
-       defer l.mu.Unlock()
-
-       addrstr := pkt.SrcAddr.String()
+// Addr returns the Listener's network address.
+func (l *Listener) Addr() net.Addr { return l.pc.LocalAddr() }
 
-       clt, ok := l.addr2peer[addrstr]
-       if !ok {
-               prev := l.peerID
-               for l.id2peer[l.peerID].Peer != nil || l.peerID < PeerIDCltMin {
-                       if l.peerID == prev-1 {
-                               return ErrOutOfPeerIDs
-                       }
-                       l.peerID++
-               }
+var ErrOutOfPeerIDs = errors.New("out of peer ids")
 
-               pkts := make(chan netPkt, 256)
+func (l *Listener) processNetPkt() error {
+       buf := make([]byte, maxUDPPktSize)
+       n, addr, err := l.pc.ReadFrom(buf)
+       if err != nil {
+               return err
+       }
 
-               clt = cltPeer{
-                       Peer:     newPeer(l.conn, pkt.SrcAddr, l.peerID, PeerIDSrv),
-                       pkts:     pkts,
-                       accepted: make(chan struct{}),
+       l.mu.RLock()
+       clt, ok := l.clts[addr.String()]
+       l.mu.RUnlock()
+       if !ok {
+               select {
+               case <-l.closed:
+                       return nil
+               default:
                }
 
-               l.addr2peer[addrstr] = clt
-               l.id2peer[clt.ID()] = clt
-
-               data := make([]byte, 1+1+2)
-               data[0] = uint8(rawTypeCtl)
-               data[1] = uint8(ctlSetPeerID)
-               be.PutUint16(data[2:4], uint16(clt.ID()))
-               if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
-                       if errors.Is(err, net.ErrClosed) {
-                               return nil
-                       }
-                       return fmt.Errorf("can't set client peer id: %w", err)
+               clt, err = l.add(addr)
+               if err != nil {
+                       return err
                }
+       }
 
-               go func() {
-                       select {
-                       case l.clts <- clt:
-                       case <-clt.Disco():
-                       }
-
-                       clt.processNetPkts(pkts)
-               }()
+       select {
+       case clt.pkts <- buf[:n]:
+       case <-clt.closed:
+       }
 
-               go func() {
-                       <-clt.Disco()
+       return nil
+}
 
-                       l.mu.Lock()
-                       close(pkts)
-                       delete(l.addr2peer, addrstr)
-                       delete(l.id2peer, clt.ID())
-                       l.mu.Unlock()
-               }()
-       }
+func (l *Listener) add(addr net.Addr) (*udpClt, error) {
+       l.mu.Lock()
+       defer l.mu.Unlock()
 
-       select {
-       case <-clt.accepted:
-               clt.pkts <- pkt
-       default:
-               select {
-               case clt.pkts <- pkt:
-               default:
-                       // It's OK to drop packets if the buffer is full
-                       // because MT RUDP can cope with packet loss.
+       start := l.peerID
+       l.peerID++
+       for l.peerID < PeerIDCltMin || l.ids[l.peerID] {
+               if l.peerID == start {
+                       return nil, ErrOutOfPeerIDs
                }
+               l.peerID++
+       }
+       l.ids[l.peerID] = true
+
+       clt := &udpClt{
+               l:      l,
+               id:     l.peerID,
+               addr:   addr,
+               pkts:   make(chan []byte),
+               closed: make(chan struct{}),
        }
+       l.clts[addr.String()] = clt
 
-       return nil
+       l.wg.Add(1)
+       go clt.mkConn()
+
+       return clt, nil
 }
diff --git a/rudp/net.go b/rudp/net.go
deleted file mode 100644 (file)
index e2e7289..0000000
+++ /dev/null
@@ -1,41 +0,0 @@
-package rudp
-
-import (
-       "errors"
-       "net"
-)
-
-// ErrClosed is deprecated, use net.ErrClosed instead.
-var ErrClosed = net.ErrClosed
-
-/*
-netPkt.Data format (big endian):
-
-       ProtoID
-       Src PeerID
-       ChNo uint8 // Must be < ChannelCount.
-       RawPkt.Data
-*/
-type netPkt struct {
-       SrcAddr net.Addr
-       Data    []byte
-}
-
-func readNetPkts(conn net.PacketConn, pkts chan<- netPkt, errs chan<- error) {
-       for {
-               buf := make([]byte, MaxNetPktSize)
-               n, addr, err := conn.ReadFrom(buf)
-               if err != nil {
-                       if errors.Is(err, net.ErrClosed) {
-                               break
-                       }
-
-                       errs <- err
-                       continue
-               }
-
-               pkts <- netPkt{addr, buf[:n]}
-       }
-
-       close(pkts)
-}
diff --git a/rudp/peer.go b/rudp/peer.go
deleted file mode 100644 (file)
index 791249c..0000000
+++ /dev/null
@@ -1,207 +0,0 @@
-package rudp
-
-import (
-       "errors"
-       "fmt"
-       "net"
-       "sync"
-       "time"
-)
-
-const (
-       // ConnTimeout is the amount of time after no packets being received
-       // from a Peer that it is automatically disconnected.
-       ConnTimeout = 30 * time.Second
-
-       // ConnTimeout is the amount of time after no packets being sent
-       // to a Peer that a CtlPing is automatically sent to prevent timeout.
-       PingTimeout = 5 * time.Second
-)
-
-// A Peer is a connection to a client or server.
-type Peer struct {
-       pc   net.PacketConn
-       addr net.Addr
-       conn net.Conn
-
-       disco chan struct{} // close-only
-
-       id PeerID
-
-       pkts     chan Pkt
-       errs     chan error    // don't close
-       timedOut chan struct{} // close-only
-
-       chans [ChannelCount]pktchan // read/write
-
-       mu       sync.RWMutex
-       idOfPeer PeerID
-       timeout  *time.Timer
-       ping     *time.Ticker
-}
-
-// Conn returns the net.PacketConn used to communicate with the Peer.
-func (p *Peer) Conn() net.PacketConn { return p.pc }
-
-// Addr returns the address of the Peer.
-func (p *Peer) Addr() net.Addr { return p.addr }
-
-// Disco returns a channel that is closed when the Peer is closed.
-func (p *Peer) Disco() <-chan struct{} { return p.disco }
-
-// ID returns the ID of the Peer.
-func (p *Peer) ID() PeerID { return p.id }
-
-// IsSrv reports whether the Peer is a server.
-func (p *Peer) IsSrv() bool {
-       return p.ID() == PeerIDSrv
-}
-
-// TimedOut reports whether the Peer has timed out.
-func (p *Peer) TimedOut() bool {
-       select {
-       case <-p.timedOut:
-               return true
-       default:
-               return false
-       }
-}
-
-type inSplit struct {
-       chunks    [][]byte
-       size, got int
-}
-
-type pktchan struct {
-       // Only accessed by Peer.processRawPkt.
-       inSplit *[65536]*inSplit
-       inRel   *[65536][]byte
-       inRelSN seqnum
-
-       ackChans sync.Map // map[seqnum]chan struct{}
-
-       outSplitMu sync.Mutex
-       outSplitSN seqnum
-
-       outRelMu  sync.Mutex
-       outRelSN  seqnum
-       outRelWin seqnum
-}
-
-// Recv recieves a packet from the Peer.
-// You should keep calling this until it returns net.ErrClosed
-// so it doesn't leak a goroutine.
-func (p *Peer) Recv() (Pkt, error) {
-       select {
-       case pkt, ok := <-p.pkts:
-               if !ok {
-                       select {
-                       case err := <-p.errs:
-                               return Pkt{}, err
-                       default:
-                               return Pkt{}, net.ErrClosed
-                       }
-               }
-               return pkt, nil
-       case err := <-p.errs:
-               return Pkt{}, err
-       }
-}
-
-// Close closes the Peer but does not send a disconnect packet.
-func (p *Peer) Close() error {
-       p.mu.Lock()
-       defer p.mu.Unlock()
-
-       select {
-       case <-p.Disco():
-               return net.ErrClosed
-       default:
-       }
-
-       p.timeout.Stop()
-       p.timeout = nil
-       p.ping.Stop()
-       p.ping = nil
-
-       close(p.disco)
-
-       return nil
-}
-
-func newPeer(pc net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer {
-       p := &Peer{
-               pc:       pc,
-               addr:     addr,
-               id:       id,
-               idOfPeer: idOfPeer,
-
-               pkts:  make(chan Pkt),
-               disco: make(chan struct{}),
-               errs:  make(chan error),
-       }
-
-       if conn, ok := pc.(net.Conn); ok && conn.RemoteAddr() != nil {
-               p.conn = conn
-       }
-
-       for i := range p.chans {
-               p.chans[i] = pktchan{
-                       inSplit: new([65536]*inSplit),
-                       inRel:   new([65536][]byte),
-                       inRelSN: seqnumInit,
-
-                       outSplitSN: seqnumInit,
-                       outRelSN:   seqnumInit,
-                       outRelWin:  seqnumInit,
-               }
-       }
-
-       p.timedOut = make(chan struct{})
-       p.timeout = time.AfterFunc(ConnTimeout, func() {
-               close(p.timedOut)
-
-               p.SendDisco(0, true)
-               p.Close()
-       })
-
-       p.ping = time.NewTicker(PingTimeout)
-       go p.sendPings(p.ping.C)
-
-       return p
-}
-
-func (p *Peer) sendPings(ping <-chan time.Time) {
-       pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}}
-
-       for {
-               select {
-               case <-ping:
-                       if _, err := p.sendRaw(pkt); err != nil {
-                               if errors.Is(err, net.ErrClosed) {
-                                       return
-                               }
-                               p.errs <- fmt.Errorf("can't send ping: %w", err)
-                       }
-               case <-p.Disco():
-                       return
-               }
-       }
-}
-
-// Connect connects to addr using pc
-// and closes pc when the returned *Peer disconnects.
-func Connect(pc net.PacketConn, addr net.Addr) *Peer {
-       srv := newPeer(pc, addr, PeerIDSrv, PeerIDNil)
-
-       pkts := make(chan netPkt)
-       go readNetPkts(pc, pkts, srv.errs)
-       go srv.processNetPkts(pkts)
-
-       go func() {
-               <-srv.Disco()
-               pc.Close()
-       }()
-
-       return srv
-}
diff --git a/rudp/process.go b/rudp/process.go
deleted file mode 100644 (file)
index 7238fe5..0000000
+++ /dev/null
@@ -1,253 +0,0 @@
-package rudp
-
-import (
-       "errors"
-       "fmt"
-       "io"
-       "net"
-)
-
-// A PktError is an error that occured while processing a packet.
-type PktError struct {
-       Type string // "net", "raw" or "rel".
-       Data []byte
-       Err  error
-}
-
-func (e PktError) Error() string {
-       return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err)
-}
-
-func (e PktError) Unwrap() error { return e.Err }
-
-func (p *Peer) processNetPkts(pkts <-chan netPkt) {
-       for pkt := range pkts {
-               if err := p.processNetPkt(pkt); err != nil {
-                       p.errs <- PktError{"net", pkt.Data, err}
-               }
-       }
-
-       close(p.pkts)
-}
-
-// A TrailingDataError reports a packet with trailing data,
-// it doesn't stop a packet from being processed.
-type TrailingDataError []byte
-
-func (e TrailingDataError) Error() string {
-       return fmt.Sprintf("trailing data: %x", []byte(e))
-}
-
-func (p *Peer) processNetPkt(pkt netPkt) (err error) {
-       if pkt.SrcAddr.String() != p.Addr().String() {
-               return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
-       }
-
-       if len(pkt.Data) < MtHdrSize {
-               return io.ErrUnexpectedEOF
-       }
-
-       if id := be.Uint32(pkt.Data[0:4]); id != protoID {
-               return fmt.Errorf("unsupported protocol id: 0x%08x", id)
-       }
-
-       // src PeerID at pkt.Data[4:6]
-
-       chno := pkt.Data[6]
-       if chno >= ChannelCount {
-               return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
-       }
-
-       p.mu.RLock()
-       if p.timeout != nil {
-               p.timeout.Reset(ConnTimeout)
-       }
-       p.mu.RUnlock()
-
-       rpkt := rawPkt{
-               Data:  pkt.Data[MtHdrSize:],
-               ChNo:  chno,
-               Unrel: true,
-       }
-       if err := p.processRawPkt(rpkt); err != nil {
-               p.errs <- PktError{"raw", rpkt.Data, err}
-       }
-
-       return nil
-}
-
-func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
-       errWrap := func(format string, a ...interface{}) {
-               if err != nil {
-                       err = fmt.Errorf(format, append(a, err)...)
-               }
-       }
-
-       c := &p.chans[pkt.ChNo]
-
-       if len(pkt.Data) < 1 {
-               return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
-       }
-       switch t := rawType(pkt.Data[0]); t {
-       case rawTypeCtl:
-               defer errWrap("ctl: %w")
-
-               if len(pkt.Data) < 1+1 {
-                       return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
-               }
-               switch ct := ctlType(pkt.Data[1]); ct {
-               case ctlAck:
-                       defer errWrap("ack: %w")
-
-                       if len(pkt.Data) < 1+1+2 {
-                               return io.ErrUnexpectedEOF
-                       }
-
-                       sn := seqnum(be.Uint16(pkt.Data[2:4]))
-
-                       if ack, ok := c.ackChans.LoadAndDelete(sn); ok {
-                               close(ack.(chan struct{}))
-                       }
-
-                       if len(pkt.Data) > 1+1+2 {
-                               return TrailingDataError(pkt.Data[1+1+2:])
-                       }
-               case ctlSetPeerID:
-                       defer errWrap("set peer id: %w")
-
-                       if len(pkt.Data) < 1+1+2 {
-                               return io.ErrUnexpectedEOF
-                       }
-
-                       // Ensure no concurrent senders while peer id changes.
-                       p.mu.Lock()
-                       if p.idOfPeer != PeerIDNil {
-                               return errors.New("peer id already set")
-                       }
-
-                       p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4]))
-                       p.mu.Unlock()
-
-                       if len(pkt.Data) > 1+1+2 {
-                               return TrailingDataError(pkt.Data[1+1+2:])
-                       }
-               case ctlPing:
-                       defer errWrap("ping: %w")
-
-                       if len(pkt.Data) > 1+1 {
-                               return TrailingDataError(pkt.Data[1+1:])
-                       }
-               case ctlDisco:
-                       defer errWrap("disco: %w")
-
-                       p.Close()
-
-                       if len(pkt.Data) > 1+1 {
-                               return TrailingDataError(pkt.Data[1+1:])
-                       }
-               default:
-                       return fmt.Errorf("unsupported ctl type: %d", ct)
-               }
-       case rawTypeOrig:
-               p.pkts <- Pkt{
-                       Data:  pkt.Data[1:],
-                       ChNo:  pkt.ChNo,
-                       Unrel: pkt.Unrel,
-               }
-       case rawTypeSplit:
-               defer errWrap("split: %w")
-
-               if len(pkt.Data) < 1+2+2+2 {
-                       return io.ErrUnexpectedEOF
-               }
-
-               sn := seqnum(be.Uint16(pkt.Data[1:3]))
-               count := be.Uint16(pkt.Data[3:5])
-               i := be.Uint16(pkt.Data[5:7])
-
-               if i >= count {
-                       return nil
-               }
-
-               splits := p.chans[pkt.ChNo].inSplit
-
-               // Delete old incomplete split packets
-               // so new ones don't get corrupted.
-               splits[sn-0x8000] = nil
-
-               if splits[sn] == nil {
-                       splits[sn] = &inSplit{chunks: make([][]byte, count)}
-               }
-
-               s := splits[sn]
-
-               if int(count) != len(s.chunks) {
-                       return fmt.Errorf("chunk count changed on split packet: %d", sn)
-               }
-
-               s.chunks[i] = pkt.Data[7:]
-               s.size += len(s.chunks[i])
-               s.got++
-
-               if s.got == len(s.chunks) {
-                       data := make([]byte, 0, s.size)
-                       for _, chunk := range s.chunks {
-                               data = append(data, chunk...)
-                       }
-
-                       p.pkts <- Pkt{
-                               Data:  data,
-                               ChNo:  pkt.ChNo,
-                               Unrel: pkt.Unrel,
-                       }
-
-                       splits[sn] = nil
-               }
-       case rawTypeRel:
-               defer errWrap("rel: %w")
-
-               if len(pkt.Data) < 1+2 {
-                       return io.ErrUnexpectedEOF
-               }
-
-               sn := seqnum(be.Uint16(pkt.Data[1:3]))
-
-               ack := make([]byte, 1+1+2)
-               ack[0] = uint8(rawTypeCtl)
-               ack[1] = uint8(ctlAck)
-               be.PutUint16(ack[2:4], uint16(sn))
-               if _, err := p.sendRaw(rawPkt{
-                       Data:  ack,
-                       ChNo:  pkt.ChNo,
-                       Unrel: true,
-               }); err != nil {
-                       if errors.Is(err, net.ErrClosed) {
-                               return nil
-                       }
-                       return fmt.Errorf("can't ack %d: %w", sn, err)
-               }
-
-               if sn-c.inRelSN >= 0x8000 {
-                       return nil // Already received.
-               }
-
-               c.inRel[sn] = pkt.Data[3:]
-
-               for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ {
-                       rpkt := rawPkt{
-                               Data:  c.inRel[c.inRelSN],
-                               ChNo:  pkt.ChNo,
-                               Unrel: false,
-                       }
-                       c.inRel[c.inRelSN] = nil
-
-                       if err := p.processRawPkt(rpkt); err != nil {
-                               p.errs <- PktError{"rel", rpkt.Data, err}
-                       }
-               }
-       default:
-               return fmt.Errorf("unsupported pkt type: %d", t)
-       }
-
-       return nil
-}
index a80b4487ef4da3f564f0bc9ef539c8662ec49b80..b8ca9f4bd4dccfaf709d219d4e48ae72f0d30bb0 100644 (file)
@@ -25,62 +25,55 @@ func main() {
                os.Exit(1)
        }
 
-       srvaddr, err := net.ResolveUDPAddr("udp", os.Args[1])
+       pc, err := net.ListenPacket("udp", os.Args[2])
        if err != nil {
                log.Fatal(err)
        }
+       defer pc.Close()
 
-       lc, err := net.ListenPacket("udp", os.Args[2])
-       if err != nil {
-               log.Fatal(err)
-       }
-       defer lc.Close()
-
-       l := rudp.Listen(lc)
+       l := rudp.Listen(pc)
        for {
                clt, err := l.Accept()
                if err != nil {
-                       log.Print(err)
+                       log.Print("accept: ", err)
                        continue
                }
 
-               log.Print(clt.Addr(), " connected")
+               log.Print(clt.ID(), ": connected")
 
-               conn, err := net.DialUDP("udp", nil, srvaddr)
+               conn, err := net.Dial("udp", os.Args[1])
                if err != nil {
                        log.Print(err)
                        continue
                }
-               srv := rudp.Connect(conn, conn.RemoteAddr())
+               srv := rudp.Connect(conn)
 
                go proxy(clt, srv)
                go proxy(srv, clt)
        }
 }
 
-func proxy(src, dest *rudp.Peer) {
+func proxy(src, dest *rudp.Conn) {
+       s := fmt.Sprint(src.ID(), " (", src.RemoteAddr(), "): ")
+
        for {
                pkt, err := src.Recv()
                if err != nil {
                        if errors.Is(err, net.ErrClosed) {
-                               msg := src.Addr().String() + " disconnected"
-                               if src.TimedOut() {
-                                       msg += " (timed out)"
+                               if err := src.WhyClosed(); err != nil {
+                                       log.Print(s, "disconnected: ", err)
+                               } else {
+                                       log.Print(s, "disconnected")
                                }
-                               log.Print(msg)
-
                                break
                        }
 
-                       log.Print(err)
+                       log.Print(s, err)
                        continue
                }
 
-               if _, err := dest.Send(pkt); err != nil {
-                       log.Print(err)
-               }
+               dest.Send(pkt)
        }
 
-       dest.SendDisco(0, true)
        dest.Close()
 }
diff --git a/rudp/recv.go b/rudp/recv.go
new file mode 100644 (file)
index 0000000..f5ac236
--- /dev/null
@@ -0,0 +1,259 @@
+package rudp
+
+import (
+       "bytes"
+       "errors"
+       "fmt"
+       "io"
+       "net"
+       "time"
+)
+
+// Recv receives a Pkt from the Conn.
+func (c *Conn) Recv() (Pkt, error) {
+       select {
+       case pkt := <-c.pkts:
+               return pkt, nil
+       case err := <-c.errs:
+               return Pkt{}, err
+       case <-c.Closed():
+               return Pkt{}, net.ErrClosed
+       }
+}
+
+func (c *Conn) gotPkt(pkt Pkt) {
+       select {
+       case c.pkts <- pkt:
+       case <-c.Closed():
+       }
+}
+
+func (c *Conn) gotErr(kind string, data []byte, err error) {
+       select {
+       case c.errs <- fmt.Errorf("%s: %x: %w", kind, data, err):
+       case <-c.Closed():
+       }
+}
+
+func (c *Conn) recvUDPPkts() {
+       for {
+               pkt, err := c.udpConn.recvUDP()
+               if err != nil {
+                       c.closeDisco(err)
+                       break
+               }
+
+               if err := c.processUDPPkt(pkt); err != nil {
+                       c.gotErr("udp", pkt, err)
+               }
+       }
+}
+
+func (c *Conn) processUDPPkt(pkt []byte) error {
+       if c.timeout.Stop() {
+               c.timeout.Reset(ConnTimeout)
+       }
+
+       if len(pkt) < 6 {
+               return io.ErrUnexpectedEOF
+       }
+
+       if id := be.Uint32(pkt[0:4]); id != protoID {
+               return fmt.Errorf("unsupported protocol id: 0x%08x", id)
+       }
+
+       ch := Channel(pkt[6])
+       if ch >= ChannelCount {
+               return TooBigChError(ch)
+       }
+
+       if err := c.processRawPkt(pkt[7:], PktInfo{Channel: ch, Unrel: true}); err != nil {
+               c.gotErr("raw", pkt, err)
+       }
+
+       return nil
+}
+
+// A TrailingDataError reports trailing data after a packet.
+type TrailingDataError []byte
+
+func (e TrailingDataError) Error() string {
+       return fmt.Sprintf("trailing data: %x", []byte(e))
+}
+
+func (c *Conn) processRawPkt(data []byte, pi PktInfo) (err error) {
+       errWrap := func(format string, a ...interface{}) {
+               if err != nil {
+                       err = fmt.Errorf(format+": %w", append(a, err)...)
+               }
+       }
+
+       eof := new(byte)
+       defer func() {
+               switch r := recover(); r {
+               case nil:
+               case eof:
+                       err = io.ErrUnexpectedEOF
+               default:
+                       panic(r)
+               }
+       }()
+
+       off := 0
+       eat := func(n int) []byte {
+               i := off
+               off += n
+               if i > len(data) {
+                       panic(eof)
+               }
+               return data[i:off]
+       }
+
+       ch := &c.chans[pi.Channel]
+
+       switch t := rawType(eat(1)[0]); t {
+       case rawCtl:
+               defer errWrap("ctl")
+
+               switch ct := ctlType(eat(1)[0]); ct {
+               case ctlAck:
+                       defer errWrap("ack")
+
+                       sn := seqnum(be.Uint16(eat(2)))
+
+                       if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
+                               close(ack.(chan struct{}))
+                       }
+               case ctlSetPeerID:
+                       defer errWrap("set peer id")
+
+                       c.mu.Lock()
+                       if c.remoteID != PeerIDNil {
+                               return errors.New("peer id already set")
+                       }
+
+                       c.remoteID = PeerID(be.Uint16(eat(2)))
+                       c.mu.Unlock()
+
+                       c.newAckBuf()
+               case ctlPing:
+                       defer errWrap("ping")
+               case ctlDisco:
+                       defer errWrap("disco")
+
+                       c.close(nil)
+               default:
+                       return fmt.Errorf("unsupported ctl type: %d", ct)
+               }
+
+               if off < len(data) {
+                       return TrailingDataError(data[off:])
+               }
+       case rawOrig:
+               c.gotPkt(Pkt{
+                       Reader:  bytes.NewReader(data[off:]),
+                       PktInfo: pi,
+               })
+       case rawSplit:
+               defer errWrap("split")
+
+               sn := seqnum(be.Uint16(eat(2)))
+               n := be.Uint16(eat(2))
+               i := be.Uint16(eat(2))
+
+               defer errWrap("%d", sn)
+
+               if i >= n {
+                       return fmt.Errorf("chunk number (%d) > chunk count (%d)", i, n)
+               }
+
+               ch.inSplitsMu.RLock()
+               s := ch.inSplits[sn]
+               ch.inSplitsMu.RUnlock()
+
+               if s == nil {
+                       s = &inSplit{chunks: make([][]byte, n)}
+                       if pi.Unrel {
+                               s.timeout = time.AfterFunc(ConnTimeout, func() {
+                                       ch.inSplitsMu.Lock()
+                                       delete(ch.inSplits, sn)
+                                       ch.inSplitsMu.Unlock()
+                               })
+                       }
+
+                       ch.inSplitsMu.Lock()
+                       ch.inSplits[sn] = s
+                       ch.inSplitsMu.Unlock()
+               }
+
+               if int(n) != len(s.chunks) {
+                       return fmt.Errorf("chunk count changed from %d to %d", len(s.chunks), n)
+               }
+
+               if s.chunks[i] == nil {
+                       s.chunks[i] = data[off:]
+                       s.got++
+               }
+
+               if s.got < len(s.chunks) {
+                       if s.timeout != nil && s.timeout.Stop() {
+                               s.timeout.Reset(ConnTimeout)
+                       }
+                       return
+               }
+
+               if s.timeout != nil {
+                       s.timeout.Stop()
+               }
+
+               ch.inSplitsMu.Lock()
+               delete(ch.inSplits, sn)
+               ch.inSplitsMu.Unlock()
+
+               c.gotPkt(Pkt{
+                       Reader:  (*net.Buffers)(&s.chunks),
+                       PktInfo: pi,
+               })
+       case rawRel:
+               defer errWrap("rel")
+
+               sn := seqnum(be.Uint16(eat(2)))
+
+               defer errWrap("%d", sn)
+
+               be.PutUint16(ch.ackBuf, uint16(sn))
+               ch.sendAck()
+
+               if sn-ch.inRelSN >= 0x8000 {
+                       // Already received.
+                       return nil
+               }
+
+               ch.inRels[sn&0x7fff] = data[off:]
+
+               i := func() seqnum { return ch.inRelSN & 0x7fff }
+               for ; ch.inRels[i()] != nil; ch.inRelSN++ {
+                       data := ch.inRels[i()]
+                       ch.inRels[i()] = nil
+                       if err := c.processRawPkt(data, PktInfo{Channel: pi.Channel}); err != nil {
+                               c.gotErr("rel", data, err)
+                       }
+               }
+       default:
+               return fmt.Errorf("unsupported pkt type: %d", t)
+       }
+
+       return nil
+}
+
+func (c *Conn) newAckBuf() {
+       for i := range c.chans {
+               ch := &c.chans[i]
+               ch.sendAck = c.sendRaw(func(buf []byte) int {
+                       buf[0] = uint8(rawCtl)
+                       buf[1] = uint8(ctlAck)
+                       ch.ackBuf = buf[2:4]
+                       return 4
+               }, PktInfo{Channel: Channel(i), Unrel: true})
+       }
+}
index 6b96b5618d063a2d34004f5d71a3e86bcd780d9e..faf67d97bf7174f77ec9d238c42b72259bf3e2fe 100644 (file)
@@ -1,21 +1,44 @@
 /*
 Package rudp implements the low-level Minetest protocol described at
 https://dev.minetest.net/Network_Protocol#Low-level_protocol.
-
-All exported functions and methods in this package are safe for concurrent use
-by multiple goroutines.
 */
 package rudp
 
-import "encoding/binary"
+import (
+       "encoding/binary"
+       "errors"
+       "io"
+       "time"
+)
 
 var be = binary.BigEndian
 
-// protoID must be at the start of every network packet.
+/*
+UDP packet format:
+
+       protoID
+       src PeerID
+       channel uint8
+       rawType...
+*/
+
+var ErrTimedOut = errors.New("timed out")
+
+const (
+       ConnTimeout = 30 * time.Second
+       PingTimeout = 5 * time.Second
+)
+
+const (
+       MaxRelPktSize   = 32439825
+       MaxUnrelPktSize = 32636430
+)
+
+// protoID must be at the start of every UDP packet.
 const protoID uint32 = 0x4f457403
 
-// PeerIDs aren't actually used to identify peers, network addresses are,
-// these just exist for backward compatability.
+// PeerIDs aren't actually used to identify peers, IP addresses and ports are,
+// these just exist for backward compatibility.
 type PeerID uint16
 
 const (
@@ -29,79 +52,59 @@ const (
        PeerIDCltMin
 )
 
-// ChannelCount is the maximum channel number + 1.
-const ChannelCount = 3
-
-/*
-rawPkt.Data format (big endian):
-
-       rawType
-       switch rawType {
-       case rawTypeCtl:
-               ctlType
-               switch ctlType {
-               case ctlAck:
-                       // Tells peer you received a rawTypeRel
-                       // and it doesn't need to resend it.
-                       seqnum
-               case ctlSetPeerId:
-                       // Tells peer to send packets with this Src PeerID.
-                       PeerId
-               case ctlPing:
-                       // Sent to prevent timeout.
-               case ctlDisco:
-                       // Tells peer that you disconnected.
-               }
-       case rawTypeOrig:
-               Pkt.(Data)
-       case rawTypeSplit:
-               // Packet larger than MaxNetPktSize split into smaller packets.
-               // Packets with I >= Count should be ignored.
-               // Once all Count chunks are recieved, they are sorted by I and
-               // concatenated to make a Pkt.(Data).
-               seqnum // Identifies split packet.
-               Count, I uint16
-               Chunk...
-       case rawTypeRel:
-               // Resent until a ctlAck with same seqnum is recieved.
-               // seqnums are sequencial and start at seqnumInit,
-               // These should be processed in seqnum order.
-               seqnum
-               rawPkt.Data
-       }
-*/
-type rawPkt struct {
-       Data  []byte
-       ChNo  uint8
-       Unrel bool
-}
-
 type rawType uint8
 
 const (
-       rawTypeCtl rawType = iota
-       rawTypeOrig
-       rawTypeSplit
-       rawTypeRel
+       rawCtl rawType = iota
+       // ctlType...
+
+       rawOrig
+       // data...
+
+       rawSplit
+       // seqnum
+       // n, i uint16
+       // data...
+
+       rawRel
+       // seqnum
+       // rawType...
 )
 
 type ctlType uint8
 
 const (
        ctlAck ctlType = iota
+       // seqnum
+
        ctlSetPeerID
-       ctlPing
+       // PeerID
+
+       ctlPing // Sent to prevent timeout.
+
        ctlDisco
 )
 
 type Pkt struct {
-       Data  []byte
-       ChNo  uint8
+       io.Reader
+       PktInfo
+}
+
+// Reliable packets in a channel are be received in the order they are sent in.
+// A Channel must be less than ChannelCount.
+type Channel uint8
+
+const ChannelCount Channel = 3
+
+type PktInfo struct {
+       Channel
+
+       // Unrel (unreliable) packets may be dropped, duplicated or reordered.
        Unrel bool
 }
 
 // seqnums are sequence numbers used to maintain reliable packet order
-// and to identify split packets.
+// and identify split packets.
 type seqnum uint16
 
-const seqnumInit seqnum = 65500
+const initSeqnum seqnum = 65500
index c43d0561a1f041a0adab453e3b082636975c0390..5522ee0f08417fafd4c111f044613474ef7b5fe6 100644 (file)
 package rudp
 
 import (
+       "bytes"
        "errors"
        "fmt"
-       "math"
+       "io"
        "net"
        "sync"
+       "sync/atomic"
        "time"
 )
 
-const (
-       // protoID + src PeerID + channel number
-       MtHdrSize = 4 + 2 + 1
-
-       // rawTypeOrig
-       OrigHdrSize = 1
-
-       // rawTypeSpilt + seqnum + chunk count + chunk number
-       SplitHdrSize = 1 + 2 + 2 + 2
-
-       // rawTypeRel + seqnum
-       RelHdrSize = 1 + 2
-)
-
-const (
-       MaxNetPktSize = 512
-
-       MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize
-       MaxRelRawPktSize   = MaxUnrelRawPktSize - RelHdrSize
-
-       MaxRelPktSize   = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16
-       MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16
-)
-
 var ErrPktTooBig = errors.New("can't send pkt: too big")
-var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount")
-
-// Send sends a packet to the Peer.
-// It returns a channel that's closed when all chunks are acked or an error.
-// The ack channel is nil if pkt.Unrel is true.
-func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
-       if pkt.ChNo >= ChannelCount {
-               return nil, ErrChNoTooBig
-       }
-
-       hdrSize := MtHdrSize
-       if !pkt.Unrel {
-               hdrSize += RelHdrSize
-       }
-
-       if hdrSize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize {
-               c := &p.chans[pkt.ChNo]
-
-               c.outSplitMu.Lock()
-               sn := c.outSplitSN
-               c.outSplitSN++
-               c.outSplitMu.Unlock()
 
-               chunks := split(pkt.Data, MaxNetPktSize-(hdrSize+SplitHdrSize))
+// A TooBigChError reports a Channel greater than or equal to ChannelCount.
+type TooBigChError Channel
 
-               if len(chunks) > math.MaxUint16 {
-                       return nil, ErrPktTooBig
-               }
+func (e TooBigChError) Error() string {
+       return fmt.Sprintf("channel >= ChannelCount (%d): %d", ChannelCount, e)
+}
 
-               var wg sync.WaitGroup
+// Send sends a Pkt to the Conn.
+// Ack is closed when the packet is acknowledged.
+// Ack is nil if pkt.Unrel is true or err != nil.
+func (c *Conn) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+       if pkt.Channel >= ChannelCount {
+               return nil, TooBigChError(pkt.Channel)
+       }
 
-               for i, chunk := range chunks {
-                       data := make([]byte, SplitHdrSize+len(chunk))
-                       data[0] = uint8(rawTypeSplit)
-                       be.PutUint16(data[1:3], uint16(sn))
-                       be.PutUint16(data[3:5], uint16(len(chunks)))
-                       be.PutUint16(data[5:7], uint16(i))
-                       copy(data[SplitHdrSize:], chunk)
+       var e error
+       send := c.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawOrig)
 
-                       wg.Add(1)
-                       ack, err := p.sendRaw(rawPkt{
-                               Data:  data,
-                               ChNo:  pkt.ChNo,
-                               Unrel: pkt.Unrel,
-                       })
+               nn := 1
+               for nn < len(buf) {
+                       n, err := pkt.Read(buf[nn:])
+                       nn += n
                        if err != nil {
-                               return nil, err
-                       }
-                       if !pkt.Unrel {
-                               go func() {
-                                       <-ack
-                                       wg.Done()
-                               }()
+                               e = err
+                               return nn
                        }
                }
 
-               if pkt.Unrel {
-                       return nil, nil
-               } else {
-                       ack := make(chan struct{})
-
-                       go func() {
-                               wg.Wait()
-                               close(ack)
-                       }()
+               if _, e = pkt.Read(nil); e != nil {
+                       return nn
+               }
 
-                       return ack, nil
+               pkt.Reader = io.MultiReader(
+                       bytes.NewReader([]byte(buf[1:nn])),
+                       pkt.Reader,
+               )
+               return nn
+       }, pkt.PktInfo)
+       if e != nil {
+               if e == io.EOF {
+                       return send()
                }
+               return nil, e
        }
 
-       return p.sendRaw(rawPkt{
-               Data:  append([]byte{uint8(rawTypeOrig)}, pkt.Data...),
-               ChNo:  pkt.ChNo,
-               Unrel: pkt.Unrel,
-       })
-}
+       var (
+               sn seqnum
+               i  uint16
+
+               sends []func() (<-chan struct{}, error)
+       )
+
+       for {
+               var (
+                       b []byte
+                       e error
+               )
+               send := c.sendRaw(func(buf []byte) int {
+                       buf[0] = uint8(rawSplit)
+
+                       n, err := io.ReadFull(pkt, buf[7:])
+                       if err != nil && err != io.ErrUnexpectedEOF {
+                               e = err
+                               return 0
+                       }
 
-// sendRaw sends a raw packet to the Peer.
-func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) {
-       if pkt.ChNo >= ChannelCount {
-               return nil, ErrChNoTooBig
-       }
+                       be.PutUint16(buf[5:7], i)
+                       if i++; i == 0 {
+                               e = ErrPktTooBig
+                               return 0
+                       }
 
-       p.mu.RLock()
-       defer p.mu.RUnlock()
+                       b = buf
+                       return 7 + n
+               }, pkt.PktInfo)
+               if e != nil {
+                       if e == io.EOF {
+                               break
+                       }
+                       return nil, e
+               }
 
-       select {
-       case <-p.Disco():
-               return nil, net.ErrClosed
-       default:
+               sends = append(sends, func() (<-chan struct{}, error) {
+                       be.PutUint16(b[1:3], uint16(sn))
+                       be.PutUint16(b[3:5], i)
+                       return send()
+               })
        }
 
-       if !pkt.Unrel {
-               return p.sendRel(pkt)
-       }
+       ch := &c.chans[pkt.Channel]
 
-       data := make([]byte, MtHdrSize+len(pkt.Data))
-       be.PutUint32(data[0:4], protoID)
-       be.PutUint16(data[4:6], uint16(p.idOfPeer))
-       data[6] = pkt.ChNo
-       copy(data[MtHdrSize:], pkt.Data)
+       ch.outSplitMu.Lock()
+       sn = ch.outSplitSN
+       ch.outSplitSN++
+       ch.outSplitMu.Unlock()
 
-       if len(data) > MaxNetPktSize {
-               return nil, ErrPktTooBig
-       }
+       var wg sync.WaitGroup
 
-       if p.conn != nil {
-               _, err = p.conn.Write(data)
-       } else {
-               _, err = p.pc.WriteTo(data, p.Addr())
-       }
-       if err != nil {
-               return nil, err
+       for _, send := range sends {
+               ack, err := send()
+               if err != nil {
+                       return nil, err
+               }
+               if !pkt.Unrel {
+                       wg.Add(1)
+                       go func() {
+                               <-ack
+                               wg.Done()
+                       }()
+               }
        }
 
-       p.ping.Reset(PingTimeout)
+       if !pkt.Unrel {
+               ack := make(chan struct{})
+               go func() {
+                       wg.Wait()
+                       close(ack)
+               }()
+               return ack, nil
+       }
 
        return nil, nil
 }
 
-// sendRel sends a reliable raw packet to the Peer.
-func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) {
-       if pkt.Unrel {
-               panic("pkt.Unrel is true")
-       }
-
-       c := &p.chans[pkt.ChNo]
+func (c *Conn) sendRaw(read func([]byte) int, pi PktInfo) func() (<-chan struct{}, error) {
+       if pi.Unrel {
+               buf := make([]byte, maxUDPPktSize)
+               be.PutUint32(buf[0:4], protoID)
+               c.mu.RLock()
+               be.PutUint16(buf[4:6], uint16(c.remoteID))
+               c.mu.RUnlock()
+               buf[6] = uint8(pi.Channel)
+               buf = buf[:7+read(buf[7:])]
+
+               return func() (<-chan struct{}, error) {
+                       if _, err := c.udpConn.Write(buf); err != nil {
+                               c.close(err)
+                               return nil, net.ErrClosed
+                       }
 
-       c.outRelMu.Lock()
-       defer c.outRelMu.Unlock()
+                       c.ping.Reset(PingTimeout)
+                       if atomic.LoadUint32(&c.closing) == 1 {
+                               c.ping.Stop()
+                       }
 
-       sn := c.outRelSN
-       for ; sn-c.outRelWin >= 0x8000; c.outRelWin++ {
-               if ack, ok := c.ackChans.Load(c.outRelWin); ok {
-                       <-ack.(chan struct{})
+                       return nil, nil
                }
        }
-       c.outRelSN++
-
-       rwack := make(chan struct{}) // close-only
-       c.ackChans.Store(sn, rwack)
-       ack = rwack
-
-       data := make([]byte, RelHdrSize+len(pkt.Data))
-       data[0] = uint8(rawTypeRel)
-       be.PutUint16(data[1:3], uint16(sn))
-       copy(data[RelHdrSize:], pkt.Data)
-       rel := rawPkt{
-               Data:  data,
-               ChNo:  pkt.ChNo,
-               Unrel: true,
-       }
-
-       if _, err := p.sendRaw(rel); err != nil {
-               c.ackChans.Delete(sn)
 
-               return nil, err
-       }
-
-       go func() {
-               for {
-                       select {
-                       case <-time.After(500 * time.Millisecond):
-                               if _, err := p.sendRaw(rel); err != nil {
-                                       if errors.Is(err, net.ErrClosed) {
-                                               return
-                                       }
-                                       p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err)
+       pi.Unrel = true
+       var snBuf []byte
+       send := c.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawRel)
+               snBuf = buf[1:3]
+               return 3 + read(buf[3:])
+       }, pi)
+
+       return func() (<-chan struct{}, error) {
+               ch := &c.chans[pi.Channel]
+
+               ch.outRelMu.Lock()
+               defer ch.outRelMu.Unlock()
+
+               sn := ch.outRelSN
+               be.PutUint16(snBuf, uint16(sn))
+               for ; sn-ch.outRelWin >= 0x8000; ch.outRelWin++ {
+                       if ack, ok := ch.ackChans.Load(ch.outRelWin); ok {
+                               select {
+                               case <-ack.(chan struct{}):
+                               case <-c.Closed():
                                }
-                       case <-ack:
-                               return
-                       case <-p.Disco():
-                               return
                        }
                }
-       }()
-
-       return ack, nil
-}
-
-// SendDisco sends a disconnect packet to the Peer but does not close it.
-// It returns a channel that's closed when it's acked or an error.
-// The ack channel is nil if unrel is true.
-func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) {
-       return p.sendRaw(rawPkt{
-               Data:  []byte{uint8(rawTypeCtl), uint8(ctlDisco)},
-               ChNo:  chno,
-               Unrel: unrel,
-       })
-}
 
-func split(data []byte, chunksize int) [][]byte {
-       chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize)
+               ack := make(chan struct{})
+               ch.ackChans.Store(sn, ack)
 
-       for i := 0; i < len(data); i += chunksize {
-               end := i + chunksize
-               if end > len(data) {
-                       end = len(data)
+               if _, err := send(); err != nil {
+                       if ack, ok := ch.ackChans.LoadAndDelete(sn); ok {
+                               close(ack.(chan struct{}))
+                       }
+                       return nil, err
                }
+               ch.outRelSN++
+
+               go func() {
+                       t := time.NewTimer(500 * time.Millisecond)
+                       defer t.Stop()
+
+                       for {
+                               select {
+                               case <-ack:
+                                       return
+                               case <-t.C:
+                                       send()
+                                       t.Reset(500 * time.Millisecond)
+                               case <-c.Closed():
+                                       return
+                               }
+                       }
+               }()
 
-               chunks = append(chunks, data[i:end])
+               return ack, nil
        }
-
-       return chunks
 }
diff --git a/rudp/udp.go b/rudp/udp.go
new file mode 100644 (file)
index 0000000..503f1d4
--- /dev/null
@@ -0,0 +1,13 @@
+package rudp
+
+import "net"
+
+const maxUDPPktSize = 512
+
+type udpConn interface {
+       recvUDP() ([]byte, error)
+       Write([]byte) (int, error)
+       Close() error
+       LocalAddr() net.Addr
+       RemoteAddr() net.Addr
+}