X-Git-Url: https://git.lizzy.rs/?a=blobdiff_plain;f=auth.go;h=879ffcee6ecbd85060be09b88148739e71d7d435;hb=f0318bd020abe57c0cf365b0479b5d14b95ff07a;hp=bc8d24e09bdab7c58e4186bd8a988fd95c76b543;hpb=fea98ddbbe886845ed41ab87d9a2d24323c8de82;p=hydra-dragonfire.git diff --git a/auth.go b/auth.go index bc8d24e..879ffce 100644 --- a/auth.go +++ b/auth.go @@ -2,8 +2,9 @@ package main import ( "github.com/HimbeerserverDE/srp" - "github.com/Shopify/go-lua" "github.com/anon55555/mt" + "github.com/dragonfireclient/hydra/tolua" + "github.com/yuin/gopher-lua" "strings" "time" ) @@ -23,47 +24,50 @@ type Auth struct { username string password string language string + version string state authState err string srpBytesA, bytesA []byte + userdata *lua.LUserData } -func getAuth(l *lua.State) *Auth { - return lua.CheckUserData(l, 1, "hydra.auth").(*Auth) +var authFuncs = map[string]lua.LGFunction{ + "username": l_auth_username, + "password": l_auth_password, + "language": l_auth_language, + "version": l_auth_version, + "state": l_auth_state, } -func (auth *Auth) create(client *Client) { +func getAuth(l *lua.LState) *Auth { + return l.CheckUserData(1).Value.(*Auth) +} + +func (auth *Auth) create(client *Client, l *lua.LState) { + if client.state != csNew { + panic("can't add auth component after connect") + } + auth.client = client auth.language = "en_US" + auth.version = "hydra-dragonfire" auth.state = asInit + auth.userdata = l.NewUserData() + auth.userdata.Value = auth + l.SetMetatable(auth.userdata, l.GetTypeMetatable("hydra.auth")) } -func (auth *Auth) push(l *lua.State) { - l.PushUserData(auth) - - if lua.NewMetaTable(l, "hydra.auth") { - lua.NewLibrary(l, []lua.RegistryFunction{ - {Name: "username", Function: l_auth_username}, - {Name: "password", Function: l_auth_password}, - {Name: "language", Function: l_auth_language}, - {Name: "state", Function: l_auth_state}, - }) - l.SetField(-2, "__index") - } - l.SetMetaTable(-2) +func (auth *Auth) tolua() lua.LValue { + return auth.userdata } -func (auth *Auth) canConnect() (bool, string) { +func (auth *Auth) connect() { if auth.username == "" { - return false, "missing username" + panic("missing username") } - return true, "" -} - -func (auth *Auth) connect() { go func() { - for auth.state == asInit && auth.client.state == csConnected { + for auth.client.state == csConnected && auth.state == asInit { auth.client.conn.SendCmd(&mt.ToSrvInit{ SerializeVer: 28, MinProtoVer: 39, @@ -75,10 +79,10 @@ func (auth *Auth) connect() { }() } -func (auth *Auth) setError(err string) { - auth.state = asError +func (auth *Auth) fail(err string) { auth.err = err - auth.client.conn.Close() + auth.state = asError + auth.client.disconnect() } func (auth *Auth) checkState(state authState, pkt *mt.Pkt) bool { @@ -86,12 +90,12 @@ func (auth *Auth) checkState(state authState, pkt *mt.Pkt) bool { return true } - auth.setError("received " + pktToString(pkt) + " in invalid state") + auth.fail("received " + string(tolua.PktType(pkt)) + " in invalid state") return false } -func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) { - if pkt == nil { +func (auth *Auth) process(pkt *mt.Pkt) { + if auth.state == asError { return } @@ -102,14 +106,14 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) { } if cmd.SerializeVer != 28 { - auth.setError("unsupported serialize_ver") + auth.fail("unsupported serialize version") return } if cmd.AuthMethods == mt.FirstSRP { salt, verifier, err := srp.NewClient([]byte(strings.ToLower(auth.username)), []byte(auth.password)) if err != nil { - auth.setError(err.Error()) + auth.fail(err.Error()) return } @@ -123,7 +127,7 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) { var err error auth.srpBytesA, auth.bytesA, err = srp.InitiateHandshake() if err != nil { - auth.setError(err.Error()) + auth.fail(err.Error()) return } @@ -133,8 +137,8 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) { }) auth.state = asRequested } else { - auth.setError("invalid auth methods") - return + auth.fail("invalid auth methods") + return } case *mt.ToCltSRPBytesSaltB: @@ -144,7 +148,7 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) { srpBytesK, err := srp.CompleteHandshake(auth.srpBytesA, auth.bytesA, []byte(strings.ToLower(auth.username)), []byte(auth.password), cmd.Salt, cmd.B) if err != nil { - auth.setError(err.Error()) + auth.fail(err.Error()) return } @@ -153,7 +157,7 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) { auth.bytesA = []byte{} if M == nil { - auth.setError("srp safety check fail") + auth.fail("srp safety check fail") return } @@ -180,72 +184,60 @@ func (auth *Auth) handle(pkt *mt.Pkt, l *lua.State, idx int) { Patch: 0, Reserved: 0, Formspec: 4, - Version: "hydra-dragonfire", + Version: auth.version, }) auth.state = asActive } } -func l_auth_username(l *lua.State) int { - auth := getAuth(l) - - if l.IsString(2) { - if auth.client.state > csNew { - panic("can't change username after connecting") +func (auth *Auth) accessProperty(l *lua.LState, key string, ptr *string) int { + if str, ok := l.Get(2).(lua.LString); ok { + if auth.client.state != csNew { + panic("can't change " + key + " after connecting") } - auth.username = lua.CheckString(l, 2) + *ptr = string(str) return 0 } else { - l.PushString(auth.username) + l.Push(lua.LString(*ptr)) return 1 } } -func l_auth_password(l *lua.State) int { +func l_auth_username(l *lua.LState) int { auth := getAuth(l) + return auth.accessProperty(l, "username", &auth.username) +} - if l.IsString(2) { - if auth.client.state > csNew { - panic("can't change password after connecting") - } - auth.password = lua.CheckString(l, 2) - return 0 - } else { - l.PushString(auth.password) - return 1 - } +func l_auth_password(l *lua.LState) int { + auth := getAuth(l) + return auth.accessProperty(l, "password", &auth.password) } -func l_auth_language(l *lua.State) int { +func l_auth_language(l *lua.LState) int { auth := getAuth(l) + return auth.accessProperty(l, "language", &auth.language) +} - if l.IsString(2) { - if auth.client.state > csNew { - panic("can't change language after connecting") - } - auth.language = lua.CheckString(l, 2) - return 0 - } else { - l.PushString(auth.language) - return 1 - } +func l_auth_version(l *lua.LState) int { + auth := getAuth(l) + return auth.accessProperty(l, "version", &auth.version) } -func l_auth_state(l *lua.State) int { +func l_auth_state(l *lua.LState) int { auth := getAuth(l) switch auth.state { case asInit: - l.PushString("init") + l.Push(lua.LString("init")) case asRequested: - l.PushString("requested") + l.Push(lua.LString("requested")) case asVerified: - l.PushString("verified") + l.Push(lua.LString("verified")) case asActive: - l.PushString("active") + l.Push(lua.LString("active")) case asError: - l.PushString("error") - l.PushString(auth.err) + l.Push(lua.LString("error")) + l.Push(lua.LString(auth.err)) return 2 }