]> git.lizzy.rs Git - mt.git/blob - proto.go
Add high-level protocol (de)serialization
[mt.git] / proto.go
1 package mt
2
3 import (
4         "fmt"
5         "io"
6         "net"
7
8         "github.com/anon55555/mt/rudp"
9 )
10
11 const ChannelCount = rudp.ChannelCount
12
13 // A Pkt is a deserialized rudp.Pkt.
14 type Pkt struct {
15         Cmd
16         rudp.PktInfo
17 }
18
19 // Peer wraps rudp.Conn, adding (de)serialization.
20 type Peer struct {
21         *rudp.Conn
22 }
23
24 func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
25         var cmdNo uint16
26         if p.IsSrv() {
27                 cmdNo = pkt.Cmd.(ToSrvCmd).toSrvCmdNo()
28         } else {
29                 cmdNo = pkt.Cmd.(ToCltCmd).toCltCmdNo()
30         }
31
32         r, w := io.Pipe()
33         go func() (err error) {
34                 defer w.CloseWithError(err)
35
36                 buf := make([]byte, 2)
37                 be.PutUint16(buf, cmdNo)
38                 if _, err := w.Write(buf); err != nil {
39                         return err
40                 }
41                 return serialize(w, pkt.Cmd)
42         }()
43
44         return p.Conn.Send(rudp.Pkt{r, pkt.PktInfo})
45 }
46
47 // SendCmd is equivalent to Send(Pkt{cmd, cmd.DefaultPktInfo()}).
48 func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) {
49         return p.Send(Pkt{cmd, cmd.DefaultPktInfo()})
50 }
51
52 func (p Peer) Recv() (_ Pkt, rerr error) {
53         pkt, err := p.Conn.Recv()
54         if err != nil {
55                 return Pkt{}, err
56         }
57
58         buf := make([]byte, 2)
59         if _, err := io.ReadFull(pkt, buf); err != nil {
60                 return Pkt{}, err
61         }
62         cmdNo := be.Uint16(buf)
63
64         var newCmd func() Cmd
65         if p.IsSrv() {
66                 newCmd = newToCltCmd[cmdNo]
67         } else {
68                 newCmd = newToSrvCmd[cmdNo]
69         }
70         if newCmd == nil {
71                 return Pkt{}, fmt.Errorf("unknown cmd: %d", cmdNo)
72         }
73         cmd := newCmd()
74
75         if err := deserialize(pkt, cmd); err != nil {
76                 return Pkt{}, fmt.Errorf("%T: %w", cmd, err)
77         }
78
79         extra, err := io.ReadAll(pkt)
80         if len(extra) > 0 {
81                 err = rudp.TrailingDataError(extra)
82         }
83         return Pkt{cmd, pkt.PktInfo}, err
84 }
85
86 func Connect(conn net.Conn) Peer {
87         return Peer{rudp.Connect(conn)}
88 }
89
90 type Listener struct {
91         *rudp.Listener
92 }
93
94 func Listen(conn net.PacketConn) Listener {
95         return Listener{rudp.Listen(conn)}
96 }
97
98 func (l Listener) Accept() (Peer, error) {
99         rpeer, err := l.Listener.Accept()
100         return Peer{rpeer}, err
101 }