-package rudp
-
-import (
- "errors"
- "fmt"
- "io"
- "net"
-)
-
-// A PktError is an error that occured while processing a packet.
-type PktError struct {
- Type string // "net", "raw" or "rel".
- Data []byte
- Err error
-}
-
-func (e PktError) Error() string {
- return fmt.Sprintf("error processing %s pkt: %x: %v", e.Type, e.Data, e.Err)
-}
-
-func (e PktError) Unwrap() error { return e.Err }
-
-func (p *Peer) processNetPkts(pkts <-chan netPkt) {
- for pkt := range pkts {
- if err := p.processNetPkt(pkt); err != nil {
- p.errs <- PktError{"net", pkt.Data, err}
- }
- }
-
- close(p.pkts)
-}
-
-// A TrailingDataError reports a packet with trailing data,
-// it doesn't stop a packet from being processed.
-type TrailingDataError []byte
-
-func (e TrailingDataError) Error() string {
- return fmt.Sprintf("trailing data: %x", []byte(e))
-}
-
-func (p *Peer) processNetPkt(pkt netPkt) (err error) {
- if pkt.SrcAddr.String() != p.Addr().String() {
- return fmt.Errorf("got pkt from wrong addr: %s", p.Addr().String())
- }
-
- if len(pkt.Data) < MtHdrSize {
- return io.ErrUnexpectedEOF
- }
-
- if id := be.Uint32(pkt.Data[0:4]); id != protoID {
- return fmt.Errorf("unsupported protocol id: 0x%08x", id)
- }
-
- // src PeerID at pkt.Data[4:6]
-
- chno := pkt.Data[6]
- if chno >= ChannelCount {
- return fmt.Errorf("invalid channel number: %d: >= ChannelCount", chno)
- }
-
- p.mu.RLock()
- if p.timeout != nil {
- p.timeout.Reset(ConnTimeout)
- }
- p.mu.RUnlock()
-
- rpkt := rawPkt{
- Data: pkt.Data[MtHdrSize:],
- ChNo: chno,
- Unrel: true,
- }
- if err := p.processRawPkt(rpkt); err != nil {
- p.errs <- PktError{"raw", rpkt.Data, err}
- }
-
- return nil
-}
-
-func (p *Peer) processRawPkt(pkt rawPkt) (err error) {
- errWrap := func(format string, a ...interface{}) {
- if err != nil {
- err = fmt.Errorf(format, append(a, err)...)
- }
- }
-
- c := &p.chans[pkt.ChNo]
-
- if len(pkt.Data) < 1 {
- return fmt.Errorf("can't read pkt type: %w", io.ErrUnexpectedEOF)
- }
- switch t := rawType(pkt.Data[0]); t {
- case rawTypeCtl:
- defer errWrap("ctl: %w")
-
- if len(pkt.Data) < 1+1 {
- return fmt.Errorf("can't read type: %w", io.ErrUnexpectedEOF)
- }
- switch ct := ctlType(pkt.Data[1]); ct {
- case ctlAck:
- defer errWrap("ack: %w")
-
- if len(pkt.Data) < 1+1+2 {
- return io.ErrUnexpectedEOF
- }
-
- sn := seqnum(be.Uint16(pkt.Data[2:4]))
-
- if ack, ok := c.ackChans.LoadAndDelete(sn); ok {
- close(ack.(chan struct{}))
- }
-
- if len(pkt.Data) > 1+1+2 {
- return TrailingDataError(pkt.Data[1+1+2:])
- }
- case ctlSetPeerID:
- defer errWrap("set peer id: %w")
-
- if len(pkt.Data) < 1+1+2 {
- return io.ErrUnexpectedEOF
- }
-
- // Ensure no concurrent senders while peer id changes.
- p.mu.Lock()
- if p.idOfPeer != PeerIDNil {
- return errors.New("peer id already set")
- }
-
- p.idOfPeer = PeerID(be.Uint16(pkt.Data[2:4]))
- p.mu.Unlock()
-
- if len(pkt.Data) > 1+1+2 {
- return TrailingDataError(pkt.Data[1+1+2:])
- }
- case ctlPing:
- defer errWrap("ping: %w")
-
- if len(pkt.Data) > 1+1 {
- return TrailingDataError(pkt.Data[1+1:])
- }
- case ctlDisco:
- defer errWrap("disco: %w")
-
- p.Close()
-
- if len(pkt.Data) > 1+1 {
- return TrailingDataError(pkt.Data[1+1:])
- }
- default:
- return fmt.Errorf("unsupported ctl type: %d", ct)
- }
- case rawTypeOrig:
- p.pkts <- Pkt{
- Data: pkt.Data[1:],
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- }
- case rawTypeSplit:
- defer errWrap("split: %w")
-
- if len(pkt.Data) < 1+2+2+2 {
- return io.ErrUnexpectedEOF
- }
-
- sn := seqnum(be.Uint16(pkt.Data[1:3]))
- count := be.Uint16(pkt.Data[3:5])
- i := be.Uint16(pkt.Data[5:7])
-
- if i >= count {
- return nil
- }
-
- splits := p.chans[pkt.ChNo].inSplit
-
- // Delete old incomplete split packets
- // so new ones don't get corrupted.
- splits[sn-0x8000] = nil
-
- if splits[sn] == nil {
- splits[sn] = &inSplit{chunks: make([][]byte, count)}
- }
-
- s := splits[sn]
-
- if int(count) != len(s.chunks) {
- return fmt.Errorf("chunk count changed on split packet: %d", sn)
- }
-
- s.chunks[i] = pkt.Data[7:]
- s.size += len(s.chunks[i])
- s.got++
-
- if s.got == len(s.chunks) {
- data := make([]byte, 0, s.size)
- for _, chunk := range s.chunks {
- data = append(data, chunk...)
- }
-
- p.pkts <- Pkt{
- Data: data,
- ChNo: pkt.ChNo,
- Unrel: pkt.Unrel,
- }
-
- splits[sn] = nil
- }
- case rawTypeRel:
- defer errWrap("rel: %w")
-
- if len(pkt.Data) < 1+2 {
- return io.ErrUnexpectedEOF
- }
-
- sn := seqnum(be.Uint16(pkt.Data[1:3]))
-
- ack := make([]byte, 1+1+2)
- ack[0] = uint8(rawTypeCtl)
- ack[1] = uint8(ctlAck)
- be.PutUint16(ack[2:4], uint16(sn))
- if _, err := p.sendRaw(rawPkt{
- Data: ack,
- ChNo: pkt.ChNo,
- Unrel: true,
- }); err != nil {
- if errors.Is(err, net.ErrClosed) {
- return nil
- }
- return fmt.Errorf("can't ack %d: %w", sn, err)
- }
-
- if sn-c.inRelSN >= 0x8000 {
- return nil // Already received.
- }
-
- c.inRel[sn] = pkt.Data[3:]
-
- for ; c.inRel[c.inRelSN] != nil; c.inRelSN++ {
- rpkt := rawPkt{
- Data: c.inRel[c.inRelSN],
- ChNo: pkt.ChNo,
- Unrel: false,
- }
- c.inRel[c.inRelSN] = nil
-
- if err := p.processRawPkt(rpkt); err != nil {
- p.errs <- PktError{"rel", rpkt.Data, err}
- }
- }
- default:
- return fmt.Errorf("unsupported pkt type: %d", t)
- }
-
- return nil
-}