]> git.lizzy.rs Git - mt.git/blobdiff - rudp/listen.go
rudp: partial rewrite with new API supporting io.Readers
[mt.git] / rudp / listen.go
index 5b7154a56c7dc02ef91e89a5ff5888e074a4c21f..e1cacf49fb267ceab8a1ccdda932dc786934744f 100644 (file)
@@ -2,155 +2,213 @@ package rudp
 
 import (
        "errors"
-       "fmt"
        "net"
        "sync"
 )
 
+func tryClose(ch chan struct{}) (ok bool) {
+       defer func() { recover() }()
+       close(ch)
+       return true
+}
+
+type udpClt struct {
+       l      *Listener
+       id     PeerID
+       addr   net.Addr
+       pkts   chan []byte
+       closed chan struct{}
+}
+
+func (c *udpClt) mkConn() {
+       conn := newConn(c, c.id, PeerIDSrv)
+       go func() {
+               <-conn.Closed()
+               c.l.wg.Done()
+       }()
+       conn.sendRaw(func(buf []byte) int {
+               buf[0] = uint8(rawCtl)
+               buf[1] = uint8(ctlSetPeerID)
+               be.PutUint16(buf[2:4], uint16(conn.ID()))
+               return 4
+       }, PktInfo{})()
+       select {
+       case c.l.conns <- conn:
+       case <-c.l.closed:
+               conn.Close()
+       }
+}
+
+func (c *udpClt) Write(pkt []byte) (int, error) {
+       select {
+       case <-c.closed:
+               return 0, net.ErrClosed
+       default:
+       }
+
+       return c.l.pc.WriteTo(pkt, c.addr)
+}
+
+func (c *udpClt) recvUDP() ([]byte, error) {
+       select {
+       case pkt := <-c.pkts:
+               return pkt, nil
+       case <-c.closed:
+               return nil, net.ErrClosed
+       }
+}
+
+func (c *udpClt) Close() error {
+       if !tryClose(c.closed) {
+               return net.ErrClosed
+       }
+
+       c.l.mu.Lock()
+       defer c.l.mu.Unlock()
+
+       delete(c.l.ids, c.id)
+       delete(c.l.clts, c.addr.String())
+
+       return nil
+}
+
+func (c *udpClt) LocalAddr() net.Addr  { return c.l.pc.LocalAddr() }
+func (c *udpClt) RemoteAddr() net.Addr { return c.addr }
+
+// All Listener's methods are safe for concurrent use.
 type Listener struct {
-       conn net.PacketConn
+       pc net.PacketConn
 
-       clts chan cltPeer
-       errs chan error
+       peerID PeerID
+       conns  chan *Conn
+       errs   chan error
+       closed chan struct{}
+       wg     sync.WaitGroup
 
-       mu        sync.Mutex
-       addr2peer map[string]cltPeer
-       id2peer   map[PeerID]cltPeer
-       peerID    PeerID
+       mu   sync.RWMutex
+       ids  map[PeerID]bool
+       clts map[string]*udpClt
 }
 
