]> git.lizzy.rs Git - mt.git/blob - proto.go
Expose methods for packet (de-)serialization
[mt.git] / proto.go
1 package mt
2
3 import (
4         "fmt"
5         "io"
6         "net"
7
8         "github.com/dragonfireclient/mt/rudp"
9 )
10
11 // A Pkt is a deserialized rudp.Pkt.
12 type Pkt struct {
13         Cmd
14         rudp.PktInfo
15 }
16
17 // Peer wraps rudp.Conn, adding (de)serialization.
18 type Peer struct {
19         *rudp.Conn
20 }
21
22 func SerializePkt(pkt Cmd, w io.WriteCloser, toSrv bool) bool {
23         var cmdNo uint16
24         if toSrv {
25                 cmdNo = pkt.(ToSrvCmd).toSrvCmdNo()
26         } else {
27                 cmdNo = pkt.(ToCltCmd).toCltCmdNo()
28         }
29
30         if cmdNo == 0xffff {
31                 return false
32         }
33
34         go func() (err error) {
35                 // defer w.CloseWithError(err)
36                 defer w.Close()
37
38                 buf := make([]byte, 2)
39                 be.PutUint16(buf, cmdNo)
40                 if _, err := w.Write(buf); err != nil {
41                         return err
42                 }
43                 return serialize(w, pkt)
44         }()
45
46         return true
47 }
48
49 func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
50         r, w := io.Pipe()
51         if !SerializePkt(pkt.Cmd, w, p.IsSrv()) {
52                 return nil, p.Close()
53         }
54
55         return p.Conn.Send(rudp.Pkt{r, pkt.PktInfo})
56 }
57
58 // SendCmd is equivalent to Send(Pkt{cmd, cmd.DefaultPktInfo()}).
59 func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) {
60         return p.Send(Pkt{cmd, cmd.DefaultPktInfo()})
61 }
62
63 func DeserializePkt(pkt io.Reader, fromSrv bool) (*Cmd, error) {
64         buf := make([]byte, 2)
65         if _, err := io.ReadFull(pkt, buf); err != nil {
66                 return nil, err
67         }
68         cmdNo := be.Uint16(buf)
69
70         var newCmd func() Cmd
71         if fromSrv {
72                 newCmd = newToCltCmd[cmdNo]
73         } else {
74                 newCmd = newToSrvCmd[cmdNo]
75         }
76
77         if newCmd == nil {
78                 return nil, fmt.Errorf("unknown cmd: %d", cmdNo)
79         }
80         cmd := newCmd()
81
82         if err := deserialize(pkt, cmd); err != nil {
83                 return nil, fmt.Errorf("%T: %w", cmd, err)
84         }
85
86         extra, err := io.ReadAll(pkt)
87         if len(extra) > 0 {
88                 err = fmt.Errorf("%T: %w", cmd, rudp.TrailingDataError(extra))
89         }
90
91         return &cmd, err
92 }
93
94 func (p Peer) Recv() (_ Pkt, rerr error) {
95         pkt, err := p.Conn.Recv()
96         if err != nil {
97                 return Pkt{}, err
98         }
99
100         cmd, err := DeserializePkt(pkt, p.IsSrv())
101
102         if cmd == nil {
103                 return Pkt{}, err
104         } else {
105                 return Pkt{*cmd, pkt.PktInfo}, err
106         }
107 }
108
109 func Connect(conn net.Conn) Peer {
110         return Peer{rudp.Connect(conn)}
111 }
112
113 type Listener struct {
114         *rudp.Listener
115 }
116
117 func Listen(conn net.PacketConn) Listener {
118         return Listener{rudp.Listen(conn)}
119 }
120
121 func (l Listener) Accept() (Peer, error) {
122         rpeer, err := l.Listener.Accept()
123         return Peer{rpeer}, err
124 }