From a260500bb56142f3f1ff31eb8740169017bd1ab5 Mon Sep 17 00:00:00 2001 From: wish Date: Sun, 30 Apr 2023 13:51:30 +1000 Subject: [PATCH] further sign server rewrite --- server/signserver/dbutils.go | 39 +++++++++--- server/signserver/dsgn_resp.go | 10 ++- server/signserver/session.go | 109 +++++++++++++-------------------- 3 files changed, 80 insertions(+), 78 deletions(-) diff --git a/server/signserver/dbutils.go b/server/signserver/dbutils.go index d66d85731..0410f351d 100644 --- a/server/signserver/dbutils.go +++ b/server/signserver/dbutils.go @@ -1,8 +1,10 @@ package signserver import ( + "database/sql" "errors" "erupe-ce/common/mhfcourse" + "erupe-ce/common/token" "strings" "time" @@ -220,11 +222,18 @@ func (s *Server) checkToken(uid uint32) (bool, error) { return false, nil } -func (s *Server) registerToken(uid uint32, token string) (uint32, error) { - var id uint32 - var err error - err = s.db.QueryRow("INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, token).Scan(&id) - return id, err +func (s *Server) registerUidToken(uid uint32) (uint32, string, error) { + token := token.Generate(16) + var tid uint32 + err := s.db.QueryRow(`INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id`, uid, token).Scan(&tid) + return tid, token, err +} + +func (s *Server) registerPsnToken(psn string) (uint32, string, error) { + token := token.Generate(16) + var tid uint32 + err := s.db.QueryRow(`INSERT INTO sign_sessions (psn_id, token) VALUES ($1, $2) RETURNING id`, psn, token).Scan(&tid) + return tid, token, err } func (s *Server) validateToken(token string, tokenID uint32) bool { @@ -240,16 +249,28 @@ func (s *Server) validateToken(token string, tokenID uint32) bool { return true } -func (s *Server) validateLogin(user string, pass string) (uint32, error) { +func (s *Server) validateLogin(user string, pass string) (uint32, RespID) { var uid uint32 var passDB string err := s.db.QueryRow(`SELECT id, password FROM users WHERE username = $1`, user).Scan(&uid, &passDB) if err != nil { - return 0, err + if err == sql.ErrNoRows { + s.logger.Info("User not found", zap.String("User", user)) + if s.erupeConfig.DevMode && s.erupeConfig.DevModeOptions.AutoCreateAccount { + uid, err = s.registerDBAccount(user, pass) + if err == nil { + return uid, SIGN_SUCCESS + } else { + return 0, SIGN_EABORT + } + } + return 0, SIGN_EAUTH + } + return 0, SIGN_EABORT } else { if bcrypt.CompareHashAndPassword([]byte(passDB), []byte(pass)) == nil { - return uid, nil + return uid, SIGN_SUCCESS } - return 0, nil + return 0, SIGN_EPASS } } diff --git a/server/signserver/dsgn_resp.go b/server/signserver/dsgn_resp.go index f07161e7e..03f9b8aee 100644 --- a/server/signserver/dsgn_resp.go +++ b/server/signserver/dsgn_resp.go @@ -4,7 +4,6 @@ import ( "erupe-ce/common/byteframe" ps "erupe-ce/common/pascalstring" "erupe-ce/common/stringsupport" - "erupe-ce/common/token" "erupe-ce/server/channelserver" "fmt" "go.uber.org/zap" @@ -19,8 +18,13 @@ func (s *Session) makeSignResponse(uid uint32) []byte { } bf := byteframe.NewByteFrame() - sessToken := token.Generate(16) - tokenID, err := s.server.registerToken(uid, sessToken) + var tokenID uint32 + var sessToken string + if uid == 0 && s.psn != "" { + tokenID, sessToken, err = s.server.registerPsnToken(s.psn) + } else { + tokenID, sessToken, err = s.server.registerUidToken(uid) + } if err != nil { bf.WriteUint8(uint8(SIGN_EABORT)) return bf.Data() diff --git a/server/signserver/session.go b/server/signserver/session.go index 3ade1426a..532998054 100644 --- a/server/signserver/session.go +++ b/server/signserver/session.go @@ -31,6 +31,7 @@ type Session struct { rawConn net.Conn cryptConn *network.CryptConn client client + psn string } func (s *Session) work() { @@ -87,46 +88,24 @@ func (s *Session) handlePacket(pkt []byte) error { func (s *Session) authenticate(username string, password string) { newCharaReq := false - if username[len(username)-1] == 43 { // '+' username = username[:len(username)-1] newCharaReq = true } - bf := byteframe.NewByteFrame() - - uid, err := s.server.validateLogin(username, password) - switch { - case err == sql.ErrNoRows: - s.logger.Info("User not found", zap.String("Username", username)) - if s.server.erupeConfig.DevMode && s.server.erupeConfig.DevModeOptions.AutoCreateAccount { - uid, err = s.server.registerDBAccount(username, password) - if err == nil && uid > 0 { - bf.WriteBytes(s.makeSignResponse(uid)) - } - } else { - bf.WriteUint8(uint8(SIGN_EAUTH)) + uid, resp := s.server.validateLogin(username, password) + switch resp { + case SIGN_SUCCESS: + if newCharaReq { + _ = s.server.newUserChara(username) } - case err != nil: - s.logger.Error("Error getting user details", zap.Error(err)) - bf.WriteUint8(uint8(SIGN_EABORT)) + bf.WriteBytes(s.makeSignResponse(uid)) default: - if uid > 0 { - s.logger.Debug("Passwords match!") - if newCharaReq { - _ = s.server.newUserChara(username) - } - bf.WriteBytes(s.makeSignResponse(uid)) - } else { - s.logger.Warn("Incorrect password") - bf.WriteUint8(uint8(SIGN_EPASS)) - } + bf.WriteUint8(uint8(resp)) } - if s.server.erupeConfig.DevMode && s.server.erupeConfig.DevModeOptions.LogOutboundMessages { fmt.Printf("\n[Server] -> [Client]\nData [%d bytes]:\n%s\n", len(bf.Data()), hex.Dump(bf.Data())) } - _ = s.cryptConn.SendPacket(bf.Data()) } @@ -148,9 +127,9 @@ func (s *Session) handlePSSGN(bf *byteframe.ByteFrame) { _ = bf.ReadNullTerminatedBytes() // VITA = 0000000256, PS3 = 0000000255 _ = bf.ReadBytes(2) // VITA = 1, PS3 = ! _ = bf.ReadBytes(82) - psnUser := string(bf.ReadNullTerminatedBytes()) + s.psn = string(bf.ReadNullTerminatedBytes()) var uid uint32 - err := s.server.db.QueryRow(`SELECT id FROM users WHERE psn_id = $1`, psnUser).Scan(&uid) + err := s.server.db.QueryRow(`SELECT id FROM users WHERE psn_id = $1`, s.psn).Scan(&uid) if err != nil { if err == sql.ErrNoRows { s.cryptConn.SendPacket(s.makeSignResponse(0)) @@ -166,45 +145,43 @@ func (s *Session) handlePSNLink(bf *byteframe.ByteFrame) { _ = bf.ReadNullTerminatedBytes() // Client ID credentials := strings.Split(stringsupport.SJISToUTF8(bf.ReadNullTerminatedBytes()), "\n") token := string(bf.ReadNullTerminatedBytes()) - if s.server.erupeConfig.DevModeOptions.DisableTokenCheck || !s.server.validateToken(token, 0) { - uid, err := s.server.validateLogin(credentials[0], credentials[1]) - if err == nil && uid > 0 { - var psn string - err = s.server.db.QueryRow(`SELECT psn_id FROM sign_sessions WHERE token = $1`, token).Scan(&psn) - if err != nil { - s.sendCode(SIGN_ECOGLINK) - return - } - - var exists int - err = s.server.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, psn).Scan(&exists) - if err != nil { - s.sendCode(SIGN_ECOGLINK) - return - } else if exists > 0 { - s.sendCode(SIGN_EPSI) - return - } - - var currentPSN string - err = s.server.db.QueryRow(`SELECT psn_id FROM users WHERE username = $1`, credentials[0]).Scan(¤tPSN) - if err != nil { - s.sendCode(SIGN_ECOGLINK) - return - } else if psn != currentPSN { - s.sendCode(SIGN_EMBID) - return - } - - _, err = s.server.db.Exec(`UPDATE users SET psn_id = $1 WHERE username = $2`, psn, credentials[0]) - if err == nil { - s.sendCode(SIGN_SUCCESS) - } - } else { + uid, resp := s.server.validateLogin(credentials[0], credentials[1]) + if resp == SIGN_SUCCESS && uid > 0 { + var psn string + err := s.server.db.QueryRow(`SELECT psn_id FROM sign_sessions WHERE token = $1`, token).Scan(&psn) + if err != nil { s.sendCode(SIGN_ECOGLINK) + return } + // Since we check for the psn_id, this will never run + var exists int + err = s.server.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, psn).Scan(&exists) + if err != nil { + s.sendCode(SIGN_ECOGLINK) + return + } else if exists > 0 { + s.sendCode(SIGN_EPSI) + return + } + + var currentPSN string + err = s.server.db.QueryRow(`SELECT COALESCE(psn_id, '') FROM users WHERE username = $1`, credentials[0]).Scan(¤tPSN) + if err != nil { + s.sendCode(SIGN_ECOGLINK) + return + } else if currentPSN != "" { + s.sendCode(SIGN_EMBID) + return + } + + _, err = s.server.db.Exec(`UPDATE users SET psn_id = $1 WHERE username = $2`, psn, credentials[0]) + if err == nil { + s.sendCode(SIGN_SUCCESS) + return + } } + s.sendCode(SIGN_ECOGLINK) } func (s *Session) handleDSGN(bf *byteframe.ByteFrame) {