From a865d2bce1aa097273fdb9d0d02d9cfa8460aefd Mon Sep 17 00:00:00 2001 From: anon5 Date: Sat, 7 Nov 2020 18:01:24 +0000 Subject: [PATCH] Initial public release --- LICENSE | 21 ++++ rudp/listen.go | 153 ++++++++++++++++++++++++++ rudp/net.go | 44 ++++++++ rudp/peer.go | 193 +++++++++++++++++++++++++++++++++ rudp/process.go | 259 ++++++++++++++++++++++++++++++++++++++++++++ rudp/proto.go | 103 ++++++++++++++++++ rudp/proxy/proxy.go | 85 +++++++++++++++ rudp/send.go | 248 ++++++++++++++++++++++++++++++++++++++++++ 8 files changed, 1106 insertions(+) create mode 100644 LICENSE create mode 100644 rudp/listen.go create mode 100644 rudp/net.go create mode 100644 rudp/peer.go create mode 100644 rudp/process.go create mode 100644 rudp/proto.go create mode 100644 rudp/proxy/proxy.go create mode 100644 rudp/send.go diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..3febf69 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 anon5 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/rudp/listen.go b/rudp/listen.go new file mode 100644 index 0000000..871a591 --- /dev/null +++ b/rudp/listen.go @@ -0,0 +1,153 @@ +package rudp + +import ( + "encoding/binary" + "errors" + "fmt" + "net" + "sync" +) + +type Listener struct { + conn net.PacketConn + + clts chan cltPeer + errs chan error + + mu sync.Mutex + addr2peer map[string]cltPeer + id2peer map[PeerID]cltPeer + peerid PeerID +} + +// Listen listens for packets on conn until it is closed. +func Listen(conn net.PacketConn) *Listener { + l := &Listener{ + conn: conn, + + clts: make(chan cltPeer), + errs: make(chan error), + + addr2peer: make(map[string]cltPeer), + id2peer: make(map[PeerID]cltPeer), + } + + pkts := make(chan netPkt) + go readNetPkts(l.conn, pkts, l.errs) + go func() { + for pkt := range pkts { + if err := l.processNetPkt(pkt); err != nil { + l.errs <- err + } + } + + close(l.clts) + + for _, clt := range l.addr2peer { + clt.Close() + } + }() + + return l +} + +// Accept waits for and returns a connecting Peer. +// You should keep calling this until it returns ErrClosed +// so it doesn't leak a goroutine. +func (l *Listener) Accept() (*Peer, error) { + select { + case clt, ok := <-l.clts: + if !ok { + select { + case err := <-l.errs: + return nil, err + default: + return nil, ErrClosed + } + } + close(clt.accepted) + return clt.Peer, nil + case err := <-l.errs: + return nil, err + } +} + +// Addr returns the net.PacketConn the Listener is listening on. +func (l *Listener) Conn() net.PacketConn { return l.conn } + +var ErrOutOfPeerIDs = errors.New("out of peer ids") + +type cltPeer struct { + *Peer + pkts chan<- netPkt + accepted chan struct{} // close-only +} + +func (l *Listener) processNetPkt(pkt netPkt) error { + l.mu.Lock() + defer l.mu.Unlock() + + addrstr := pkt.SrcAddr.String() + + clt, ok := l.addr2peer[addrstr] + if !ok { + prev := l.peerid + for l.id2peer[l.peerid].Peer != nil || l.peerid < PeerIDCltMin { + if l.peerid == prev-1 { + return ErrOutOfPeerIDs + } + l.peerid++ + } + + pkts := make(chan netPkt, 256) + + clt = cltPeer{ + Peer: newPeer(l.conn, pkt.SrcAddr, l.peerid, PeerIDSrv), + pkts: pkts, + accepted: make(chan struct{}), + } + + l.addr2peer[addrstr] = clt + l.id2peer[clt.ID()] = clt + + data := make([]byte, 1+1+2) + data[0] = uint8(rawTypeCtl) + data[1] = uint8(ctlSetPeerID) + binary.BigEndian.PutUint16(data[2:4], uint16(clt.ID())) + if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil { + return fmt.Errorf("can't set client peer id: %w", err) + } + + go func() { + select { + case l.clts <- clt: + case <-clt.Disco(): + } + + clt.processNetPkts(pkts) + }() + + go func() { + <-clt.Disco() + + l.mu.Lock() + close(pkts) + delete(l.addr2peer, addrstr) + delete(l.id2peer, clt.ID()) + l.mu.Unlock() + }() + } + + select { + case <-clt.accepted: + clt.pkts <- pkt + default: + select { + case clt.pkts <- pkt: + default: + return fmt.Errorf("ignoring net pkt from %s because buf is full", addrstr) + } + } + + return nil +} diff --git a/rudp/net.go b/rudp/net.go new file mode 100644 index 0000000..421a3e7 --- /dev/null +++ b/rudp/net.go @@ -0,0 +1,44 @@ +package rudp + +import ( + "errors" + "net" + "strings" +) + +// TODO: Use net.ErrClosed when Go 1.16 is released. +var ErrClosed = errors.New("use of closed peer") + +/* +netPkt.Data format (big endian): + + ProtoID + Src PeerID + ChNo uint8 // Must be < ChannelCount. + RawPkt.Data +*/ +type netPkt struct { + SrcAddr net.Addr + Data []byte +} + +func readNetPkts(conn net.PacketConn, pkts chan<- netPkt, errs chan<- error) { + for { + buf := make([]byte, MaxNetPktSize) + n, addr, err := conn.ReadFrom(buf) + if err != nil { + // TODO: Change to this when Go 1.16 is released: + // if errors.Is(err, net.ErrClosed) { + if strings.Contains(err.Error(), "use of closed network connection") { + break + } + + errs <- err + continue + } + + pkts <- netPkt{addr, buf[:n]} + } + + close(pkts) +} diff --git a/rudp/peer.go b/rudp/peer.go new file mode 100644 index 0000000..feb0ff9 --- /dev/null +++ b/rudp/peer.go @@ -0,0 +1,193 @@ +package rudp + +import ( + "fmt" + "net" + "sync" + "time" +) + +const ( + // ConnTimeout is the amount of time after no packets being received + // from a Peer that it is automatically disconnected. + ConnTimeout = 30 * time.Second + + // ConnTimeout is the amount of time after no packets being sent + // to a Peer that a CtlPing is automatically sent to prevent timeout. + PingTimeout = 5 * time.Second +) + +// A Peer is a connection to a client or server. +type Peer struct { + conn net.PacketConn + addr net.Addr + + disco chan struct{} // close-only + + id PeerID + + pkts chan Pkt + errs chan error // don't close + timedout chan struct{} // close-only + + chans [ChannelCount]pktchan // read/write + + mu sync.RWMutex + idOfPeer PeerID + timeout *time.Timer + ping *time.Ticker +} + +type pktchan struct { + // Only accessed by Peer.processRawPkt. + insplit map[seqnum][][]byte + inrel map[seqnum][]byte + inrelsn seqnum + + ackchans sync.Map // map[seqnum]chan struct{} + + outsplitmu sync.Mutex + outsplitsn seqnum + + outrelmu sync.Mutex + outrelsn seqnum + outrelwin seqnum +} + +// Conn returns the net.PacketConn used to communicate with the Peer. +func (p *Peer) Conn() net.PacketConn { return p.conn } + +// Addr returns the address of the Peer. +func (p *Peer) Addr() net.Addr { return p.addr } + +// Disco returns a channel that is closed when the Peer is closed. +func (p *Peer) Disco() <-chan struct{} { return p.disco } + +// ID returns the ID of the Peer. +func (p *Peer) ID() PeerID { return p.id } + +// IsSrv reports whether the Peer is a server. +func (p *Peer) IsSrv() bool { + return p.ID() == PeerIDSrv +} + +// TimedOut reports whether the Peer has timed out. +func (p *Peer) TimedOut() bool { + select { + case <-p.timedout: + return true + default: + return false + } +} + +// Recv recieves a packet from the Peer. +// You should keep calling this until it returns ErrClosed +// so it doesn't leak a goroutine. +func (p *Peer) Recv() (Pkt, error) { + select { + case pkt, ok := <-p.pkts: + if !ok { + select { + case err := <-p.errs: + return Pkt{}, err + default: + return Pkt{}, ErrClosed + } + } + return pkt, nil + case err := <-p.errs: + return Pkt{}, err + } +} + +// Close closes the Peer but does not send a disconnect packet. +func (p *Peer) Close() error { + p.mu.Lock() + defer p.mu.Unlock() + + select { + case <-p.Disco(): + return ErrClosed + default: + } + + p.timeout.Stop() + p.timeout = nil + p.ping.Stop() + p.ping = nil + + close(p.disco) + + return nil +} + +func newPeer(conn net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer { + p := &Peer{ + conn: conn, + addr: addr, + id: id, + idOfPeer: idOfPeer, + + pkts: make(chan Pkt), + disco: make(chan struct{}), + errs: make(chan error), + } + + for i := range p.chans { + p.chans[i] = pktchan{ + insplit: make(map[seqnum][][]byte), + inrel: make(map[seqnum][]byte), + inrelsn: seqnumInit, + + outsplitsn: seqnumInit, + outrelsn: seqnumInit, + outrelwin: seqnumInit, + } + } + + p.timedout = make(chan struct{}) + p.timeout = time.AfterFunc(ConnTimeout, func() { + close(p.timedout) + + p.SendDisco(0, true) + p.Close() + }) + + p.ping = time.NewTicker(PingTimeout) + go p.sendPings(p.ping.C) + + return p +} + +func (p *Peer) sendPings(ping <-chan time.Time) { + pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}} + + for { + select { + case <-ping: + if _, err := p.sendRaw(pkt); err != nil { + p.errs <- fmt.Errorf("can't send ping: %w", err) + } + case <-p.Disco(): + return + } + } +} + +// Connect connects to the server on conn +// and closes conn when the Peer disconnects. +func Connect(conn net.PacketConn, addr net.Addr) *Peer { + srv := newPeer(conn, addr, PeerIDSrv, PeerIDNil) + + pkts := make(chan netPkt) + go readNetPkts(conn, pkts, srv.errs) + go srv.processNetPkts(pkts) + + go func() { + <-srv.Disco() + conn.Close() + }() + + return srv +} diff --git a/rudp/process.go b/rudp/process.go new file mode 100644 index 0000000..c36af81 --- /dev/null +++ b/rudp/process.go @@ -0,0 +1,259 @@ +package rudp + +import ( + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" +) + +// A PktError is an error that occured while processing a packet. +type PktError struct { + Type string // "net", "raw" or "rel". + Data []byte + Err error +} + +func (e PktError) Error() string { + return "error processing " + e.Type + " pkt: " + + hex.EncodeToString(e.Data) + ": " + + e.Err.Error() +} + +func (e PktError) Unwrap() error { return e.Err } + +func (p *Peer) processNetPkts(pkts <-chan netPkt) { + for pkt := range pkts { + if err := p.processNetPkt(pkt); err != nil { + p.errs <- PktError{"net", pkt.Data, err} + } + } + + close(p.pkts) +} + +// A TrailingDataError reports a packet with trailing data, +// it doesn't stop a packet from being processed. +type TrailingDataError []byte + +func (e TrailingDataError) Error() string { + return "trailing data: " + hex.EncodeToString([]byte(e)) +} + +func (p *Peer) processNetPkt(pkt netPkt) (err error) { + if pkt.SrcAddr.String() != p.Addr().String() { + return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String()) + } + + if len(pkt.Data) < MtHdrSize { + return io.ErrUnexpectedEOF + } + + if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID { + return fmt.Errorf("unsupported protocol id: 0x%08x", id) + } + + // src PeerID at pkt.Data[4:6] + + chno := pkt.Data[6] + if chno >= ChannelCount { + return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno) + } + + p.mu.RLock() + if p.timeout != nil { + p.timeout.Reset(ConnTimeout) + } + p.mu.RUnlock() + + rpkt := rawPkt{ + Data: pkt.Data[MtHdrSize:], + ChNo: chno, + Unrel: true, + } + if err := p.processRawPkt(rpkt); err != nil { + p.errs <- PktError{"raw", rpkt.Data, err} + } + + return nil +} + +func (p *Peer) processRawPkt(pkt rawPkt) (err error) { + errWrap := func(format string, a ...interface{}) { + if err != nil { + err = fmt.Errorf(format, append(a, err)...) + } + } + + c := &p.chans[pkt.ChNo] + + if len(pkt.Data) < 1 { + return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF) + } + switch t := rawType(pkt.Data[0]); t { + case rawTypeCtl: + defer errWrap("ctl: %w") + + if len(pkt.Data) < 1+1 { + return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF) + } + switch ct := ctlType(pkt.Data[1]); ct { + case ctlAck: + defer errWrap("ack: %w") + + if len(pkt.Data) < 1+1+2 { + return io.ErrUnexpectedEOF + } + + sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4])) + + if ack, ok := c.ackchans.LoadAndDelete(sn); ok { + close(ack.(chan struct{})) + } + + if len(pkt.Data) > 1+1+2 { + return TrailingDataError(pkt.Data[1+1+2:]) + } + case ctlSetPeerID: + defer errWrap("set peer id: %w") + + if len(pkt.Data) < 1+1+2 { + return io.ErrUnexpectedEOF + } + + // Ensure no concurrent senders while peer id changes. + p.mu.Lock() + if p.idOfPeer != PeerIDNil { + return errors.New("peer id already set") + } + + p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4])) + p.mu.Unlock() + + if len(pkt.Data) > 1+1+2 { + return TrailingDataError(pkt.Data[1+1+2:]) + } + case ctlPing: + defer errWrap("ping: %w") + + if len(pkt.Data) > 1+1 { + return TrailingDataError(pkt.Data[1+1:]) + } + case ctlDisco: + defer errWrap("disco: %w") + + if err := p.Close(); err != nil { + return fmt.Errorf("can't close: %w", err) + } + + if len(pkt.Data) > 1+1 { + return TrailingDataError(pkt.Data[1+1:]) + } + default: + return fmt.Errorf("unsupported ctl type: %d", ct) + } + case rawTypeOrig: + p.pkts <- Pkt{ + Data: pkt.Data[1:], + ChNo: pkt.ChNo, + Unrel: pkt.Unrel, + } + case rawTypeSplit: + defer errWrap("split: %w") + + if len(pkt.Data) < 1+2+2+2 { + return io.ErrUnexpectedEOF + } + + sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3])) + count := binary.BigEndian.Uint16(pkt.Data[3:5]) + i := binary.BigEndian.Uint16(pkt.Data[5:7]) + + if i >= count { + return nil + } + + splitpkts := p.chans[pkt.ChNo].insplit + + // Delete old incomplete split packets + // so new ones don't get corrupted. + delete(splitpkts, sn-0x8000) + + if splitpkts[sn] == nil { + splitpkts[sn] = make([][]byte, count) + } + + chunks := splitpkts[sn] + + if int(count) != len(chunks) { + return fmt.Errorf("chunk count changed on seqnum: %d", sn) + } + + chunks[i] = pkt.Data[7:] + + for _, chunk := range chunks { + if chunk == nil { + return nil + } + } + + var data []byte + for _, chunk := range chunks { + data = append(data, chunk...) + } + + p.pkts <- Pkt{ + Data: data, + ChNo: pkt.ChNo, + Unrel: pkt.Unrel, + } + + delete(splitpkts, sn) + case rawTypeRel: + defer errWrap("rel: %w") + + if len(pkt.Data) < 1+2 { + return io.ErrUnexpectedEOF + } + + sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3])) + + ackdata := make([]byte, 1+1+2) + ackdata[0] = uint8(rawTypeCtl) + ackdata[1] = uint8(ctlAck) + binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn)) + ack := rawPkt{ + Data: ackdata, + ChNo: pkt.ChNo, + Unrel: true, + } + if _, err := p.sendRaw(ack); err != nil { + return fmt.Errorf("can't ack %d: %w", sn, err) + } + + if sn-c.inrelsn >= 0x8000 { + return nil // Already received. + } + + c.inrel[sn] = pkt.Data[3:] + + for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ { + data := c.inrel[c.inrelsn] + delete(c.inrel, c.inrelsn) + + rpkt := rawPkt{ + Data: data, + ChNo: pkt.ChNo, + Unrel: false, + } + if err := p.processRawPkt(rpkt); err != nil { + p.errs <- PktError{"rel", rpkt.Data, err} + } + } + default: + return fmt.Errorf("unsupported pkt type: %d", t) + } + + return nil +} diff --git a/rudp/proto.go b/rudp/proto.go new file mode 100644 index 0000000..04176b2 --- /dev/null +++ b/rudp/proto.go @@ -0,0 +1,103 @@ +/* +Package rudp implements the low-level Minetest protocol described at +https://dev.minetest.net/Network_Protocol#Low-level_protocol. + +All exported functions and methods in this package are safe for concurrent use +by multiple goroutines. +*/ +package rudp + +// protoID must be at the start of every network packet. +const protoID uint32 = 0x4f457403 + +// PeerIDs aren't actually used to identify peers, network addresses are, +// these just exist for backward compatability. +type PeerID uint16 + +const ( + // Used by clients before the server sets their ID. + PeerIDNil PeerID = iota + + // The server always has this ID. + PeerIDSrv + + // Lowest ID the server can assign to a client. + PeerIDCltMin +) + +// ChannelCount is the maximum channel number + 1. +const ChannelCount = 3 + +/* +rawPkt.Data format (big endian): + + rawType + switch rawType { + case rawTypeCtl: + ctlType + switch ctlType { + case ctlAck: + // Tells peer you received a rawTypeRel + // and it doesn't need to resend it. + seqnum + case ctlSetPeerId: + // Tells peer to send packets with this Src PeerID. + PeerId + case ctlPing: + // Sent to prevent timeout. + case ctlDisco: + // Tells peer that you disconnected. + } + case rawTypeOrig: + Pkt.(Data) + case rawTypeSplit: + // Packet larger than MaxNetPktSize split into smaller packets. + // Packets with I >= Count should be ignored. + // Once all Count chunks are recieved, they are sorted by I and + // concatenated to make a Pkt.(Data). + seqnum // Identifies split packet. + Count, I uint16 + Chunk... + case rawTypeRel: + // Resent until a ctlAck with same seqnum is recieved. + // seqnums are sequencial and start at seqnumInit, + // These should be processed in seqnum order. + seqnum + rawPkt.Data + } +*/ +type rawPkt struct { + Data []byte + ChNo uint8 + Unrel bool +} + +type rawType uint8 + +const ( + rawTypeCtl rawType = iota + rawTypeOrig + rawTypeSplit + rawTypeRel +) + +type ctlType uint8 + +const ( + ctlAck ctlType = iota + ctlSetPeerID + ctlPing + ctlDisco +) + +type Pkt struct { + Data []byte + ChNo uint8 + Unrel bool +} + +// seqnums are sequence numbers used to maintain reliable packet order +// and to identify split packets. +type seqnum uint16 + +const seqnumInit seqnum = 65500 diff --git a/rudp/proxy/proxy.go b/rudp/proxy/proxy.go new file mode 100644 index 0000000..6fc14ec --- /dev/null +++ b/rudp/proxy/proxy.go @@ -0,0 +1,85 @@ +/* +Proxy is a Minetest RUDP proxy server +supporting multiple concurrent connections. + +Usage: + proxy dial:port listen:port +where dial:port is the server address +and listen:port is the address to listen on. +*/ +package main + +import ( + "fmt" + "log" + "net" + "os" + + "github.com/anon55555/mt/rudp" +) + +func main() { + if len(os.Args) != 3 { + fmt.Fprintln(os.Stderr, "usage: proxy dial:port listen:port") + os.Exit(1) + } + + srvaddr, err := net.ResolveUDPAddr("udp", os.Args[1]) + if err != nil { + log.Fatal(err) + } + + lc, err := net.ListenPacket("udp", os.Args[2]) + if err != nil { + log.Fatal(err) + } + defer lc.Close() + + l := rudp.Listen(lc) + for { + clt, err := l.Accept() + if err != nil { + log.Print(err) + continue + } + + log.Print(clt.Addr(), " connected") + + conn, err := net.DialUDP("udp", nil, srvaddr) + if err != nil { + log.Print(err) + continue + } + srv := rudp.Connect(conn, conn.RemoteAddr()) + + go proxy(clt, srv) + go proxy(srv, clt) + } +} + +func proxy(src, dest *rudp.Peer) { + for { + pkt, err := src.Recv() + if err != nil { + if err == rudp.ErrClosed { + msg := src.Addr().String() + " disconnected" + if src.TimedOut() { + msg += " (timed out)" + } + log.Print(msg) + + break + } + + log.Print(err) + continue + } + + if _, err := dest.Send(pkt); err != nil { + log.Print(err) + } + } + + dest.SendDisco(0, true) + dest.Close() +} diff --git a/rudp/send.go b/rudp/send.go new file mode 100644 index 0000000..3cfcda4 --- /dev/null +++ b/rudp/send.go @@ -0,0 +1,248 @@ +package rudp + +import ( + "encoding/binary" + "errors" + "fmt" + "math" + "net" + "sync" + "time" +) + +const ( + // protoID + src PeerID + channel number + MtHdrSize = 4 + 2 + 1 + + // rawTypeOrig + OrigHdrSize = 1 + + // rawTypeSpilt + seqnum + chunk count + chunk number + SplitHdrSize = 1 + 2 + 2 + 2 + + // rawTypeRel + seqnum + RelHdrSize = 1 + 2 +) + +const ( + MaxNetPktSize = 512 + + MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize + MaxRelRawPktSize = MaxUnrelRawPktSize - RelHdrSize + + MaxRelPktSize = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16 + MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16 +) + +var ErrPktTooBig = errors.New("can't send pkt: too big") +var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount") + +// Send sends a packet to the Peer. +// It returns a channel that's closed when all chunks are acked or an error. +// The ack channel is nil if pkt.Unrel is true. +func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { + if pkt.ChNo >= ChannelCount { + return nil, ErrChNoTooBig + } + + hdrsize := MtHdrSize + if !pkt.Unrel { + hdrsize += RelHdrSize + } + + if hdrsize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize { + c := &p.chans[pkt.ChNo] + + c.outsplitmu.Lock() + sn := c.outsplitsn + c.outsplitsn++ + c.outsplitmu.Unlock() + + chunks := split(pkt.Data, MaxNetPktSize-(hdrsize+SplitHdrSize)) + + if len(chunks) > math.MaxUint16 { + return nil, ErrPktTooBig + } + + var wg sync.WaitGroup + + for i, chunk := range chunks { + data := make([]byte, SplitHdrSize+len(chunk)) + data[0] = uint8(rawTypeSplit) + binary.BigEndian.PutUint16(data[1:3], uint16(sn)) + binary.BigEndian.PutUint16(data[3:5], uint16(len(chunks))) + binary.BigEndian.PutUint16(data[5:7], uint16(i)) + copy(data[SplitHdrSize:], chunk) + + wg.Add(1) + ack, err := p.sendRaw(rawPkt{ + Data: data, + ChNo: pkt.ChNo, + Unrel: pkt.Unrel, + }) + if err != nil { + return nil, err + } + if !pkt.Unrel { + if ack == nil { + panic("ack is nil") + } + go func() { + <-ack + wg.Done() + }() + } + } + + if pkt.Unrel { + return nil, nil + } else { + ack := make(chan struct{}) + + go func() { + wg.Wait() + close(ack) + }() + + return ack, nil + } + } + + return p.sendRaw(rawPkt{ + Data: append([]byte{uint8(rawTypeOrig)}, pkt.Data...), + ChNo: pkt.ChNo, + Unrel: pkt.Unrel, + }) +} + +// sendRaw sends a raw packet to the Peer. +func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) { + if pkt.ChNo >= ChannelCount { + return nil, ErrChNoTooBig + } + + p.mu.RLock() + defer p.mu.RUnlock() + + select { + case <-p.Disco(): + return nil, ErrClosed + default: + } + + if !pkt.Unrel { + return p.sendRel(pkt) + } + + data := make([]byte, MtHdrSize+len(pkt.Data)) + binary.BigEndian.PutUint32(data[0:4], protoID) + binary.BigEndian.PutUint16(data[4:6], uint16(p.idOfPeer)) + data[6] = pkt.ChNo + copy(data[MtHdrSize:], pkt.Data) + + if len(data) > MaxNetPktSize { + return nil, ErrPktTooBig + } + + _, err = p.Conn().WriteTo(data, p.Addr()) + if errors.Is(err, net.ErrWriteToConnected) { + conn, ok := p.Conn().(net.Conn) + if !ok { + return nil, err + } + _, err = conn.Write(data) + } + if err != nil { + return nil, err + } + + p.ping.Reset(PingTimeout) + + return nil, nil +} + +// sendRel sends a reliable raw packet to the Peer. +func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) { + if pkt.Unrel { + panic("mt/rudp: sendRel: pkt.Unrel is true") + } + + c := &p.chans[pkt.ChNo] + + c.outrelmu.Lock() + defer c.outrelmu.Unlock() + + sn := c.outrelsn + for ; sn-c.outrelwin >= 0x8000; c.outrelwin++ { + if ack, ok := c.ackchans.Load(c.outrelwin); ok { + <-ack.(chan struct{}) + } + } + c.outrelsn++ + + rwack := make(chan struct{}) // close-only + c.ackchans.Store(sn, rwack) + ack = rwack + + reldata := make([]byte, RelHdrSize+len(pkt.Data)) + reldata[0] = uint8(rawTypeRel) + binary.BigEndian.PutUint16(reldata[1:3], uint16(sn)) + copy(reldata[RelHdrSize:], pkt.Data) + relpkt := rawPkt{ + Data: reldata, + ChNo: pkt.ChNo, + Unrel: true, + } + + if _, err := p.sendRaw(relpkt); err != nil { + c.ackchans.Delete(sn) + + return nil, err + } + + go func() { + resend := time.NewTicker(500 * time.Millisecond) + defer resend.Stop() + + for { + select { + case <-resend.C: + if _, err := p.sendRaw(relpkt); err != nil { + p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err) + } + case <-ack: + return + case <-p.Disco(): + return + } + } + }() + + return ack, nil +} + +// SendDisco sends a disconnect packet to the Peer but does not close it. +// It returns a channel that's closed when it's acked or an error. +// The ack channel is nil if unrel is true. +func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) { + return p.sendRaw(rawPkt{ + Data: []byte{uint8(rawTypeCtl), uint8(ctlDisco)}, + ChNo: chno, + Unrel: unrel, + }) +} + +func split(data []byte, chunksize int) [][]byte { + chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize) + + for i := 0; i < len(data); i += chunksize { + end := i + chunksize + if end > len(data) { + end = len(data) + } + + chunks = append(chunks, data[i:end]) + } + + return chunks +} -- 2.44.0