import (
"errors"
- "fmt"
"net"
"sync"
)
+func tryClose(ch chan struct{}) (ok bool) {
+ defer func() { recover() }()
+ close(ch)
+ return true
+}
+
+type udpClt struct {
+ l *Listener
+ id PeerID
+ addr net.Addr
+ pkts chan []byte
+ closed chan struct{}
+}
+
+func (c *udpClt) mkConn() {
+ conn := newConn(c, c.id, PeerIDSrv)
+ go func() {
+ <-conn.Closed()
+ c.l.wg.Done()
+ }()
+ conn.sendRaw(func(buf []byte) int {
+ buf[0] = uint8(rawCtl)
+ buf[1] = uint8(ctlSetPeerID)
+ be.PutUint16(buf[2:4], uint16(conn.ID()))
+ return 4
+ }, PktInfo{})()
+ select {
+ case c.l.conns <- conn:
+ case <-c.l.closed:
+ conn.Close()
+ }
+}
+
+func (c *udpClt) Write(pkt []byte) (int, error) {
+ select {
+ case <-c.closed:
+ return 0, net.ErrClosed
+ default:
+ }
+
+ return c.l.pc.WriteTo(pkt, c.addr)
+}
+
+func (c *udpClt) recvUDP() ([]byte, error) {
+ select {
+ case pkt := <-c.pkts:
+ return pkt, nil
+ case <-c.closed:
+ return nil, net.ErrClosed
+ }
+}
+
+func (c *udpClt) Close() error {
+ if !tryClose(c.closed) {
+ return net.ErrClosed
+ }
+
+ c.l.mu.Lock()
+ defer c.l.mu.Unlock()
+
+ delete(c.l.ids, c.id)
+ delete(c.l.clts, c.addr.String())
+
+ return nil
+}
+
+func (c *udpClt) LocalAddr() net.Addr { return c.l.pc.LocalAddr() }
+func (c *udpClt) RemoteAddr() net.Addr { return c.addr }
+
+// All Listener's methods are safe for concurrent use.
type Listener struct {
- conn net.PacketConn
+ pc net.PacketConn
- clts chan cltPeer
- errs chan error
+ peerID PeerID
+ conns chan *Conn
+ errs chan error
+ closed chan struct{}
+ wg sync.WaitGroup
- mu sync.Mutex
- addr2peer map[string]cltPeer
- id2peer map[PeerID]cltPeer
- peerID PeerID
+ mu sync.RWMutex
+ ids map[PeerID]bool
+ clts map[string]*udpClt
}
-// Listen listens for packets on conn until it is closed.
-func Listen(conn net.PacketConn) *Listener {
+// Listen listens for connections on pc, pc is closed once the returned Listener
+// and all Conns connected through it are closed.
+func Listen(pc net.PacketConn) *Listener {
l := &Listener{
- conn: conn,
+ pc: pc,
- clts: make(chan cltPeer),
- errs: make(chan error),
+ conns: make(chan *Conn),
+ closed: make(chan struct{}),
- addr2peer: make(map[string]cltPeer),
- id2peer: make(map[PeerID]cltPeer),
+ ids: make(map[PeerID]bool),
+ clts: make(map[string]*udpClt),
}
- pkts := make(chan netPkt)
- go readNetPkts(l.conn, pkts, l.errs)
go func() {
- for pkt := range pkts {
- if err := l.processNetPkt(pkt); err != nil {
- l.errs <- err
+ for {
+ if err := l.processNetPkt(); err != nil {
+ if errors.Is(err, net.ErrClosed) {
+ break
+ }
+ select {
+ case l.errs <- err:
+ case <-l.closed:
+ }
}
}
-
- close(l.clts)
-
- for _, clt := range l.addr2peer {
- clt.Close()
- }
}()
return l
}
-// Accept waits for and returns a connecting Peer.
-// You should keep calling this until it returns net.ErrClosed
-// so it doesn't leak a goroutine.
-func (l *Listener) Accept() (*Peer, error) {
+// Accept waits for and returns the next incoming Conn or an error.
+func (l *Listener) Accept() (*Conn, error) {
select {
- case clt, ok := <-l.clts:
- if !ok {
- select {
- case err := <-l.errs:
- return nil, err
- default:
- return nil, net.ErrClosed
- }
- }
- close(clt.accepted)
- return clt.Peer, nil
+ case c := <-l.conns:
+ return c, nil
case err := <-l.errs:
return nil, err
+ case <-l.closed:
+ return nil, net.ErrClosed
}
}
-// Addr returns the net.PacketConn the Listener is listening on.
-func (l *Listener) Conn() net.PacketConn { return l.conn }
+// Close makes the Listener stop listening for new Conns.
+// Blocked Accept calls will return net.ErrClosed.
+// Already Accepted Conns are not closed.
+func (l *Listener) Close() error {
+ if !tryClose(l.closed) {
+ return net.ErrClosed
+ }
-var ErrOutOfPeerIDs = errors.New("out of peer ids")
+ go func() {
+ l.wg.Wait()
+ l.pc.Close()
+ }()
-type cltPeer struct {
- *Peer
- pkts chan<- netPkt
- accepted chan struct{} // close-only
+ return nil
}
-func (l *Listener) processNetPkt(pkt netPkt) error {
- l.mu.Lock()
- defer l.mu.Unlock()
-
- addrstr := pkt.SrcAddr.String()
+// Addr returns the Listener's network address.
+func (l *Listener) Addr() net.Addr { return l.pc.LocalAddr() }
- clt, ok := l.addr2peer[addrstr]
- if !ok {
- prev := l.peerID
- for l.id2peer[l.peerID].Peer != nil || l.peerID < PeerIDCltMin {
- if l.peerID == prev-1 {
- return ErrOutOfPeerIDs
- }
- l.peerID++
- }
+var ErrOutOfPeerIDs = errors.New("out of peer ids")
- pkts := make(chan netPkt, 256)
+func (l *Listener) processNetPkt() error {
+ buf := make([]byte, maxUDPPktSize)
+ n, addr, err := l.pc.ReadFrom(buf)
+ if err != nil {
+ return err
+ }
- clt = cltPeer{
- Peer: newPeer(l.conn, pkt.SrcAddr, l.peerID, PeerIDSrv),
- pkts: pkts,
- accepted: make(chan struct{}),
+ l.mu.RLock()
+ clt, ok := l.clts[addr.String()]
+ l.mu.RUnlock()
+ if !ok {
+ select {
+ case <-l.closed:
+ return nil
+ default:
}
- l.addr2peer[addrstr] = clt
- l.id2peer[clt.ID()] = clt
-
- data := make([]byte, 1+1+2)
- data[0] = uint8(rawTypeCtl)
- data[1] = uint8(ctlSetPeerID)
- be.PutUint16(data[2:4], uint16(clt.ID()))
- if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
- if errors.Is(err, net.ErrClosed) {
- return nil
- }
- return fmt.Errorf("can't set client peer id: %w", err)
+ clt, err = l.add(addr)
+ if err != nil {
+ return err
}
+ }
- go func() {
- select {
- case l.clts <- clt:
- case <-clt.Disco():
- }
-
- clt.processNetPkts(pkts)
- }()
+ select {
+ case clt.pkts <- buf[:n]:
+ case <-clt.closed:
+ }
- go func() {
- <-clt.Disco()
+ return nil
+}
- l.mu.Lock()
- close(pkts)
- delete(l.addr2peer, addrstr)
- delete(l.id2peer, clt.ID())
- l.mu.Unlock()
- }()
- }
+func (l *Listener) add(addr net.Addr) (*udpClt, error) {
+ l.mu.Lock()
+ defer l.mu.Unlock()
- select {
- case <-clt.accepted:
- clt.pkts <- pkt
- default:
- select {
- case clt.pkts <- pkt:
- default:
- // It's OK to drop packets if the buffer is full
- // because MT RUDP can cope with packet loss.
+ start := l.peerID
+ l.peerID++
+ for l.peerID < PeerIDCltMin || l.ids[l.peerID] {
+ if l.peerID == start {
+ return nil, ErrOutOfPeerIDs
}
+ l.peerID++
+ }
+ l.ids[l.peerID] = true
+
+ clt := &udpClt{
+ l: l,
+ id: l.peerID,
+ addr: addr,
+ pkts: make(chan []byte),
+ closed: make(chan struct{}),
}
+ l.clts[addr.String()] = clt
- return nil
+ l.wg.Add(1)
+ go clt.mkConn()
+
+ return clt, nil
}