]> git.lizzy.rs Git - mt.git/blobdiff - rudp/process.go
rudp/proxy: fix import path
[mt.git] / rudp / process.go
index c36af81f7066a311b19f3ee06f1cd2c77e2a051d..7238fe5463ccbc17018e796e978e8749637db2e6 100644 (file)
@@ -1,11 +1,10 @@
 package rudp
 
 import (
-       "encoding/binary"
-       "encoding/hex"
        "errors"
        "fmt"
        "io"
+       "net"
 )
 
 // A PktError is an error that occured while processing a packet.
@@ -16,9 +15,7 @@ type PktError struct {
 }
 
 func (e PktError) Error() string {
-       return "error processing " + e.Type + " pkt: " +
-               hex.EncodeToString(e.Data) + ": " +
-               e.Err.Error()
+       return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err)
 }
 
 func (e PktError) Unwrap() error { return e.Err }
@@ -38,7 +35,7 @@ func (p *Peer) processNetPkts(pkts <-chan netPkt) {
 type TrailingDataError []byte
 
 func (e TrailingDataError) Error() string {
-       return "trailing data: " + hex.EncodeToString([]byte(e))
+       return fmt.Sprintf("trailing data: %x", []byte(e))
 }
 
 func (p *Peer) processNetPkt(pkt netPkt) (err error) {
@@ -50,7 +47,7 @@ func (p *Peer) processNetPkt(pkt netPkt) (err error) {
                return io.ErrUnexpectedEOF
        }
 
-       if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
+       if id := be.Uint32(pkt.Data[0:4]); id != protoID {
                return fmt.Errorf("unsupported protocol id: 0x%08x", id)
        }
 
@@ -106,9 +103,9 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                                return io.ErrUnexpectedEOF
                        }
 
-                       sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
+                       sn := seqnum(be.Uint16(pkt.Data[2:4]))
 
-                       if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
+                       if ack, ok := c.ackChans.LoadAndDelete(sn); ok {
                                close(ack.(chan struct{}))
                        }
 
@@ -128,7 +125,7 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                                return errors.New("peer id already set")
                        }
 
-                       p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
+                       p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4]))
                        p.mu.Unlock()
 
                        if len(pkt.Data) > 1+1+2 {
@@ -143,9 +140,7 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                case ctlDisco:
                        defer errWrap("disco: %w")
 
-                       if err := p.Close(); err != nil {
-                               return fmt.Errorf("can't close: %w", err)
-                       }
+                       p.Close()
 
                        if len(pkt.Data) > 1+1 {
                                return TrailingDataError(pkt.Data[1+1:])
@@ -166,50 +161,48 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                        return io.ErrUnexpectedEOF
                }
 
-               sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
-               count := binary.BigEndian.Uint16(pkt.Data[3:5])
-               i := binary.BigEndian.Uint16(pkt.Data[5:7])
+               sn := seqnum(be.Uint16(pkt.Data[1:3]))
+               count := be.Uint16(pkt.Data[3:5])
+               i := be.Uint16(pkt.Data[5:7])
 
                if i >= count {
                        return nil
                }
 
-               splitpkts := p.chans[pkt.ChNo].insplit
+               splits := p.chans[pkt.ChNo].inSplit
 
                // Delete old incomplete split packets
                // so new ones don't get corrupted.
-               delete(splitpkts, sn-0x8000)
+               splits[sn-0x8000] = nil
 
-               if splitpkts[sn] == nil {
-                       splitpkts[sn] = make([][]byte, count)
+               if splits[sn] == nil {
+                       splits[sn] = &inSplit{chunks: make([][]byte, count)}
                }
 
-               chunks := splitpkts[sn]
+               s := splits[sn]
 
-               if int(count) != len(chunks) {
-                       return fmt.Errorf("chunk count changed on seqnum: %d", sn)
+               if int(count) != len(s.chunks) {
+                       return fmt.Errorf("chunk count changed on split packet: %d", sn)
                }
 
-               chunks[i] = pkt.Data[7:]
+               s.chunks[i] = pkt.Data[7:]
+               s.size += len(s.chunks[i])
+               s.got++
 
-               for _, chunk := range chunks {
-                       if chunk == nil {
-                               return nil
+               if s.got == len(s.chunks) {
+                       data := make([]byte, 0, s.size)
+                       for _, chunk := range s.chunks {
+                               data = append(data, chunk...)
                        }
-               }
 
-               var data []byte
-               for _, chunk := range chunks {
-                       data = append(data, chunk...)
-               }
+                       p.pkts <- Pkt{
+                               Data:  data,
+                               ChNo:  pkt.ChNo,
+                               Unrel: pkt.Unrel,
+                       }
 
-               p.pkts <- Pkt{
-                       Data:  data,
-                       ChNo:  pkt.ChNo,
-                       Unrel: pkt.Unrel,
+                       splits[sn] = nil
                }
-
-               delete(splitpkts, sn)
        case rawTypeRel:
                defer errWrap("rel: %w")
 
@@ -217,36 +210,37 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                        return io.ErrUnexpectedEOF
                }
 
-               sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+               sn := seqnum(be.Uint16(pkt.Data[1:3]))
 
-               ackdata := make([]byte, 1+1+2)
-               ackdata[0] = uint8(rawTypeCtl)
-               ackdata[1] = uint8(ctlAck)
-               binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
-               ack := rawPkt{
-                       Data:  ackdata,
+               ack := make([]byte, 1+1+2)
+               ack[0] = uint8(rawTypeCtl)
+               ack[1] = uint8(ctlAck)
+               be.PutUint16(ack[2:4], uint16(sn))
+               if _, err := p.sendRaw(rawPkt{
+                       Data:  ack,
                        ChNo:  pkt.ChNo,
                        Unrel: true,
-               }
-               if _, err := p.sendRaw(ack); err != nil {
+               }); err != nil {
+                       if errors.Is(err, net.ErrClosed) {
+                               return nil
+                       }
                        return fmt.Errorf("can't ack %d: %w", sn, err)
                }
 
-               if sn-c.inrelsn >= 0x8000 {
+               if sn-c.inRelSN >= 0x8000 {
                        return nil // Already received.
                }
 
-               c.inrel[sn] = pkt.Data[3:]
-
-               for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
-                       data := c.inrel[c.inrelsn]
-                       delete(c.inrel, c.inrelsn)
+               c.inRel[sn] = pkt.Data[3:]
 
+               for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ {
                        rpkt := rawPkt{
-                               Data:  data,
+                               Data:  c.inRel[c.inRelSN],
                                ChNo:  pkt.ChNo,
                                Unrel: false,
                        }
+                       c.inRel[c.inRelSN] = nil
+
                        if err := p.processRawPkt(rpkt); err != nil {
                                p.errs <- PktError{"rel", rpkt.Data, err}
                        }