]> git.lizzy.rs Git - mt.git/blob - rudp/peer.go
rudp: deprecate ErrClosed and replace with net.ErrClosed
[mt.git] / rudp / peer.go
1 package rudp
2
3 import (
4         "fmt"
5         "net"
6         "sync"
7         "time"
8 )
9
10 const (
11         // ConnTimeout is the amount of time after no packets being received
12         // from a Peer that it is automatically disconnected.
13         ConnTimeout = 30 * time.Second
14
15         // ConnTimeout is the amount of time after no packets being sent
16         // to a Peer that a CtlPing is automatically sent to prevent timeout.
17         PingTimeout = 5 * time.Second
18 )
19
20 // A Peer is a connection to a client or server.
21 type Peer struct {
22         conn net.PacketConn
23         addr net.Addr
24
25         disco chan struct{} // close-only
26
27         id PeerID
28
29         pkts     chan Pkt
30         errs     chan error    // don't close
31         timedout chan struct{} // close-only
32
33         chans [ChannelCount]pktchan // read/write
34
35         mu       sync.RWMutex
36         idOfPeer PeerID
37         timeout  *time.Timer
38         ping     *time.Ticker
39 }
40
41 type pktchan struct {
42         // Only accessed by Peer.processRawPkt.
43         insplit map[seqnum][][]byte
44         inrel   map[seqnum][]byte
45         inrelsn seqnum
46
47         ackchans sync.Map // map[seqnum]chan struct{}
48
49         outsplitmu sync.Mutex
50         outsplitsn seqnum
51
52         outrelmu  sync.Mutex
53         outrelsn  seqnum
54         outrelwin seqnum
55 }
56
57 // Conn returns the net.PacketConn used to communicate with the Peer.
58 func (p *Peer) Conn() net.PacketConn { return p.conn }
59
60 // Addr returns the address of the Peer.
61 func (p *Peer) Addr() net.Addr { return p.addr }
62
63 // Disco returns a channel that is closed when the Peer is closed.
64 func (p *Peer) Disco() <-chan struct{} { return p.disco }
65
66 // ID returns the ID of the Peer.
67 func (p *Peer) ID() PeerID { return p.id }
68
69 // IsSrv reports whether the Peer is a server.
70 func (p *Peer) IsSrv() bool {
71         return p.ID() == PeerIDSrv
72 }
73
74 // TimedOut reports whether the Peer has timed out.
75 func (p *Peer) TimedOut() bool {
76         select {
77         case <-p.timedout:
78                 return true
79         default:
80                 return false
81         }
82 }
83
84 // Recv recieves a packet from the Peer.
85 // You should keep calling this until it returns net.ErrClosed
86 // so it doesn't leak a goroutine.
87 func (p *Peer) Recv() (Pkt, error) {
88         select {
89         case pkt, ok := <-p.pkts:
90                 if !ok {
91                         select {
92                         case err := <-p.errs:
93                                 return Pkt{}, err
94                         default:
95                                 return Pkt{}, net.ErrClosed
96                         }
97                 }
98                 return pkt, nil
99         case err := <-p.errs:
100                 return Pkt{}, err
101         }
102 }
103
104 // Close closes the Peer but does not send a disconnect packet.
105 func (p *Peer) Close() error {
106         p.mu.Lock()
107         defer p.mu.Unlock()
108
109         select {
110         case <-p.Disco():
111                 return net.ErrClosed
112         default:
113         }
114
115         p.timeout.Stop()
116         p.timeout = nil
117         p.ping.Stop()
118         p.ping = nil
119
120         close(p.disco)
121
122         return nil
123 }
124
125 func newPeer(conn net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer {
126         p := &Peer{
127                 conn:     conn,
128                 addr:     addr,
129                 id:       id,
130                 idOfPeer: idOfPeer,
131
132                 pkts:  make(chan Pkt),
133                 disco: make(chan struct{}),
134                 errs:  make(chan error),
135         }
136
137         for i := range p.chans {
138                 p.chans[i] = pktchan{
139                         insplit: make(map[seqnum][][]byte),
140                         inrel:   make(map[seqnum][]byte),
141                         inrelsn: seqnumInit,
142
143                         outsplitsn: seqnumInit,
144                         outrelsn:   seqnumInit,
145                         outrelwin:  seqnumInit,
146                 }
147         }
148
149         p.timedout = make(chan struct{})
150         p.timeout = time.AfterFunc(ConnTimeout, func() {
151                 close(p.timedout)
152
153                 p.SendDisco(0, true)
154                 p.Close()
155         })
156
157         p.ping = time.NewTicker(PingTimeout)
158         go p.sendPings(p.ping.C)
159
160         return p
161 }
162
163 func (p *Peer) sendPings(ping <-chan time.Time) {
164         pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}}
165
166         for {
167                 select {
168                 case <-ping:
169                         if _, err := p.sendRaw(pkt); err != nil {
170                                 p.errs <- fmt.Errorf("can't send ping: %w", err)
171                         }
172                 case <-p.Disco():
173                         return
174                 }
175         }
176 }
177
178 // Connect connects to the server on conn
179 // and closes conn when the Peer disconnects.
180 func Connect(conn net.PacketConn, addr net.Addr) *Peer {
181         srv := newPeer(conn, addr, PeerIDSrv, PeerIDNil)
182
183         pkts := make(chan netPkt)
184         go readNetPkts(conn, pkts, srv.errs)
185         go srv.processNetPkts(pkts)
186
187         go func() {
188                 <-srv.Disco()
189                 conn.Close()
190         }()
191
192         return srv
193 }