]> git.lizzy.rs Git - hydra-dragonfire.git/blobdiff - auth.go
Migrate to gopher-lua
[hydra-dragonfire.git] / auth.go
diff --git a/auth.go b/auth.go
index bc8d24e09bdab7c58e4186bd8a988fd95c76b543..879ffcee6ecbd85060be09b88148739e71d7d435 100644 (file)
--- a/auth.go
+++ b/auth.go
@@ -2,8 +2,9 @@ package main
 
 import (
        "github.com/HimbeerserverDE/srp"
-       "github.com/Shopify/go-lua"
        "github.com/anon55555/mt"
+       "github.com/dragonfireclient/hydra/tolua"
+       "github.com/yuin/gopher-lua"
        "strings"
        "time"
 )
@@ -23,47 +24,50 @@ type Auth struct {
        username          string
        password          string
        language          string
+       version           string
        state             authState
        err               string
        srpBytesA, bytesA []byte
+       userdata          *lua.LUserData
 }
 
-func getAuth(l *lua.State) *Auth {
-       return lua.CheckUserData(l, 1, "hydra.auth").(*Auth)
+var authFuncs = map[string]lua.LGFunction{
+       "username": l_auth_username,
+       "password": l_auth_password,
+       "language": l_auth_language,
+       "version":  l_auth_version,
+       "state":    l_auth_state,
 }
 
-func (auth *Auth) create(client *Client) {
+func getAuth(l *lua.LState) *Auth {
+       return l.CheckUserData(1).Value.(*Auth)
+}
+
+func (auth *Auth) create(client *Client, l *lua.LState) {
+       if client.state != csNew {
+               panic("can't add auth component after connect")
+       }
+
        auth.client = client
        auth.language = "en_US"
+       auth.version = "hydra-dragonfire"
        auth.state = asInit
+       auth.userdata = l.NewUserData()
+       auth.userdata.Value = auth
+       l.SetMetatable(auth.userdata, l.GetTypeMetatable("hydra.auth"))
 }
 
-func (auth *Auth) push(l *lua.State) {
-       l.PushUserData(auth)
-
-       if lua.NewMetaTable(l, "hydra.auth") {
-               lua.NewLibrary(l, []lua.RegistryFunction{
-                       {Name: "username", Function: l_auth_username},
-                       {Name: "password", Function: l_auth_password},
-                       {Name: "language", Function: l_auth_language},
-                       {Name: "state", Function: l_auth_state},
-               })
-               l.SetField(-2, "__index")
-       }
-       l.SetMetaTable(-2)
+func (auth *Auth) tolua() lua.LValue {
+       return auth.userdata
 }
 
-func (auth *Auth) canConnect() (bool, string) {
+func (auth *Auth) connect() {
        if auth.username == "" {
-               return false, "missing username"
+               panic("missing username")
        }
 
-       return true, ""
-}
-
-func (auth *Auth) connect() {
        go func() {
-               for auth.state == asInit && auth.client.state == csConnected {
+               for auth.client.state == csConnected && auth.state == asInit {
                        auth.client.conn.SendCmd(&mt.ToSrvInit{
                                SerializeVer: 28,
                                MinProtoVer:  39,
@@ -75,10 +79,10 @@ func (auth *Auth) connect() {
        }()
 }
 
-func (auth *Auth) setError(err string) {
-       auth.state = asError
+func (auth *Auth) fail(err string) {
        auth.err = err
-       auth.client.conn.Close()
+       auth.state = asError
+       auth.client.disconnect()
 }
 
 func (auth *Auth) checkState(state authState, pkt *mt.Pkt) bool {
@@ -86,12 +90,12 @@ func (auth *Auth) checkState(state authState, pkt *mt.Pkt) bool {
                return true
        }
 
-       auth.setError("received " + pktToString(pkt) + " in invalid state")
+       auth.fail("received " + string(tolua.PktType(pkt)) + " in invalid state")
        return false
 }
 
-func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) {
-       if pkt == nil {
+func (auth *Auth) process(pkt *mt.Pkt) {
+       if auth.state == asError {
                return
        }
 
@@ -102,14 +106,14 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) {
                }
 
                if cmd.SerializeVer != 28 {
-                       auth.setError("unsupported serialize_ver")
+                       auth.fail("unsupported serialize version")
                        return
                }
 
                if cmd.AuthMethods == mt.FirstSRP {
                        salt, verifier, err := srp.NewClient([]byte(strings.ToLower(auth.username)), []byte(auth.password))
                        if err != nil {
-                               auth.setError(err.Error())
+                               auth.fail(err.Error())
                                return
                        }
 
@@ -123,7 +127,7 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) {
                        var err error
                        auth.srpBytesA, auth.bytesA, err = srp.InitiateHandshake()
                        if err != nil {
-                               auth.setError(err.Error())
+                               auth.fail(err.Error())
                                return
                        }
 
@@ -133,8 +137,8 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) {
                        })
                        auth.state = asRequested
                } else {
-                       auth.setError("invalid auth methods")
-                       return                  
+                       auth.fail("invalid auth methods")
+                       return
                }
 
        case *mt.ToCltSRPBytesSaltB:
@@ -144,7 +148,7 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) {
 
                srpBytesK, err := srp.CompleteHandshake(auth.srpBytesA, auth.bytesA, []byte(strings.ToLower(auth.username)), []byte(auth.password), cmd.Salt, cmd.B)
                if err != nil {
-                       auth.setError(err.Error())
+                       auth.fail(err.Error())
                        return
                }
 
@@ -153,7 +157,7 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) {
                auth.bytesA = []byte{}
 
                if M == nil {
-                       auth.setError("srp safety check fail")
+                       auth.fail("srp safety check fail")
                        return
                }
 
@@ -180,72 +184,60 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) {
                        Patch:    0,
                        Reserved: 0,
                        Formspec: 4,
-                       Version:  "hydra-dragonfire",
+                       Version:  auth.version,
                })
                auth.state = asActive
        }
 }
 
-func l_auth_username(l *lua.State) int {
-       auth := getAuth(l)
-
-       if l.IsString(2) {
-               if auth.client.state > csNew {
-                       panic("can't change username after connecting")
+func (auth *Auth) accessProperty(l *lua.LState, key string, ptr *string) int {
+       if str, ok := l.Get(2).(lua.LString); ok {
+               if auth.client.state != csNew {
+                       panic("can't change " + key + " after connecting")
                }
-               auth.username = lua.CheckString(l, 2)
+               *ptr = string(str)
                return 0
        } else {
-               l.PushString(auth.username)
+               l.Push(lua.LString(*ptr))
                return 1
        }
 }
 
-func l_auth_password(l *lua.State) int {
+func l_auth_username(l *lua.LState) int {
        auth := getAuth(l)
+       return auth.accessProperty(l, "username", &auth.username)
+}
 
-       if l.IsString(2) {
-               if auth.client.state > csNew {
-                       panic("can't change password after connecting")
-               }
-               auth.password = lua.CheckString(l, 2)
-               return 0
-       } else {
-               l.PushString(auth.password)
-               return 1
-       }
+func l_auth_password(l *lua.LState) int {
+       auth := getAuth(l)
+       return auth.accessProperty(l, "password", &auth.password)
 }
 
-func l_auth_language(l *lua.State) int {
+func l_auth_language(l *lua.LState) int {
        auth := getAuth(l)
+       return auth.accessProperty(l, "language", &auth.language)
+}
 
-       if l.IsString(2) {
-               if auth.client.state > csNew {
-                       panic("can't change language after connecting")
-               }
-               auth.language = lua.CheckString(l, 2)
-               return 0
-       } else {
-               l.PushString(auth.language)
-               return 1
-       }
+func l_auth_version(l *lua.LState) int {
+       auth := getAuth(l)
+       return auth.accessProperty(l, "version", &auth.version)
 }
 
-func l_auth_state(l *lua.State) int {
+func l_auth_state(l *lua.LState) int {
        auth := getAuth(l)
 
        switch auth.state {
        case asInit:
-               l.PushString("init")
+               l.Push(lua.LString("init"))
        case asRequested:
-               l.PushString("requested")
+               l.Push(lua.LString("requested"))
        case asVerified:
-               l.PushString("verified")
+               l.Push(lua.LString("verified"))
        case asActive:
-               l.PushString("active")
+               l.Push(lua.LString("active"))
        case asError:
-               l.PushString("error")
-               l.PushString(auth.err)
+               l.Push(lua.LString("error"))
+               l.Push(lua.LString(auth.err))
                return 2
        }