]> git.lizzy.rs Git - mt.git/blobdiff - proto.go
Add WaitGroup to SerializePkt
[mt.git] / proto.go
index b859a3949b0acd3cde42104e5d32e4bb42ad6035..c5566f038ad3546ea6fbfafde66a2c30e404b89a 100644 (file)
--- a/proto.go
+++ b/proto.go
@@ -4,12 +4,11 @@ import (
        "fmt"
        "io"
        "net"
+       "sync"
 
-       "github.com/anon55555/mt/rudp"
+       "github.com/dragonfireclient/mt/rudp"
 )
 
-const ChannelCount = rudp.ChannelCount
-
 // A Pkt is a deserialized rudp.Pkt.
 type Pkt struct {
        Cmd
@@ -21,26 +20,41 @@ type Peer struct {
        *rudp.Conn
 }
 
-func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+func SerializePkt(pkt Cmd, w io.WriteCloser, toSrv bool, wg *sync.WaitGroup) bool {
        var cmdNo uint16
-       if p.IsSrv() {
-               cmdNo = pkt.Cmd.(ToSrvCmd).toSrvCmdNo()
+       if toSrv {
+               cmdNo = pkt.(ToSrvCmd).toSrvCmdNo()
        } else {
-               cmdNo = pkt.Cmd.(ToCltCmd).toCltCmdNo()
+               cmdNo = pkt.(ToCltCmd).toCltCmdNo()
        }
 
-       r, w := io.Pipe()
+       if cmdNo == 0xffff {
+               return false
+       }
+
+       wg.Add(1)
        go func() (err error) {
-               defer w.CloseWithError(err)
+               defer wg.Done()
+               // defer w.CloseWithError(err)
+               defer w.Close()
 
                buf := make([]byte, 2)
                be.PutUint16(buf, cmdNo)
                if _, err := w.Write(buf); err != nil {
                        return err
                }
-               return serialize(w, pkt.Cmd)
+               return serialize(w, pkt)
        }()
 
+       return true
+}
+
+func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
+       r, w := io.Pipe()
+       if !SerializePkt(pkt.Cmd, w, p.IsSrv(), &sync.WaitGroup{}) {
+               return nil, p.Close()
+       }
+
        return p.Conn.Send(rudp.Pkt{r, pkt.PktInfo})
 }
 
@@ -49,38 +63,50 @@ func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) {
        return p.Send(Pkt{cmd, cmd.DefaultPktInfo()})
 }
 
-func (p Peer) Recv() (_ Pkt, rerr error) {
-       pkt, err := p.Conn.Recv()
-       if err != nil {
-               return Pkt{}, err
-       }
-
+func DeserializePkt(pkt io.Reader, fromSrv bool) (*Cmd, error) {
        buf := make([]byte, 2)
        if _, err := io.ReadFull(pkt, buf); err != nil {
-               return Pkt{}, err
+               return nil, err
        }
        cmdNo := be.Uint16(buf)
 
        var newCmd func() Cmd
-       if p.IsSrv() {
+       if fromSrv {
                newCmd = newToCltCmd[cmdNo]
        } else {
                newCmd = newToSrvCmd[cmdNo]
        }
+
        if newCmd == nil {
-               return Pkt{}, fmt.Errorf("unknown cmd: %d", cmdNo)
+               return nil, fmt.Errorf("unknown cmd: %d", cmdNo)
        }
        cmd := newCmd()
 
        if err := deserialize(pkt, cmd); err != nil {
-               return Pkt{}, fmt.Errorf("%T: %w", cmd, err)
+               return nil, fmt.Errorf("%T: %w", cmd, err)
        }
 
        extra, err := io.ReadAll(pkt)
        if len(extra) > 0 {
-               err = rudp.TrailingDataError(extra)
+               err = fmt.Errorf("%T: %w", cmd, rudp.TrailingDataError(extra))
+       }
+
+       return &cmd, err
+}
+
+func (p Peer) Recv() (_ Pkt, rerr error) {
+       pkt, err := p.Conn.Recv()
+       if err != nil {
+               return Pkt{}, err
+       }
+
+       cmd, err := DeserializePkt(pkt, p.IsSrv())
+
+       if cmd == nil {
+               return Pkt{}, err
+       } else {
+               return Pkt{*cmd, pkt.PktInfo}, err
        }
-       return Pkt{cmd, pkt.PktInfo}, err
 }
 
 func Connect(conn net.Conn) Peer {