]> git.lizzy.rs Git - mt.git/blob - proto.go
Add WaitGroup to SerializePkt
[mt.git] / proto.go
1 package mt
2
3 import (
4         "fmt"
5         "io"
6         "net"
7         "sync"
8
9         "github.com/dragonfireclient/mt/rudp"
10 )
11
12 // A Pkt is a deserialized rudp.Pkt.
13 type Pkt struct {
14         Cmd
15         rudp.PktInfo
16 }
17
18 // Peer wraps rudp.Conn, adding (de)serialization.
19 type Peer struct {
20         *rudp.Conn
21 }
22
23 func SerializePkt(pkt Cmd, w io.WriteCloser, toSrv bool, wg *sync.WaitGroup) bool {
24         var cmdNo uint16
25         if toSrv {
26                 cmdNo = pkt.(ToSrvCmd).toSrvCmdNo()
27         } else {
28                 cmdNo = pkt.(ToCltCmd).toCltCmdNo()
29         }
30
31         if cmdNo == 0xffff {
32                 return false
33         }
34
35         wg.Add(1)
36         go func() (err error) {
37                 defer wg.Done()
38                 // defer w.CloseWithError(err)
39                 defer w.Close()
40
41                 buf := make([]byte, 2)
42                 be.PutUint16(buf, cmdNo)
43                 if _, err := w.Write(buf); err != nil {
44                         return err
45                 }
46                 return serialize(w, pkt)
47         }()
48
49         return true
50 }
51
52 func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
53         r, w := io.Pipe()
54         if !SerializePkt(pkt.Cmd, w, p.IsSrv(), &sync.WaitGroup{}) {
55                 return nil, p.Close()
56         }
57
58         return p.Conn.Send(rudp.Pkt{r, pkt.PktInfo})
59 }
60
61 // SendCmd is equivalent to Send(Pkt{cmd, cmd.DefaultPktInfo()}).
62 func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) {
63         return p.Send(Pkt{cmd, cmd.DefaultPktInfo()})
64 }
65
66 func DeserializePkt(pkt io.Reader, fromSrv bool) (*Cmd, error) {
67         buf := make([]byte, 2)
68         if _, err := io.ReadFull(pkt, buf); err != nil {
69                 return nil, err
70         }
71         cmdNo := be.Uint16(buf)
72
73         var newCmd func() Cmd
74         if fromSrv {
75                 newCmd = newToCltCmd[cmdNo]
76         } else {
77                 newCmd = newToSrvCmd[cmdNo]
78         }
79
80         if newCmd == nil {
81                 return nil, fmt.Errorf("unknown cmd: %d", cmdNo)
82         }
83         cmd := newCmd()
84
85         if err := deserialize(pkt, cmd); err != nil {
86                 return nil, fmt.Errorf("%T: %w", cmd, err)
87         }
88
89         extra, err := io.ReadAll(pkt)
90         if len(extra) > 0 {
91                 err = fmt.Errorf("%T: %w", cmd, rudp.TrailingDataError(extra))
92         }
93
94         return &cmd, err
95 }
96
97 func (p Peer) Recv() (_ Pkt, rerr error) {
98         pkt, err := p.Conn.Recv()
99         if err != nil {
100                 return Pkt{}, err
101         }
102
103         cmd, err := DeserializePkt(pkt, p.IsSrv())
104
105         if cmd == nil {
106                 return Pkt{}, err
107         } else {
108                 return Pkt{*cmd, pkt.PktInfo}, err
109         }
110 }
111
112 func Connect(conn net.Conn) Peer {
113         return Peer{rudp.Connect(conn)}
114 }
115
116 type Listener struct {
117         *rudp.Listener
118 }
119
120 func Listen(conn net.PacketConn) Listener {
121         return Listener{rudp.Listen(conn)}
122 }
123
124 func (l Listener) Accept() (Peer, error) {
125         rpeer, err := l.Listener.Accept()
126         return Peer{rpeer}, err
127 }