From e1bfd543b068fd64d7e12f6ea4f5a8a013085f74 Mon Sep 17 00:00:00 2001 From: Lizzy Fleckenstein Date: Sun, 12 Feb 2023 19:25:09 +0100 Subject: [PATCH] Expose methods for packet (de-)serialization --- proto.go | 59 ++++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/proto.go b/proto.go index 0e02771..597e146 100644 --- a/proto.go +++ b/proto.go @@ -19,30 +19,39 @@ type Peer struct { *rudp.Conn } -func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { +func SerializePkt(pkt Cmd, w io.WriteCloser, toSrv bool) bool { var cmdNo uint16 - if p.IsSrv() { - cmdNo = pkt.Cmd.(ToSrvCmd).toSrvCmdNo() + if toSrv { + cmdNo = pkt.(ToSrvCmd).toSrvCmdNo() } else { - cmdNo = pkt.Cmd.(ToCltCmd).toCltCmdNo() + cmdNo = pkt.(ToCltCmd).toCltCmdNo() } if cmdNo == 0xffff { - return nil, p.Close() + return false } - r, w := io.Pipe() go func() (err error) { - defer w.CloseWithError(err) + // defer w.CloseWithError(err) + defer w.Close() buf := make([]byte, 2) be.PutUint16(buf, cmdNo) if _, err := w.Write(buf); err != nil { return err } - return serialize(w, pkt.Cmd) + return serialize(w, pkt) }() + return true +} + +func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) { + r, w := io.Pipe() + if !SerializePkt(pkt.Cmd, w, p.IsSrv()) { + return nil, p.Close() + } + return p.Conn.Send(rudp.Pkt{r, pkt.PktInfo}) } @@ -51,38 +60,50 @@ func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) { return p.Send(Pkt{cmd, cmd.DefaultPktInfo()}) } -func (p Peer) Recv() (_ Pkt, rerr error) { - pkt, err := p.Conn.Recv() - if err != nil { - return Pkt{}, err - } - +func DeserializePkt(pkt io.Reader, fromSrv bool) (*Cmd, error) { buf := make([]byte, 2) if _, err := io.ReadFull(pkt, buf); err != nil { - return Pkt{}, err + return nil, err } cmdNo := be.Uint16(buf) var newCmd func() Cmd - if p.IsSrv() { + if fromSrv { newCmd = newToCltCmd[cmdNo] } else { newCmd = newToSrvCmd[cmdNo] } + if newCmd == nil { - return Pkt{}, fmt.Errorf("unknown cmd: %d", cmdNo) + return nil, fmt.Errorf("unknown cmd: %d", cmdNo) } cmd := newCmd() if err := deserialize(pkt, cmd); err != nil { - return Pkt{}, fmt.Errorf("%T: %w", cmd, err) + return nil, fmt.Errorf("%T: %w", cmd, err) } extra, err := io.ReadAll(pkt) if len(extra) > 0 { err = fmt.Errorf("%T: %w", cmd, rudp.TrailingDataError(extra)) } - return Pkt{cmd, pkt.PktInfo}, err + + return &cmd, err +} + +func (p Peer) Recv() (_ Pkt, rerr error) { + pkt, err := p.Conn.Recv() + if err != nil { + return Pkt{}, err + } + + cmd, err := DeserializePkt(pkt, p.IsSrv()) + + if cmd == nil { + return Pkt{}, err + } else { + return Pkt{*cmd, pkt.PktInfo}, err + } } func Connect(conn net.Conn) Peer { -- 2.44.0