From d37e0a8d7d545e9e6d6840d8ffd873f947c1a541 Mon Sep 17 00:00:00 2001 From: anon5 Date: Wed, 3 Mar 2021 17:33:07 +0000 Subject: [PATCH] rudp: optimize and refactor --- rudp/listen.go | 15 +++---- rudp/peer.go | 82 +++++++++++++++++++---------------- rudp/process.go | 87 ++++++++++++++++++-------------------- rudp/proxy/proxy.go | 2 +- rudp/{proto.go => rudp.go} | 4 ++ rudp/send.go | 75 +++++++++++++++----------------- 6 files changed, 133 insertions(+), 132 deletions(-) rename rudp/{proto.go => rudp.go} (97%) diff --git a/rudp/listen.go b/rudp/listen.go index 2d702c4..5b7154a 100644 --- a/rudp/listen.go +++ b/rudp/listen.go @@ -1,7 +1,6 @@ package rudp import ( - "encoding/binary" "errors" "fmt" "net" @@ -17,7 +16,7 @@ type Listener struct { mu sync.Mutex addr2peer map[string]cltPeer id2peer map[PeerID]cltPeer - peerid PeerID + peerID PeerID } // Listen listens for packets on conn until it is closed. @@ -91,18 +90,18 @@ func (l *Listener) processNetPkt(pkt netPkt) error { clt, ok := l.addr2peer[addrstr] if !ok { - prev := l.peerid - for l.id2peer[l.peerid].Peer != nil || l.peerid < PeerIDCltMin { - if l.peerid == prev-1 { + prev := l.peerID + for l.id2peer[l.peerID].Peer != nil || l.peerID < PeerIDCltMin { + if l.peerID == prev-1 { return ErrOutOfPeerIDs } - l.peerid++ + l.peerID++ } pkts := make(chan netPkt, 256) clt = cltPeer{ - Peer: newPeer(l.conn, pkt.SrcAddr, l.peerid, PeerIDSrv), + Peer: newPeer(l.conn, pkt.SrcAddr, l.peerID, PeerIDSrv), pkts: pkts, accepted: make(chan struct{}), } @@ -113,7 +112,7 @@ func (l *Listener) processNetPkt(pkt netPkt) error { data := make([]byte, 1+1+2) data[0] = uint8(rawTypeCtl) data[1] = uint8(ctlSetPeerID) - binary.BigEndian.PutUint16(data[2:4], uint16(clt.ID())) + be.PutUint16(data[2:4], uint16(clt.ID())) if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil { if errors.Is(err, net.ErrClosed) { return nil diff --git a/rudp/peer.go b/rudp/peer.go index 4d8df47..791249c 100644 --- a/rudp/peer.go +++ b/rudp/peer.go @@ -20,8 +20,9 @@ const ( // A Peer is a connection to a client or server. type Peer struct { - conn net.PacketConn + pc net.PacketConn addr net.Addr + conn net.Conn disco chan struct{} // close-only @@ -29,7 +30,7 @@ type Peer struct { pkts chan Pkt errs chan error // don't close - timedout chan struct{} // close-only + timedOut chan struct{} // close-only chans [ChannelCount]pktchan // read/write @@ -39,24 +40,8 @@ type Peer struct { ping *time.Ticker } -type pktchan struct { - // Only accessed by Peer.processRawPkt. - insplit map[seqnum][][]byte - inrel map[seqnum][]byte - inrelsn seqnum - - ackchans sync.Map // map[seqnum]chan struct{} - - outsplitmu sync.Mutex - outsplitsn seqnum - - outrelmu sync.Mutex - outrelsn seqnum - outrelwin seqnum -} - // Conn returns the net.PacketConn used to communicate with the Peer. -func (p *Peer) Conn() net.PacketConn { return p.conn } +func (p *Peer) Conn() net.PacketConn { return p.pc } // Addr returns the address of the Peer. func (p *Peer) Addr() net.Addr { return p.addr } @@ -75,13 +60,34 @@ func (p *Peer) IsSrv() bool { // TimedOut reports whether the Peer has timed out. func (p *Peer) TimedOut() bool { select { - case <-p.timedout: + case <-p.timedOut: return true default: return false } } +type inSplit struct { + chunks [][]byte + size, got int +} + +type pktchan struct { + // Only accessed by Peer.processRawPkt. + inSplit *[65536]*inSplit + inRel *[65536][]byte + inRelSN seqnum + + ackChans sync.Map // map[seqnum]chan struct{} + + outSplitMu sync.Mutex + outSplitSN seqnum + + outRelMu sync.Mutex + outRelSN seqnum + outRelWin seqnum +} + // Recv recieves a packet from the Peer. // You should keep calling this until it returns net.ErrClosed // so it doesn't leak a goroutine. @@ -123,9 +129,9 @@ func (p *Peer) Close() error { return nil } -func newPeer(conn net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer { +func newPeer(pc net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer { p := &Peer{ - conn: conn, + pc: pc, addr: addr, id: id, idOfPeer: idOfPeer, @@ -135,21 +141,25 @@ func newPeer(conn net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer { errs: make(chan error), } + if conn, ok := pc.(net.Conn); ok && conn.RemoteAddr() != nil { + p.conn = conn + } + for i := range p.chans { p.chans[i] = pktchan{ - insplit: make(map[seqnum][][]byte), - inrel: make(map[seqnum][]byte), - inrelsn: seqnumInit, + inSplit: new([65536]*inSplit), + inRel: new([65536][]byte), + inRelSN: seqnumInit, - outsplitsn: seqnumInit, - outrelsn: seqnumInit, - outrelwin: seqnumInit, + outSplitSN: seqnumInit, + outRelSN: seqnumInit, + outRelWin: seqnumInit, } } - p.timedout = make(chan struct{}) + p.timedOut = make(chan struct{}) p.timeout = time.AfterFunc(ConnTimeout, func() { - close(p.timedout) + close(p.timedOut) p.SendDisco(0, true) p.Close() @@ -179,18 +189,18 @@ func (p *Peer) sendPings(ping <-chan time.Time) { } } -// Connect connects to the server on conn -// and closes conn when the returned *Peer disconnects. -func Connect(conn net.PacketConn, addr net.Addr) *Peer { - srv := newPeer(conn, addr, PeerIDSrv, PeerIDNil) +// Connect connects to addr using pc +// and closes pc when the returned *Peer disconnects. +func Connect(pc net.PacketConn, addr net.Addr) *Peer { + srv := newPeer(pc, addr, PeerIDSrv, PeerIDNil) pkts := make(chan netPkt) - go readNetPkts(conn, pkts, srv.errs) + go readNetPkts(pc, pkts, srv.errs) go srv.processNetPkts(pkts) go func() { <-srv.Disco() - conn.Close() + pc.Close() }() return srv diff --git a/rudp/process.go b/rudp/process.go index c85aba4..7238fe5 100644 --- a/rudp/process.go +++ b/rudp/process.go @@ -1,7 +1,6 @@ package rudp import ( - "encoding/binary" "errors" "fmt" "io" @@ -48,7 +47,7 @@ func (p *Peer) processNetPkt(pkt netPkt) (err error) { return io.ErrUnexpectedEOF } - if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID { + if id := be.Uint32(pkt.Data[0:4]); id != protoID { return fmt.Errorf("unsupported protocol id: 0x%08x", id) } @@ -104,9 +103,9 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) { return io.ErrUnexpectedEOF } - sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4])) + sn := seqnum(be.Uint16(pkt.Data[2:4])) - if ack, ok := c.ackchans.LoadAndDelete(sn); ok { + if ack, ok := c.ackChans.LoadAndDelete(sn); ok { close(ack.(chan struct{})) } @@ -126,7 +125,7 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) { return errors.New("peer id already set") } - p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4])) + p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4])) p.mu.Unlock() if len(pkt.Data) > 1+1+2 { @@ -162,50 +161,48 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) { return io.ErrUnexpectedEOF } - sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3])) - count := binary.BigEndian.Uint16(pkt.Data[3:5]) - i := binary.BigEndian.Uint16(pkt.Data[5:7]) + sn := seqnum(be.Uint16(pkt.Data[1:3])) + count := be.Uint16(pkt.Data[3:5]) + i := be.Uint16(pkt.Data[5:7]) if i >= count { return nil } - splitpkts := p.chans[pkt.ChNo].insplit + splits := p.chans[pkt.ChNo].inSplit // Delete old incomplete split packets // so new ones don't get corrupted. - delete(splitpkts, sn-0x8000) + splits[sn-0x8000] = nil - if splitpkts[sn] == nil { - splitpkts[sn] = make([][]byte, count) + if splits[sn] == nil { + splits[sn] = &inSplit{chunks: make([][]byte, count)} } - chunks := splitpkts[sn] + s := splits[sn] - if int(count) != len(chunks) { - return fmt.Errorf("chunk count changed on seqnum: %d", sn) + if int(count) != len(s.chunks) { + return fmt.Errorf("chunk count changed on split packet: %d", sn) } - chunks[i] = pkt.Data[7:] + s.chunks[i] = pkt.Data[7:] + s.size += len(s.chunks[i]) + s.got++ - for _, chunk := range chunks { - if chunk == nil { - return nil + if s.got == len(s.chunks) { + data := make([]byte, 0, s.size) + for _, chunk := range s.chunks { + data = append(data, chunk...) } - } - var data []byte - for _, chunk := range chunks { - data = append(data, chunk...) - } + p.pkts <- Pkt{ + Data: data, + ChNo: pkt.ChNo, + Unrel: pkt.Unrel, + } - p.pkts <- Pkt{ - Data: data, - ChNo: pkt.ChNo, - Unrel: pkt.Unrel, + splits[sn] = nil } - - delete(splitpkts, sn) case rawTypeRel: defer errWrap("rel: %w") @@ -213,39 +210,37 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) { return io.ErrUnexpectedEOF } - sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3])) + sn := seqnum(be.Uint16(pkt.Data[1:3])) - ackdata := make([]byte, 1+1+2) - ackdata[0] = uint8(rawTypeCtl) - ackdata[1] = uint8(ctlAck) - binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn)) - ack := rawPkt{ - Data: ackdata, + ack := make([]byte, 1+1+2) + ack[0] = uint8(rawTypeCtl) + ack[1] = uint8(ctlAck) + be.PutUint16(ack[2:4], uint16(sn)) + if _, err := p.sendRaw(rawPkt{ + Data: ack, ChNo: pkt.ChNo, Unrel: true, - } - if _, err := p.sendRaw(ack); err != nil { + }); err != nil { if errors.Is(err, net.ErrClosed) { return nil } return fmt.Errorf("can't ack %d: %w", sn, err) } - if sn-c.inrelsn >= 0x8000 { + if sn-c.inRelSN >= 0x8000 { return nil // Already received. } - c.inrel[sn] = pkt.Data[3:] - - for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ { - data := c.inrel[c.inrelsn] - delete(c.inrel, c.inrelsn) + c.inRel[sn] = pkt.Data[3:] + for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ { rpkt := rawPkt{ - Data: data, + Data: c.inRel[c.inRelSN], ChNo: pkt.ChNo, Unrel: false, } + c.inRel[c.inRelSN] = nil + if err := p.processRawPkt(rpkt); err != nil { p.errs <- PktError{"rel", rpkt.Data, err} } diff --git a/rudp/proxy/proxy.go b/rudp/proxy/proxy.go index a80b448..05d8a66 100644 --- a/rudp/proxy/proxy.go +++ b/rudp/proxy/proxy.go @@ -16,7 +16,7 @@ import ( "net" "os" - "github.com/anon55555/mt/rudp" + "mt/rudp" ) func main() { diff --git a/rudp/proto.go b/rudp/rudp.go similarity index 97% rename from rudp/proto.go rename to rudp/rudp.go index 04176b2..6b96b56 100644 --- a/rudp/proto.go +++ b/rudp/rudp.go @@ -7,6 +7,10 @@ by multiple goroutines. */ package rudp +import "encoding/binary" + +var be = binary.BigEndian + // protoID must be at the start of every network packet. const protoID uint32 = 0x4f457403 diff --git a/rudp/send.go b/rudp/send.go index ce3f013..c43d056 100644 --- a/rudp/send.go +++ b/rudp/send.go @@ -1,7 +1,6 @@ package rudp import ( - "encoding/binary" "errors" "fmt" "math" @@ -45,20 +44,20 @@ func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { return nil, ErrChNoTooBig } - hdrsize := MtHdrSize + hdrSize := MtHdrSize if !pkt.Unrel { - hdrsize += RelHdrSize + hdrSize += RelHdrSize } - if hdrsize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize { + if hdrSize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize { c := &p.chans[pkt.ChNo] - c.outsplitmu.Lock() - sn := c.outsplitsn - c.outsplitsn++ - c.outsplitmu.Unlock() + c.outSplitMu.Lock() + sn := c.outSplitSN + c.outSplitSN++ + c.outSplitMu.Unlock() - chunks := split(pkt.Data, MaxNetPktSize-(hdrsize+SplitHdrSize)) + chunks := split(pkt.Data, MaxNetPktSize-(hdrSize+SplitHdrSize)) if len(chunks) > math.MaxUint16 { return nil, ErrPktTooBig @@ -69,9 +68,9 @@ func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { for i, chunk := range chunks { data := make([]byte, SplitHdrSize+len(chunk)) data[0] = uint8(rawTypeSplit) - binary.BigEndian.PutUint16(data[1:3], uint16(sn)) - binary.BigEndian.PutUint16(data[3:5], uint16(len(chunks))) - binary.BigEndian.PutUint16(data[5:7], uint16(i)) + be.PutUint16(data[1:3], uint16(sn)) + be.PutUint16(data[3:5], uint16(len(chunks))) + be.PutUint16(data[5:7], uint16(i)) copy(data[SplitHdrSize:], chunk) wg.Add(1) @@ -84,9 +83,6 @@ func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { return nil, err } if !pkt.Unrel { - if ack == nil { - panic("ack is nil") - } go func() { <-ack wg.Done() @@ -135,8 +131,8 @@ func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) { } data := make([]byte, MtHdrSize+len(pkt.Data)) - binary.BigEndian.PutUint32(data[0:4], protoID) - binary.BigEndian.PutUint16(data[4:6], uint16(p.idOfPeer)) + be.PutUint32(data[0:4], protoID) + be.PutUint16(data[4:6], uint16(p.idOfPeer)) data[6] = pkt.ChNo copy(data[MtHdrSize:], pkt.Data) @@ -144,13 +140,10 @@ func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) { return nil, ErrPktTooBig } - _, err = p.Conn().WriteTo(data, p.Addr()) - if errors.Is(err, net.ErrWriteToConnected) { - conn, ok := p.Conn().(net.Conn) - if !ok { - return nil, err - } - _, err = conn.Write(data) + if p.conn != nil { + _, err = p.conn.Write(data) + } else { + _, err = p.pc.WriteTo(data, p.Addr()) } if err != nil { return nil, err @@ -164,38 +157,38 @@ func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) { // sendRel sends a reliable raw packet to the Peer. func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) { if pkt.Unrel { - panic("mt/rudp: sendRel: pkt.Unrel is true") + panic("pkt.Unrel is true") } c := &p.chans[pkt.ChNo] - c.outrelmu.Lock() - defer c.outrelmu.Unlock() + c.outRelMu.Lock() + defer c.outRelMu.Unlock() - sn := c.outrelsn - for ; sn-c.outrelwin >= 0x8000; c.outrelwin++ { - if ack, ok := c.ackchans.Load(c.outrelwin); ok { + sn := c.outRelSN + for ; sn-c.outRelWin >= 0x8000; c.outRelWin++ { + if ack, ok := c.ackChans.Load(c.outRelWin); ok { <-ack.(chan struct{}) } } - c.outrelsn++ + c.outRelSN++ rwack := make(chan struct{}) // close-only - c.ackchans.Store(sn, rwack) + c.ackChans.Store(sn, rwack) ack = rwack - reldata := make([]byte, RelHdrSize+len(pkt.Data)) - reldata[0] = uint8(rawTypeRel) - binary.BigEndian.PutUint16(reldata[1:3], uint16(sn)) - copy(reldata[RelHdrSize:], pkt.Data) - relpkt := rawPkt{ - Data: reldata, + data := make([]byte, RelHdrSize+len(pkt.Data)) + data[0] = uint8(rawTypeRel) + be.PutUint16(data[1:3], uint16(sn)) + copy(data[RelHdrSize:], pkt.Data) + rel := rawPkt{ + Data: data, ChNo: pkt.ChNo, Unrel: true, } - if _, err := p.sendRaw(relpkt); err != nil { - c.ackchans.Delete(sn) + if _, err := p.sendRaw(rel); err != nil { + c.ackChans.Delete(sn) return nil, err } @@ -204,7 +197,7 @@ func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) { for { select { case <-time.After(500 * time.Millisecond): - if _, err := p.sendRaw(relpkt); err != nil { + if _, err := p.sendRaw(rel); err != nil { if errors.Is(err, net.ErrClosed) { return } -- 2.44.0