git.lizzy.rs Git - mt.git/commitdiff
Initial public release
author anon5 <anon5clam@protonmail.com>
Sat, 7 Nov 2020 18:01:24 +0000 (18:01 +0000)
committer anon5 <anon5clam@protonmail.com>
Sat, 7 Nov 2020 18:01:24 +0000 (18:01 +0000)
LICENSE [new file with mode: 0644]
rudp/listen.go [new file with mode: 0644]
rudp/net.go [new file with mode: 0644]
rudp/peer.go [new file with mode: 0644]
rudp/process.go [new file with mode: 0644]
rudp/proto.go [new file with mode: 0644]
rudp/proxy/proxy.go [new file with mode: 0644]
rudp/send.go [new file with mode: 0644]

diff --git a/LICENSE b/LICENSE
new file mode 100644 (file)
index 0000000..3febf69
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2020 anon5
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/rudp/listen.go b/rudp/listen.go
new file mode 100644 (file)
index 0000000..871a591
--- /dev/null
+++ b/rudp/listen.go
@@ -0,0 +1,153 @@
+package rudp
+
+import (
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "net"
+       "sync"
+)
+
+type Listener struct {
+       conn net.PacketConn
+
+       clts chan cltPeer
+       errs chan error
+
+       mu        sync.Mutex
+       addr2peer map[string]cltPeer
+       id2peer   map[PeerID]cltPeer
+       peerid    PeerID
+}
+
+// Listen listens for packets on conn until it is closed.
+func Listen(conn net.PacketConn) *Listener {
+       l := &Listener{
+               conn: conn,
+
+               clts: make(chan cltPeer),
+               errs: make(chan error),
+
+               addr2peer: make(map[string]cltPeer),
+               id2peer:   make(map[PeerID]cltPeer),
+       }
+
+       pkts := make(chan netPkt)
+       go readNetPkts(l.conn, pkts, l.errs)
+       go func() {
+               for pkt := range pkts {
+                       if err := l.processNetPkt(pkt); err != nil {
+                               l.errs <- err
+                       }
+               }
+
+               close(l.clts)
+
+               for _, clt := range l.addr2peer {
+                       clt.Close()
+               }
+       }()
+
+       return l
+}
+
+// Accept waits for and returns a connecting Peer.
+// You should keep calling this until it returns ErrClosed
+// so it doesn't leak a goroutine.
+func (l *Listener) Accept() (*Peer, error) {
+       select {
+       case clt, ok := <-l.clts:
+               if !ok {
+                       select {
+                       case err := <-l.errs:
+                               return nil, err
+                       default:
+                               return nil, ErrClosed
+                       }
+               }
+               close(clt.accepted)
+               return clt.Peer, nil
+       case err := <-l.errs:
+               return nil, err
+       }
+}
+
+// Conn returns the net.PacketConn the Listener is listening on.
+func (l *Listener) Conn() net.PacketConn { return l.conn }
+
+var ErrOutOfPeerIDs = errors.New("out of peer ids")
+
+type cltPeer struct {
+       *Peer
+       pkts     chan<- netPkt
+       accepted chan struct{} // close-only
+}
+
+func (l *Listener) processNetPkt(pkt netPkt) error {
+       l.mu.Lock()
+       defer l.mu.Unlock()
+
+       addrstr := pkt.SrcAddr.String()
+
+       clt, ok := l.addr2peer[addrstr]
+       if !ok {
+               prev := l.peerid
+               for l.id2peer[l.peerid].Peer != nil || l.peerid < PeerIDCltMin {
+                       if l.peerid == prev-1 {
+                               return ErrOutOfPeerIDs
+                       }
+                       l.peerid++
+               }
+
+               pkts := make(chan netPkt, 256)
+
+               clt = cltPeer{
+                       Peer:     newPeer(l.conn, pkt.SrcAddr, l.peerid, PeerIDSrv),
+                       pkts:     pkts,
+                       accepted: make(chan struct{}),
+               }
+
+               l.addr2peer[addrstr] = clt
+               l.id2peer[clt.ID()] = clt
+
+               data := make([]byte, 1+1+2)
+               data[0] = uint8(rawTypeCtl)
+               data[1] = uint8(ctlSetPeerID)
+               binary.BigEndian.PutUint16(data[2:4], uint16(clt.ID()))
+               if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
+                       return fmt.Errorf("can't set client peer id: %w", err)
+               }
+
+               go func() {
+                       select {
+                       case l.clts <- clt:
+                       case <-clt.Disco():
+                       }
+
+                       clt.processNetPkts(pkts)
+               }()
+
+               go func() {
+                       <-clt.Disco()
+
+                       l.mu.Lock()
+                       close(pkts)
+                       delete(l.addr2peer, addrstr)
+                       delete(l.id2peer, clt.ID())
+                       l.mu.Unlock()
+               }()
+       }
+
+       select {
+       case <-clt.accepted:
+               clt.pkts <- pkt
+       default:
+               select {
+               case clt.pkts <- pkt:
+               default:
+                       return fmt.Errorf("ignoring net pkt from %s because buf is full", addrstr)
+               }
+       }
+
+       return nil
+}
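
A minimal usage sketch for the Listener above: it wraps a UDP socket with rudp.Listen, accepts peers until ErrClosed, and drains each Peer's Recv as the Accept and Recv comments require. The ":30000" listen address and the handle helper are assumptions made for illustration.

package main

import (
        "log"
        "net"

        "github.com/anon55555/mt/rudp"
)

func main() {
        pc, err := net.ListenPacket("udp", ":30000")
        if err != nil {
                log.Fatal(err)
        }
        defer pc.Close()

        l := rudp.Listen(pc)
        for {
                // Keep calling Accept until ErrClosed, as documented above.
                clt, err := l.Accept()
                if err != nil {
                        if err == rudp.ErrClosed {
                                return
                        }
                        log.Print(err)
                        continue
                }

                log.Print(clt.Addr(), " connected")
                go handle(clt)
        }
}

// handle drains Recv until ErrClosed so the Peer's goroutine isn't leaked.
func handle(clt *rudp.Peer) {
        for {
                pkt, err := clt.Recv()
                if err != nil {
                        if err == rudp.ErrClosed {
                                return
                        }
                        log.Print(err)
                        continue
                }
                log.Printf("%d bytes on channel %d", len(pkt.Data), pkt.ChNo)
        }
}
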
diff --git a/rudp/net.go b/rudp/net.go
new file mode 100644 (file)
index 0000000..421a3e7
--- /dev/null
+++ b/rudp/net.go
@@ -0,0 +1,44 @@
+package rudp
+
+import (
+       "errors"
+       "net"
+       "strings"
+)
+
+// TODO: Use net.ErrClosed when Go 1.16 is released.
+var ErrClosed = errors.New("use of closed peer")
+
+/*
+netPkt.Data format (big endian):
+
+       ProtoID
+       Src PeerID
+       ChNo uint8 // Must be < ChannelCount.
+       RawPkt.Data
+*/
+type netPkt struct {
+       SrcAddr net.Addr
+       Data    []byte
+}
+
+func readNetPkts(conn net.PacketConn, pkts chan<- netPkt, errs chan<- error) {
+       for {
+               buf := make([]byte, MaxNetPktSize)
+               n, addr, err := conn.ReadFrom(buf)
+               if err != nil {
+                       // TODO: Change to this when Go 1.16 is released:
+                       // if errors.Is(err, net.ErrClosed) {
+                       if strings.Contains(err.Error(), "use of closed network connection") {
+                               break
+                       }
+
+                       errs <- err
+                       continue
+               }
+
+               pkts <- netPkt{addr, buf[:n]}
+       }
+
+       close(pkts)
+}
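
As a sketch of the netPkt.Data layout documented above, the hypothetical helper below assembles the 7-byte header (big-endian protocol ID, source PeerID and channel number) in front of a raw packet body. The real encoding lives in Peer.sendRaw; the protocol ID constant is unexported, so it is passed in here, and the package name is an assumption.

package rudpexample

import (
        "encoding/binary"

        "github.com/anon55555/mt/rudp"
)

// buildNetPkt is a hypothetical illustration of the netPkt.Data format:
// ProtoID (4 bytes) | Src PeerID (2 bytes) | ChNo (1 byte) | rawPkt.Data.
func buildNetPkt(protoID uint32, src rudp.PeerID, chno uint8, raw []byte) []byte {
        data := make([]byte, rudp.MtHdrSize+len(raw))
        binary.BigEndian.PutUint32(data[0:4], protoID)
        binary.BigEndian.PutUint16(data[4:6], uint16(src))
        data[6] = chno
        copy(data[rudp.MtHdrSize:], raw)
        return data
}
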
diff --git a/rudp/peer.go b/rudp/peer.go
new file mode 100644 (file)
index 0000000..feb0ff9
--- /dev/null
+++ b/rudp/peer.go
@@ -0,0 +1,193 @@
+package rudp
+
+import (
+       "fmt"
+       "net"
+       "sync"
+       "time"
+)
+
+const (
+       // ConnTimeout is how long a Peer may go without receiving
+       // any packets before it is automatically disconnected.
+       ConnTimeout = 30 * time.Second
+
+       // PingTimeout is how long a Peer may go without being sent
+       // any packets before a ctlPing is automatically sent to prevent
+       // the connection from timing out.
+       PingTimeout = 5 * time.Second
+)
+
+// A Peer is a connection to a client or server.
+type Peer struct {
+       conn net.PacketConn
+       addr net.Addr
+
+       disco chan struct{} // close-only
+
+       id PeerID
+
+       pkts     chan Pkt
+       errs     chan error    // don't close
+       timedout chan struct{} // close-only
+
+       chans [ChannelCount]pktchan // read/write
+
+       mu       sync.RWMutex
+       idOfPeer PeerID
+       timeout  *time.Timer
+       ping     *time.Ticker
+}
+
+type pktchan struct {
+       // Only accessed by Peer.processRawPkt.
+       insplit map[seqnum][][]byte
+       inrel   map[seqnum][]byte
+       inrelsn seqnum
+
+       ackchans sync.Map // map[seqnum]chan struct{}
+
+       outsplitmu sync.Mutex
+       outsplitsn seqnum
+
+       outrelmu  sync.Mutex
+       outrelsn  seqnum
+       outrelwin seqnum
+}
+
+// Conn returns the net.PacketConn used to communicate with the Peer.
+func (p *Peer) Conn() net.PacketConn { return p.conn }
+
+// Addr returns the address of the Peer.
+func (p *Peer) Addr() net.Addr { return p.addr }
+
+// Disco returns a channel that is closed when the Peer is closed.
+func (p *Peer) Disco() <-chan struct{} { return p.disco }
+
+// ID returns the ID of the Peer.
+func (p *Peer) ID() PeerID { return p.id }
+
+// IsSrv reports whether the Peer is a server.
+func (p *Peer) IsSrv() bool {
+       return p.ID() == PeerIDSrv
+}
+
+// TimedOut reports whether the Peer has timed out.
+func (p *Peer) TimedOut() bool {
+       select {
+       case <-p.timedout:
+               return true
+       default:
+               return false
+       }
+}
+
+// Recv receives a packet from the Peer.
+// You should keep calling this until it returns ErrClosed
+// so it doesn't leak a goroutine.
+func (p *Peer) Recv() (Pkt, error) {
+       select {
+       case pkt, ok := <-p.pkts:
+               if !ok {
+                       select {
+                       case err := <-p.errs:
+                               return Pkt{}, err
+                       default:
+                               return Pkt{}, ErrClosed
+                       }
+               }
+               return pkt, nil
+       case err := <-p.errs:
+               return Pkt{}, err
+       }
+}
+
+// Close closes the Peer but does not send a disconnect packet.
+func (p *Peer) Close() error {
+       p.mu.Lock()
+       defer p.mu.Unlock()
+
+       select {
+       case <-p.Disco():
+               return ErrClosed
+       default:
+       }
+
+       p.timeout.Stop()
+       p.timeout = nil
+       p.ping.Stop()
+       p.ping = nil
+
+       close(p.disco)
+
+       return nil
+}
+
+func newPeer(conn net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer {
+       p := &Peer{
+               conn:     conn,
+               addr:     addr,
+               id:       id,
+               idOfPeer: idOfPeer,
+
+               pkts:  make(chan Pkt),
+               disco: make(chan struct{}),
+               errs:  make(chan error),
+       }
+
+       for i := range p.chans {
+               p.chans[i] = pktchan{
+                       insplit: make(map[seqnum][][]byte),
+                       inrel:   make(map[seqnum][]byte),
+                       inrelsn: seqnumInit,
+
+                       outsplitsn: seqnumInit,
+                       outrelsn:   seqnumInit,
+                       outrelwin:  seqnumInit,
+               }
+       }
+
+       p.timedout = make(chan struct{})
+       p.timeout = time.AfterFunc(ConnTimeout, func() {
+               close(p.timedout)
+
+               p.SendDisco(0, true)
+               p.Close()
+       })
+
+       p.ping = time.NewTicker(PingTimeout)
+       go p.sendPings(p.ping.C)
+
+       return p
+}
+
+func (p *Peer) sendPings(ping <-chan time.Time) {
+       pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}}
+
+       for {
+               select {
+               case <-ping:
+                       if _, err := p.sendRaw(pkt); err != nil {
+                               p.errs <- fmt.Errorf("can't send ping: %w", err)
+                       }
+               case <-p.Disco():
+                       return
+               }
+       }
+}
+
+// Connect connects to the server on conn
+// and closes conn when the Peer disconnects.
+func Connect(conn net.PacketConn, addr net.Addr) *Peer {
+       srv := newPeer(conn, addr, PeerIDSrv, PeerIDNil)
+
+       pkts := make(chan netPkt)
+       go readNetPkts(conn, pkts, srv.errs)
+       go srv.processNetPkts(pkts)
+
+       go func() {
+               <-srv.Disco()
+               conn.Close()
+       }()
+
+       return srv
+}
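
A client-side sketch built on Connect above, assuming a server at 127.0.0.1:30000 and an arbitrary payload: it dials UDP, sends one reliable packet on channel 0, waits for the ack, then drains Recv until ErrClosed as the Recv comment requires.

package main

import (
        "log"
        "net"

        "github.com/anon55555/mt/rudp"
)

func main() {
        addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:30000")
        if err != nil {
                log.Fatal(err)
        }
        conn, err := net.DialUDP("udp", nil, addr)
        if err != nil {
                log.Fatal(err)
        }

        srv := rudp.Connect(conn, conn.RemoteAddr())

        // Send a reliable packet and block until all of it is acked.
        ack, err := srv.Send(rudp.Pkt{Data: []byte("hello"), ChNo: 0, Unrel: false})
        if err != nil {
                log.Fatal(err)
        }
        <-ack

        for {
                pkt, err := srv.Recv()
                if err != nil {
                        if err == rudp.ErrClosed {
                                log.Print("disconnected")
                                return
                        }
                        log.Print(err)
                        continue
                }
                log.Printf("%d bytes on channel %d", len(pkt.Data), pkt.ChNo)
        }
}
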
diff --git a/rudp/process.go b/rudp/process.go
new file mode 100644 (file)
index 0000000..c36af81
--- /dev/null
+++ b/rudp/process.go
@@ -0,0 +1,259 @@
+package rudp
+
+import (
+       "encoding/binary"
+       "encoding/hex"
+       "errors"
+       "fmt"
+       "io"
+)
+
+// A PktError is an error that occurred while processing a packet.
+type PktError struct {
+       Type string // "net", "raw" or "rel".
+       Data []byte
+       Err  error
+}
+
+func (e PktError) Error() string {
+       return "error processing " + e.Type + " pkt: " +
+               hex.EncodeToString(e.Data) + ": " +
+               e.Err.Error()
+}
+
+func (e PktError) Unwrap() error { return e.Err }
+
+func (p *Peer) processNetPkts(pkts <-chan netPkt) {
+       for pkt := range pkts {
+               if err := p.processNetPkt(pkt); err != nil {
+                       p.errs <- PktError{"net", pkt.Data, err}
+               }
+       }
+
+       close(p.pkts)
+}
+
+// A TrailingDataError reports a packet with trailing data;
+// it doesn't stop the packet from being processed.
+type TrailingDataError []byte
+
+func (e TrailingDataError) Error() string {
+       return "trailing data: " + hex.EncodeToString([]byte(e))
+}
+
+func (p *Peer) processNetPkt(pkt netPkt) (err error) {
+       if pkt.SrcAddr.String() != p.Addr().String() {
+               return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
+       }
+
+       if len(pkt.Data) < MtHdrSize {
+               return io.ErrUnexpectedEOF
+       }
+
+       if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
+               return fmt.Errorf("unsupported protocol id: 0x%08x", id)
+       }
+
+       // src PeerID at pkt.Data[4:6]
+
+       chno := pkt.Data[6]
+       if chno >= ChannelCount {
+               return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
+       }
+
+       p.mu.RLock()
+       if p.timeout != nil {
+               p.timeout.Reset(ConnTimeout)
+       }
+       p.mu.RUnlock()
+
+       rpkt := rawPkt{
+               Data:  pkt.Data[MtHdrSize:],
+               ChNo:  chno,
+               Unrel: true,
+       }
+       if err := p.processRawPkt(rpkt); err != nil {
+               p.errs <- PktError{"raw", rpkt.Data, err}
+       }
+
+       return nil
+}
+
+func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
+       errWrap := func(format string, a ...interface{}) {
+               if err != nil {
+                       err = fmt.Errorf(format, append(a, err)...)
+               }
+       }
+
+       c := &p.chans[pkt.ChNo]
+
+       if len(pkt.Data) < 1 {
+               return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
+       }
+       switch t := rawType(pkt.Data[0]); t {
+       case rawTypeCtl:
+               defer errWrap("ctl: %w")
+
+               if len(pkt.Data) < 1+1 {
+                       return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
+               }
+               switch ct := ctlType(pkt.Data[1]); ct {
+               case ctlAck:
+                       defer errWrap("ack: %w")
+
+                       if len(pkt.Data) < 1+1+2 {
+                               return io.ErrUnexpectedEOF
+                       }
+
+                       sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
+
+                       if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
+                               close(ack.(chan struct{}))
+                       }
+
+                       if len(pkt.Data) > 1+1+2 {
+                               return TrailingDataError(pkt.Data[1+1+2:])
+                       }
+               case ctlSetPeerID:
+                       defer errWrap("set peer id: %w")
+
+                       if len(pkt.Data) < 1+1+2 {
+                               return io.ErrUnexpectedEOF
+                       }
+
+                       // Ensure no concurrent senders while peer id changes.
+                       p.mu.Lock()
+                       if p.idOfPeer != PeerIDNil {
+                               p.mu.Unlock()
+                               return errors.New("peer id already set")
+                       }
+
+                       p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
+                       p.mu.Unlock()
+
+                       if len(pkt.Data) > 1+1+2 {
+                               return TrailingDataError(pkt.Data[1+1+2:])
+                       }
+               case ctlPing:
+                       defer errWrap("ping: %w")
+
+                       if len(pkt.Data) > 1+1 {
+                               return TrailingDataError(pkt.Data[1+1:])
+                       }
+               case ctlDisco:
+                       defer errWrap("disco: %w")
+
+                       if err := p.Close(); err != nil {
+                               return fmt.Errorf("can't close: %w", err)
+                       }
+
+                       if len(pkt.Data) > 1+1 {
+                               return TrailingDataError(pkt.Data[1+1:])
+                       }
+               default:
+                       return fmt.Errorf("unsupported ctl type: %d", ct)
+               }
+       case rawTypeOrig:
+               p.pkts <- Pkt{
+                       Data:  pkt.Data[1:],
+                       ChNo:  pkt.ChNo,
+                       Unrel: pkt.Unrel,
+               }
+       case rawTypeSplit:
+               defer errWrap("split: %w")
+
+               if len(pkt.Data) < 1+2+2+2 {
+                       return io.ErrUnexpectedEOF
+               }
+
+               sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+               count := binary.BigEndian.Uint16(pkt.Data[3:5])
+               i := binary.BigEndian.Uint16(pkt.Data[5:7])
+
+               if i >= count {
+                       return nil
+               }
+
+               splitpkts := p.chans[pkt.ChNo].insplit
+
+               // Delete old incomplete split packets
+               // so new ones don't get corrupted.
+               delete(splitpkts, sn-0x8000)
+
+               if splitpkts[sn] == nil {
+                       splitpkts[sn] = make([][]byte, count)
+               }
+
+               chunks := splitpkts[sn]
+
+               if int(count) != len(chunks) {
+                       return fmt.Errorf("chunk count changed on seqnum: %d", sn)
+               }
+
+               chunks[i] = pkt.Data[7:]
+
+               for _, chunk := range chunks {
+                       if chunk == nil {
+                               return nil
+                       }
+               }
+
+               var data []byte
+               for _, chunk := range chunks {
+                       data = append(data, chunk...)
+               }
+
+               p.pkts <- Pkt{
+                       Data:  data,
+                       ChNo:  pkt.ChNo,
+                       Unrel: pkt.Unrel,
+               }
+
+               delete(splitpkts, sn)
+       case rawTypeRel:
+               defer errWrap("rel: %w")
+
+               if len(pkt.Data) < 1+2 {
+                       return io.ErrUnexpectedEOF
+               }
+
+               sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+
+               ackdata := make([]byte, 1+1+2)
+               ackdata[0] = uint8(rawTypeCtl)
+               ackdata[1] = uint8(ctlAck)
+               binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
+               ack := rawPkt{
+                       Data:  ackdata,
+                       ChNo:  pkt.ChNo,
+                       Unrel: true,
+               }
+               if _, err := p.sendRaw(ack); err != nil {
+                       return fmt.Errorf("can't ack %d: %w", sn, err)
+               }
+
+               if sn-c.inrelsn >= 0x8000 {
+                       return nil // Already received.
+               }
+
+               c.inrel[sn] = pkt.Data[3:]
+
+               for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
+                       data := c.inrel[c.inrelsn]
+                       delete(c.inrel, c.inrelsn)
+
+                       rpkt := rawPkt{
+                               Data:  data,
+                               ChNo:  pkt.ChNo,
+                               Unrel: false,
+                       }
+                       if err := p.processRawPkt(rpkt); err != nil {
+                               p.errs <- PktError{"rel", rpkt.Data, err}
+                       }
+               }
+       default:
+               return fmt.Errorf("unsupported pkt type: %d", t)
+       }
+
+       return nil
+}
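
A hedged sketch of how a caller might inspect the PktError values that Recv surfaces from this file; recvLoop and the package name are hypothetical, and errors.As works here because PktError implements Unwrap.

package rudpexample

import (
        "errors"
        "log"

        "github.com/anon55555/mt/rudp"
)

// recvLoop drains a Peer and logs malformed packets separately from
// other errors by unwrapping rudp.PktError.
func recvLoop(p *rudp.Peer) {
        for {
                pkt, err := p.Recv()
                if err != nil {
                        if errors.Is(err, rudp.ErrClosed) {
                                return
                        }

                        var perr rudp.PktError
                        if errors.As(err, &perr) {
                                // perr.Type is "net", "raw" or "rel";
                                // perr.Data holds the offending packet.
                                log.Printf("bad %s pkt (%d bytes): %v", perr.Type, len(perr.Data), perr.Err)
                                continue
                        }

                        log.Print(err)
                        continue
                }

                _ = pkt // Handle the packet here.
        }
}
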
diff --git a/rudp/proto.go b/rudp/proto.go
new file mode 100644 (file)
index 0000000..04176b2
--- /dev/null
+++ b/rudp/proto.go
@@ -0,0 +1,103 @@
+/*
+Package rudp implements the low-level Minetest protocol described at
+https://dev.minetest.net/Network_Protocol#Low-level_protocol.
+
+All exported functions and methods in this package are safe for concurrent use
+by multiple goroutines.
+*/
+package rudp
+
+// protoID must be at the start of every network packet.
+const protoID uint32 = 0x4f457403
+
+// A PeerID identifies a Peer. Peers aren't actually identified by their
+// PeerID but by their network address; PeerIDs only exist for backward
+// compatibility.
+type PeerID uint16
+
+const (
+       // Used by clients before the server sets their ID.
+       PeerIDNil PeerID = iota
+
+       // The server always has this ID.
+       PeerIDSrv
+
+       // Lowest ID the server can assign to a client.
+       PeerIDCltMin
+)
+
+// ChannelCount is the maximum channel number + 1.
+const ChannelCount = 3
+
+/*
+rawPkt.Data format (big endian):
+
+       rawType
+       switch rawType {
+       case rawTypeCtl:
+               ctlType
+               switch ctlType {
+               case ctlAck:
+                       // Tells peer you received a rawTypeRel
+                       // and it doesn't need to resend it.
+                       seqnum
+               case ctlSetPeerID:
+                       // Tells peer to send packets with this Src PeerID.
+                       PeerID
+               case ctlPing:
+                       // Sent to prevent timeout.
+               case ctlDisco:
+                       // Tells peer that you disconnected.
+               }
+       case rawTypeOrig:
+               Pkt.(Data)
+       case rawTypeSplit:
+               // Packet larger than MaxNetPktSize split into smaller packets.
+               // Packets with I >= Count should be ignored.
+               // Once all Count chunks are received, they are sorted by I and
+               // concatenated to make a Pkt.(Data).
+               seqnum // Identifies split packet.
+               Count, I uint16
+               Chunk...
+       case rawTypeRel:
+               // Resent until a ctlAck with the same seqnum is received.
+               // seqnums are sequential and start at seqnumInit;
+               // these should be processed in seqnum order.
+               seqnum
+               rawPkt.Data
+       }
+*/
+type rawPkt struct {
+       Data  []byte
+       ChNo  uint8
+       Unrel bool
+}
+
+type rawType uint8
+
+const (
+       rawTypeCtl rawType = iota
+       rawTypeOrig
+       rawTypeSplit
+       rawTypeRel
+)
+
+type ctlType uint8
+
+const (
+       ctlAck ctlType = iota
+       ctlSetPeerID
+       ctlPing
+       ctlDisco
+)
+
+type Pkt struct {
+       Data  []byte
+       ChNo  uint8
+       Unrel bool
+}
+
+// seqnums are sequence numbers used to maintain reliable packet order
+// and to identify split packets.
+type seqnum uint16
+
+const seqnumInit seqnum = 65500
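
As a worked illustration of the rawPkt.Data layout above (the MT header that sendRaw prepends on the wire is omitted): the sketch below hard-codes the unexported values implied by the iota declarations, rawTypeCtl = 0, rawTypeOrig = 1, rawTypeRel = 3 and ctlAck = 0, plus seqnumInit = 65500, and prints the bytes of a reliable packet wrapping an orig packet together with the ack it provokes.

package main

import (
        "encoding/binary"
        "fmt"
)

// Values mirroring the unexported constants above, for illustration only.
const (
        rawTypeCtl  = 0
        rawTypeOrig = 1
        rawTypeRel  = 3
        ctlAck      = 0
        seqnumInit  = 65500
)

func main() {
        payload := []byte("P") // Arbitrary Pkt.Data.

        // rawTypeRel | seqnum | rawTypeOrig | Pkt.Data
        rel := make([]byte, 3, 3+1+len(payload))
        rel[0] = rawTypeRel
        binary.BigEndian.PutUint16(rel[1:3], seqnumInit)
        rel = append(rel, rawTypeOrig)
        rel = append(rel, payload...)
        fmt.Printf("rel pkt: % x\n", rel) // 03 ff dc 01 50

        // The receiver answers with: rawTypeCtl | ctlAck | seqnum.
        ack := make([]byte, 4)
        ack[0] = rawTypeCtl
        ack[1] = ctlAck
        binary.BigEndian.PutUint16(ack[2:4], seqnumInit)
        fmt.Printf("ack pkt: % x\n", ack) // 00 00 ff dc
}
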
diff --git a/rudp/proxy/proxy.go b/rudp/proxy/proxy.go
new file mode 100644 (file)
index 0000000..6fc14ec
--- /dev/null
+++ b/rudp/proxy/proxy.go
@@ -0,0 +1,85 @@
+/*
+Proxy is a Minetest RUDP proxy server
+supporting multiple concurrent connections.
+
+Usage:
+       proxy dial:port listen:port
+where dial:port is the server address
+and listen:port is the address to listen on.
+*/
+package main
+
+import (
+       "fmt"
+       "log"
+       "net"
+       "os"
+
+       "github.com/anon55555/mt/rudp"
+)
+
+func main() {
+       if len(os.Args) != 3 {
+               fmt.Fprintln(os.Stderr, "usage: proxy dial:port listen:port")
+               os.Exit(1)
+       }
+
+       srvaddr, err := net.ResolveUDPAddr("udp", os.Args[1])
+       if err != nil {
+               log.Fatal(err)
+       }
+
+       lc, err := net.ListenPacket("udp", os.Args[2])
+       if err != nil {
+               log.Fatal(err)
+       }
+       defer lc.Close()
+
+       l := rudp.Listen(lc)
+       for {
+               clt, err := l.Accept()
+               if err != nil {
+                       log.Print(err)
+                       continue
+               }
+
+               log.Print(clt.Addr(), " connected")
+
+               conn, err := net.DialUDP("udp", nil, srvaddr)
+               if err != nil {
+                       log.Print(err)
+                       continue
+               }
+               srv := rudp.Connect(conn, conn.RemoteAddr())
+
+               go proxy(clt, srv)
+               go proxy(srv, clt)
+       }
+}
+
+func proxy(src, dest *rudp.Peer) {
+       for {
+               pkt, err := src.Recv()
+               if err != nil {
+                       if err == rudp.ErrClosed {
+                               msg := src.Addr().String() + " disconnected"
+                               if src.TimedOut() {
+                                       msg += " (timed out)"
+                               }
+                               log.Print(msg)
+
+                               break
+                       }
+
+                       log.Print(err)
+                       continue
+               }
+
+               if _, err := dest.Send(pkt); err != nil {
+                       log.Print(err)
+               }
+       }
+
+       dest.SendDisco(0, true)
+       dest.Close()
+}
diff --git a/rudp/send.go b/rudp/send.go
new file mode 100644 (file)
index 0000000..3cfcda4
--- /dev/null
+++ b/rudp/send.go
@@ -0,0 +1,248 @@
+package rudp
+
+import (
+       "encoding/binary"
+       "errors"
+       "fmt"
+       "math"
+       "net"
+       "sync"
+       "time"
+)
+
+const (
+       // protoID + src PeerID + channel number
+       MtHdrSize = 4 + 2 + 1
+
+       // rawTypeOrig
+       OrigHdrSize = 1
+
+       // rawTypeSplit + seqnum + chunk count + chunk number
+       SplitHdrSize = 1 + 2 + 2 + 2
+
+       // rawTypeRel + seqnum
+       RelHdrSize = 1 + 2
+)
+
+const (
+       MaxNetPktSize = 512
+
+       MaxUnrelRawPktSize = MaxNetPktSize - MtHdrSize
+       MaxRelRawPktSize   = MaxUnrelRawPktSize - RelHdrSize
+
+       MaxRelPktSize   = (MaxRelRawPktSize - SplitHdrSize) * math.MaxUint16
+       MaxUnrelPktSize = (MaxUnrelRawPktSize - SplitHdrSize) * math.MaxUint16
+)
+
+var ErrPktTooBig = errors.New("can't send pkt: too big")
+var ErrChNoTooBig = errors.New("can't send pkt: channel number >= ChannelCount")
+
+// Send sends a packet to the Peer.
+// It returns a channel that's closed when all chunks are acked,
+// or an error if the packet can't be sent.
+// The ack channel is nil if pkt.Unrel is true.
+func (p *Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+       if pkt.ChNo >= ChannelCount {
+               return nil, ErrChNoTooBig
+       }
+
+       hdrsize := MtHdrSize
+       if !pkt.Unrel {
+               hdrsize += RelHdrSize
+       }
+
+       if hdrsize+OrigHdrSize+len(pkt.Data) > MaxNetPktSize {
+               c := &p.chans[pkt.ChNo]
+
+               c.outsplitmu.Lock()
+               sn := c.outsplitsn
+               c.outsplitsn++
+               c.outsplitmu.Unlock()
+
+               chunks := split(pkt.Data, MaxNetPktSize-(hdrsize+SplitHdrSize))
+
+               if len(chunks) > math.MaxUint16 {
+                       return nil, ErrPktTooBig
+               }
+
+               var wg sync.WaitGroup
+
+               for i, chunk := range chunks {
+                       data := make([]byte, SplitHdrSize+len(chunk))
+                       data[0] = uint8(rawTypeSplit)
+                       binary.BigEndian.PutUint16(data[1:3], uint16(sn))
+                       binary.BigEndian.PutUint16(data[3:5], uint16(len(chunks)))
+                       binary.BigEndian.PutUint16(data[5:7], uint16(i))
+                       copy(data[SplitHdrSize:], chunk)
+
+                       wg.Add(1)
+                       ack, err := p.sendRaw(rawPkt{
+                               Data:  data,
+                               ChNo:  pkt.ChNo,
+                               Unrel: pkt.Unrel,
+                       })
+                       if err != nil {
+                               return nil, err
+                       }
+                       if !pkt.Unrel {
+                               if ack == nil {
+                                       panic("ack is nil")
+                               }
+                               go func() {
+                                       <-ack
+                                       wg.Done()
+                               }()
+                       }
+               }
+
+               if pkt.Unrel {
+                       return nil, nil
+               } else {
+                       ack := make(chan struct{})
+
+                       go func() {
+                               wg.Wait()
+                               close(ack)
+                       }()
+
+                       return ack, nil
+               }
+       }
+
+       return p.sendRaw(rawPkt{
+               Data:  append([]byte{uint8(rawTypeOrig)}, pkt.Data...),
+               ChNo:  pkt.ChNo,
+               Unrel: pkt.Unrel,
+       })
+}
+
+// sendRaw sends a raw packet to the Peer.
+func (p *Peer) sendRaw(pkt rawPkt) (ack <-chan struct{}, err error) {
+       if pkt.ChNo >= ChannelCount {
+               return nil, ErrChNoTooBig
+       }
+
+       p.mu.RLock()
+       defer p.mu.RUnlock()
+
+       select {
+       case <-p.Disco():
+               return nil, ErrClosed
+       default:
+       }
+
+       if !pkt.Unrel {
+               return p.sendRel(pkt)
+       }
+
+       data := make([]byte, MtHdrSize+len(pkt.Data))
+       binary.BigEndian.PutUint32(data[0:4], protoID)
+       binary.BigEndian.PutUint16(data[4:6], uint16(p.idOfPeer))
+       data[6] = pkt.ChNo
+       copy(data[MtHdrSize:], pkt.Data)
+
+       if len(data) > MaxNetPktSize {
+               return nil, ErrPktTooBig
+       }
+
+       _, err = p.Conn().WriteTo(data, p.Addr())
+       if errors.Is(err, net.ErrWriteToConnected) {
+               conn, ok := p.Conn().(net.Conn)
+               if !ok {
+                       return nil, err
+               }
+               _, err = conn.Write(data)
+       }
+       if err != nil {
+               return nil, err
+       }
+
+       p.ping.Reset(PingTimeout)
+
+       return nil, nil
+}
+
+// sendRel sends a reliable raw packet to the Peer.
+func (p *Peer) sendRel(pkt rawPkt) (ack <-chan struct{}, err error) {
+       if pkt.Unrel {
+               panic("mt/rudp: sendRel: pkt.Unrel is true")
+       }
+
+       c := &p.chans[pkt.ChNo]
+
+       c.outrelmu.Lock()
+       defer c.outrelmu.Unlock()
+
+       sn := c.outrelsn
+       for ; sn-c.outrelwin >= 0x8000; c.outrelwin++ {
+               if ack, ok := c.ackchans.Load(c.outrelwin); ok {
+                       <-ack.(chan struct{})
+               }
+       }
+       c.outrelsn++
+
+       rwack := make(chan struct{}) // close-only
+       c.ackchans.Store(sn, rwack)
+       ack = rwack
+
+       reldata := make([]byte, RelHdrSize+len(pkt.Data))
+       reldata[0] = uint8(rawTypeRel)
+       binary.BigEndian.PutUint16(reldata[1:3], uint16(sn))
+       copy(reldata[RelHdrSize:], pkt.Data)
+       relpkt := rawPkt{
+               Data:  reldata,
+               ChNo:  pkt.ChNo,
+               Unrel: true,
+       }
+
+       if _, err := p.sendRaw(relpkt); err != nil {
+               c.ackchans.Delete(sn)
+
+               return nil, err
+       }
+
+       go func() {
+               resend := time.NewTicker(500 * time.Millisecond)
+               defer resend.Stop()
+
+               for {
+                       select {
+                       case <-resend.C:
+                               if _, err := p.sendRaw(relpkt); err != nil {
+                                       p.errs <- fmt.Errorf("failed to re-send timed out reliable seqnum: %d: %w", sn, err)
+                               }
+                       case <-ack:
+                               return
+                       case <-p.Disco():
+                               return
+                       }
+               }
+       }()
+
+       return ack, nil
+}
+
+// SendDisco sends a disconnect packet to the Peer but does not close it.
+// It returns a channel that's closed when the packet is acked,
+// or an error if it can't be sent.
+// The ack channel is nil if unrel is true.
+func (p *Peer) SendDisco(chno uint8, unrel bool) (ack <-chan struct{}, err error) {
+       return p.sendRaw(rawPkt{
+               Data:  []byte{uint8(rawTypeCtl), uint8(ctlDisco)},
+               ChNo:  chno,
+               Unrel: unrel,
+       })
+}
+
+func split(data []byte, chunksize int) [][]byte {
+       chunks := make([][]byte, 0, (len(data)+chunksize-1)/chunksize)
+
+       for i := 0; i < len(data); i += chunksize {
+               end := i + chunksize
+               if end > len(data) {
+                       end = len(data)
+               }
+
+               chunks = append(chunks, data[i:end])
+       }
+
+       return chunks
+}
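
Finally, a hedged sketch of how a caller might consume the ack channel returned by Send above; sendReliably, the package name and the 10-second deadline are illustrative assumptions, and the helper only makes sense for reliable packets since Send returns a nil ack channel when Unrel is true.

package rudpexample

import (
        "errors"
        "time"

        "github.com/anon55555/mt/rudp"
)

// sendReliably sends a reliable packet on channel 0 and waits until every
// chunk has been acked, the Peer disconnects, or a deadline passes.
func sendReliably(p *rudp.Peer, data []byte) error {
        ack, err := p.Send(rudp.Pkt{Data: data, ChNo: 0, Unrel: false})
        if err != nil {
                return err
        }

        select {
        case <-ack:
                return nil
        case <-p.Disco():
                return rudp.ErrClosed
        case <-time.After(10 * time.Second):
                return errors.New("timed out waiting for ack")
        }
}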