]> git.lizzy.rs Git - mt.git/blob - proto.go
Make seperate module to facilitate testing
[mt.git] / proto.go
1 package mt
2
3 import (
4         "fmt"
5         "io"
6         "net"
7
8         "github.com/anon55555/mt/rudp"
9 )
10
11 // A Pkt is a deserialized rudp.Pkt.
12 type Pkt struct {
13         Cmd
14         rudp.PktInfo
15 }
16
17 // Peer wraps rudp.Conn, adding (de)serialization.
18 type Peer struct {
19         *rudp.Conn
20 }
21
22 func (p Peer) Send(pkt Pkt) (ack <-chan struct{}, err error) {
23         var cmdNo uint16
24         if p.IsSrv() {
25                 cmdNo = pkt.Cmd.(ToSrvCmd).toSrvCmdNo()
26         } else {
27                 cmdNo = pkt.Cmd.(ToCltCmd).toCltCmdNo()
28         }
29
30         if cmdNo == 0xffff {
31                 return nil, p.Close()
32         }
33
34         r, w := io.Pipe()
35         go func() (err error) {
36                 defer w.CloseWithError(err)
37
38                 buf := make([]byte, 2)
39                 be.PutUint16(buf, cmdNo)
40                 if _, err := w.Write(buf); err != nil {
41                         return err
42                 }
43                 return serialize(w, pkt.Cmd)
44         }()
45
46         return p.Conn.Send(rudp.Pkt{r, pkt.PktInfo})
47 }
48
49 // SendCmd is equivalent to Send(Pkt{cmd, cmd.DefaultPktInfo()}).
50 func (p Peer) SendCmd(cmd Cmd) (ack <-chan struct{}, err error) {
51         return p.Send(Pkt{cmd, cmd.DefaultPktInfo()})
52 }
53
54 func (p Peer) Recv() (_ Pkt, rerr error) {
55         pkt, err := p.Conn.Recv()
56         if err != nil {
57                 return Pkt{}, err
58         }
59
60         buf := make([]byte, 2)
61         if _, err := io.ReadFull(pkt, buf); err != nil {
62                 return Pkt{}, err
63         }
64         cmdNo := be.Uint16(buf)
65
66         var newCmd func() Cmd
67         if p.IsSrv() {
68                 newCmd = newToCltCmd[cmdNo]
69         } else {
70                 newCmd = newToSrvCmd[cmdNo]
71         }
72         if newCmd == nil {
73                 return Pkt{}, fmt.Errorf("unknown cmd: %d", cmdNo)
74         }
75         cmd := newCmd()
76
77         if err := deserialize(pkt, cmd); err != nil {
78                 return Pkt{}, fmt.Errorf("%T: %w", cmd, err)
79         }
80
81         extra, err := io.ReadAll(pkt)
82         if len(extra) > 0 {
83                 err = fmt.Errorf("%T: %w", cmd, rudp.TrailingDataError(extra))
84         }
85         return Pkt{cmd, pkt.PktInfo}, err
86 }
87
88 func Connect(conn net.Conn) Peer {
89         return Peer{rudp.Connect(conn)}
90 }
91
92 type Listener struct {
93         *rudp.Listener
94 }
95
96 func Listen(conn net.PacketConn) Listener {
97         return Listener{rudp.Listen(conn)}
98 }
99
100 func (l Listener) Accept() (Peer, error) {
101         rpeer, err := l.Listener.Accept()
102         return Peer{rpeer}, err
103 }