]> git.lizzy.rs Git - mt.git/blob - rudp/listen.go
Add WaitGroup to SerializePkt
[mt.git] / rudp / listen.go
1 package rudp
2
3 import (
4         "errors"
5         "net"
6         "sync"
7 )
8
9 func tryClose(ch chan struct{}) (ok bool) {
10         defer func() { recover() }()
11         close(ch)
12         return true
13 }
14
15 type udpClt struct {
16         l      *Listener
17         id     PeerID
18         addr   net.Addr
19         pkts   chan []byte
20         closed chan struct{}
21 }
22
23 func (c *udpClt) mkConn() {
24         conn := newConn(c, c.id, PeerIDSrv)
25         go func() {
26                 <-conn.Closed()
27                 c.l.wg.Done()
28         }()
29         conn.sendRaw(func(buf []byte) int {
30                 buf[0] = uint8(rawCtl)
31                 buf[1] = uint8(ctlSetPeerID)
32                 be.PutUint16(buf[2:4], uint16(conn.ID()))
33                 return 4
34         }, PktInfo{})()
35         select {
36         case c.l.conns <- conn:
37         case <-c.l.closed:
38                 conn.Close()
39         }
40 }
41
42 func (c *udpClt) Write(pkt []byte) (int, error) {
43         select {
44         case <-c.closed:
45                 return 0, net.ErrClosed
46         default:
47         }
48
49         return c.l.pc.WriteTo(pkt, c.addr)
50 }
51
52 func (c *udpClt) recvUDP() ([]byte, error) {
53         select {
54         case pkt := <-c.pkts:
55                 return pkt, nil
56         case <-c.closed:
57                 return nil, net.ErrClosed
58         }
59 }
60
61 func (c *udpClt) Close() error {
62         if !tryClose(c.closed) {
63                 return net.ErrClosed
64         }
65
66         c.l.mu.Lock()
67         defer c.l.mu.Unlock()
68
69         delete(c.l.ids, c.id)
70         delete(c.l.clts, c.addr.String())
71
72         return nil
73 }
74
75 func (c *udpClt) LocalAddr() net.Addr  { return c.l.pc.LocalAddr() }
76 func (c *udpClt) RemoteAddr() net.Addr { return c.addr }
77
78 // All Listener's methods are safe for concurrent use.
79 type Listener struct {
80         pc net.PacketConn
81
82         peerID PeerID
83         conns  chan *Conn
84         errs   chan error
85         closed chan struct{}
86         wg     sync.WaitGroup
87
88         mu   sync.RWMutex
89         ids  map[PeerID]bool
90         clts map[string]*udpClt
91 }
92
93 // Listen listens for connections on pc, pc is closed once the returned Listener
94 // and all Conns connected through it are closed.
95 func Listen(pc net.PacketConn) *Listener {
96         l := &Listener{
97                 pc: pc,
98
99                 conns:  make(chan *Conn),
100                 closed: make(chan struct{}),
101
102                 ids:  make(map[PeerID]bool),
103                 clts: make(map[string]*udpClt),
104         }
105
106         go func() {
107                 for {
108                         if err := l.processNetPkt(); err != nil {
109                                 if errors.Is(err, net.ErrClosed) {
110                                         break
111                                 }
112                                 select {
113                                 case l.errs <- err:
114                                 case <-l.closed:
115                                 }
116                         }
117                 }
118         }()
119
120         return l
121 }
122
123 // Accept waits for and returns the next incoming Conn or an error.
124 func (l *Listener) Accept() (*Conn, error) {
125         select {
126         case c := <-l.conns:
127                 return c, nil
128         case err := <-l.errs:
129                 return nil, err
130         case <-l.closed:
131                 return nil, net.ErrClosed
132         }
133 }
134
135 // Close makes the Listener stop listening for new Conns.
136 // Blocked Accept calls will return net.ErrClosed.
137 // Already Accepted Conns are not closed.
138 func (l *Listener) Close() error {
139         if !tryClose(l.closed) {
140                 return net.ErrClosed
141         }
142
143         go func() {
144                 l.wg.Wait()
145                 l.pc.Close()
146         }()
147
148         return nil
149 }
150
151 // Addr returns the Listener's network address.
152 func (l *Listener) Addr() net.Addr { return l.pc.LocalAddr() }
153
154 var ErrOutOfPeerIDs = errors.New("out of peer ids")
155
156 func (l *Listener) processNetPkt() error {
157         buf := make([]byte, maxUDPPktSize)
158         n, addr, err := l.pc.ReadFrom(buf)
159         if err != nil {
160                 return err
161         }
162
163         l.mu.RLock()
164         clt, ok := l.clts[addr.String()]
165         l.mu.RUnlock()
166         if !ok {
167                 select {
168                 case <-l.closed:
169                         return nil
170                 default:
171                 }
172
173                 clt, err = l.add(addr)
174                 if err != nil {
175                         return err
176                 }
177         }
178
179         select {
180         case clt.pkts <- buf[:n]:
181         case <-clt.closed:
182         }
183
184         return nil
185 }
186
187 func (l *Listener) add(addr net.Addr) (*udpClt, error) {
188         l.mu.Lock()
189         defer l.mu.Unlock()
190
191         start := l.peerID
192         l.peerID++
193         for l.peerID < PeerIDCltMin || l.ids[l.peerID] {
194                 if l.peerID == start {
195                         return nil, ErrOutOfPeerIDs
196                 }
197                 l.peerID++
198         }
199         l.ids[l.peerID] = true
200
201         clt := &udpClt{
202                 l:      l,
203                 id:     l.peerID,
204                 addr:   addr,
205                 pkts:   make(chan []byte),
206                 closed: make(chan struct{}),
207         }
208         l.clts[addr.String()] = clt
209
210         l.wg.Add(1)
211         go clt.mkConn()
212
213         return clt, nil
214 }