]> git.lizzy.rs Git - mt.git/blob - rudp/peer.go
791249c5107e4d3b1b933d210b67de094c9e3439
[mt.git] / rudp / peer.go
1 package rudp
2
3 import (
4         "errors"
5         "fmt"
6         "net"
7         "sync"
8         "time"
9 )
10
11 const (
12         // ConnTimeout is the amount of time after no packets being received
13         // from a Peer that it is automatically disconnected.
14         ConnTimeout = 30 * time.Second
15
16         // ConnTimeout is the amount of time after no packets being sent
17         // to a Peer that a CtlPing is automatically sent to prevent timeout.
18         PingTimeout = 5 * time.Second
19 )
20
21 // A Peer is a connection to a client or server.
22 type Peer struct {
23         pc   net.PacketConn
24         addr net.Addr
25         conn net.Conn
26
27         disco chan struct{} // close-only
28
29         id PeerID
30
31         pkts     chan Pkt
32         errs     chan error    // don't close
33         timedOut chan struct{} // close-only
34
35         chans [ChannelCount]pktchan // read/write
36
37         mu       sync.RWMutex
38         idOfPeer PeerID
39         timeout  *time.Timer
40         ping     *time.Ticker
41 }
42
43 // Conn returns the net.PacketConn used to communicate with the Peer.
44 func (p *Peer) Conn() net.PacketConn { return p.pc }
45
46 // Addr returns the address of the Peer.
47 func (p *Peer) Addr() net.Addr { return p.addr }
48
49 // Disco returns a channel that is closed when the Peer is closed.
50 func (p *Peer) Disco() <-chan struct{} { return p.disco }
51
52 // ID returns the ID of the Peer.
53 func (p *Peer) ID() PeerID { return p.id }
54
55 // IsSrv reports whether the Peer is a server.
56 func (p *Peer) IsSrv() bool {
57         return p.ID() == PeerIDSrv
58 }
59
60 // TimedOut reports whether the Peer has timed out.
61 func (p *Peer) TimedOut() bool {
62         select {
63         case <-p.timedOut:
64                 return true
65         default:
66                 return false
67         }
68 }
69
70 type inSplit struct {
71         chunks    [][]byte
72         size, got int
73 }
74
75 type pktchan struct {
76         // Only accessed by Peer.processRawPkt.
77         inSplit *[65536]*inSplit
78         inRel   *[65536][]byte
79         inRelSN seqnum
80
81         ackChans sync.Map // map[seqnum]chan struct{}
82
83         outSplitMu sync.Mutex
84         outSplitSN seqnum
85
86         outRelMu  sync.Mutex
87         outRelSN  seqnum
88         outRelWin seqnum
89 }
90
91 // Recv recieves a packet from the Peer.
92 // You should keep calling this until it returns net.ErrClosed
93 // so it doesn't leak a goroutine.
94 func (p *Peer) Recv() (Pkt, error) {
95         select {
96         case pkt, ok := <-p.pkts:
97                 if !ok {
98                         select {
99                         case err := <-p.errs:
100                                 return Pkt{}, err
101                         default:
102                                 return Pkt{}, net.ErrClosed
103                         }
104                 }
105                 return pkt, nil
106         case err := <-p.errs:
107                 return Pkt{}, err
108         }
109 }
110
111 // Close closes the Peer but does not send a disconnect packet.
112 func (p *Peer) Close() error {
113         p.mu.Lock()
114         defer p.mu.Unlock()
115
116         select {
117         case <-p.Disco():
118                 return net.ErrClosed
119         default:
120         }
121
122         p.timeout.Stop()
123         p.timeout = nil
124         p.ping.Stop()
125         p.ping = nil
126
127         close(p.disco)
128
129         return nil
130 }
131
132 func newPeer(pc net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer {
133         p := &Peer{
134                 pc:       pc,
135                 addr:     addr,
136                 id:       id,
137                 idOfPeer: idOfPeer,
138
139                 pkts:  make(chan Pkt),
140                 disco: make(chan struct{}),
141                 errs:  make(chan error),
142         }
143
144         if conn, ok := pc.(net.Conn); ok && conn.RemoteAddr() != nil {
145                 p.conn = conn
146         }
147
148         for i := range p.chans {
149                 p.chans[i] = pktchan{
150                         inSplit: new([65536]*inSplit),
151                         inRel:   new([65536][]byte),
152                         inRelSN: seqnumInit,
153
154                         outSplitSN: seqnumInit,
155                         outRelSN:   seqnumInit,
156                         outRelWin:  seqnumInit,
157                 }
158         }
159
160         p.timedOut = make(chan struct{})
161         p.timeout = time.AfterFunc(ConnTimeout, func() {
162                 close(p.timedOut)
163
164                 p.SendDisco(0, true)
165                 p.Close()
166         })
167
168         p.ping = time.NewTicker(PingTimeout)
169         go p.sendPings(p.ping.C)
170
171         return p
172 }
173
174 func (p *Peer) sendPings(ping <-chan time.Time) {
175         pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}}
176
177         for {
178                 select {
179                 case <-ping:
180                         if _, err := p.sendRaw(pkt); err != nil {
181                                 if errors.Is(err, net.ErrClosed) {
182                                         return
183                                 }
184                                 p.errs <- fmt.Errorf("can't send ping: %w", err)
185                         }
186                 case <-p.Disco():
187                         return
188                 }
189         }
190 }
191
192 // Connect connects to addr using pc
193 // and closes pc when the returned *Peer disconnects.
194 func Connect(pc net.PacketConn, addr net.Addr) *Peer {
195         srv := newPeer(pc, addr, PeerIDSrv, PeerIDNil)
196
197         pkts := make(chan netPkt)
198         go readNetPkts(pc, pkts, srv.errs)
199         go srv.processNetPkts(pkts)
200
201         go func() {
202                 <-srv.Disco()
203                 pc.Close()
204         }()
205
206         return srv
207 }