From 086b338f84ad548edd997ef83b2771feacbc80c9 Mon Sep 17 00:00:00 2001 From: wish Date: Sat, 13 May 2023 01:21:37 +1000 Subject: [PATCH] automatically create new character when none exist --- server/signserver/dbutils.go | 42 ++++++++-------------------------- server/signserver/dsgn_resp.go | 8 ++++++- server/signserver/session.go | 4 ++-- 3 files changed, 18 insertions(+), 36 deletions(-) diff --git a/server/signserver/dbutils.go b/server/signserver/dbutils.go index ee4dcc493..3d7e34318 100644 --- a/server/signserver/dbutils.go +++ b/server/signserver/dbutils.go @@ -8,15 +8,9 @@ import ( "golang.org/x/crypto/bcrypt" ) -func (s *Server) newUserChara(username string) error { - var id int - err := s.db.QueryRow("SELECT id FROM users WHERE username = $1", username).Scan(&id) - if err != nil { - return err - } - +func (s *Server) newUserChara(uid int) error { var numNewChars int - err = s.db.QueryRow("SELECT COUNT(*) FROM characters WHERE user_id = $1 AND is_new_character = true", id).Scan(&numNewChars) + err := s.db.QueryRow("SELECT COUNT(*) FROM characters WHERE user_id = $1 AND is_new_character = true", uid).Scan(&numNewChars) if err != nil { return err } @@ -31,7 +25,7 @@ func (s *Server) newUserChara(username string) error { user_id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login) VALUES($1, False, True, '', '', 0, 0, 0, $2)`, - id, + uid, uint32(time.Now().Unix()), ) if err != nil { @@ -41,38 +35,20 @@ func (s *Server) newUserChara(username string) error { return nil } -func (s *Server) registerDBAccount(username string, password string) error { +func (s *Server) registerDBAccount(username string, password string) (int, error) { // Create salted hash of user password passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { - return err - } - - _, err = s.db.Exec("INSERT INTO users (username, password, return_expires) VALUES ($1, $2, $3)", username, string(passwordHash), time.Now().Add(time.Hour*24*30)) - if err != nil { - return err + return 0, err } var id int - err = s.db.QueryRow("SELECT id FROM users WHERE username = $1", username).Scan(&id) + err = s.db.QueryRow("INSERT INTO users (username, password, return_expires) VALUES ($1, $2, $3) RETURNING id", username, string(passwordHash), time.Now().Add(time.Hour*24*30)).Scan(&id) if err != nil { - return err + return 0, err } - // Create a base new character. - _, err = s.db.Exec(` - INSERT INTO characters ( - user_id, is_female, is_new_character, name, unk_desc_string, - hrp, gr, weapon_type, last_login) - VALUES($1, False, True, '', '', 0, 0, 0, $2)`, - id, - uint32(time.Now().Unix()), - ) - if err != nil { - return err - } - - return nil + return id, nil } type character struct { @@ -89,7 +65,7 @@ type character struct { func (s *Server) getCharactersForUser(uid int) ([]character, error) { characters := make([]character, 0) - err := s.db.Select(&characters, "SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = $1 AND deleted = false ORDER BY id ASC", uid) + err := s.db.Select(&characters, "SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = $1 AND deleted = false ORDER BY id", uid) if err != nil { return nil, err } diff --git a/server/signserver/dsgn_resp.go b/server/signserver/dsgn_resp.go index 40dc707d8..5d47dc067 100644 --- a/server/signserver/dsgn_resp.go +++ b/server/signserver/dsgn_resp.go @@ -14,6 +14,12 @@ import ( func (s *Session) makeSignResponse(uid int) []byte { // Get the characters from the DB. chars, err := s.server.getCharactersForUser(uid) + if len(chars) == 0 { + err = s.server.newUserChara(uid) + if err == nil { + chars, err = s.server.getCharactersForUser(uid) + } + } if err != nil { s.logger.Warn("Error getting characters from DB", zap.Error(err)) } @@ -23,7 +29,7 @@ func (s *Session) makeSignResponse(uid int) []byte { bf := byteframe.NewByteFrame() - bf.WriteUint8(1) // resp_code + bf.WriteUint8(uint8(SIGN_SUCCESS)) // resp_code if (s.server.erupeConfig.PatchServerManifest != "" && s.server.erupeConfig.PatchServerFile != "") || s.client == PS3 { bf.WriteUint8(2) } else { diff --git a/server/signserver/session.go b/server/signserver/session.go index 3eb914c93..de46c5337 100644 --- a/server/signserver/session.go +++ b/server/signserver/session.go @@ -99,7 +99,7 @@ func (s *Session) authenticate(username string, password string) { s.logger.Info("User not found", zap.String("Username", username)) if s.server.erupeConfig.DevMode && s.server.erupeConfig.DevModeOptions.AutoCreateAccount { s.logger.Info("Creating user", zap.String("Username", username)) - err = s.server.registerDBAccount(username, password) + id, err = s.server.registerDBAccount(username, password) if err == nil { bf.WriteBytes(s.makeSignResponse(id)) } @@ -113,7 +113,7 @@ func (s *Session) authenticate(username string, password string) { if bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil || s.client == VITA || s.client == PS3 || s.client == WIIU { s.logger.Debug("Passwords match!") if newCharaReq { - err = s.server.newUserChara(username) + err = s.server.newUserChara(id) if err != nil { s.logger.Error("Error adding new character to user", zap.Error(err)) bf.WriteUint8(uint8(SIGN_EABORT))