]> git.lizzy.rs Git - hydra-dragonfire.git/commitdiff
Event system
authorElias Fleckenstein <eliasfleckenstein@web.de>
Tue, 31 May 2022 16:10:27 +0000 (18:10 +0200)
committerElias Fleckenstein <eliasfleckenstein@web.de>
Tue, 31 May 2022 16:10:27 +0000 (18:10 +0200)
13 files changed:
client.go
convert/push_auto.go
convert/push_mkauto.lua
doc/client.md
doc/hydra.md
doc/pkts.md [new file with mode: 0644]
doc/poll.md
example/chat-client.lua
example/dump-traffic.lua
example/print-node.lua
hydra.go
pkts.go [new file with mode: 0644]
poll.go

index 13795e5ecfced6d50d57502fab03ea9a885e2943..4595d131383ac8b2fc746934713be2d7a63b5be7 100644 (file)
--- a/client.go
+++ b/client.go
@@ -29,24 +29,38 @@ type Client struct {
        address    string
        state      clientState
        conn       mt.Peer
-       queue      chan *mt.Pkt
-       wildcard   bool
-       subscribed map[string]struct{}
+       queue      chan Event
        components map[string]Component
+       table      *lua.LTable
        userdata   *lua.LUserData
 }
 
 var clientFuncs = map[string]lua.LGFunction{
-       "address":     l_client_address,
-       "state":       l_client_state,
-       "connect":     l_client_connect,
-       "poll":        l_client_poll,
-       "close":       l_client_close,
-       "enable":      l_client_enable,
-       "subscribe":   l_client_subscribe,
-       "unsubscribe": l_client_unsubscribe,
-       "wildcard":    l_client_wildcard,
-       "send":        l_client_send,
+       "address": l_client_address,
+       "state":   l_client_state,
+       "connect": l_client_connect,
+       "poll":    l_client_poll,
+       "close":   l_client_close,
+       "enable":  l_client_enable,
+       "send":    l_client_send,
+}
+
+type EventError struct {
+       err string
+}
+
+func (evt EventError) handle(l *lua.LState, val lua.LValue) {
+       l.SetField(val, "type", lua.LString("error"))
+       l.SetField(val, "error", lua.LString(evt.err))
+}
+
+type EventDisconnect struct {
+       client *Client
+}
+
+func (evt EventDisconnect) handle(l *lua.LState, val lua.LValue) {
+       l.SetField(val, "type", lua.LString("disconnect"))
+       evt.client.state = csDisconnected
 }
 
 func getClient(l *lua.LState) *Client {
@@ -65,17 +79,6 @@ func getClients(l *lua.LState) []*Client {
        return clients
 }
 
-func getStrings(l *lua.LState) []string {
-       n := l.GetTop()
-
-       strs := make([]string, 0, n-1)
-       for i := 2; i <= n; i++ {
-               strs = append(strs, l.CheckString(i))
-       }
-
-       return strs
-}
-
 func (client *Client) closeConn() {
        client.mu.Lock()
        defer client.mu.Unlock()
@@ -90,9 +93,8 @@ func l_client(l *lua.LState) int {
 
        client.address = l.CheckString(1)
        client.state = csNew
-       client.wildcard = false
-       client.subscribed = map[string]struct{}{}
        client.components = map[string]Component{}
+       client.table = l.NewTable()
        client.userdata = l.NewUserData()
        client.userdata.Value = client
        l.SetMetatable(client.userdata, l.GetTypeMetatable("hydra.client"))
@@ -105,7 +107,9 @@ func l_client_index(l *lua.LState) int {
        client := getClient(l)
        key := l.CheckString(2)
 
-       if fun, exists := clientFuncs[key]; exists {
+       if key == "data" {
+               l.Push(client.table)
+       } else if fun, exists := clientFuncs[key]; exists {
                l.Push(l.NewFunction(fun))
        } else if component, exists := client.components[key]; exists {
                l.Push(component.push())
@@ -154,7 +158,7 @@ func l_client_connect(l *lua.LState) int {
 
        client.state = csConnected
        client.conn = mt.Connect(conn)
-       client.queue = make(chan *mt.Pkt, 1024)
+       client.queue = make(chan Event, 1024)
 
        go func() {
                for {
@@ -165,15 +169,12 @@ func l_client_connect(l *lua.LState) int {
                                for _, component := range client.components {
                                        component.process(&pkt)
                                }
-                               _, subscribed := client.subscribed[string(convert.PushPktType(&pkt))]
                                client.mu.Unlock()
-
-                               if subscribed || client.wildcard {
-                                       client.queue <- &pkt
-                               }
                        } else if errors.Is(err, net.ErrClosed) {
-                               close(client.queue)
+                               client.queue <- EventDisconnect{client: client}
                                return
+                       } else {
+                               client.queue <- EventError{err: err.Error()}
                        }
                }
        }()
@@ -189,11 +190,7 @@ func l_client_connect(l *lua.LState) int {
 
 func l_client_poll(l *lua.LState) int {
        client := getClient(l)
-       _, pkt, timeout := doPoll(l, []*Client{client})
-
-       l.Push(convert.PushPkt(l, pkt))
-       l.Push(lua.LBool(timeout))
-       return 2
+       return doPoll(l, []*Client{client})
 }
 
 func l_client_close(l *lua.LState) int {
@@ -204,16 +201,22 @@ func l_client_close(l *lua.LState) int {
 
 func l_client_enable(l *lua.LState) int {
        client := getClient(l)
+       n := l.GetTop()
+
        client.mu.Lock()
        defer client.mu.Unlock()
 
-       for _, compname := range getStrings(l) {
+       for i := 2; i <= n; i++ {
+               compname := l.CheckString(i)
+
                if component, exists := client.components[compname]; !exists {
                        switch compname {
                        case "auth":
                                component = &Auth{}
                        case "map":
                                component = &Map{}
+                       case "pkts":
+                               component = &Pkts{}
                        default:
                                panic("invalid component: " + compname)
                        }
@@ -226,36 +229,6 @@ func l_client_enable(l *lua.LState) int {
        return 0
 }
 
-func l_client_subscribe(l *lua.LState) int {
-       client := getClient(l)
-       client.mu.Lock()
-       defer client.mu.Unlock()
-
-       for _, pkt := range getStrings(l) {
-               client.subscribed[pkt] = struct{}{}
-       }
-
-       return 0
-}
-
-func l_client_unsubscribe(l *lua.LState) int {
-       client := getClient(l)
-       client.mu.Lock()
-       defer client.mu.Unlock()
-
-       for _, pkt := range getStrings(l) {
-               delete(client.subscribed, pkt)
-       }
-
-       return 0
-}
-
-func l_client_wildcard(l *lua.LState) int {
-       client := getClient(l)
-       client.wildcard = l.ToBool(2)
-       return 0
-}
-
 func l_client_send(l *lua.LState) int {
        client := getClient(l)
 
@@ -271,7 +244,7 @@ func l_client_send(l *lua.LState) int {
 
        if client.state == csConnected {
                ack, err := client.conn.SendCmd(cmd)
-               if err != nil {
+               if err != nil && !errors.Is(err, net.ErrClosed) {
                        panic(err)
                }
 
index beb30c25271984720542a223ad3277fc6ccce82d..bcd0948d61a95ab27b55ae007e254e5988558e54 100644 (file)
@@ -408,7 +408,6 @@ func PushPkt(l *lua.LState, pkt *mt.Pkt) lua.LValue {
                return lua.LNil
        }
        tbl := l.NewTable()
-       l.SetField(tbl, "_type", PushPktType(pkt))
        switch val := pkt.Cmd.(type) {
        case *mt.ToCltAcceptAuth:
                l.SetField(tbl, "map_seed", lua.LNumber(val.MapSeed))
index df13198df9a93e2080925d4efa2295e1311a1787..3e1f290f6ccffd254b5d21990aa0d3a28988eb72 100755 (executable)
@@ -120,7 +120,6 @@ func PushPkt(l *lua.LState, pkt *mt.Pkt) lua.LValue {
                return lua.LNil
        }
        tbl := l.NewTable()
-       l.SetField(tbl, "_type", PushPktType(pkt))
        switch val := pkt.Cmd.(type) {
 ]] .. pkt_impl .. [[
        }
index 9c1922e16aa753a41ae5fdcc542f66e5731206b2..fe9fb7d9878f7ab7ab88336e57d9e144cc4dc78f 100644 (file)
@@ -9,17 +9,15 @@ After being disconnect, a client cannot be reconnected.
 - `self:address()`: Returns address passed to `hydra.client` upon creation as a string.
 - `self:state()`: Returns current connection state as a string ("new", "connected", "disconnected")
 - `self:connect()`: Connects to server. Throws an error if the client is not in "new" state OR address resolution / dial fails (Note: If required, you can use `pcall` to catch and handle errors instead of crashing the script). Connection failure (= host found, but no minetest server running on port) is equivalent to an immediate disconnect and does not cause an error to be thrown.
-- `self:poll([timeout])`: Polls packets from client. See [poll.md](poll.md) for behavior and return values.
+- `self:poll([timeout])`: Polls events from client. See [poll.md](poll.md) for behavior and return values.
 - `self:close()`: Closes the network connection if in `connected` state. The client remains in `connected` state until passed to poll.
 - `self:enable(component)`: Enables the component with the name `component` (string), if not already enabled. By default, no components are enabled. See Components section.
-- `self:subscribe(pkt1, [pkt2, ...])`: Subscribes to all packet passed as arguments (strings). For available packets, see [client_pkts.md](client_pkts.md). By default, the client is not subscribed to any packets.
-- `self:unsubscribe(pkt1, [pkt2, ...])`: Unsubscribes from all packet passed as arguments (strings).
-- `self:wildcard(wildcard)`: Sets wildcard mode to `wildcard` (boolean). If wildcard is enabled, ALL packets are returned by poll, even those that the client did not subscribe to. It is not recommended to use this without a reason since converting packets to Lua costs performance and creates and overhead due to poll returning more often. `wildcard` is unnecessary if only certain packets are handled anyway, but it is useful for traffic inspection and debugging.
-- `self:send(pkt_name, pkt_data, [ack])`: Sends a packet to server. Throws an error if the client is not connected. `pkt_name` is the type of the packet as string. `pkt_data` is a table containing packet parameters. Some packets don't have parameters (e.g. `respawn`) - in this case, `pkt_data` can be omitted. See [server_pkts.md](server_pkts.md) for available packets. If `ack` is true, this function will block until acknowledgement from server is received.
+- `self:send(pkt_type, pkt_data, [ack])`: Sends a packet to server. Throws an error if the client is not connected. `pkt_type` is the type of the packet as string. `pkt_data` is a table containing packet parameters. Some packets don't have parameters (e.g. `respawn`) - in this case, `pkt_data` can be omitted. See [server_pkts.md](server_pkts.md) for available packets. If `ack` is true, this function will block until acknowledgement from server is received.
 
 ## Components
 
 Enabled components can be accessed by using `self.<component name>`.
 
+- `self.pkt`: Allows you to handle selected packets yourself. Most scripts use this. See [pkts.md](pkts.md).
 - `self.auth`: Handles authentication. Recommended for the vast majority of scripts. See [auth.md](auth.md).
 - `self.map`: Stores MapBlocks received from server. See [map.md](map.md).
index 4eac2cfa6907cf6b50034bd3964ba03de3ab01e1..fd916e07fbd481c6ca3067cfbc687b0a5785105c 100644 (file)
@@ -13,6 +13,5 @@ The `hydra` table contains functions necessary to handle connections.
 
 - `hydra.client(address)`: Returns a new client. Address must be a string. For client functions, see [client.md](client.md).
 - `hydra.dtime()`: Utility function that turns the elapsed time in seconds (floating point) since it was last called (or since program start).
-- `hydra.canceled()`: Utility function that returns true if the program was interrupted (SIGINT, SIGTERM, SIGHUP).
-- `hydra.poll(clients, [timeout])`: Polls subscribed packets from all clients in `clients` (table). For behavior and return value, see [poll.md](poll.md).
+- `hydra.poll(clients, [timeout])`: Polls events from all clients in `clients` (table). For behavior and return value, see [poll.md](poll.md).
 - `hydra.close(clients)`: Closes all clients in `clients` (table) that are currently connected. See `client:close()` in [client.md](client.md) for more info.
diff --git a/doc/pkts.md b/doc/pkts.md
new file mode 100644 (file)
index 0000000..fba5083
--- /dev/null
@@ -0,0 +1,18 @@
+# Packet Handler Component
+Source code: [pkts.go](../pkts.go)
+
+The packet handler component allows you to handle packets yourself. It fires events in the form of `{ type = "pkt", client = ..., pkt_type = "...", pkt_data = { ... } }``` when subscribed packets are received.
+For available packets, see [client_pkts.md](client_pkts.md). By default, not packets are packets subscribed.
+
+## Wildcard mode
+
+If wildcard is enabled, events for all packets are fired, even ones that are not subscribed. It is not recommended to use this without a reason since converting packets to Lua costs performance and creates and overhead due to poll returning more often. `wildcard` is unnecessary if only certain packets are handled anyway, but it is useful for traffic inspection and debugging.
+
+## Functions
+
+- `self:subscribe(pkt1, [pkt2, ...])`: Subscribes to all packet types passed as arguments (strings).
+
+- `self:unsubscribe(pkt1, [pkt2, ...])`: Unsubscribes from all packet passed as arguments (strings).
+
+- `self:wildcard(wildcard)`: Sets wildcard mode to `wildcard` (boolean).
+
index 494fe48293f54cc368015f30da3c6c8bc360ce50..16aa62ce1fe3cef9aa03de08965acaa6799fe481 100644 (file)
@@ -1,49 +1,19 @@
 # Polling API
 Source code: [poll.go](../poll.go)
 
-**TL;DR**: poll is complex and has many different cases, but in general, it returns the received packet and the associated client; if one of the clients closes, a nil packet is returned once. client may also be nil in some cases so watch out for that.
+`poll` waits for and returns the next event from one or more clients, or `nil` if none of the clients passed to it are active (`connected` state).
+Optionally, a timeout can be passed to poll; if no other event occurs until the timeout elapses, a timeout event is returned.
 
-Together with sending, polling is the core function of hydra. It is used to receive packets from a packet queue.
+## Events
 
-For each client, only packets that the client has subscribed to are inserted into that queue, unless wildcard is enabled.
+An event is a table that contains a string `type`. Depending on the type, it may have different other fields.
 
-Packet receival from network happens asynchronously. When a packet is received and has been processed by components, it is enqueued for polling if the client is subscribed to it. **Because of the poll queue, packets may be returned by poll that the client was subscribed to in the past but unsubscribed recently.** Since the queue has a limited capacity of 1024 packets (this may change in the future), it is your responsibility to actually poll in a frequency suitable to keep up with the amount of packets you expect based on what you are subscribed to. If the queue is full, the thread responsible for receival will block.
+- `type = "interrupt"`: Fired globally when the program was interrupted using a signal.
 
-Clients that are not in `connected` state are ignored by poll.
+- `type = "timeout"`: Fired when the timeout elapses.
 
-Poll blocks until one of these conditions is met (in this order). The return value depends on which condition is met:
+- `type = "pkt"`: Fired when a packet was received. See [pkts.md](pkts.md)
 
-1. No clients are available when the function is called. This happens if either no clients were passed to poll or none of them is connected.
+- `type = "disconnect"`: Fired when a client connection closed. Has a `client` field. 
 
-2. One of the clients closes. In this case, the client that closed is set to `disconnected` state. The close may happen before or during the call to poll, but it has effect only once.
-
-3. A packet is in queue for one of the clients (Main case).
-
-4. An interrupt signal is received during polling (See `hydra.canceled`).
-
-5. The configured timeout elapses.
-
-## Different versions
-
-There is two different versions of poll: `client:poll` for polling a single client and `hydra.poll` for polling multiple clients.
-They are mostly equivalent but differ in return values and arguments:
-
-- `client:poll([timeout])` polls from the client `client` and returns `pkt, interrupted`
-
-- `hydra.poll(clients, [timeout])` takes table of clients as argument and returns `pkt, client, interrupted`
-
-## Arguments and return values
-
-The timeout argument is an optional floating point number holding the timeout in seconds, if `nil`, poll will block until one of the conditions 1.-4. are met. Timeout may be `0`, in this case poll returns immediately even if none of the other conditions are met immediately.
-
-Return values for different cases:
-
-1. If no clients are available, `nil, nil, false` (or `nil, false` respectively) is returned.
-
-2. If a client closes, `nil, client, false` (or `nil, false` respectively) is returned.
-
-3. If a packet is available, poll returns `pkt, client, false` (or `pkt, false` respectively). `pkt` is a table containing the received packet (see [client_pkts.md](client_pkts.md)) and `client` is the client reference that has received the packet.
-
-4. If the program is interrupted, poll returns `nil, nil, true` (or `nil, true` respectively).
-
-5. If the timeout elapses, poll returns `nil, nil, true` (or `nil, true` respectively).
+- `type = "error"`: Fired when an error occurs during deserialization of a packet. Has a `client` field. Stores the error message in an `error` field.
index c1828ec0d5eaf2e1514822a9df7cb3f39e573c6e..889369fe4ad453bde1c7c372a57a763c2dd96130 100755 (executable)
@@ -2,19 +2,24 @@
 local escapes = require("escapes")
 local client = require("client")()
 
-client:subscribe("chat_msg")
+client:enable("pkts")
+client.pkts:subscribe("chat_msg")
+
 client:connect()
 
-while not hydra.canceled() do
-       local pkt, interrupt = client:poll(1)
+while true do
+       local evt = client:poll(1)
 
-       if pkt then
-               print(escapes.strip_all(pkt.text))
-       elseif interrupt then
-               client:send("chat_msg", {msg = "test"})
-       else
-               print("disconnected")
+       if not evt then
+               break
+       end
+
+       if not evt or evt.type == "interrupt" or evt.type == "disconnect" then
                break
+       elseif evt.type == "pkt" then
+               print(escapes.strip_all(evt.pkt_data.text))
+       elseif evt.type == "timeout" then
+               client:send("chat_msg", {msg = "test"})
        end
 end
 
index 5dc83b6eea4ffad1db81006ba892a8f6f74f35a6..f003c745b17ddf6b2e07df2856eec11dcfb9a220 100755 (executable)
@@ -3,7 +3,9 @@ local escapes = require("escapes")
 local base64 = require("base64")
 local client = require("client")()
 
-client:wildcard(true)
+client:enable("pkts")
+client.pkts:wildcard(true)
+
 client:connect()
 
 local function dump(val, indent)
@@ -16,41 +18,43 @@ local function dump(val, indent)
                end
                print(val)
        else
-               print(val._type or "")
+               print()
 
                local idt = (indent or "") .. "  "
                for k, v in pairs(val) do
-                       if k ~= "_type" then
-                               io.write(idt .. k .. " ")
-                               dump(v, idt)
-                       end
+                       io.write(idt .. k .. " ")
+                       dump(v, idt)
                end
        end
 end
 
-while not hydra.canceled() do
-       local pkt, interrupt = client:poll()
+while true do
+       local evt = client:poll()
 
-       if pkt then
-               if pkt._type == "srp_bytes_salt_b" then
-                       pkt.b = base64.encode(pkt.b)
-                       pkt.salt = base64.encode(pkt.salt)
+       if not evt or evt.type == "disconnect" or evt.type == "interrupt" then
+               break
+       elseif evt.type == "error" then
+               print(evt.error)
+       elseif evt.type == "pkt" then
+               local type, data = evt.pkt_type, evt.pkt_data
+
+               if type == "srp_bytes_salt_b" then
+                       data.b = base64.encode(data.b)
+                       data.salt = base64.encode(data.salt)
                end
 
-               if pkt._type == "chat_msg" then
-                       pkt.text = escapes.strip_all(pkt.text)
+               if type == "chat_msg" then
+                       data.text = escapes.strip_all(data.text)
                end
 
-               if pkt._type == "blk_data" then
-                       pkt.blk.param0 = {}
-                       pkt.blk.param1 = {}
-                       pkt.blk.param2 = {}
+               if type == "blk_data" then
+                       data.blk.param0 = {}
+                       data.blk.param1 = {}
+                       data.blk.param2 = {}
                end
 
-               dump(pkt)
-       elseif not interrupt then
-               print("disconnected")
-               break
+               io.write(type)
+               dump(data)
        end
 end
 
index 3cf514e846cbe2a768490df64142cce9ef8fd93c..d72f4744476d665a6de0620b957039c7040a3a4f 100755 (executable)
@@ -1,20 +1,21 @@
 #!/usr/bin/env hydra-dragonfire
 local client = require("client")()
-client:enable("map")
 
-client:subscribe("move_player")
+client:enable("pkts", "map")
+client.pkts:subscribe("move_player")
+
 client:connect()
 
 local pos
 
-while not hydra.canceled() do
-       local pkt, interrupted = client:poll(1)
+while true do
+       local evt = client:poll(1)
 
-       if pkt then
-               pos = (pkt.pos / hydra.BS + vec3(0, -1, 0)):round()
-       elseif not interrupted then
+       if not evt or evt.type == "disconnect" or evt.type == "interrupt" then
                break
-       elseif pos then
+       elseif evt.type == "pkt" then
+               pos = (evt.pkt_data.pos / hydra.BS + vec3(0, -1, 0)):round()
+       elseif evt.type == "timeout" and pos then
                local node = client.map:node(pos)
                print(pos, node and node.param0)
        end
index f34ca2bb733f55f03b659a4cc1a4ce6f87e5cab7..7f7b2dac28e563bb8ced73b410711b868971673f 100644 (file)
--- a/hydra.go
+++ b/hydra.go
@@ -2,7 +2,6 @@ package main
 
 import (
        _ "embed"
-       "github.com/dragonfireclient/hydra-dragonfire/convert"
        "github.com/yuin/gopher-lua"
        "os"
        "os/signal"
@@ -11,7 +10,7 @@ import (
 )
 
 var lastTime = time.Now()
-var canceled = false
+var signalChannel chan os.Signal
 
 var serializeVer uint8 = 28
 var protoVer uint16 = 39
@@ -40,17 +39,10 @@ var builtinFiles = []string{
 }
 
 var hydraFuncs = map[string]lua.LGFunction{
-       "client":   l_client,
-       "dtime":    l_dtime,
-       "canceled": l_canceled,
-       "poll":     l_poll,
-       "close":    l_close,
-}
-
-func signalChannel() chan os.Signal {
-       sig := make(chan os.Signal, 1)
-       signal.Notify(sig, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
-       return sig
+       "client": l_client,
+       "dtime":  l_dtime,
+       "poll":   l_poll,
+       "close":  l_close,
 }
 
 func l_dtime(l *lua.LState) int {
@@ -59,21 +51,8 @@ func l_dtime(l *lua.LState) int {
        return 1
 }
 
-func l_canceled(l *lua.LState) int {
-       l.Push(lua.LBool(canceled))
-       return 1
-}
-
 func l_poll(l *lua.LState) int {
-       client, pkt, timeout := doPoll(l, getClients(l))
-       l.Push(convert.PushPkt(l, pkt))
-       if client == nil {
-               l.Push(lua.LNil)
-       } else {
-               l.Push(client.userdata)
-       }
-       l.Push(lua.LBool(timeout))
-       return 3
+       return doPoll(l, getClients(l))
 }
 
 func l_close(l *lua.LState) int {
@@ -89,12 +68,10 @@ func main() {
                panic("missing filename")
        }
 
-       go func() {
-               <-signalChannel()
-               canceled = true
-       }()
+       signalChannel = make(chan os.Signal, 1)
+       signal.Notify(signalChannel, os.Interrupt, syscall.SIGTERM, syscall.SIGHUP)
 
-       l := lua.NewState(lua.Options{IncludeGoStackTrace: true})
+       l := lua.NewState()
        defer l.Close()
 
        arg := l.NewTable()
@@ -112,6 +89,7 @@ func main() {
        l.SetField(l.NewTypeMetatable("hydra.auth"), "__index", l.SetFuncs(l.NewTable(), authFuncs))
        l.SetField(l.NewTypeMetatable("hydra.client"), "__index", l.NewFunction(l_client_index))
        l.SetField(l.NewTypeMetatable("hydra.map"), "__index", l.SetFuncs(l.NewTable(), mapFuncs))
+       l.SetField(l.NewTypeMetatable("hydra.pkts"), "__index", l.SetFuncs(l.NewTable(), pktsFuncs))
 
        for _, str := range builtinFiles {
                if err := l.DoString(str); err != nil {
diff --git a/pkts.go b/pkts.go
new file mode 100644 (file)
index 0000000..e34f3b8
--- /dev/null
+++ b/pkts.go
@@ -0,0 +1,99 @@
+package main
+
+import (
+       "github.com/anon55555/mt"
+       "github.com/dragonfireclient/hydra-dragonfire/convert"
+       "github.com/yuin/gopher-lua"
+       "sync"
+)
+
+type Pkts struct {
+       client     *Client
+       mu         sync.Mutex
+       wildcard   bool
+       subscribed map[string]struct{}
+       userdata   *lua.LUserData
+}
+
+var pktsFuncs = map[string]lua.LGFunction{
+       "subscribe":   l_pkts_subscribe,
+       "unsubscribe": l_pkts_unsubscribe,
+       "wildcard":    l_pkts_wildcard,
+}
+
+type EventPkt struct {
+       pktType string
+       pktData *mt.Pkt
+}
+
+func (evt EventPkt) handle(l *lua.LState, val lua.LValue) {
+       l.SetField(val, "type", lua.LString("pkt"))
+       l.SetField(val, "pkt_type", lua.LString(evt.pktType))
+       l.SetField(val, "pkt_data", convert.PushPkt(l, evt.pktData))
+}
+
+func getPkts(l *lua.LState) *Pkts {
+       return l.CheckUserData(1).Value.(*Pkts)
+}
+
+func (pkts *Pkts) create(client *Client, l *lua.LState) {
+       pkts.client = client
+       pkts.wildcard = false
+       pkts.subscribed = map[string]struct{}{}
+       pkts.userdata = l.NewUserData()
+       pkts.userdata.Value = pkts
+       l.SetMetatable(pkts.userdata, l.GetTypeMetatable("hydra.pkts"))
+}
+
+func (pkts *Pkts) push() lua.LValue {
+       return pkts.userdata
+}
+
+func (pkts *Pkts) connect() {
+}
+
+func (pkts *Pkts) process(pkt *mt.Pkt) {
+       pktType := string(convert.PushPktType(pkt))
+
+       pkts.mu.Lock()
+       _, subscribed := pkts.subscribed[pktType]
+       pkts.mu.Unlock()
+
+       if subscribed || pkts.wildcard {
+               pkts.client.queue <- EventPkt{pktType: pktType, pktData: pkt}
+       }
+}
+
+func l_pkts_subscribe(l *lua.LState) int {
+       pkts := getPkts(l)
+       n := l.GetTop()
+
+       pkts.mu.Lock()
+       defer pkts.mu.Unlock()
+
+       for i := 2; i <= n; i++ {
+               pkts.subscribed[l.CheckString(i)] = struct{}{}
+       }
+
+       return 0
+}
+
+func l_pkts_unsubscribe(l *lua.LState) int {
+       pkts := getPkts(l)
+       n := l.GetTop()
+
+       pkts.mu.Lock()
+       defer pkts.mu.Unlock()
+
+       for i := 2; i <= n; i++ {
+               delete(pkts.subscribed, l.CheckString(i))
+       }
+
+       return 0
+}
+
+func l_pkts_wildcard(l *lua.LState) int {
+       pkts := getPkts(l)
+       pkts.wildcard = l.ToBool(2)
+       return 0
+}
diff --git a/poll.go b/poll.go
index db00ef9561d487de29d1448f50d9b8180408d96b..dfc30a6fa0701a3d51f8e19fb472c8571a566841 100644 (file)
--- a/poll.go
+++ b/poll.go
@@ -1,21 +1,30 @@
 package main
 
 import (
-       "github.com/anon55555/mt"
        "github.com/yuin/gopher-lua"
        "reflect"
        "time"
 )
 
-func doPoll(l *lua.LState, clients []*Client) (*Client, *mt.Pkt, bool) {
-       var timeout time.Duration
-       hasTimeout := false
-       if l.GetTop() > 1 {
-               timeout = time.Duration(float64(l.ToNumber(2)) * float64(time.Second))
-               hasTimeout = true
-       }
+type Event interface {
+       handle(l *lua.LState, val lua.LValue)
+}
+
+type EventTimeout struct{}
 
+func (evt EventTimeout) handle(l *lua.LState, val lua.LValue) {
+       l.SetField(val, "type", lua.LString("timeout"))
+}
+
+type EventInterrupt struct{}
+
+func (evt EventInterrupt) handle(l *lua.LState, val lua.LValue) {
+       l.SetField(val, "type", lua.LString("interrupt"))
+}
+
+func doPoll(l *lua.LState, clients []*Client) int {
        cases := make([]reflect.SelectCase, 0, len(clients)+2)
+
        for _, client := range clients {
                if client.state != csConnected {
                        continue
@@ -30,35 +39,39 @@ func doPoll(l *lua.LState, clients []*Client) (*Client, *mt.Pkt, bool) {
        offset := len(cases)
 
        if offset < 1 {
-               return nil, nil, false
+               return 0
        }
 
        cases = append(cases, reflect.SelectCase{
                Dir:  reflect.SelectRecv,
-               Chan: reflect.ValueOf(signalChannel()),
+               Chan: reflect.ValueOf(signalChannel),
        })
 
-       if hasTimeout {
+       if l.GetTop() > 1 {
+               timeout := time.After(time.Duration(float64(l.ToNumber(2)) * float64(time.Second)))
+
                cases = append(cases, reflect.SelectCase{
                        Dir:  reflect.SelectRecv,
-                       Chan: reflect.ValueOf(time.After(timeout)),
+                       Chan: reflect.ValueOf(timeout),
                })
        }
 
-       idx, value, ok := reflect.Select(cases)
+       idx, value, _ := reflect.Select(cases)
 
-       if idx >= offset {
-               return nil, nil, true
-       }
-
-       client := clients[idx]
+       var evt Event
+       tbl := l.NewTable()
 
-       var pkt *mt.Pkt = nil
-       if ok {
-               pkt = value.Interface().(*mt.Pkt)
+       if idx > offset {
+               evt = EventTimeout{}
+       } else if idx == offset {
+               evt = EventInterrupt{}
        } else {
-               client.state = csDisconnected
+               evt = value.Interface().(Event)
+               l.SetField(tbl, "client", clients[idx].userdata)
        }
 
-       return client, pkt, false
+       evt.handle(l, tbl)
+
+       l.Push(tbl)
+       return 1
 }