]> git.lizzy.rs Git - mt.git/blob - rudp/peer.go
rudp: fix errors returned by Peer.Recv other than net.ErrClosed when the Peer is...
[mt.git] / rudp / peer.go
1 package rudp
2
3 import (
4         "errors"
5         "fmt"
6         "net"
7         "sync"
8         "time"
9 )
10
11 const (
12         // ConnTimeout is the amount of time after no packets being received
13         // from a Peer that it is automatically disconnected.
14         ConnTimeout = 30 * time.Second
15
16         // ConnTimeout is the amount of time after no packets being sent
17         // to a Peer that a CtlPing is automatically sent to prevent timeout.
18         PingTimeout = 5 * time.Second
19 )
20
21 // A Peer is a connection to a client or server.
22 type Peer struct {
23         conn net.PacketConn
24         addr net.Addr
25
26         disco chan struct{} // close-only
27
28         id PeerID
29
30         pkts     chan Pkt
31         errs     chan error    // don't close
32         timedout chan struct{} // close-only
33
34         chans [ChannelCount]pktchan // read/write
35
36         mu       sync.RWMutex
37         idOfPeer PeerID
38         timeout  *time.Timer
39         ping     *time.Ticker
40 }
41
42 type pktchan struct {
43         // Only accessed by Peer.processRawPkt.
44         insplit map[seqnum][][]byte
45         inrel   map[seqnum][]byte
46         inrelsn seqnum
47
48         ackchans sync.Map // map[seqnum]chan struct{}
49
50         outsplitmu sync.Mutex
51         outsplitsn seqnum
52
53         outrelmu  sync.Mutex
54         outrelsn  seqnum
55         outrelwin seqnum
56 }
57
58 // Conn returns the net.PacketConn used to communicate with the Peer.
59 func (p *Peer) Conn() net.PacketConn { return p.conn }
60
61 // Addr returns the address of the Peer.
62 func (p *Peer) Addr() net.Addr { return p.addr }
63
64 // Disco returns a channel that is closed when the Peer is closed.
65 func (p *Peer) Disco() <-chan struct{} { return p.disco }
66
67 // ID returns the ID of the Peer.
68 func (p *Peer) ID() PeerID { return p.id }
69
70 // IsSrv reports whether the Peer is a server.
71 func (p *Peer) IsSrv() bool {
72         return p.ID() == PeerIDSrv
73 }
74
75 // TimedOut reports whether the Peer has timed out.
76 func (p *Peer) TimedOut() bool {
77         select {
78         case <-p.timedout:
79                 return true
80         default:
81                 return false
82         }
83 }
84
85 // Recv recieves a packet from the Peer.
86 // You should keep calling this until it returns net.ErrClosed
87 // so it doesn't leak a goroutine.
88 func (p *Peer) Recv() (Pkt, error) {
89         select {
90         case pkt, ok := <-p.pkts:
91                 if !ok {
92                         select {
93                         case err := <-p.errs:
94                                 return Pkt{}, err
95                         default:
96                                 return Pkt{}, net.ErrClosed
97                         }
98                 }
99                 return pkt, nil
100         case err := <-p.errs:
101                 return Pkt{}, err
102         }
103 }
104
105 // Close closes the Peer but does not send a disconnect packet.
106 func (p *Peer) Close() error {
107         p.mu.Lock()
108         defer p.mu.Unlock()
109
110         select {
111         case <-p.Disco():
112                 return net.ErrClosed
113         default:
114         }
115
116         p.timeout.Stop()
117         p.timeout = nil
118         p.ping.Stop()
119         p.ping = nil
120
121         close(p.disco)
122
123         return nil
124 }
125
126 func newPeer(conn net.PacketConn, addr net.Addr, id, idOfPeer PeerID) *Peer {
127         p := &Peer{
128                 conn:     conn,
129                 addr:     addr,
130                 id:       id,
131                 idOfPeer: idOfPeer,
132
133                 pkts:  make(chan Pkt),
134                 disco: make(chan struct{}),
135                 errs:  make(chan error),
136         }
137
138         for i := range p.chans {
139                 p.chans[i] = pktchan{
140                         insplit: make(map[seqnum][][]byte),
141                         inrel:   make(map[seqnum][]byte),
142                         inrelsn: seqnumInit,
143
144                         outsplitsn: seqnumInit,
145                         outrelsn:   seqnumInit,
146                         outrelwin:  seqnumInit,
147                 }
148         }
149
150         p.timedout = make(chan struct{})
151         p.timeout = time.AfterFunc(ConnTimeout, func() {
152                 close(p.timedout)
153
154                 p.SendDisco(0, true)
155                 p.Close()
156         })
157
158         p.ping = time.NewTicker(PingTimeout)
159         go p.sendPings(p.ping.C)
160
161         return p
162 }
163
164 func (p *Peer) sendPings(ping <-chan time.Time) {
165         pkt := rawPkt{Data: []byte{uint8(rawTypeCtl), uint8(ctlPing)}}
166
167         for {
168                 select {
169                 case <-ping:
170                         if _, err := p.sendRaw(pkt); err != nil {
171                                 if errors.Is(err, net.ErrClosed) {
172                                         return
173                                 }
174                                 p.errs <- fmt.Errorf("can't send ping: %w", err)
175                         }
176                 case <-p.Disco():
177                         return
178                 }
179         }
180 }
181
182 // Connect connects to the server on conn
183 // and closes conn when the returned *Peer disconnects.
184 func Connect(conn net.PacketConn, addr net.Addr) *Peer {
185         srv := newPeer(conn, addr, PeerIDSrv, PeerIDNil)
186
187         pkts := make(chan netPkt)
188         go readNetPkts(conn, pkts, srv.errs)
189         go srv.processNetPkts(pkts)
190
191         go func() {
192                 <-srv.Disco()
193                 conn.Close()
194         }()
195
196         return srv
197 }