]> git.lizzy.rs Git - mt.git/blob - rudp/listen.go
5b7154a56c7dc02ef91e89a5ff5888e074a4c21f
[mt.git] / rudp / listen.go
1 package rudp
2
3 import (
4         "errors"
5         "fmt"
6         "net"
7         "sync"
8 )
9
10 type Listener struct {
11         conn net.PacketConn
12
13         clts chan cltPeer
14         errs chan error
15
16         mu        sync.Mutex
17         addr2peer map[string]cltPeer
18         id2peer   map[PeerID]cltPeer
19         peerID    PeerID
20 }
21
22 // Listen listens for packets on conn until it is closed.
23 func Listen(conn net.PacketConn) *Listener {
24         l := &Listener{
25                 conn: conn,
26
27                 clts: make(chan cltPeer),
28                 errs: make(chan error),
29
30                 addr2peer: make(map[string]cltPeer),
31                 id2peer:   make(map[PeerID]cltPeer),
32         }
33
34         pkts := make(chan netPkt)
35         go readNetPkts(l.conn, pkts, l.errs)
36         go func() {
37                 for pkt := range pkts {
38                         if err := l.processNetPkt(pkt); err != nil {
39                                 l.errs <- err
40                         }
41                 }
42
43                 close(l.clts)
44
45                 for _, clt := range l.addr2peer {
46                         clt.Close()
47                 }
48         }()
49
50         return l
51 }
52
53 // Accept waits for and returns a connecting Peer.
54 // You should keep calling this until it returns net.ErrClosed
55 // so it doesn't leak a goroutine.
56 func (l *Listener) Accept() (*Peer, error) {
57         select {
58         case clt, ok := <-l.clts:
59                 if !ok {
60                         select {
61                         case err := <-l.errs:
62                                 return nil, err
63                         default:
64                                 return nil, net.ErrClosed
65                         }
66                 }
67                 close(clt.accepted)
68                 return clt.Peer, nil
69         case err := <-l.errs:
70                 return nil, err
71         }
72 }
73
74 // Addr returns the net.PacketConn the Listener is listening on.
75 func (l *Listener) Conn() net.PacketConn { return l.conn }
76
77 var ErrOutOfPeerIDs = errors.New("out of peer ids")
78
79 type cltPeer struct {
80         *Peer
81         pkts     chan<- netPkt
82         accepted chan struct{} // close-only
83 }
84
85 func (l *Listener) processNetPkt(pkt netPkt) error {
86         l.mu.Lock()
87         defer l.mu.Unlock()
88
89         addrstr := pkt.SrcAddr.String()
90
91         clt, ok := l.addr2peer[addrstr]
92         if !ok {
93                 prev := l.peerID
94                 for l.id2peer[l.peerID].Peer != nil || l.peerID < PeerIDCltMin {
95                         if l.peerID == prev-1 {
96                                 return ErrOutOfPeerIDs
97                         }
98                         l.peerID++
99                 }
100
101                 pkts := make(chan netPkt, 256)
102
103                 clt = cltPeer{
104                         Peer:     newPeer(l.conn, pkt.SrcAddr, l.peerID, PeerIDSrv),
105                         pkts:     pkts,
106                         accepted: make(chan struct{}),
107                 }
108
109                 l.addr2peer[addrstr] = clt
110                 l.id2peer[clt.ID()] = clt
111
112                 data := make([]byte, 1+1+2)
113                 data[0] = uint8(rawTypeCtl)
114                 data[1] = uint8(ctlSetPeerID)
115                 be.PutUint16(data[2:4], uint16(clt.ID()))
116                 if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
117                         if errors.Is(err, net.ErrClosed) {
118                                 return nil
119                         }
120                         return fmt.Errorf("can't set client peer id: %w", err)
121                 }
122
123                 go func() {
124                         select {
125                         case l.clts <- clt:
126                         case <-clt.Disco():
127                         }
128
129                         clt.processNetPkts(pkts)
130                 }()
131
132                 go func() {
133                         <-clt.Disco()
134
135                         l.mu.Lock()
136                         close(pkts)
137                         delete(l.addr2peer, addrstr)
138                         delete(l.id2peer, clt.ID())
139                         l.mu.Unlock()
140                 }()
141         }
142
143         select {
144         case <-clt.accepted:
145                 clt.pkts <- pkt
146         default:
147                 select {
148                 case clt.pkts <- pkt:
149                 default:
150                         // It's OK to drop packets if the buffer is full
151                         // because MT RUDP can cope with packet loss.
152                 }
153         }
154
155         return nil
156 }