-// Listen listens for packets on conn until it is closed.
-func Listen(conn net.PacketConn) *Listener {
+// Listen listens for connections on pc, pc is closed once the returned Listener
+// and all Conns connected through it are closed.
+func Listen(pc net.PacketConn) *Listener {
        l := &Listener{
-               conn: conn,
+               pc: pc,
 
-               clts: make(chan cltPeer),
-               errs: make(chan error),
+               conns:  make(chan *Conn),
+               closed: make(chan struct{}),
 
-               addr2peer: make(map[string]cltPeer),
-               id2peer:   make(map[PeerID]cltPeer),
+               ids:  make(map[PeerID]bool),
+               clts: make(map[string]*udpClt),
        }
 
-       pkts := make(chan netPkt)
-       go readNetPkts(l.conn, pkts, l.errs)
        go func() {
-               for pkt := range pkts {
-                       if err := l.processNetPkt(pkt); err != nil {
-                               l.errs <- err
+               for {
+                       if err := l.processNetPkt(); err != nil {
+                               if errors.Is(err, net.ErrClosed) {
+                                       break
+                               }
+                               select {
+                               case l.errs <- err:
+                               case <-l.closed:
+                               }
                        }
                }
-
-               close(l.clts)
-
-               for _, clt := range l.addr2peer {
-                       clt.Close()
-               }
        }()
 
        return l
 }
 
-// Accept waits for and returns a connecting Peer.
-// You should keep calling this until it returns net.ErrClosed
-// so it doesn't leak a goroutine.
-func (l *Listener) Accept() (*Peer, error) {
+// Accept waits for and returns the next incoming Conn or an error.
+func (l *Listener) Accept() (*Conn, error) {
        select {
-       case clt, ok := <-l.clts:
-               if !ok {
-                       select {
-                       case err := <-l.errs:
-                               return nil, err
-                       default:
-                               return nil, net.ErrClosed
-                       }
-               }
-               close(clt.accepted)
-               return clt.Peer, nil
+       case c := <-l.conns:
+               return c, nil
        case err := <-l.errs:
                return nil, err
+       case <-l.closed:
+               return nil, net.ErrClosed
        }
 }
 
-// Addr returns the net.PacketConn the Listener is listening on.
-func (l *Listener) Conn() net.PacketConn { return l.conn }
+// Close makes the Listener stop listening for new Conns.
+// Blocked Accept calls will return net.ErrClosed.
+// Already Accepted Conns are not closed.
+func (l *Listener) Close() error {
+       if !tryClose(l.closed) {
+               return net.ErrClosed
+       }
 
-var ErrOutOfPeerIDs = errors.New("out of peer ids")
+       go func() {
+               l.wg.Wait()
+               l.pc.Close()
+       }()
 
-type cltPeer struct {
-       *Peer
-       pkts     chan<- netPkt
-       accepted chan struct{} // close-only
+       return nil
 }
 
-func (l *Listener) processNetPkt(pkt netPkt) error {
-       l.mu.Lock()
-       defer l.mu.Unlock()
-
-       addrstr := pkt.SrcAddr.String()
+// Addr returns the Listener's network address.
+func (l *Listener) Addr() net.Addr { return l.pc.LocalAddr() }
 
-       clt, ok := l.addr2peer[addrstr]
-       if !ok {
-               prev := l.peerID
-               for l.id2peer[l.peerID].Peer != nil || l.peerID < PeerIDCltMin {
-                       if l.peerID == prev-1 {
-                               return ErrOutOfPeerIDs
-                       }
-                       l.peerID++
-               }
+var ErrOutOfPeerIDs = errors.New("out of peer ids")
 
-               pkts := make(chan netPkt, 256)
+func (l *Listener) processNetPkt() error {
+       buf := make([]byte, maxUDPPktSize)
+       n, addr, err := l.pc.ReadFrom(buf)
+       if err != nil {
+               return err
+       }
 
-               clt = cltPeer{
-                       Peer:     newPeer(l.conn, pkt.SrcAddr, l.peerID, PeerIDSrv),
-                       pkts:     pkts,
-                       accepted: make(chan struct{}),
+       l.mu.RLock()
+       clt, ok := l.clts[addr.String()]
+       l.mu.RUnlock()
+       if !ok {
+               select {
+               case <-l.closed:
+                       return nil
+               default:
                }
 
-               l.addr2peer[addrstr] = clt
-               l.id2peer[clt.ID()] = clt
-
-               data := make([]byte, 1+1+2)
-               data[0] = uint8(rawTypeCtl)
-               data[1] = uint8(ctlSetPeerID)
-               be.PutUint16(data[2:4], uint16(clt.ID()))
-               if _, err := clt.sendRaw(rawPkt{Data: data}); err != nil {
-                       if errors.Is(err, net.ErrClosed) {
-                               return nil
-                       }
-                       return fmt.Errorf("can't set client peer id: %w", err)
+               clt, err = l.add(addr)
+               if err != nil {
+                       return err
                }
+       }
 
-               go func() {
-                       select {
-                       case l.clts <- clt:
-                       case <-clt.Disco():
-                       }
-
-                       clt.processNetPkts(pkts)
-               }()
+       select {
+       case clt.pkts <- buf[:n]:
+       case <-clt.closed:
+       }
 
-               go func() {
-                       <-clt.Disco()
+       return nil
+}
 
-                       l.mu.Lock()
-                       close(pkts)
-                       delete(l.addr2peer, addrstr)
-                       delete(l.id2peer, clt.ID())
-                       l.mu.Unlock()
-               }()
-       }
+func (l *Listener) add(addr net.Addr) (*udpClt, error) {
+       l.mu.Lock()
+       defer l.mu.Unlock()
 
-       select {
-       case <-clt.accepted:
-               clt.pkts <- pkt
-       default:
-               select {
-               case clt.pkts <- pkt:
-               default:
-                       // It's OK to drop packets if the buffer is full
-                       // because MT RUDP can cope with packet loss.
+       start := l.peerID
+       l.peerID++
+       for l.peerID < PeerIDCltMin || l.ids[l.peerID] {
+               if l.peerID == start {
+                       return nil, ErrOutOfPeerIDs
                }
+               l.peerID++
+       }
+       l.ids[l.peerID] = true
+
+       clt := &udpClt{
+               l:      l,
+               id:     l.peerID,
+               addr:   addr,
+               pkts:   make(chan []byte),
+               closed: make(chan struct{}),
        }
+       l.clts[addr.String()] = clt
 
-       return nil
+       l.wg.Add(1)
+       go clt.mkConn()
+
+       return clt, nil
 }