]> git.lizzy.rs Git - mt.git/blob - rudp/listen.go
871a5912adcc18a0ad005c39f01454e0543cd98f
[mt.git] / rudp / listen.go
1 package rudp
2
3 import (
4         "encoding/binary"
5         "errors"
6         "fmt"
7         "net"
8         "sync"
9 )
10
11 type Listener struct {
12         conn net.PacketConn
13
14         clts chan cltPeer
15         errs chan error
16
17         mu        sync.Mutex
18         addr2peer map[string]cltPeer
19         id2peer   map[PeerID]cltPeer
20         peerid    PeerID
21 }
22
23 // Listen listens for packets on conn until it is closed.
24 func Listen(conn net.PacketConn) *Listener {
25         l := &Listener{
26                 conn: conn,
27
28                 clts: make(chan cltPeer),
29                 errs: make(chan error),
30
31                 addr2peer: make(map[string]cltPeer),
32                 id2peer:   make(map[PeerID]cltPeer),
33         }
34
35         pkts := make(chan netPkt)
36         go readNetPkts(l.conn, pkts, l.errs)
37         go func() {
38                 for pkt := range pkts {
39                         if err := l.processNetPkt(pkt); err != nil {
40                                 l.errs <- err
41                         }
42                 }
43
44                 close(l.clts)
45
46                 for _, clt := range l.addr2peer {
47                         clt.Close()
48                 }
49         }()
50
51         return l
52 }
53
54 // Accept waits for and returns a connecting Peer.
55 // You should keep calling this until it returns ErrClosed
56 // so it doesn't leak a goroutine.
57 func (l *Listener) Accept() (*Peer, error) {
58         select {
59         case clt, ok := <-l.clts:
60                 if !ok {
61                         select {
62                         case err := <-l.errs:
63                                 return nil, err
64                         default:
65                                 return nil, ErrClosed
66                         }
67                 }
68                 close(clt.accepted)
69                 return clt.Peer, nil
70         case err := <-l.errs:
71                 return nil, err
72         }
73 }
74
75 // Addr returns the net.PacketConn the Listener is listening on.
76 func (l *Listener) Conn() net.PacketConn { return l.conn }
77
78 var ErrOutOfPeerIDs = errors.New("out of peer ids")
79
80 type cltPeer struct {
81         *Peer
82         pkts     chan<- netPkt
83         accepted chan struct{} // close-only
84 }
85
86 func (l *Listener) processNetPkt(pkt netPkt) error {
87         l.mu.Lock()
88         defer l.mu.Unlock()
89
90         addrstr := pkt.SrcAddr.String()
91
92         clt, ok := l.addr2peer[addrstr]
93         if !ok {
94                 prev := l.peerid
95                 for l.id2peer[l.peerid].Peer != nil || l.peerid < PeerIDCltMin {
96                         if l.peerid == prev-1 {
97                                 return ErrOutOfPeerIDs
98                         }
99                         l.peerid++
100                 }
101
102                 pkts := make(chan netPkt, 256)
103
104                 clt = cltPeer{
105                         Peer:     newPeer(l.conn, pkt.SrcAddr, l.peerid, PeerIDSrv),
106                         pkts:     pkts,
107                         accepted: make(chan struct{}),
108                 }
109
110                 l.addr2peer[addrstr] = clt
111                 l.id2peer[clt.ID()] = clt
112
113                 data := make([]byte, 1+1+2)
114                 data[0] = uint8(rawTypeCtl)
115                 data[1] = uint8(ctlSetPeerID)
116                 binary.BigEndian.PutUint16(data[2:4], uint16(clt.ID()))
117                 if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
118                         return fmt.Errorf("can't set client peer id: %w", err)
119                 }
120
121                 go func() {
122                         select {
123                         case l.clts <- clt:
124                         case <-clt.Disco():
125                         }
126
127                         clt.processNetPkts(pkts)
128                 }()
129
130                 go func() {
131                         <-clt.Disco()
132
133                         l.mu.Lock()
134                         close(pkts)
135                         delete(l.addr2peer, addrstr)
136                         delete(l.id2peer, clt.ID())
137                         l.mu.Unlock()
138                 }()
139         }
140
141         select {
142         case <-clt.accepted:
143                 clt.pkts <- pkt
144         default:
145                 select {
146                 case clt.pkts <- pkt:
147                 default:
148                         return fmt.Errorf("ignoring net pkt from %s because buf is full", addrstr)
149                 }
150         }
151
152         return nil
153 }