From 433955e45ef35476ab6863c4aff06b6ddd0bec67 Mon Sep 17 00:00:00 2001 From: anon5 Date: Mon, 22 Mar 2021 18:37:36 +0000 Subject: [PATCH] rudp: partial rewrite with new API supporting io.Readers --- rudp/conn.go | 173 +++++++++++++++++++++ rudp/connect.go | 18 +++ rudp/listen.go | 270 ++++++++++++++++++++------------- rudp/net.go | 41 ----- rudp/peer.go | 207 ------------------------- rudp/process.go | 253 ------------------------------- rudp/proxy/proxy.go | 39 ++--- rudp/recv.go | 259 ++++++++++++++++++++++++++++++++ rudp/rudp.go | 129 ++++++++-------- rudp/send.go | 358 +++++++++++++++++++++----------------------- rudp/udp.go | 13 ++ 11 files changed, 878 insertions(+), 882 deletions(-) create mode 100644 rudp/conn.go create mode 100644 rudp/connect.go delete mode 100644 rudp/net.go delete mode 100644 rudp/peer.go delete mode 100644 rudp/process.go create mode 100644 rudp/recv.go create mode 100644 rudp/udp.go diff --git a/rudp/conn.go b/rudp/conn.go new file mode 100644 index 0000000..7e241a8 --- /dev/null +++ b/rudp/conn.go @@ -0,0 +1,173 @@ +package rudp + +import ( + "net" + "sync" + "sync/atomic" + "time" +) + +// A Conn is a connection to a client or server. +// All Conn's methods are safe for concurrent use. +type Conn struct { + udpConn udpConn + + id PeerID + + pkts chan Pkt + errs chan error + + timeout *time.Timer + ping *time.Ticker + + closing uint32 + closed chan struct{} + err error + + mu sync.RWMutex + remoteID PeerID + + chans [ChannelCount]pktChan // read/write +} + +// ID returns the PeerID of the Conn. +func (c *Conn) ID() PeerID { return c.id } + +// IsSrv reports whether the Conn is a connection to a server. +func (c *Conn) IsSrv() bool { return c.ID() == PeerIDSrv } + +// Closed returns a channel which is closed when the Conn is closed. +func (c *Conn) Closed() <-chan struct{} { return c.closed } + +// WhyClosed returns the error that caused the Conn to be closed or nil +// if the Conn was closed using the Close method or by the peer. +// WhyClosed returns nil if the Conn is not closed. +func (c *Conn) WhyClosed() error { + select { + case <-c.Closed(): + return c.err + default: + return nil + } +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { return c.udpConn.LocalAddr() } + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { return c.udpConn.RemoteAddr() } + +type pktChan struct { + // Only accessed by Conn.recvUDPPkts goroutine. + inRels *[0x8000][]byte + inRelSN seqnum + sendAck func() (<-chan struct{}, error) + ackBuf []byte + + inSplitsMu sync.RWMutex + inSplits map[seqnum]*inSplit + + ackChans sync.Map // map[seqnum]chan struct{} + + outSplitMu sync.Mutex + outSplitSN seqnum + + outRelMu sync.Mutex + outRelSN seqnum + outRelWin seqnum +} + +type inSplit struct { + chunks [][]byte + got int + timeout *time.Timer +} + +// Close closes the Conn. +// Any blocked Send or Recv calls will return net.ErrClosed. +func (c *Conn) Close() error { + return c.closeDisco(nil) +} + +func (c *Conn) closeDisco(err error) error { + c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawCtl) + buf[1] = uint8(ctlDisco) + return 2 + }, PktInfo{Unrel: true})() + + return c.close(err) +} + +func (c *Conn) close(err error) error { + if atomic.SwapUint32(&c.closing, 1) == 1 { + return net.ErrClosed + } + + c.timeout.Stop() + c.ping.Stop() + + c.err = err + defer close(c.closed) + + return c.udpConn.Close() +} + +func newConn(uc udpConn, id, remoteID PeerID) *Conn { + var c *Conn + c = &Conn{ + udpConn: uc, + + id: id, + + pkts: make(chan Pkt), + errs: make(chan error), + + timeout: time.AfterFunc(ConnTimeout, func() { + c.closeDisco(ErrTimedOut) + }), + ping: time.NewTicker(PingTimeout), + + closed: make(chan struct{}), + + remoteID: remoteID, + } + + for i := range c.chans { + c.chans[i] = pktChan{ + inRels: new([0x8000][]byte), + inRelSN: initSeqnum, + + inSplits: make(map[seqnum]*inSplit), + + outSplitSN: initSeqnum, + + outRelSN: initSeqnum, + outRelWin: initSeqnum, + } + } + + c.newAckBuf() + + go c.sendPings(c.ping.C) + go c.recvUDPPkts() + + return c +} + +func (c *Conn) sendPings(ping <-chan time.Time) { + send := c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawCtl) + buf[1] = uint8(ctlPing) + return 2 + }, PktInfo{}) + + for { + select { + case <-ping: + send() + case <-c.Closed(): + return + } + } +} diff --git a/rudp/connect.go b/rudp/connect.go new file mode 100644 index 0000000..548ab15 --- /dev/null +++ b/rudp/connect.go @@ -0,0 +1,18 @@ +package rudp + +import "net" + +type udpSrv struct { + net.Conn +} + +func (us udpSrv) recvUDP() ([]byte, error) { + buf := make([]byte, maxUDPPktSize) + n, err := us.Read(buf) + return buf[:n], err +} + +// Connect returns a Conn connected to conn. +func Connect(conn net.Conn) *Conn { + return newConn(udpSrv{conn}, PeerIDSrv, PeerIDNil) +} diff --git a/rudp/listen.go b/rudp/listen.go index 5b7154a..e1cacf4 100644 --- a/rudp/listen.go +++ b/rudp/listen.go @@ -2,155 +2,213 @@ package rudp import ( "errors" - "fmt" "net" "sync" ) +func tryClose(ch chan struct{}) (ok bool) { + defer func() { recover() }() + close(ch) + return true +} + +type udpClt struct { + l *Listener + id PeerID + addr net.Addr + pkts chan []byte + closed chan struct{} +} + +func (c *udpClt) mkConn() { + conn := newConn(c, c.id, PeerIDSrv) + go func() { + <-conn.Closed() + c.l.wg.Done() + }() + conn.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawCtl) + buf[1] = uint8(ctlSetPeerID) + be.PutUint16(buf[2:4], uint16(conn.ID())) + return 4 + }, PktInfo{})() + select { + case c.l.conns <- conn: + case <-c.l.closed: + conn.Close() + } +} + +func (c *udpClt) Write(pkt []byte) (int, error) { + select { + case <-c.closed: + return 0, net.ErrClosed + default: + } + + return c.l.pc.WriteTo(pkt, c.addr) +} + +func (c *udpClt) recvUDP() ([]byte, error) { + select { + case pkt := <-c.pkts: + return pkt, nil + case <-c.closed: + return nil, net.ErrClosed + } +} + +func (c *udpClt) Close() error { + if !tryClose(c.closed) { + return net.ErrClosed + } + + c.l.mu.Lock() + defer c.l.mu.Unlock() + + delete(c.l.ids, c.id) + delete(c.l.clts, c.addr.String()) + + return nil +} + +func (c *udpClt) LocalAddr() net.Addr { return c.l.pc.LocalAddr() } +func (c *udpClt) RemoteAddr() net.Addr { return c.addr } + +// All Listener's methods are safe for concurrent use. type Listener struct { - conn net.PacketConn + pc net.PacketConn - clts chan cltPeer - errs chan error + peerID PeerID + conns chan *Conn + errs chan error + closed chan struct{} + wg sync.WaitGroup - mu sync.Mutex - addr2peer map[string]cltPeer - id2peer map[PeerID]cltPeer - peerID PeerID + mu sync.RWMutex + ids map[PeerID]bool + clts map[string]*udpClt } -// Listen listens for packets on conn until it is closed. -func Listen(conn net.PacketConn) *Listener { +// Listen listens for connections on pc, pc is closed once the returned Listener +// and all Conns connected through it are closed. +func Listen(pc net.PacketConn) *Listener { l := &Listener{ - conn: conn, + pc: pc, - clts: make(chan cltPeer), - errs: make(chan error), + conns: make(chan *Conn), + closed: make(chan struct{}), - addr2peer: make(map[string]cltPeer), - id2peer: make(map[PeerID]cltPeer), + ids: make(map[PeerID]bool), + clts: make(map[string]*udpClt), } - pkts := make(chan netPkt) - go readNetPkts(l.conn, pkts, l.errs) go func() { - for pkt := range pkts { - if err := l.processNetPkt(pkt); err != nil { - l.errs <- err + for { + if err := l.processNetPkt(); err != nil { + if errors.Is(err, net.ErrClosed) { + break + } + select { + case l.errs <- err: + case <-l.closed: + } } } - - close(l.clts) - - for _, clt := range l.addr2peer { - clt.Close() - } }() return l } -// Accept waits for and returns a connecting Peer. -// You should keep calling this until it returns net.ErrClosed -// so it doesn't leak a goroutine. -func (l *Listener) Accept() (*Peer, error) { +// Accept waits for and returns the next incoming Conn or an error. +func (l *Listener) Accept() (*Conn, error) { select { - case clt, ok := <-l.clts: - if !ok { - select { - case err := <-l.errs: - return nil, err - default: - return nil, net.ErrClosed - } - } - close(clt.accepted) - return clt.Peer, nil + case c := <-l.conns: + return c, nil case err := <-l.errs: return nil, err + case <-l.closed: + return nil, net.ErrClosed } } -// Addr returns the net.PacketConn the Listener is listening on. -func (l *Listener) Conn() net.PacketConn { return l.conn } +// Close makes the Listener stop listening for new Conns. +// Blocked Accept calls will return net.ErrClosed. +// Already Accepted Conns are not closed. +func (l *Listener) Close() error { + if !tryClose(l.closed) { + return net.ErrClosed + } -var ErrOutOfPeerIDs = errors.New("out of peer ids") + go func() { + l.wg.Wait() + l.pc.Close() + }() -type cltPeer struct { - *Peer - pkts chan<- netPkt - accepted chan struct{} // close-only + return nil } -func (l *Listener) processNetPkt(pkt netPkt) error { - l.mu.Lock() - defer l.mu.Unlock() - - addrstr := pkt.SrcAddr.String() +// Addr returns the Listener's network address. +func (l *Listener) Addr() net.Addr { return l.pc.LocalAddr() } - clt, ok := l.addr2peer[addrstr] - if !ok { - prev := l.peerID - for l.id2peer[l.peerID].Peer != nil || l.peerID < PeerIDCltMin { - if l.peerID == prev-1 { - return ErrOutOfPeerIDs - } - l.peerID++ - } +var ErrOutOfPeerIDs = errors.New("out of peer ids") - pkts := make(chan netPkt, 256) +func (l *Listener) processNetPkt() error { + buf := make([]byte, maxUDPPktSize) + n, addr, err := l.pc.ReadFrom(buf) + if err != nil { + return err + } - clt = cltPeer{ - Peer: newPeer(l.conn, pkt.SrcAddr, l.peerID, PeerIDSrv), - pkts: pkts, - accepted: make(chan struct{}), + l.mu.RLock() + clt, ok := l.clts[addr.String()] + l.mu.RUnlock() + if !ok { + select { + case <-l.closed: + return nil + default: } - l.addr2peer[addrstr] = clt - l.id2peer[clt.ID()] = clt - - data := make([]byte, 1+1+2) - data[0] = uint8(rawTypeCtl) - data[1] = uint8(ctlSetPeerID) - be.PutUint16(data[2:4], uint16(clt.ID())) - if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return fmt.Errorf("can't set client peer id: %w", err) + clt, err = l.add(addr) + if err != nil { + return err } + } - go func() { - select { - case l.clts <- clt: - case <-clt.Disco(): - } - - clt.processNetPkts(pkts) - }() + select { + case clt.pkts <- buf[:n]: + case <-clt.closed: + } - go func() { - <-clt.Disco() + return nil +} - l.mu.Lock() - close(pkts) - delete(l.addr2peer, addrstr) - delete(l.id2peer, clt.ID()) - l.mu.Unlock() - }() - } +func (l *Listener) add(addr net.Addr) (*udpClt, error) { + l.mu.Lock() + defer l.mu.Unlock() - select { - case <-clt.accepted: - clt.pkts <- pkt - default: - select { - case clt.pkts <- pkt: - default: - // It's OK to drop packets if the buffer is full - // because MT RUDP can cope with packet loss. + start := l.peerID + l.peerID++ + for l.peerID < PeerIDCltMin || l.ids[l.peerID] { + if l.peerID == start { + return nil, ErrOutOfPeerIDs } + l.peerID++ + } + l.ids[l.peerID] = true + + clt := &udpClt{ + l: l, + id: l.peerID, + addr: addr, + pkts: make(chan []byte), + closed: make(chan struct{}), } + l.clts[addr.String()] = clt - return nil + l.wg.Add(1) + go clt.mkConn() + + return clt, nil } diff --git a/rudp/net.go b/rudp/net.go deleted file mode 100644 index e2e7289..0000000 --- a/rudp/net.go +++ /dev/null @@ -1,41 +0,0 @@ -package rudp - -import ( - "errors" - "net" -) - -// ErrClosed is deprecated, use net.ErrClosed instead. -var ErrClosed = net.ErrClosed - -/* -netPkt.Data format (big endian): - - ProtoID - Src PeerID - ChNo uint8 // Must be < ChannelCount. - RawPkt.Data -*/ -type netPkt struct { - SrcAddr net.Addr - Data []byte -} - -func readNetPkts(conn net.PacketConn, pkts chan<- netPkt, errs chan<- error) { - for { - buf := make([]byte, MaxNetPktSize) - n, addr, err := conn.ReadFrom(buf) - if err != nil { - if errors.Is(err, net.ErrClosed) { - break - } - - errs <- err - continue - } - - pkts <- netPkt{addr, buf[:n]} - } - - close(pkts) -} diff --git a/rudp/peer.go b/rudp/peer.go deleted file mode 100644 index 791249c..0000000 --- a/rudp/peer.go +++ /dev/null @@ -1,207 +0,0 @@ -package rudp - -import ( - "errors" - "fmt" - "net" - "sync" - "time" -) - -const ( - // ConnTimeout is the amount of time after no packets being received - // from a Peer that it is automatically disconnected. - ConnTimeout = 30 * time.Second - - // ConnTimeout is the amount of time after no packets being sent - // to a Peer that a CtlPing is automatically sent to prevent timeout. - PingTimeout = 5 * time.Second -) - -// A Peer is a connection to a client or server. -type Peer struct { - pc net.PacketConn - addr net.Addr - conn net.Conn - - disco chan struct{} // close-only - - id PeerID - - pkts chan Pkt - errs chan error // don't close - timedOut chan struct{} // close-only - - chans [ChannelCount]pktchan // read/write - - mu sync.RWMutex - idOfPeer PeerID - timeout *time.Timer - ping *time.Ticker -} - -// Conn returns the net.PacketConn used to communicate with the Peer. -func (p *Peer) Conn() net.PacketConn { return p.pc } - -// Addr returns the address of the Peer. -func (p *Peer) Addr() net.Addr { return p.addr } - -// Disco returns a channel that is closed when the Peer is closed. -func (p *Peer) Disco() <-chan struct{} { return p.disco } - -// ID returns the ID of the Peer. -func (p *Peer) ID() PeerID { return p.id } - -// IsSrv reports whether the Peer is a server. -func (p *Peer) IsSrv() bool { - return p.ID() == PeerIDSrv -} - -// TimedOut reports whether the Peer has timed out. -func (p *Peer) TimedOut() bool { - select { - case <-p.timedOut: - return true - default: - return false - } -} - -type inSplit struct { - chunks [][]byte - size, got int -} - -type pktchan struct { - // Only accessed by Peer.processRawPkt. - inSplit *[65536]*inSplit - inRel *[65536][]byte - inRelSN seqnum - - ackChans sync.Map // map[seqnum]chan struct{} - - outSplitMu sync.Mutex - outSplitSN seqnum - - outRelMu sync.Mutex - outRelSN seqnum - outRelWin seqnum -} - -// Recv recieves a packet from the Peer. -// You should keep calling this until it returns net.ErrClosed -// so it doesn't leak a goroutine. -func (p *Peer) Recv() (Pkt, error) { - select { - case pkt, ok := <-p.pkts: - if !ok { - select { - case err := <-p.errs: - return Pkt{}, err - default: - return Pkt{}, net.ErrClosed - } - } - return pkt, nil - case err := <-p.errs: - return Pkt{}, err - } -} - -// Close closes the Peer but does not send a disconnect packet. -func (p *Peer) Close() error { - p.mu.Lock() - defer p.mu.Unlock() - - select { - case <-p.Disco(): - return net.ErrClosed - default: - } - - p.timeout.Stop() - p.timeout = nil - p.ping.Stop() - p.ping = nil - - close(p.disco) - - return nil -} - -func newPeer(pc net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer { - p := &Peer{ - pc: pc, - addr: addr, - id: id, - idOfPeer: idOfPeer, - - pkts: make(chan Pkt), - disco: make(chan struct{}), - errs: make(chan error), - } - - if conn, ok := pc.(net.Conn); ok && conn.RemoteAddr() != nil { - p.conn = conn - } - - for i := range p.chans { - p.chans[i] = pktchan{ - inSplit: new([65536]*inSplit), - inRel: new([65536][]byte), - inRelSN: seqnumInit, - - outSplitSN: seqnumInit, - outRelSN: seqnumInit, - outRelWin: seqnumInit, - } - } - - p.timedOut = make(chan struct{}) - p.timeout = time.AfterFunc(ConnTimeout, func() { - close(p.timedOut) - - p.SendDisco(0, true) - p.Close() - }) - - p.ping = time.NewTicker(PingTimeout) - go p.sendPings(p.ping.C) - - return p -} - -func (p *Peer) sendPings(ping <-chan time.Time) { - pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}} - - for { - select { - case <-ping: - if _, err := p.sendRaw(pkt); err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - p.errs <- fmt.Errorf("can't send ping: %w", err) - } - case <-p.Disco(): - return - } - } -} - -// Connect connects to addr using pc -// and closes pc when the returned *Peer disconnects. -func Connect(pc net.PacketConn, addr net.Addr) *Peer { - srv := newPeer(pc, addr, PeerIDSrv, PeerIDNil) - - pkts := make(chan netPkt) - go readNetPkts(pc, pkts, srv.errs) - go srv.processNetPkts(pkts) - - go func() { - <-srv.Disco() - pc.Close() - }() - - return srv -} diff --git a/rudp/process.go b/rudp/process.go deleted file mode 100644 index 7238fe5..0000000 --- a/rudp/process.go +++ /dev/null @@ -1,253 +0,0 @@ -package rudp - -import ( - "errors" - "fmt" - "io" - "net" -) - -// A PktError is an error that occured while processing a packet. -type PktError struct { - Type string // "net", "raw" or "rel". - Data []byte - Err error -} - -func (e PktError) Error() string { - return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err) -} - -func (e PktError) Unwrap() error { return e.Err } - -func (p *Peer) processNetPkts(pkts <-chan netPkt) { - for pkt := range pkts { - if err := p.processNetPkt(pkt); err != nil { - p.errs <- PktError{"net", pkt.Data, err} - } - } - - close(p.pkts) -} - -// A TrailingDataError reports a packet with trailing data, -// it doesn't stop a packet from being processed. -type TrailingDataError []byte - -func (e TrailingDataError) Error() string { - return fmt.Sprintf("trailing data: %x", []byte(e)) -} - -func (p *Peer) processNetPkt(pkt netPkt) (err error) { - if pkt.SrcAddr.String() != p.Addr().String() { - return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String()) - } - - if len(pkt.Data) < MtHdrSize { - return io.ErrUnexpectedEOF - } - - if id := be.Uint32(pkt.Data[0:4]); id != protoID { - return fmt.Errorf("unsupported protocol id: 0x%08x", id) - } - - // src PeerID at pkt.Data[4:6] - - chno := pkt.Data[6] - if chno >= ChannelCount { - return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno) - } - - p.mu.RLock() - if p.timeout != nil { - p.timeout.Reset(ConnTimeout) - } - p.mu.RUnlock() - - rpkt := rawPkt{ - Data: pkt.Data[MtHdrSize:], - ChNo: chno, - Unrel: true, - } - if err := p.processRawPkt(rpkt); err != nil { - p.errs <- PktError{"raw", rpkt.Data, err} - } - - return nil -} - -func (p *Peer) processRawPkt(pkt rawPkt) (err error) { - errWrap := func(format string, a ...interface{}) { - if err != nil { - err = fmt.Errorf(format, append(a, err)...) - } - } - - c := &p.chans[pkt.ChNo] - - if len(pkt.Data) < 1 { - return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF) - } - switch t := rawType(pkt.Data[0]); t { - case rawTypeCtl: - defer errWrap("ctl: %w") - - if len(pkt.Data) < 1+1 { - return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF) - } - switch ct := ctlType(pkt.Data[1]); ct { - case ctlAck: - defer errWrap("ack: %w") - - if len(pkt.Data) < 1+1+2 { - return io.ErrUnexpectedEOF - } - - sn := seqnum(be.Uint16(pkt.Data[2:4])) - - if ack, ok := c.ackChans.LoadAndDelete(sn); ok { - close(ack.(chan struct{})) - } - - if len(pkt.Data) > 1+1+2 { - return TrailingDataError(pkt.Data[1+1+2:]) - } - case ctlSetPeerID: - defer errWrap("set peer id: %w") - - if len(pkt.Data) < 1+1+2 { - return io.ErrUnexpectedEOF - } - - // Ensure no concurrent senders while peer id changes. - p.mu.Lock() - if p.idOfPeer != PeerIDNil { - return errors.New("peer id already set") - } - - p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4])) - p.mu.Unlock() - - if len(pkt.Data) > 1+1+2 { - return TrailingDataError(pkt.Data[1+1+2:]) - } - case ctlPing: - defer errWrap("ping: %w") - - if len(pkt.Data) > 1+1 { - return TrailingDataError(pkt.Data[1+1:]) - } - case ctlDisco: - defer errWrap("disco: %w") - - p.Close() - - if len(pkt.Data) > 1+1 { - return TrailingDataError(pkt.Data[1+1:]) - } - default: - return fmt.Errorf("unsupported ctl type: %d", ct) - } - case rawTypeOrig: - p.pkts <- Pkt{ - Data: pkt.Data[1:], - ChNo: pkt.ChNo, - Unrel: pkt.Unrel, - } - case rawTypeSplit: - defer errWrap("split: %w") - - if len(pkt.Data) < 1+2+2+2 { - return io.ErrUnexpectedEOF - } - - sn := seqnum(be.Uint16(pkt.Data[1:3])) - count := be.Uint16(pkt.Data[3:5]) - i := be.Uint16(pkt.Data[5:7]) - - if i >= count { - return nil - } - - splits := p.chans[pkt.ChNo].inSplit - - // Delete old incomplete split packets - // so new ones don't get corrupted. - splits[sn-0x8000] = nil - - if splits[sn] == nil { - splits[sn] = &inSplit{chunks: make([][]byte, count)} - } - - s := splits[sn] - - if int(count) != len(s.chunks) { - return fmt.Errorf("chunk count changed on split packet: %d", sn) - } - - s.chunks[i] = pkt.Data[7:] - s.size += len(s.chunks[i]) - s.got++ - - if s.got == len(s.chunks) { - data := make([]byte, 0, s.size) - for _, chunk := range s.chunks { - data = append(data, chunk...) - } - - p.pkts <- Pkt{ - Data: data, - ChNo: pkt.ChNo, - Unrel: pkt.Unrel, - } - - splits[sn] = nil - } - case rawTypeRel: - defer errWrap("rel: %w") - - if len(pkt.Data) < 1+2 { - return io.ErrUnexpectedEOF - } - - sn := seqnum(be.Uint16(pkt.Data[1:3])) - - ack := make([]byte, 1+1+2) - ack[0] = uint8(rawTypeCtl) - ack[1] = uint8(ctlAck) - be.PutUint16(ack[2:4], uint16(sn)) - if _, err := p.sendRaw(rawPkt{ - Data: ack, - ChNo: pkt.ChNo, - Unrel: true, - }); err != nil { - if errors.Is(err, net.ErrClosed) { - return nil - } - return fmt.Errorf("can't ack %d: %w", sn, err) - } - - if sn-c.inRelSN >= 0x8000 { - return nil // Already received. - } - - c.inRel[sn] = pkt.Data[3:] - - for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ { - rpkt := rawPkt{ - Data: c.inRel[c.inRelSN], - ChNo: pkt.ChNo, - Unrel: false, - } - c.inRel[c.inRelSN] = nil - - if err := p.processRawPkt(rpkt); err != nil { - p.errs <- PktError{"rel", rpkt.Data, err} - } - } - default: - return fmt.Errorf("unsupported pkt type: %d", t) - } - - return nil -} diff --git a/rudp/proxy/proxy.go b/rudp/proxy/proxy.go index a80b448..b8ca9f4 100644 --- a/rudp/proxy/proxy.go +++ b/rudp/proxy/proxy.go @@ -25,62 +25,55 @@ func main() { os.Exit(1) } - srvaddr, err := net.ResolveUDPAddr("udp", os.Args[1]) + pc, err := net.ListenPacket("udp", os.Args[2]) if err != nil { log.Fatal(err) } + defer pc.Close() - lc, err := net.ListenPacket("udp", os.Args[2]) - if err != nil { - log.Fatal(err) - } - defer lc.Close() - - l := rudp.Listen(lc) + l := rudp.Listen(pc) for { clt, err := l.Accept() if err != nil { - log.Print(err) + log.Print("accept: ", err) continue } - log.Print(clt.Addr(), " connected") + log.Print(clt.ID(), ": connected") - conn, err := net.DialUDP("udp", nil, srvaddr) + conn, err := net.Dial("udp", os.Args[1]) if err != nil { log.Print(err) continue } - srv := rudp.Connect(conn, conn.RemoteAddr()) + srv := rudp.Connect(conn) go proxy(clt, srv) go proxy(srv, clt) } } -func proxy(src, dest *rudp.Peer) { +func proxy(src, dest *rudp.Conn) { + s := fmt.Sprint(src.ID(), " (", src.RemoteAddr(), "): ") + for { pkt, err := src.Recv() if err != nil { if errors.Is(err, net.ErrClosed) { - msg := src.Addr().String() + " disconnected" - if src.TimedOut() { - msg += " (timed out)" + if err := src.WhyClosed(); err != nil { + log.Print(s, "disconnected: ", err) + } else { + log.Print(s, "disconnected") } - log.Print(msg) - break } - log.Print(err) + log.Print(s, err) continue } - if _, err := dest.Send(pkt); err != nil { - log.Print(err) - } + dest.Send(pkt) } - dest.SendDisco(0, true) dest.Close() } diff --git a/rudp/recv.go b/rudp/recv.go new file mode 100644 index 0000000..f5ac236 --- /dev/null +++ b/rudp/recv.go @@ -0,0 +1,259 @@ +package rudp + +import ( + "bytes" + "errors" + "fmt" + "io" + "net" + "time" +) + +// Recv receives a Pkt from the Conn. +func (c *Conn) Recv() (Pkt, error) { + select { + case pkt := <-c.pkts: + return pkt, nil + case err := <-c.errs: + return Pkt{}, err + case <-c.Closed(): + return Pkt{}, net.ErrClosed + } +} + +func (c *Conn) gotPkt(pkt Pkt) { + select { + case c.pkts <- pkt: + case <-c.Closed(): + } +} + +func (c *Conn) gotErr(kind string, data []byte, err error) { + select { + case c.errs <- fmt.Errorf("%s: %x: %w", kind, data, err): + case <-c.Closed(): + } +} + +func (c *Conn) recvUDPPkts() { + for { + pkt, err := c.udpConn.recvUDP() + if err != nil { + c.closeDisco(err) + break + } + + if err := c.processUDPPkt(pkt); err != nil { + c.gotErr("udp", pkt, err) + } + } +} + +func (c *Conn) processUDPPkt(pkt []byte) error { + if c.timeout.Stop() { + c.timeout.Reset(ConnTimeout) + } + + if len(pkt) < 6 { + return io.ErrUnexpectedEOF + } + + if id := be.Uint32(pkt[0:4]); id != protoID { + return fmt.Errorf("unsupported protocol id: 0x%08x", id) + } + + ch := Channel(pkt[6]) + if ch >= ChannelCount { + return TooBigChError(ch) + } + + if err := c.processRawPkt(pkt[7:], PktInfo{Channel: ch, Unrel: true}); err != nil { + c.gotErr("raw", pkt, err) + } + + return nil +} + +// A TrailingDataError reports trailing data after a packet. +type TrailingDataError []byte + +func (e TrailingDataError) Error() string { + return fmt.Sprintf("trailing data: %x", []byte(e)) +} + +func (c *Conn) processRawPkt(data []byte, pi PktInfo) (err error) { + errWrap := func(format string, a ...interface{}) { + if err != nil { + err = fmt.Errorf(format+": %w", append(a, err)...) + } + } + + eof := new(byte) + defer func() { + switch r := recover(); r { + case nil: + case eof: + err = io.ErrUnexpectedEOF + default: + panic(r) + } + }() + + off := 0 + eat := func(n int) []byte { + i := off + off += n + if i > len(data) { + panic(eof) + } + return data[i:off] + } + + ch := &c.chans[pi.Channel] + + switch t := rawType(eat(1)[0]); t { + case rawCtl: + defer errWrap("ctl") + + switch ct := ctlType(eat(1)[0]); ct { + case ctlAck: + defer errWrap("ack") + + sn := seqnum(be.Uint16(eat(2))) + + if ack, ok := ch.ackChans.LoadAndDelete(sn); ok { + close(ack.(chan struct{})) + } + case ctlSetPeerID: + defer errWrap("set peer id") + + c.mu.Lock() + if c.remoteID != PeerIDNil { + return errors.New("peer id already set") + } + + c.remoteID = PeerID(be.Uint16(eat(2))) + c.mu.Unlock() + + c.newAckBuf() + case ctlPing: + defer errWrap("ping") + case ctlDisco: + defer errWrap("disco") + + c.close(nil) + default: + return fmt.Errorf("unsupported ctl type: %d", ct) + } + + if off < len(data) { + return TrailingDataError(data[off:]) + } + case rawOrig: + c.gotPkt(Pkt{ + Reader: bytes.NewReader(data[off:]), + PktInfo: pi, + }) + case rawSplit: + defer errWrap("split") + + sn := seqnum(be.Uint16(eat(2))) + n := be.Uint16(eat(2)) + i := be.Uint16(eat(2)) + + defer errWrap("%d", sn) + + if i >= n { + return fmt.Errorf("chunk number (%d) > chunk count (%d)", i, n) + } + + ch.inSplitsMu.RLock() + s := ch.inSplits[sn] + ch.inSplitsMu.RUnlock() + + if s == nil { + s = &inSplit{chunks: make([][]byte, n)} + if pi.Unrel { + s.timeout = time.AfterFunc(ConnTimeout, func() { + ch.inSplitsMu.Lock() + delete(ch.inSplits, sn) + ch.inSplitsMu.Unlock() + }) + } + + ch.inSplitsMu.Lock() + ch.inSplits[sn] = s + ch.inSplitsMu.Unlock() + } + + if int(n) != len(s.chunks) { + return fmt.Errorf("chunk count changed from %d to %d", len(s.chunks), n) + } + + if s.chunks[i] == nil { + s.chunks[i] = data[off:] + s.got++ + } + + if s.got < len(s.chunks) { + if s.timeout != nil && s.timeout.Stop() { + s.timeout.Reset(ConnTimeout) + } + return + } + + if s.timeout != nil { + s.timeout.Stop() + } + + ch.inSplitsMu.Lock() + delete(ch.inSplits, sn) + ch.inSplitsMu.Unlock() + + c.gotPkt(Pkt{ + Reader: (*net.Buffers)(&s.chunks), + PktInfo: pi, + }) + case rawRel: + defer errWrap("rel") + + sn := seqnum(be.Uint16(eat(2))) + + defer errWrap("%d", sn) + + be.PutUint16(ch.ackBuf, uint16(sn)) + ch.sendAck() + + if sn-ch.inRelSN >= 0x8000 { + // Already received. + return nil + } + + ch.inRels[sn&0x7fff] = data[off:] + + i := func() seqnum { return ch.inRelSN & 0x7fff } + for ; ch.inRels[i()] != nil; ch.inRelSN++ { + data := ch.inRels[i()] + ch.inRels[i()] = nil + if err := c.processRawPkt(data, PktInfo{Channel: pi.Channel}); err != nil { + c.gotErr("rel", data, err) + } + } + default: + return fmt.Errorf("unsupported pkt type: %d", t) + } + + return nil +} + +func (c *Conn) newAckBuf() { + for i := range c.chans { + ch := &c.chans[i] + ch.sendAck = c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawCtl) + buf[1] = uint8(ctlAck) + ch.ackBuf = buf[2:4] + return 4 + }, PktInfo{Channel: Channel(i), Unrel: true}) + } +} diff --git a/rudp/rudp.go b/rudp/rudp.go index 6b96b56..faf67d9 100644 --- a/rudp/rudp.go +++ b/rudp/rudp.go @@ -1,21 +1,44 @@ /* Package rudp implements the low-level Minetest protocol described at https://dev.minetest.net/Network_Protocol#Low-level_protocol. - -All exported functions and methods in this package are safe for concurrent use -by multiple goroutines. */ package rudp -import "encoding/binary" +import ( + "encoding/binary" + "errors" + "io" + "time" +) var be = binary.BigEndian -// protoID must be at the start of every network packet. +/* +UDP packet format: + + protoID + src PeerID + channel uint8 + rawType... +*/ + +var ErrTimedOut = errors.New("timed out") + +const ( + ConnTimeout = 30 * time.Second + PingTimeout = 5 * time.Second +) + +const ( + MaxRelPktSize = 32439825 + MaxUnrelPktSize = 32636430 +) + +// protoID must be at the start of every UDP packet. const protoID uint32 = 0x4f457403 -// PeerIDs aren't actually used to identify peers, network addresses are, -// these just exist for backward compatability. +// PeerIDs aren't actually used to identify peers, IP addresses and ports are, +// these just exist for backward compatibility. type PeerID uint16 const ( @@ -29,79 +52,59 @@ const ( PeerIDCltMin ) -// ChannelCount is the maximum channel number + 1. -const ChannelCount = 3 - -/* -rawPkt.Data format (big endian): - - rawType - switch rawType { - case rawTypeCtl: - ctlType - switch ctlType { - case ctlAck: - // Tells peer you received a rawTypeRel - // and it doesn't need to resend it. - seqnum - case ctlSetPeerId: - // Tells peer to send packets with this Src PeerID. - PeerId - case ctlPing: - // Sent to prevent timeout. - case ctlDisco: - // Tells peer that you disconnected. - } - case rawTypeOrig: - Pkt.(Data) - case rawTypeSplit: - // Packet larger than MaxNetPktSize split into smaller packets. - // Packets with I >= Count should be ignored. - // Once all Count chunks are recieved, they are sorted by I and - // concatenated to make a Pkt.(Data). - seqnum // Identifies split packet. - Count, I uint16 - Chunk... - case rawTypeRel: - // Resent until a ctlAck with same seqnum is recieved. - // seqnums are sequencial and start at seqnumInit, - // These should be processed in seqnum order. - seqnum - rawPkt.Data - } -*/ -type rawPkt struct { - Data []byte - ChNo uint8 - Unrel bool -} - type rawType uint8 const ( - rawTypeCtl rawType = iota - rawTypeOrig - rawTypeSplit - rawTypeRel + rawCtl rawType = iota + // ctlType... + + rawOrig + // data... + + rawSplit + // seqnum + // n, i uint16 + // data... + + rawRel + // seqnum + // rawType... ) type ctlType uint8 const ( ctlAck ctlType = iota + // seqnum + ctlSetPeerID - ctlPing + // PeerID + + ctlPing // Sent to prevent timeout. + ctlDisco ) type Pkt struct { - Data []byte - ChNo uint8 + io.Reader + PktInfo +} + +// Reliable packets in a channel are be received in the order they are sent in. +// A Channel must be less than ChannelCount. +type Channel uint8 + +const ChannelCount Channel = 3 + +type PktInfo struct { + Channel + + // Unrel (unreliable) packets may be dropped, duplicated or reordered. Unrel bool } // seqnums are sequence numbers used to maintain reliable packet order -// and to identify split packets. +// and identify split packets. type seqnum uint16 -const seqnumInit seqnum = 65500 +const initSeqnum seqnum = 65500 diff --git a/rudp/send.go b/rudp/send.go index c43d056..5522ee0 100644 --- a/rudp/send.go +++ b/rudp/send.go @@ -1,241 +1,221 @@ package rudp import ( + "bytes" "errors" "fmt" - "math" + "io" "net" "sync" + "sync/atomic" "time" ) -const ( - // protoID + src PeerID + channel number - MtHdrSize = 4 + 2 + 1 - - // rawTypeOrig - OrigHdrSize = 1 - - // rawTypeSpilt + seqnum + chunk count + chunk number - SplitHdrSize = 1 + 2 + 2 + 2 - - // rawTypeRel + seqnum - RelHdrSize = 1 + 2 -) - -const ( - MaxNetPktSize = 512 - - MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize - MaxRelRawPktSize = MaxUnrelRawPktSize - RelHdrSize - - MaxRelPktSize = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16 - MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16 -) - var ErrPktTooBig = errors.New("can't send pkt: too big") -var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount") - -// Send sends a packet to the Peer. -// It returns a channel that's closed when all chunks are acked or an error. -// The ack channel is nil if pkt.Unrel is true. -func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { - if pkt.ChNo >= ChannelCount { - return nil, ErrChNoTooBig - } - - hdrSize := MtHdrSize - if !pkt.Unrel { - hdrSize += RelHdrSize - } - - if hdrSize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize { - c := &p.chans[pkt.ChNo] - - c.outSplitMu.Lock() - sn := c.outSplitSN - c.outSplitSN++ - c.outSplitMu.Unlock() - chunks := split(pkt.Data, MaxNetPktSize-(hdrSize+SplitHdrSize)) +// A TooBigChError reports a Channel greater than or equal to ChannelCount. +type TooBigChError Channel - if len(chunks) > math.MaxUint16 { - return nil, ErrPktTooBig - } +func (e TooBigChError) Error() string { + return fmt.Sprintf("channel >= ChannelCount (%d): %d", ChannelCount, e) +} - var wg sync.WaitGroup +// Send sends a Pkt to the Conn. +// Ack is closed when the packet is acknowledged. +// Ack is nil if pkt.Unrel is true or err != nil. +func (c *Conn) Send(pkt Pkt) (ack <-chan struct{}, err error) { + if pkt.Channel >= ChannelCount { + return nil, TooBigChError(pkt.Channel) + } - for i, chunk := range chunks { - data := make([]byte, SplitHdrSize+len(chunk)) - data[0] = uint8(rawTypeSplit) - be.PutUint16(data[1:3], uint16(sn)) - be.PutUint16(data[3:5], uint16(len(chunks))) - be.PutUint16(data[5:7], uint16(i)) - copy(data[SplitHdrSize:], chunk) + var e error + send := c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawOrig) - wg.Add(1) - ack, err := p.sendRaw(rawPkt{ - Data: data, - ChNo: pkt.ChNo, - Unrel: pkt.Unrel, - }) + nn := 1 + for nn < len(buf) { + n, err := pkt.Read(buf[nn:]) + nn += n if err != nil { - return nil, err - } - if !pkt.Unrel { - go func() { - <-ack - wg.Done() - }() + e = err + return nn } } - if pkt.Unrel { - return nil, nil - } else { - ack := make(chan struct{}) - - go func() { - wg.Wait() - close(ack) - }() + if _, e = pkt.Read(nil); e != nil { + return nn + } - return ack, nil + pkt.Reader = io.MultiReader( + bytes.NewReader([]byte(buf[1:nn])), + pkt.Reader, + ) + return nn + }, pkt.PktInfo) + if e != nil { + if e == io.EOF { + return send() } + return nil, e } - return p.sendRaw(rawPkt{ - Data: append([]byte{uint8(rawTypeOrig)}, pkt.Data...), - ChNo: pkt.ChNo, - Unrel: pkt.Unrel, - }) -} + var ( + sn seqnum + i uint16 + + sends []func() (<-chan struct{}, error) + ) + + for { + var ( + b []byte + e error + ) + send := c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawSplit) + + n, err := io.ReadFull(pkt, buf[7:]) + if err != nil && err != io.ErrUnexpectedEOF { + e = err + return 0 + } -// sendRaw sends a raw packet to the Peer. -func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) { - if pkt.ChNo >= ChannelCount { - return nil, ErrChNoTooBig - } + be.PutUint16(buf[5:7], i) + if i++; i == 0 { + e = ErrPktTooBig + return 0 + } - p.mu.RLock() - defer p.mu.RUnlock() + b = buf + return 7 + n + }, pkt.PktInfo) + if e != nil { + if e == io.EOF { + break + } + return nil, e + } - select { - case <-p.Disco(): - return nil, net.ErrClosed - default: + sends = append(sends, func() (<-chan struct{}, error) { + be.PutUint16(b[1:3], uint16(sn)) + be.PutUint16(b[3:5], i) + return send() + }) } - if !pkt.Unrel { - return p.sendRel(pkt) - } + ch := &c.chans[pkt.Channel] - data := make([]byte, MtHdrSize+len(pkt.Data)) - be.PutUint32(data[0:4], protoID) - be.PutUint16(data[4:6], uint16(p.idOfPeer)) - data[6] = pkt.ChNo - copy(data[MtHdrSize:], pkt.Data) + ch.outSplitMu.Lock() + sn = ch.outSplitSN + ch.outSplitSN++ + ch.outSplitMu.Unlock() - if len(data) > MaxNetPktSize { - return nil, ErrPktTooBig - } + var wg sync.WaitGroup - if p.conn != nil { - _, err = p.conn.Write(data) - } else { - _, err = p.pc.WriteTo(data, p.Addr()) - } - if err != nil { - return nil, err + for _, send := range sends { + ack, err := send() + if err != nil { + return nil, err + } + if !pkt.Unrel { + wg.Add(1) + go func() { + <-ack + wg.Done() + }() + } } - p.ping.Reset(PingTimeout) + if !pkt.Unrel { + ack := make(chan struct{}) + go func() { + wg.Wait() + close(ack) + }() + return ack, nil + } return nil, nil } -// sendRel sends a reliable raw packet to the Peer. -func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) { - if pkt.Unrel { - panic("pkt.Unrel is true") - } - - c := &p.chans[pkt.ChNo] +func (c *Conn) sendRaw(read func([]byte) int, pi PktInfo) func() (<-chan struct{}, error) { + if pi.Unrel { + buf := make([]byte, maxUDPPktSize) + be.PutUint32(buf[0:4], protoID) + c.mu.RLock() + be.PutUint16(buf[4:6], uint16(c.remoteID)) + c.mu.RUnlock() + buf[6] = uint8(pi.Channel) + buf = buf[:7+read(buf[7:])] + + return func() (<-chan struct{}, error) { + if _, err := c.udpConn.Write(buf); err != nil { + c.close(err) + return nil, net.ErrClosed + } - c.outRelMu.Lock() - defer c.outRelMu.Unlock() + c.ping.Reset(PingTimeout) + if atomic.LoadUint32(&c.closing) == 1 { + c.ping.Stop() + } - sn := c.outRelSN - for ; sn-c.outRelWin >= 0x8000; c.outRelWin++ { - if ack, ok := c.ackChans.Load(c.outRelWin); ok { - <-ack.(chan struct{}) + return nil, nil } } - c.outRelSN++ - - rwack := make(chan struct{}) // close-only - c.ackChans.Store(sn, rwack) - ack = rwack - - data := make([]byte, RelHdrSize+len(pkt.Data)) - data[0] = uint8(rawTypeRel) - be.PutUint16(data[1:3], uint16(sn)) - copy(data[RelHdrSize:], pkt.Data) - rel := rawPkt{ - Data: data, - ChNo: pkt.ChNo, - Unrel: true, - } - - if _, err := p.sendRaw(rel); err != nil { - c.ackChans.Delete(sn) - return nil, err - } - - go func() { - for { - select { - case <-time.After(500 * time.Millisecond): - if _, err := p.sendRaw(rel); err != nil { - if errors.Is(err, net.ErrClosed) { - return - } - p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err) + pi.Unrel = true + var snBuf []byte + send := c.sendRaw(func(buf []byte) int { + buf[0] = uint8(rawRel) + snBuf = buf[1:3] + return 3 + read(buf[3:]) + }, pi) + + return func() (<-chan struct{}, error) { + ch := &c.chans[pi.Channel] + + ch.outRelMu.Lock() + defer ch.outRelMu.Unlock() + + sn := ch.outRelSN + be.PutUint16(snBuf, uint16(sn)) + for ; sn-ch.outRelWin >= 0x8000; ch.outRelWin++ { + if ack, ok := ch.ackChans.Load(ch.outRelWin); ok { + select { + case <-ack.(chan struct{}): + case <-c.Closed(): } - case <-ack: - return - case <-p.Disco(): - return } } - }() - - return ack, nil -} - -// SendDisco sends a disconnect packet to the Peer but does not close it. -// It returns a channel that's closed when it's acked or an error. -// The ack channel is nil if unrel is true. -func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) { - return p.sendRaw(rawPkt{ - Data: []byte{uint8(rawTypeCtl), uint8(ctlDisco)}, - ChNo: chno, - Unrel: unrel, - }) -} -func split(data []byte, chunksize int) [][]byte { - chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize) + ack := make(chan struct{}) + ch.ackChans.Store(sn, ack) - for i := 0; i < len(data); i += chunksize { - end := i + chunksize - if end > len(data) { - end = len(data) + if _, err := send(); err != nil { + if ack, ok := ch.ackChans.LoadAndDelete(sn); ok { + close(ack.(chan struct{})) + } + return nil, err } + ch.outRelSN++ + + go func() { + t := time.NewTimer(500 * time.Millisecond) + defer t.Stop() + + for { + select { + case <-ack: + return + case <-t.C: + send() + t.Reset(500 * time.Millisecond) + case <-c.Closed(): + return + } + } + }() - chunks = append(chunks, data[i:end]) + return ack, nil } - - return chunks } diff --git a/rudp/udp.go b/rudp/udp.go new file mode 100644 index 0000000..503f1d4 --- /dev/null +++ b/rudp/udp.go @@ -0,0 +1,13 @@ +package rudp + +import "net" + +const maxUDPPktSize = 512 + +type udpConn interface { + recvUDP() ([]byte, error) + Write([]byte) (int, error) + Close() error + LocalAddr() net.Addr + RemoteAddr() net.Addr +} -- 2.44.0