]> git.lizzy.rs Git - mt.git/blobdiff - rudp/process.go
rudp/proxy: fix import path
[mt.git] / rudp / process.go
index c85aba4edafec01fb008bc174a4c20b4c0e066a7..7238fe5463ccbc17018e796e978e8749637db2e6 100644 (file)
@@ -1,7 +1,6 @@
 package rudp
 
 import (
-       "encoding/binary"
        "errors"
        "fmt"
        "io"
@@ -48,7 +47,7 @@ func (p *Peer) processNetPkt(pkt netPkt) (err error) {
                return io.ErrUnexpectedEOF
        }
 
-       if id := binary.BigEndian.Uint32(pkt.Data[0:4]); id != protoID {
+       if id := be.Uint32(pkt.Data[0:4]); id != protoID {
                return fmt.Errorf("unsupported protocol id: 0x%08x", id)
        }
 
@@ -104,9 +103,9 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                                return io.ErrUnexpectedEOF
                        }
 
-                       sn := seqnum(binary.BigEndian.Uint16(pkt.Data[2:4]))
+                       sn := seqnum(be.Uint16(pkt.Data[2:4]))
 
-                       if ack, ok := c.ackchans.LoadAndDelete(sn); ok {
+                       if ack, ok := c.ackChans.LoadAndDelete(sn); ok {
                                close(ack.(chan struct{}))
                        }
 
@@ -126,7 +125,7 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                                return errors.New("peer id already set")
                        }
 
-                       p.idOfPeer = PeerID(binary.BigEndian.Uint16(pkt.Data[2:4]))
+                       p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4]))
                        p.mu.Unlock()
 
                        if len(pkt.Data) > 1+1+2 {
@@ -162,50 +161,48 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                        return io.ErrUnexpectedEOF
                }
 
-               sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
-               count := binary.BigEndian.Uint16(pkt.Data[3:5])
-               i := binary.BigEndian.Uint16(pkt.Data[5:7])
+               sn := seqnum(be.Uint16(pkt.Data[1:3]))
+               count := be.Uint16(pkt.Data[3:5])
+               i := be.Uint16(pkt.Data[5:7])
 
                if i >= count {
                        return nil
                }
 
-               splitpkts := p.chans[pkt.ChNo].insplit
+               splits := p.chans[pkt.ChNo].inSplit
 
                // Delete old incomplete split packets
                // so new ones don't get corrupted.
-               delete(splitpkts, sn-0x8000)
+               splits[sn-0x8000] = nil
 
-               if splitpkts[sn] == nil {
-                       splitpkts[sn] = make([][]byte, count)
+               if splits[sn] == nil {
+                       splits[sn] = &inSplit{chunks: make([][]byte, count)}
                }
 
-               chunks := splitpkts[sn]
+               s := splits[sn]
 
-               if int(count) != len(chunks) {
-                       return fmt.Errorf("chunk count changed on seqnum: %d", sn)
+               if int(count) != len(s.chunks) {
+                       return fmt.Errorf("chunk count changed on split packet: %d", sn)
                }
 
-               chunks[i] = pkt.Data[7:]
+               s.chunks[i] = pkt.Data[7:]
+               s.size += len(s.chunks[i])
+               s.got++
 
-               for _, chunk := range chunks {
-                       if chunk == nil {
-                               return nil
+               if s.got == len(s.chunks) {
+                       data := make([]byte, 0, s.size)
+                       for _, chunk := range s.chunks {
+                               data = append(data, chunk...)
                        }
-               }
 
-               var data []byte
-               for _, chunk := range chunks {
-                       data = append(data, chunk...)
-               }
+                       p.pkts <- Pkt{
+                               Data:  data,
+                               ChNo:  pkt.ChNo,
+                               Unrel: pkt.Unrel,
+                       }
 
-               p.pkts <- Pkt{
-                       Data:  data,
-                       ChNo:  pkt.ChNo,
-                       Unrel: pkt.Unrel,
+                       splits[sn] = nil
                }
-
-               delete(splitpkts, sn)
        case rawTypeRel:
                defer errWrap("rel: %w")
 
@@ -213,39 +210,37 @@ func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
                        return io.ErrUnexpectedEOF
                }
 
-               sn := seqnum(binary.BigEndian.Uint16(pkt.Data[1:3]))
+               sn := seqnum(be.Uint16(pkt.Data[1:3]))
 
-               ackdata := make([]byte, 1+1+2)
-               ackdata[0] = uint8(rawTypeCtl)
-               ackdata[1] = uint8(ctlAck)
-               binary.BigEndian.PutUint16(ackdata[2:4], uint16(sn))
-               ack := rawPkt{
-                       Data:  ackdata,
+               ack := make([]byte, 1+1+2)
+               ack[0] = uint8(rawTypeCtl)
+               ack[1] = uint8(ctlAck)
+               be.PutUint16(ack[2:4], uint16(sn))
+               if _, err := p.sendRaw(rawPkt{
+                       Data:  ack,
                        ChNo:  pkt.ChNo,
                        Unrel: true,
-               }
-               if _, err := p.sendRaw(ack); err != nil {
+               }); err != nil {
                        if errors.Is(err, net.ErrClosed) {
                                return nil
                        }
                        return fmt.Errorf("can't ack %d: %w", sn, err)
                }
 
-               if sn-c.inrelsn >= 0x8000 {
+               if sn-c.inRelSN >= 0x8000 {
                        return nil // Already received.
                }
 
-               c.inrel[sn] = pkt.Data[3:]
-
-               for ; c.inrel[c.inrelsn] != nil; c.inrelsn++ {
-                       data := c.inrel[c.inrelsn]
-                       delete(c.inrel, c.inrelsn)
+               c.inRel[sn] = pkt.Data[3:]
 
+               for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ {
                        rpkt := rawPkt{
-                               Data:  data,
+                               Data:  c.inRel[c.inRelSN],
                                ChNo:  pkt.ChNo,
                                Unrel: false,
                        }
+                       c.inRel[c.inRelSN] = nil
+
                        if err := p.processRawPkt(rpkt); err != nil {
                                p.errs <- PktError{"rel", rpkt.Data, err}
                        }