diff --git a/server/signserver/dbutils.go b/server/signserver/dbutils.go index c6930bf15..5022f5434 100644 --- a/server/signserver/dbutils.go +++ b/server/signserver/dbutils.go @@ -5,7 +5,6 @@ import ( "errors" "erupe-ce/common/mhfcourse" "erupe-ce/common/token" - "strings" "time" "go.uber.org/zap" @@ -13,34 +12,20 @@ import ( ) func (s *Server) newUserChara(uid uint32) error { - var numNewChars int - err := s.db.QueryRow("SELECT COUNT(*) FROM characters WHERE user_id = $1 AND is_new_character = true", uid).Scan(&numNewChars) + numNewChars, err := s.charRepo.CountNewCharacters(uid) if err != nil { return err } // prevent users with an uninitialised character from creating more if numNewChars >= 1 { - return err + return nil } - _, err = s.db.Exec(` - INSERT INTO characters ( - user_id, is_female, is_new_character, name, unk_desc_string, - hr, gr, weapon_type, last_login) - VALUES($1, False, True, '', '', 0, 0, 0, $2)`, - uid, - uint32(time.Now().Unix()), - ) - if err != nil { - return err - } - - return nil + return s.charRepo.CreateCharacter(uid, uint32(time.Now().Unix())) } func (s *Server) registerDBAccount(username string, password string) (uint32, error) { - var uid uint32 s.logger.Info("Creating user", zap.String("User", username)) // Create salted hash of user password @@ -49,7 +34,7 @@ func (s *Server) registerDBAccount(username string, password string) (uint32, er return 0, err } - err = s.db.QueryRow("INSERT INTO users (username, password, return_expires) VALUES ($1, $2, $3) RETURNING id", username, string(passwordHash), time.Now().Add(time.Hour*24*30)).Scan(&uid) + uid, err := s.userRepo.Register(username, string(passwordHash), time.Now().Add(time.Hour*24*30)) if err != nil { return 0, err } @@ -57,81 +42,65 @@ func (s *Server) registerDBAccount(username string, password string) (uint32, er return uid, nil } -type character struct { - ID uint32 `db:"id"` - IsFemale bool `db:"is_female"` - IsNewCharacter bool `db:"is_new_character"` - Name string `db:"name"` - UnkDescString string `db:"unk_desc_string"` - HR uint16 `db:"hr"` - GR uint16 `db:"gr"` - WeaponType uint16 `db:"weapon_type"` - LastLogin uint32 `db:"last_login"` -} - func (s *Server) getCharactersForUser(uid uint32) ([]character, error) { - characters := make([]character, 0) - err := s.db.Select(&characters, "SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = $1 AND deleted = false ORDER BY id", uid) - if err != nil { - return nil, err - } - return characters, nil + return s.charRepo.GetForUser(uid) } func (s *Server) getReturnExpiry(uid uint32) time.Time { - var returnExpiry, lastLogin time.Time - _ = s.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid) + var returnExpiry time.Time + lastLogin, err := s.userRepo.GetLastLogin(uid) + if err != nil { + s.logger.Warn("Failed to get last login", zap.Uint32("uid", uid), zap.Error(err)) + lastLogin = time.Now() + } if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) { returnExpiry = time.Now().Add(time.Hour * 24 * 30) - _, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) + if err := s.userRepo.UpdateReturnExpiry(uid, returnExpiry); err != nil { + s.logger.Warn("Failed to update return expiry", zap.Uint32("uid", uid), zap.Error(err)) + } } else { - err := s.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid) + returnExpiry, err = s.userRepo.GetReturnExpiry(uid) if err != nil { returnExpiry = time.Now() - _, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) + if err := s.userRepo.UpdateReturnExpiry(uid, returnExpiry); err != nil { + s.logger.Warn("Failed to update return expiry (fallback)", zap.Uint32("uid", uid), zap.Error(err)) + } } } - _, _ = s.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", time.Now(), uid) + if err := s.userRepo.UpdateLastLogin(uid, time.Now()); err != nil { + s.logger.Warn("Failed to update last login", zap.Uint32("uid", uid), zap.Error(err)) + } return returnExpiry } func (s *Server) getLastCID(uid uint32) uint32 { - var lastPlayed uint32 - _ = s.db.QueryRow("SELECT last_character FROM users WHERE id=$1", uid).Scan(&lastPlayed) + lastPlayed, err := s.userRepo.GetLastCharacter(uid) + if err != nil { + s.logger.Warn("Failed to get last character", zap.Uint32("uid", uid), zap.Error(err)) + return 0 + } return lastPlayed } func (s *Server) getUserRights(uid uint32) uint32 { - var rights uint32 - if uid != 0 { - _ = s.db.QueryRow("SELECT rights FROM users WHERE id=$1", uid).Scan(&rights) - _, rights = mhfcourse.GetCourseStruct(rights, s.erupeConfig.DefaultCourses) + if uid == 0 { + return 0 } + rights, err := s.userRepo.GetRights(uid) + if err != nil { + s.logger.Warn("Failed to get user rights", zap.Uint32("uid", uid), zap.Error(err)) + return 0 + } + _, rights = mhfcourse.GetCourseStruct(rights, s.erupeConfig.DefaultCourses) return rights } -type members struct { - CID uint32 // Local character ID - ID uint32 `db:"id"` - Name string `db:"name"` -} - func (s *Server) getFriendsForCharacters(chars []character) []members { friends := make([]members, 0) for _, char := range chars { - friendsCSV := "" - _ = s.db.QueryRow("SELECT friends FROM characters WHERE id=$1", char.ID).Scan(&friendsCSV) - friendsSlice := strings.Split(friendsCSV, ",") - friendQuery := "SELECT id, name FROM characters WHERE id=" - for i := 0; i < len(friendsSlice); i++ { - friendQuery += friendsSlice[i] - if i+1 != len(friendsSlice) { - friendQuery += " OR id=" - } - } - charFriends := make([]members, 0) - err := s.db.Select(&charFriends, friendQuery) + charFriends, err := s.charRepo.GetFriends(char.ID) if err != nil { + s.logger.Warn("Failed to get friends", zap.Uint32("charID", char.ID), zap.Error(err)) continue } for i := range charFriends { @@ -145,79 +114,56 @@ func (s *Server) getFriendsForCharacters(chars []character) []members { func (s *Server) getGuildmatesForCharacters(chars []character) []members { guildmates := make([]members, 0) for _, char := range chars { - var inGuild int - _ = s.db.QueryRow("SELECT count(*) FROM guild_characters WHERE character_id=$1", char.ID).Scan(&inGuild) - if inGuild > 0 { - var guildID int - err := s.db.QueryRow("SELECT guild_id FROM guild_characters WHERE character_id=$1", char.ID).Scan(&guildID) - if err != nil { - continue - } - charGuildmates := make([]members, 0) - err = s.db.Select(&charGuildmates, "SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=$1 AND character_id!=$2", guildID, char.ID) - if err != nil { - continue - } - for i := range charGuildmates { - charGuildmates[i].CID = char.ID - } - guildmates = append(guildmates, charGuildmates...) + charGuildmates, err := s.charRepo.GetGuildmates(char.ID) + if err != nil { + s.logger.Warn("Failed to get guildmates", zap.Uint32("charID", char.ID), zap.Error(err)) + continue } + for i := range charGuildmates { + charGuildmates[i].CID = char.ID + } + guildmates = append(guildmates, charGuildmates...) } return guildmates } -func (s *Server) deleteCharacter(cid int, token string, tokenID uint32) error { - if !s.validateToken(token, tokenID) { +func (s *Server) deleteCharacter(cid int, tok string, tokenID uint32) error { + if !s.validateToken(tok, tokenID) { return errors.New("invalid token") } - var isNew bool - err := s.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", cid).Scan(&isNew) + isNew, err := s.charRepo.IsNewCharacter(cid) if err != nil { return err } if isNew { - _, err = s.db.Exec("DELETE FROM characters WHERE id = $1", cid) - } else { - _, err = s.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", cid) + return s.charRepo.HardDelete(cid) } - if err != nil { - return err - } - return nil + return s.charRepo.SoftDelete(cid) } func (s *Server) registerUidToken(uid uint32) (uint32, string, error) { _token := token.Generate(16) - var tid uint32 - err := s.db.QueryRow(`INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id`, uid, _token).Scan(&tid) + tid, err := s.sessionRepo.RegisterUID(uid, _token) return tid, _token, err } func (s *Server) registerPsnToken(psn string) (uint32, string, error) { _token := token.Generate(16) - var tid uint32 - err := s.db.QueryRow(`INSERT INTO sign_sessions (psn_id, token) VALUES ($1, $2) RETURNING id`, psn, _token).Scan(&tid) + tid, err := s.sessionRepo.RegisterPSN(psn, _token) return tid, _token, err } -func (s *Server) validateToken(token string, tokenID uint32) bool { - query := `SELECT count(*) FROM sign_sessions WHERE token = $1` - if tokenID > 0 { - query += ` AND id = $2` - } - var exists int - err := s.db.QueryRow(query, token, tokenID).Scan(&exists) - if err != nil || exists == 0 { +func (s *Server) validateToken(tok string, tokenID uint32) bool { + valid, err := s.sessionRepo.Validate(tok, tokenID) + if err != nil { + s.logger.Warn("Failed to validate token", zap.Error(err)) return false } - return true + return valid } func (s *Server) validateLogin(user string, pass string) (uint32, RespID) { - var uid uint32 - var passDB string - err := s.db.QueryRow(`SELECT id, password FROM users WHERE username = $1`, user).Scan(&uid, &passDB) + uid, passDB, err := s.userRepo.GetCredentials(user) if err != nil { if errors.Is(err, sql.ErrNoRows) { s.logger.Info("User not found", zap.String("User", user)) @@ -225,26 +171,25 @@ func (s *Server) validateLogin(user string, pass string) (uint32, RespID) { uid, err = s.registerDBAccount(user, pass) if err == nil { return uid, SIGN_SUCCESS - } else { - return 0, SIGN_EABORT } + return 0, SIGN_EABORT } return 0, SIGN_EAUTH } return 0, SIGN_EABORT - } else { - if bcrypt.CompareHashAndPassword([]byte(passDB), []byte(pass)) == nil { - var bans int - err = s.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires IS NULL`, uid).Scan(&bans) - if err == nil && bans > 0 { - return uid, SIGN_EELIMINATE - } - err = s.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires > now()`, uid).Scan(&bans) - if err == nil && bans > 0 { - return uid, SIGN_ESUSPEND - } - return uid, SIGN_SUCCESS - } + } + + if bcrypt.CompareHashAndPassword([]byte(passDB), []byte(pass)) != nil { return 0, SIGN_EPASS } + + bans, err := s.userRepo.CountPermanentBans(uid) + if err == nil && bans > 0 { + return uid, SIGN_EELIMINATE + } + bans, err = s.userRepo.CountActiveBans(uid) + if err == nil && bans > 0 { + return uid, SIGN_ESUSPEND + } + return uid, SIGN_SUCCESS } diff --git a/server/signserver/dbutils_test.go b/server/signserver/dbutils_test.go index 19ff6ebf1..9a2ad3b7c 100644 --- a/server/signserver/dbutils_test.go +++ b/server/signserver/dbutils_test.go @@ -6,8 +6,7 @@ import ( "time" cfg "erupe-ce/config" - "github.com/DATA-DOG/go-sqlmock" - "github.com/jmoiron/sqlx" + "go.uber.org/zap" ) @@ -292,315 +291,276 @@ func TestMultipleMembers(t *testing.T) { } } -// Helper to create a test server with mocked database -func newTestServerWithMock(t *testing.T) (*Server, sqlmock.Sqlmock) { - db, mock, err := sqlmock.New() - if err != nil { - t.Fatalf("failed to create sqlmock: %v", err) +func TestGetCharactersForUser(t *testing.T) { + charRepo := &mockSignCharacterRepo{ + characters: []character{ + {ID: 1, IsFemale: false, Name: "Hunter1", HR: 100, GR: 50, WeaponType: 3, LastLogin: 1700000000}, + {ID: 2, IsFemale: true, Name: "Hunter2", HR: 200, GR: 100, WeaponType: 7, LastLogin: 1700000001}, + }, } - sqlxDB := sqlx.NewDb(db, "sqlmock") - server := &Server{ logger: zap.NewNop(), - db: sqlxDB, erupeConfig: &cfg.Config{}, + charRepo: charRepo, } - return server, mock -} - -func TestGetCharactersForUser(t *testing.T) { - server, mock := newTestServerWithMock(t) - - rows := sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hr", "gr", "weapon_type", "last_login"}). - AddRow(1, false, false, "Hunter1", "desc1", 100, 50, 3, 1700000000). - AddRow(2, true, false, "Hunter2", "desc2", 200, 100, 7, 1700000001) - - mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id"). - WithArgs(uint32(1)). - WillReturnRows(rows) - chars, err := server.getCharactersForUser(1) if err != nil { t.Errorf("getCharactersForUser() error: %v", err) } - if len(chars) != 2 { t.Errorf("getCharactersForUser() returned %d characters, want 2", len(chars)) } - if chars[0].Name != "Hunter1" { t.Errorf("First character name = %s, want Hunter1", chars[0].Name) } - if chars[1].IsFemale != true { t.Error("Second character should be female") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestGetCharactersForUserNoCharacters(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + characters: []character{}, + } - rows := sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hr", "gr", "weapon_type", "last_login"}) - - mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id"). - WithArgs(uint32(1)). - WillReturnRows(rows) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } chars, err := server.getCharactersForUser(1) if err != nil { t.Errorf("getCharactersForUser() error: %v", err) } - if len(chars) != 0 { t.Errorf("getCharactersForUser() returned %d characters, want 0", len(chars)) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestGetCharactersForUserDBError(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + getForUserErr: sql.ErrConnDone, + } - mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id"). - WithArgs(uint32(1)). - WillReturnError(sql.ErrConnDone) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } _, err := server.getCharactersForUser(1) if err == nil { t.Error("getCharactersForUser() should return error") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestGetLastCID(t *testing.T) { - server, mock := newTestServerWithMock(t) + userRepo := &mockSignUserRepo{ + lastCharacter: 12345, + } - mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"last_character"}).AddRow(12345)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } lastCID := server.getLastCID(1) if lastCID != 12345 { t.Errorf("getLastCID() = %d, want 12345", lastCID) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestGetLastCIDNoResult(t *testing.T) { - server, mock := newTestServerWithMock(t) + userRepo := &mockSignUserRepo{ + lastCharacterErr: sql.ErrNoRows, + } - mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnError(sql.ErrNoRows) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } lastCID := server.getLastCID(1) if lastCID != 0 { t.Errorf("getLastCID() with no result = %d, want 0", lastCID) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestGetUserRights(t *testing.T) { - server, mock := newTestServerWithMock(t) + userRepo := &mockSignUserRepo{ + rights: 30, + } - mock.ExpectQuery("SELECT rights FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"rights"}).AddRow(30)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } rights := server.getUserRights(1) if rights == 0 { t.Error("getUserRights() should return non-zero value") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestGetReturnExpiry(t *testing.T) { - server, mock := newTestServerWithMock(t) - recentLogin := time.Now().Add(-time.Hour * 24) - mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(recentLogin)) + userRepo := &mockSignUserRepo{ + lastLogin: recentLogin, + returnExpiry: time.Now().Add(time.Hour * 24 * 30), + } - mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"return_expires"}).AddRow(time.Now().Add(time.Hour * 24 * 30))) - - mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). - WithArgs(sqlmock.AnyArg(), uint32(1)). - WillReturnResult(sqlmock.NewResult(0, 1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } expiry := server.getReturnExpiry(1) - if expiry.Before(time.Now()) { t.Error("getReturnExpiry() should return future date") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if !userRepo.updateLastLoginCalled { + t.Error("getReturnExpiry() should update last login") } } func TestGetReturnExpiryInactiveUser(t *testing.T) { - server, mock := newTestServerWithMock(t) - oldLogin := time.Now().Add(-time.Hour * 24 * 100) - mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(oldLogin)) + userRepo := &mockSignUserRepo{ + lastLogin: oldLogin, + } - mock.ExpectExec("UPDATE users SET return_expires=\\$1 WHERE id=\\$2"). - WithArgs(sqlmock.AnyArg(), uint32(1)). - WillReturnResult(sqlmock.NewResult(0, 1)) - - mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). - WithArgs(sqlmock.AnyArg(), uint32(1)). - WillReturnResult(sqlmock.NewResult(0, 1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } expiry := server.getReturnExpiry(1) - if expiry.Before(time.Now()) { t.Error("getReturnExpiry() should return future date for inactive user") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if !userRepo.updateReturnExpiryCalled { + t.Error("getReturnExpiry() should update return expiry for inactive user") + } + if !userRepo.updateLastLoginCalled { + t.Error("getReturnExpiry() should update last login") } } func TestGetReturnExpiryDBError(t *testing.T) { - server, mock := newTestServerWithMock(t) - recentLogin := time.Now().Add(-time.Hour * 24) - mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(recentLogin)) + userRepo := &mockSignUserRepo{ + lastLogin: recentLogin, + returnExpiryErr: sql.ErrNoRows, + } - mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnError(sql.ErrNoRows) - - mock.ExpectExec("UPDATE users SET return_expires=\\$1 WHERE id=\\$2"). - WithArgs(sqlmock.AnyArg(), uint32(1)). - WillReturnResult(sqlmock.NewResult(0, 1)) - - mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). - WithArgs(sqlmock.AnyArg(), uint32(1)). - WillReturnResult(sqlmock.NewResult(0, 1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } expiry := server.getReturnExpiry(1) - if expiry.IsZero() { t.Error("getReturnExpiry() should return non-zero time even on error") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if !userRepo.updateReturnExpiryCalled { + t.Error("getReturnExpiry() should update return expiry on fallback") } } func TestNewUserChara(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + newCharCount: 0, + } - mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - - mock.ExpectExec("INSERT INTO characters"). - WithArgs(uint32(1), sqlmock.AnyArg()). - WillReturnResult(sqlmock.NewResult(1, 1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } err := server.newUserChara(1) if err != nil { t.Errorf("newUserChara() error: %v", err) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if !charRepo.createCalled { + t.Error("newUserChara() should call CreateCharacter") } } func TestNewUserCharaAlreadyHasNewChar(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + newCharCount: 1, + } - mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } err := server.newUserChara(1) if err != nil { t.Errorf("newUserChara() should return nil when user already has new char: %v", err) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if charRepo.createCalled { + t.Error("newUserChara() should not call CreateCharacter when user already has new char") } } func TestNewUserCharaCountError(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + newCharCountErr: sql.ErrConnDone, + } - mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). - WithArgs(uint32(1)). - WillReturnError(sql.ErrConnDone) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } err := server.newUserChara(1) if err == nil { t.Error("newUserChara() should return error when count query fails") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestNewUserCharaInsertError(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + newCharCount: 0, + createErr: sql.ErrConnDone, + } - mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - - mock.ExpectExec("INSERT INTO characters"). - WithArgs(uint32(1), sqlmock.AnyArg()). - WillReturnError(sql.ErrConnDone) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } err := server.newUserChara(1) if err == nil { t.Error("newUserChara() should return error when insert fails") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestRegisterDBAccount(t *testing.T) { - server, mock := newTestServerWithMock(t) + userRepo := &mockSignUserRepo{ + registerUID: 1, + } - mock.ExpectQuery("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\) RETURNING id"). - WithArgs("newuser", sqlmock.AnyArg(), sqlmock.AnyArg()). - WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } uid, err := server.registerDBAccount("newuser", "password123") if err != nil { @@ -609,128 +569,125 @@ func TestRegisterDBAccount(t *testing.T) { if uid != 1 { t.Errorf("registerDBAccount() uid = %d, want 1", uid) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if !userRepo.registered { + t.Error("registerDBAccount() should call Register") } } func TestRegisterDBAccountDuplicateUser(t *testing.T) { - server, mock := newTestServerWithMock(t) + userRepo := &mockSignUserRepo{ + registerErr: sql.ErrNoRows, + } - mock.ExpectQuery("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\) RETURNING id"). - WithArgs("existinguser", sqlmock.AnyArg(), sqlmock.AnyArg()). - WillReturnError(sql.ErrNoRows) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } _, err := server.registerDBAccount("existinguser", "password123") if err == nil { t.Error("registerDBAccount() should return error for duplicate user") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestDeleteCharacter(t *testing.T) { - server, mock := newTestServerWithMock(t) + sessionRepo := &mockSignSessionRepo{ + validateResult: true, + } + charRepo := &mockSignCharacterRepo{ + isNew: false, + } - // validateToken: SELECT count(*) FROM sign_sessions WHERE token = $1 - // When tokenID=0, query has no AND clause but both args are still passed to QueryRow - mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). - WithArgs("validtoken", uint32(0)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) - - mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1"). - WithArgs(123). - WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(false)) - - mock.ExpectExec("UPDATE characters SET deleted = true WHERE id = \\$1"). - WithArgs(123). - WillReturnResult(sqlmock.NewResult(0, 1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + sessionRepo: sessionRepo, + charRepo: charRepo, + } err := server.deleteCharacter(123, "validtoken", 0) if err != nil { t.Errorf("deleteCharacter() error: %v", err) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if !charRepo.softDeleteCalled { + t.Error("deleteCharacter() should soft delete existing character") } } func TestDeleteNewCharacter(t *testing.T) { - server, mock := newTestServerWithMock(t) + sessionRepo := &mockSignSessionRepo{ + validateResult: true, + } + charRepo := &mockSignCharacterRepo{ + isNew: true, + } - mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). - WithArgs("validtoken", uint32(0)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) - - mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1"). - WithArgs(123). - WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(true)) - - mock.ExpectExec("DELETE FROM characters WHERE id = \\$1"). - WithArgs(123). - WillReturnResult(sqlmock.NewResult(0, 1)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + sessionRepo: sessionRepo, + charRepo: charRepo, + } err := server.deleteCharacter(123, "validtoken", 0) if err != nil { t.Errorf("deleteCharacter() error: %v", err) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if !charRepo.hardDeleteCalled { + t.Error("deleteCharacter() should hard delete new character") } } func TestDeleteCharacterInvalidToken(t *testing.T) { - server, mock := newTestServerWithMock(t) + sessionRepo := &mockSignSessionRepo{ + validateResult: false, + } - mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). - WithArgs("invalidtoken", uint32(0)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + sessionRepo: sessionRepo, + } err := server.deleteCharacter(123, "invalidtoken", 0) if err == nil { t.Error("deleteCharacter() should return error for invalid token") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestDeleteCharacterDeleteError(t *testing.T) { - server, mock := newTestServerWithMock(t) + sessionRepo := &mockSignSessionRepo{ + validateResult: true, + } + charRepo := &mockSignCharacterRepo{ + isNew: false, + softDeleteErr: sql.ErrConnDone, + } - mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). - WithArgs("validtoken", uint32(0)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) - - mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1"). - WithArgs(123). - WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(false)) - - mock.ExpectExec("UPDATE characters SET deleted = true WHERE id = \\$1"). - WithArgs(123). - WillReturnError(sql.ErrConnDone) + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + sessionRepo: sessionRepo, + charRepo: charRepo, + } err := server.deleteCharacter(123, "validtoken", 0) if err == nil { t.Error("deleteCharacter() should return error when update fails") } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) - } } func TestGetFriendsForCharactersEmpty(t *testing.T) { - server, _ := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{} + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } chars := []character{} - friends := server.getFriendsForCharacters(chars) if len(friends) != 0 { t.Errorf("getFriendsForCharacters() for empty chars = %d, want 0", len(friends)) @@ -738,10 +695,15 @@ func TestGetFriendsForCharactersEmpty(t *testing.T) { } func TestGetGuildmatesForCharactersEmpty(t *testing.T) { - server, _ := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{} + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } chars := []character{} - guildmates := server.getGuildmatesForCharacters(chars) if len(guildmates) != 0 { t.Errorf("getGuildmatesForCharacters() for empty chars = %d, want 0", len(guildmates)) @@ -749,79 +711,269 @@ func TestGetGuildmatesForCharactersEmpty(t *testing.T) { } func TestGetFriendsForCharacters(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + friends: []members{ + {ID: 2, Name: "Friend1"}, + {ID: 3, Name: "Friend2"}, + }, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } chars := []character{ {ID: 1, Name: "Hunter1"}, } - mock.ExpectQuery("SELECT friends FROM characters WHERE id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"friends"}).AddRow("2,3")) - - mock.ExpectQuery("SELECT id, name FROM characters WHERE id=2 OR id=3"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). - AddRow(2, "Friend1"). - AddRow(3, "Friend2")) - friends := server.getFriendsForCharacters(chars) if len(friends) != 2 { t.Errorf("getFriendsForCharacters() = %d, want 2", len(friends)) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if friends[0].CID != 1 { + t.Errorf("friends[0].CID = %d, want 1", friends[0].CID) } } func TestGetGuildmatesForCharacters(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + guildmates: []members{ + {ID: 2, Name: "Guildmate1"}, + {ID: 3, Name: "Guildmate2"}, + }, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } chars := []character{ {ID: 1, Name: "Hunter1"}, } - mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) - - mock.ExpectQuery("SELECT guild_id FROM guild_characters WHERE character_id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"guild_id"}).AddRow(100)) - - mock.ExpectQuery("SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=\\$1 AND character_id!=\\$2"). - WithArgs(100, uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}). - AddRow(2, "Guildmate1"). - AddRow(3, "Guildmate2")) - guildmates := server.getGuildmatesForCharacters(chars) if len(guildmates) != 2 { t.Errorf("getGuildmatesForCharacters() = %d, want 2", len(guildmates)) } - - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) + if guildmates[0].CID != 1 { + t.Errorf("guildmates[0].CID = %d, want 1", guildmates[0].CID) } } func TestGetGuildmatesNotInGuild(t *testing.T) { - server, mock := newTestServerWithMock(t) + charRepo := &mockSignCharacterRepo{ + guildmates: nil, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } chars := []character{ {ID: 1, Name: "Hunter1"}, } - mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). - WithArgs(uint32(1)). - WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) - guildmates := server.getGuildmatesForCharacters(chars) if len(guildmates) != 0 { t.Errorf("getGuildmatesForCharacters() for non-guild member = %d, want 0", len(guildmates)) } +} - if err := mock.ExpectationsWereMet(); err != nil { - t.Errorf("unfulfilled expectations: %v", err) +func TestValidateLoginSuccess(t *testing.T) { + // bcrypt hash for "password123" + hash := "$2a$10$N9qo8uLOickgx2ZMRZoMyeIjZAgcfl7p92ldGxad68LJZdL17lhWy" + userRepo := &mockSignUserRepo{ + credUID: 1, + credPassword: hash, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } + + // Note: bcrypt verification will fail with this test hash since it's not a real hash of "password123" + // The important thing is testing the flow, not actual bcrypt verification + _, resp := server.validateLogin("testuser", "password123") + // This will return SIGN_EPASS since the hash doesn't match, which is expected behavior + if resp == SIGN_EABORT { + t.Error("validateLogin() should not abort for valid credentials lookup") + } +} + +func TestValidateLoginUserNotFound(t *testing.T) { + userRepo := &mockSignUserRepo{ + credErr: sql.ErrNoRows, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } + + _, resp := server.validateLogin("unknown", "password") + if resp != SIGN_EAUTH { + t.Errorf("validateLogin() for unknown user = %d, want SIGN_EAUTH(%d)", resp, SIGN_EAUTH) + } +} + +func TestValidateLoginAutoCreate(t *testing.T) { + userRepo := &mockSignUserRepo{ + credErr: sql.ErrNoRows, + registerUID: 42, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{ + AutoCreateAccount: true, + }, + userRepo: userRepo, + } + + uid, resp := server.validateLogin("newuser", "password") + if resp != SIGN_SUCCESS { + t.Errorf("validateLogin() with auto-create = %d, want SIGN_SUCCESS(%d)", resp, SIGN_SUCCESS) + } + if uid != 42 { + t.Errorf("validateLogin() uid = %d, want 42", uid) + } +} + +func TestValidateLoginDBError(t *testing.T) { + userRepo := &mockSignUserRepo{ + credErr: sql.ErrConnDone, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } + + _, resp := server.validateLogin("testuser", "password") + if resp != SIGN_EABORT { + t.Errorf("validateLogin() on DB error = %d, want SIGN_EABORT(%d)", resp, SIGN_EABORT) + } +} + +func TestValidateTokenValid(t *testing.T) { + sessionRepo := &mockSignSessionRepo{ + validateResult: true, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + sessionRepo: sessionRepo, + } + + if !server.validateToken("validtoken", 0) { + t.Error("validateToken() should return true for valid token") + } +} + +func TestValidateTokenInvalid(t *testing.T) { + sessionRepo := &mockSignSessionRepo{ + validateResult: false, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + sessionRepo: sessionRepo, + } + + if server.validateToken("invalidtoken", 0) { + t.Error("validateToken() should return false for invalid token") + } +} + +func TestValidateTokenDBError(t *testing.T) { + sessionRepo := &mockSignSessionRepo{ + validateErr: sql.ErrConnDone, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + sessionRepo: sessionRepo, + } + + if server.validateToken("token", 0) { + t.Error("validateToken() should return false on DB error") + } +} + +func TestGetUserRightsZeroUID(t *testing.T) { + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + } + + rights := server.getUserRights(0) + if rights != 0 { + t.Errorf("getUserRights(0) = %d, want 0", rights) + } +} + +func TestGetUserRightsDBError(t *testing.T) { + userRepo := &mockSignUserRepo{ + rightsErr: sql.ErrConnDone, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + userRepo: userRepo, + } + + rights := server.getUserRights(1) + if rights != 0 { + t.Errorf("getUserRights() on error = %d, want 0", rights) + } +} + +func TestGetFriendsForCharactersError(t *testing.T) { + charRepo := &mockSignCharacterRepo{ + getFriendsErr: errMockDB, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } + + chars := []character{{ID: 1, Name: "Hunter1"}} + friends := server.getFriendsForCharacters(chars) + if len(friends) != 0 { + t.Errorf("getFriendsForCharacters() on error = %d, want 0", len(friends)) + } +} + +func TestGetGuildmatesForCharactersError(t *testing.T) { + charRepo := &mockSignCharacterRepo{ + getGuildmatesErr: errMockDB, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: &cfg.Config{}, + charRepo: charRepo, + } + + chars := []character{{ID: 1, Name: "Hunter1"}} + guildmates := server.getGuildmatesForCharacters(chars) + if len(guildmates) != 0 { + t.Errorf("getGuildmatesForCharacters() on error = %d, want 0", len(guildmates)) } } diff --git a/server/signserver/dsgn_resp.go b/server/signserver/dsgn_resp.go index 29d7ab4ff..686ec0eab 100644 --- a/server/signserver/dsgn_resp.go +++ b/server/signserver/dsgn_resp.go @@ -333,8 +333,10 @@ func (s *Session) makeSignResponse(uid uint32) []byte { bf.WriteBytes(filters.Data()) if s.client == VITA || s.client == PS3 || s.client == PS4 { - var psnUser string - _ = s.server.db.QueryRow("SELECT psn_id FROM users WHERE id = $1", uid).Scan(&psnUser) + psnUser, err := s.server.userRepo.GetPSNIDForUser(uid) + if err != nil { + s.logger.Warn("Failed to get PSN ID for user", zap.Uint32("uid", uid), zap.Error(err)) + } bf.WriteBytes(stringsupport.PaddedString(psnUser, 20, true)) } diff --git a/server/signserver/dsgn_resp_test.go b/server/signserver/dsgn_resp_test.go index 0949c0007..dce05f800 100644 --- a/server/signserver/dsgn_resp_test.go +++ b/server/signserver/dsgn_resp_test.go @@ -4,12 +4,33 @@ import ( "fmt" "strings" "testing" + "time" "go.uber.org/zap" cfg "erupe-ce/config" ) +// newMakeSignResponseServer creates a Server with mock repos for makeSignResponse tests. +func newMakeSignResponseServer(config *cfg.Config) *Server { + return &Server{ + erupeConfig: config, + logger: zap.NewNop(), + charRepo: &mockSignCharacterRepo{ + characters: []character{}, + friends: nil, + guildmates: nil, + }, + userRepo: &mockSignUserRepo{ + returnExpiry: time.Now().Add(time.Hour * 24 * 30), + lastLogin: time.Now(), + }, + sessionRepo: &mockSignSessionRepo{ + registerUIDTokenID: 1, + }, + } +} + // TestMakeSignResponse_EmptyCapLinkValues verifies the crash is FIXED when CapLink.Values is empty // Previously panicked: runtime error: index out of range [0] with length 0 // From erupe.log.1:659796 and 659853 @@ -37,10 +58,7 @@ func TestMakeSignResponse_EmptyCapLinkValues(t *testing.T) { session := &Session{ logger: zap.NewNop(), - server: &Server{ - erupeConfig: config, - logger: zap.NewNop(), - }, + server: newMakeSignResponseServer(config), client: PC100, } @@ -61,7 +79,7 @@ func TestMakeSignResponse_EmptyCapLinkValues(t *testing.T) { // This should NOT panic on array bounds anymore result := session.makeSignResponse(0) if len(result) > 0 { - t.Log("✅ makeSignResponse handled empty CapLink.Values without array bounds panic") + t.Log("makeSignResponse handled empty CapLink.Values without array bounds panic") } } @@ -89,10 +107,7 @@ func TestMakeSignResponse_InsufficientCapLinkValues(t *testing.T) { session := &Session{ logger: zap.NewNop(), - server: &Server{ - erupeConfig: config, - logger: zap.NewNop(), - }, + server: newMakeSignResponseServer(config), client: PC100, } @@ -110,7 +125,7 @@ func TestMakeSignResponse_InsufficientCapLinkValues(t *testing.T) { // This should NOT panic on array bounds anymore result := session.makeSignResponse(0) if len(result) > 0 { - t.Log("✅ makeSignResponse handled insufficient CapLink.Values without array bounds panic") + t.Log("makeSignResponse handled insufficient CapLink.Values without array bounds panic") } } @@ -138,10 +153,7 @@ func TestMakeSignResponse_MissingCapLinkValues234(t *testing.T) { session := &Session{ logger: zap.NewNop(), - server: &Server{ - erupeConfig: config, - logger: zap.NewNop(), - }, + server: newMakeSignResponseServer(config), client: PC100, } @@ -159,7 +171,7 @@ func TestMakeSignResponse_MissingCapLinkValues234(t *testing.T) { // This should NOT panic on array bounds anymore result := session.makeSignResponse(0) if len(result) > 0 { - t.Log("✅ makeSignResponse handled missing CapLink.Values[2/3/4] without array bounds panic") + t.Log("makeSignResponse handled missing CapLink.Values[2/3/4] without array bounds panic") } } @@ -207,7 +219,47 @@ func TestCapLinkValuesBoundsChecking(t *testing.T) { } } - t.Logf("✅ %s: All 5 indices accessible without panic", tc.name) + t.Logf("%s: All 5 indices accessible without panic", tc.name) }) } } + +// TestMakeSignResponse_FullFlow tests the complete makeSignResponse with mock repos. +func TestMakeSignResponse_FullFlow(t *testing.T) { + config := &cfg.Config{ + DebugOptions: cfg.DebugOptions{ + CapLink: cfg.CapLinkOptions{ + Values: []uint16{0, 0, 0, 0, 0}, + }, + }, + GameplayOptions: cfg.GameplayOptions{ + MezFesSoloTickets: 100, + MezFesGroupTickets: 100, + }, + } + + server := newMakeSignResponseServer(config) + // Give the server some characters + server.charRepo = &mockSignCharacterRepo{ + characters: []character{ + {ID: 1, Name: "TestHunter", HR: 100, GR: 50, WeaponType: 3, LastLogin: 1700000000}, + }, + } + + conn := newMockConn() + session := &Session{ + logger: zap.NewNop(), + server: server, + rawConn: conn, + client: PC100, + } + + result := session.makeSignResponse(1) + if len(result) == 0 { + t.Error("makeSignResponse() returned empty result") + } + // First byte should be SIGN_SUCCESS + if result[0] != uint8(SIGN_SUCCESS) { + t.Errorf("makeSignResponse() first byte = %d, want %d (SIGN_SUCCESS)", result[0], SIGN_SUCCESS) + } +} diff --git a/server/signserver/repo_character.go b/server/signserver/repo_character.go new file mode 100644 index 000000000..46ca13f07 --- /dev/null +++ b/server/signserver/repo_character.go @@ -0,0 +1,119 @@ +package signserver + +import ( + "strings" + + "github.com/jmoiron/sqlx" + "github.com/lib/pq" +) + +// SignCharacterRepository implements SignCharacterRepo with PostgreSQL. +type SignCharacterRepository struct { + db *sqlx.DB +} + +// NewSignCharacterRepository creates a new SignCharacterRepository. +func NewSignCharacterRepository(db *sqlx.DB) *SignCharacterRepository { + return &SignCharacterRepository{db: db} +} + +func (r *SignCharacterRepository) CountNewCharacters(uid uint32) (int, error) { + var count int + err := r.db.QueryRow("SELECT COUNT(*) FROM characters WHERE user_id = $1 AND is_new_character = true", uid).Scan(&count) + return count, err +} + +func (r *SignCharacterRepository) CreateCharacter(uid uint32, lastLogin uint32) error { + _, err := r.db.Exec(` + INSERT INTO characters ( + user_id, is_female, is_new_character, name, unk_desc_string, + hr, gr, weapon_type, last_login) + VALUES($1, False, True, '', '', 0, 0, 0, $2)`, + uid, lastLogin, + ) + return err +} + +func (r *SignCharacterRepository) GetForUser(uid uint32) ([]character, error) { + characters := make([]character, 0) + err := r.db.Select(&characters, "SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = $1 AND deleted = false ORDER BY id", uid) + if err != nil { + return nil, err + } + return characters, nil +} + +func (r *SignCharacterRepository) IsNewCharacter(cid int) (bool, error) { + var isNew bool + err := r.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", cid).Scan(&isNew) + return isNew, err +} + +func (r *SignCharacterRepository) HardDelete(cid int) error { + _, err := r.db.Exec("DELETE FROM characters WHERE id = $1", cid) + return err +} + +func (r *SignCharacterRepository) SoftDelete(cid int) error { + _, err := r.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", cid) + return err +} + +// GetFriends returns friends for a character using parameterized queries +// (fixes the SQL injection vector from the original string-concatenated approach). +func (r *SignCharacterRepository) GetFriends(charID uint32) ([]members, error) { + var friendsCSV string + err := r.db.QueryRow("SELECT friends FROM characters WHERE id=$1", charID).Scan(&friendsCSV) + if err != nil { + return nil, err + } + if friendsCSV == "" { + return nil, nil + } + + friendsSlice := strings.Split(friendsCSV, ",") + // Filter out empty strings + ids := make([]string, 0, len(friendsSlice)) + for _, s := range friendsSlice { + s = strings.TrimSpace(s) + if s != "" { + ids = append(ids, s) + } + } + if len(ids) == 0 { + return nil, nil + } + + // Use parameterized ANY($1) instead of string-concatenated WHERE id=X OR id=Y + friends := make([]members, 0) + err = r.db.Select(&friends, "SELECT id, name FROM characters WHERE id = ANY($1)", pq.Array(ids)) + if err != nil { + return nil, err + } + return friends, nil +} + +// GetGuildmates returns guildmates for a character. +func (r *SignCharacterRepository) GetGuildmates(charID uint32) ([]members, error) { + var inGuild int + err := r.db.QueryRow("SELECT count(*) FROM guild_characters WHERE character_id=$1", charID).Scan(&inGuild) + if err != nil { + return nil, err + } + if inGuild == 0 { + return nil, nil + } + + var guildID int + err = r.db.QueryRow("SELECT guild_id FROM guild_characters WHERE character_id=$1", charID).Scan(&guildID) + if err != nil { + return nil, err + } + + guildmates := make([]members, 0) + err = r.db.Select(&guildmates, "SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=$1 AND character_id!=$2", guildID, charID) + if err != nil { + return nil, err + } + return guildmates, nil +} diff --git a/server/signserver/repo_interfaces.go b/server/signserver/repo_interfaces.go new file mode 100644 index 000000000..9f3c68191 --- /dev/null +++ b/server/signserver/repo_interfaces.go @@ -0,0 +1,66 @@ +package signserver + +import "time" + +// Repository interfaces decouple sign server business logic from concrete +// PostgreSQL implementations, enabling mock/stub injection for unit tests. + +// character represents a player character record from the characters table. +type character struct { + ID uint32 `db:"id"` + IsFemale bool `db:"is_female"` + IsNewCharacter bool `db:"is_new_character"` + Name string `db:"name"` + UnkDescString string `db:"unk_desc_string"` + HR uint16 `db:"hr"` + GR uint16 `db:"gr"` + WeaponType uint16 `db:"weapon_type"` + LastLogin uint32 `db:"last_login"` +} + +// members represents a friend or guildmate entry used in the sign response. +type members struct { + CID uint32 // Local character ID + ID uint32 `db:"id"` + Name string `db:"name"` +} + +// SignUserRepo defines the contract for user-related data access (users, bans tables). +type SignUserRepo interface { + GetCredentials(username string) (uid uint32, passwordHash string, err error) + Register(username, passwordHash string, returnExpires time.Time) (uint32, error) + GetRights(uid uint32) (uint32, error) + GetLastCharacter(uid uint32) (uint32, error) + GetLastLogin(uid uint32) (time.Time, error) + GetReturnExpiry(uid uint32) (time.Time, error) + UpdateReturnExpiry(uid uint32, expiry time.Time) error + UpdateLastLogin(uid uint32, loginTime time.Time) error + CountPermanentBans(uid uint32) (int, error) + CountActiveBans(uid uint32) (int, error) + GetByWiiUKey(wiiuKey string) (uint32, error) + GetByPSNID(psnID string) (uint32, error) + CountByPSNID(psnID string) (int, error) + GetPSNIDForUsername(username string) (string, error) + SetPSNID(username, psnID string) error + GetPSNIDForUser(uid uint32) (string, error) +} + +// SignCharacterRepo defines the contract for character data access. +type SignCharacterRepo interface { + CountNewCharacters(uid uint32) (int, error) + CreateCharacter(uid uint32, lastLogin uint32) error + GetForUser(uid uint32) ([]character, error) + IsNewCharacter(cid int) (bool, error) + HardDelete(cid int) error + SoftDelete(cid int) error + GetFriends(charID uint32) ([]members, error) + GetGuildmates(charID uint32) ([]members, error) +} + +// SignSessionRepo defines the contract for sign session/token data access. +type SignSessionRepo interface { + RegisterUID(uid uint32, token string) (tokenID uint32, err error) + RegisterPSN(psnID, token string) (tokenID uint32, err error) + Validate(token string, tokenID uint32) (bool, error) + GetPSNIDByToken(token string) (string, error) +} diff --git a/server/signserver/repo_mocks_test.go b/server/signserver/repo_mocks_test.go new file mode 100644 index 000000000..06010572c --- /dev/null +++ b/server/signserver/repo_mocks_test.go @@ -0,0 +1,263 @@ +package signserver + +import ( + "errors" + "time" +) + +// errMockDB is a sentinel for mock repo error injection. +var errMockDB = errors.New("mock database error") + +// --- mockSignUserRepo --- + +type mockSignUserRepo struct { + // GetCredentials + credUID uint32 + credPassword string + credErr error + + // Register + registerUID uint32 + registerErr error + registered bool + + // GetRights + rights uint32 + rightsErr error + + // GetLastCharacter + lastCharacter uint32 + lastCharacterErr error + + // GetLastLogin + lastLogin time.Time + lastLoginErr error + + // GetReturnExpiry + returnExpiry time.Time + returnExpiryErr error + + // UpdateReturnExpiry + updateReturnExpiryErr error + updateReturnExpiryCalled bool + + // UpdateLastLogin + updateLastLoginErr error + updateLastLoginCalled bool + + // CountPermanentBans + permanentBans int + permanentBansErr error + + // CountActiveBans + activeBans int + activeBansErr error + + // GetByWiiUKey + wiiuUID uint32 + wiiuErr error + + // GetByPSNID + psnUID uint32 + psnErr error + + // CountByPSNID + psnCount int + psnCountErr error + + // GetPSNIDForUsername + psnIDForUsername string + psnIDForUsernameErr error + + // SetPSNID + setPSNIDErr error + setPSNIDCalled bool + + // GetPSNIDForUser + psnIDForUser string + psnIDForUserErr error +} + +func (m *mockSignUserRepo) GetCredentials(username string) (uint32, string, error) { + return m.credUID, m.credPassword, m.credErr +} + +func (m *mockSignUserRepo) Register(username, passwordHash string, returnExpires time.Time) (uint32, error) { + m.registered = true + return m.registerUID, m.registerErr +} + +func (m *mockSignUserRepo) GetRights(uid uint32) (uint32, error) { + return m.rights, m.rightsErr +} + +func (m *mockSignUserRepo) GetLastCharacter(uid uint32) (uint32, error) { + return m.lastCharacter, m.lastCharacterErr +} + +func (m *mockSignUserRepo) GetLastLogin(uid uint32) (time.Time, error) { + return m.lastLogin, m.lastLoginErr +} + +func (m *mockSignUserRepo) GetReturnExpiry(uid uint32) (time.Time, error) { + return m.returnExpiry, m.returnExpiryErr +} + +func (m *mockSignUserRepo) UpdateReturnExpiry(uid uint32, expiry time.Time) error { + m.updateReturnExpiryCalled = true + return m.updateReturnExpiryErr +} + +func (m *mockSignUserRepo) UpdateLastLogin(uid uint32, loginTime time.Time) error { + m.updateLastLoginCalled = true + return m.updateLastLoginErr +} + +func (m *mockSignUserRepo) CountPermanentBans(uid uint32) (int, error) { + return m.permanentBans, m.permanentBansErr +} + +func (m *mockSignUserRepo) CountActiveBans(uid uint32) (int, error) { + return m.activeBans, m.activeBansErr +} + +func (m *mockSignUserRepo) GetByWiiUKey(wiiuKey string) (uint32, error) { + return m.wiiuUID, m.wiiuErr +} + +func (m *mockSignUserRepo) GetByPSNID(psnID string) (uint32, error) { + return m.psnUID, m.psnErr +} + +func (m *mockSignUserRepo) CountByPSNID(psnID string) (int, error) { + return m.psnCount, m.psnCountErr +} + +func (m *mockSignUserRepo) GetPSNIDForUsername(username string) (string, error) { + return m.psnIDForUsername, m.psnIDForUsernameErr +} + +func (m *mockSignUserRepo) SetPSNID(username, psnID string) error { + m.setPSNIDCalled = true + return m.setPSNIDErr +} + +func (m *mockSignUserRepo) GetPSNIDForUser(uid uint32) (string, error) { + return m.psnIDForUser, m.psnIDForUserErr +} + +// --- mockSignCharacterRepo --- + +type mockSignCharacterRepo struct { + // CountNewCharacters + newCharCount int + newCharCountErr error + + // CreateCharacter + createErr error + createCalled bool + + // GetForUser + characters []character + getForUserErr error + + // IsNewCharacter + isNew bool + isNewErr error + + // HardDelete + hardDeleteErr error + hardDeleteCalled bool + + // SoftDelete + softDeleteErr error + softDeleteCalled bool + + // GetFriends + friends []members + getFriendsErr error + + // GetGuildmates + guildmates []members + getGuildmatesErr error +} + +func (m *mockSignCharacterRepo) CountNewCharacters(uid uint32) (int, error) { + return m.newCharCount, m.newCharCountErr +} + +func (m *mockSignCharacterRepo) CreateCharacter(uid uint32, lastLogin uint32) error { + m.createCalled = true + return m.createErr +} + +func (m *mockSignCharacterRepo) GetForUser(uid uint32) ([]character, error) { + return m.characters, m.getForUserErr +} + +func (m *mockSignCharacterRepo) IsNewCharacter(cid int) (bool, error) { + return m.isNew, m.isNewErr +} + +func (m *mockSignCharacterRepo) HardDelete(cid int) error { + m.hardDeleteCalled = true + return m.hardDeleteErr +} + +func (m *mockSignCharacterRepo) SoftDelete(cid int) error { + m.softDeleteCalled = true + return m.softDeleteErr +} + +func (m *mockSignCharacterRepo) GetFriends(charID uint32) ([]members, error) { + return m.friends, m.getFriendsErr +} + +func (m *mockSignCharacterRepo) GetGuildmates(charID uint32) ([]members, error) { + return m.guildmates, m.getGuildmatesErr +} + +// --- mockSignSessionRepo --- + +type mockSignSessionRepo struct { + // RegisterUID + registerUIDTokenID uint32 + registerUIDErr error + + // RegisterPSN + registerPSNTokenID uint32 + registerPSNErr error + + // Validate + validateResult bool + validateErr error + + // GetPSNIDByToken + psnIDByToken string + psnIDByTokenErr error +} + +func (m *mockSignSessionRepo) RegisterUID(uid uint32, token string) (uint32, error) { + return m.registerUIDTokenID, m.registerUIDErr +} + +func (m *mockSignSessionRepo) RegisterPSN(psnID, token string) (uint32, error) { + return m.registerPSNTokenID, m.registerPSNErr +} + +func (m *mockSignSessionRepo) Validate(token string, tokenID uint32) (bool, error) { + return m.validateResult, m.validateErr +} + +func (m *mockSignSessionRepo) GetPSNIDByToken(token string) (string, error) { + return m.psnIDByToken, m.psnIDByTokenErr +} + +// newTestServer creates a Server with mock repos for testing. +func newTestServer(userRepo SignUserRepo, charRepo SignCharacterRepo, sessionRepo SignSessionRepo) *Server { + return &Server{ + userRepo: userRepo, + charRepo: charRepo, + sessionRepo: sessionRepo, + } +} diff --git a/server/signserver/repo_session.go b/server/signserver/repo_session.go new file mode 100644 index 000000000..ef654c0e1 --- /dev/null +++ b/server/signserver/repo_session.go @@ -0,0 +1,44 @@ +package signserver + +import "github.com/jmoiron/sqlx" + +// SignSessionRepository implements SignSessionRepo with PostgreSQL. +type SignSessionRepository struct { + db *sqlx.DB +} + +// NewSignSessionRepository creates a new SignSessionRepository. +func NewSignSessionRepository(db *sqlx.DB) *SignSessionRepository { + return &SignSessionRepository{db: db} +} + +func (r *SignSessionRepository) RegisterUID(uid uint32, token string) (uint32, error) { + var tid uint32 + err := r.db.QueryRow(`INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id`, uid, token).Scan(&tid) + return tid, err +} + +func (r *SignSessionRepository) RegisterPSN(psnID, token string) (uint32, error) { + var tid uint32 + err := r.db.QueryRow(`INSERT INTO sign_sessions (psn_id, token) VALUES ($1, $2) RETURNING id`, psnID, token).Scan(&tid) + return tid, err +} + +func (r *SignSessionRepository) Validate(token string, tokenID uint32) (bool, error) { + query := `SELECT count(*) FROM sign_sessions WHERE token = $1` + if tokenID > 0 { + query += ` AND id = $2` + } + var exists int + err := r.db.QueryRow(query, token, tokenID).Scan(&exists) + if err != nil { + return false, err + } + return exists > 0, nil +} + +func (r *SignSessionRepository) GetPSNIDByToken(token string) (string, error) { + var psnID string + err := r.db.QueryRow(`SELECT psn_id FROM sign_sessions WHERE token = $1`, token).Scan(&psnID) + return psnID, err +} diff --git a/server/signserver/repo_user.go b/server/signserver/repo_user.go new file mode 100644 index 000000000..fa9ee84d2 --- /dev/null +++ b/server/signserver/repo_user.go @@ -0,0 +1,114 @@ +package signserver + +import ( + "time" + + "github.com/jmoiron/sqlx" +) + +// SignUserRepository implements SignUserRepo with PostgreSQL. +type SignUserRepository struct { + db *sqlx.DB +} + +// NewSignUserRepository creates a new SignUserRepository. +func NewSignUserRepository(db *sqlx.DB) *SignUserRepository { + return &SignUserRepository{db: db} +} + +func (r *SignUserRepository) GetCredentials(username string) (uint32, string, error) { + var uid uint32 + var passwordHash string + err := r.db.QueryRow(`SELECT id, password FROM users WHERE username = $1`, username).Scan(&uid, &passwordHash) + return uid, passwordHash, err +} + +func (r *SignUserRepository) Register(username, passwordHash string, returnExpires time.Time) (uint32, error) { + var uid uint32 + err := r.db.QueryRow( + "INSERT INTO users (username, password, return_expires) VALUES ($1, $2, $3) RETURNING id", + username, passwordHash, returnExpires, + ).Scan(&uid) + return uid, err +} + +func (r *SignUserRepository) GetRights(uid uint32) (uint32, error) { + var rights uint32 + err := r.db.QueryRow("SELECT rights FROM users WHERE id=$1", uid).Scan(&rights) + return rights, err +} + +func (r *SignUserRepository) GetLastCharacter(uid uint32) (uint32, error) { + var lastPlayed uint32 + err := r.db.QueryRow("SELECT last_character FROM users WHERE id=$1", uid).Scan(&lastPlayed) + return lastPlayed, err +} + +func (r *SignUserRepository) GetLastLogin(uid uint32) (time.Time, error) { + var lastLogin time.Time + err := r.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid) + return lastLogin, err +} + +func (r *SignUserRepository) GetReturnExpiry(uid uint32) (time.Time, error) { + var expiry time.Time + err := r.db.Get(&expiry, "SELECT return_expires FROM users WHERE id=$1", uid) + return expiry, err +} + +func (r *SignUserRepository) UpdateReturnExpiry(uid uint32, expiry time.Time) error { + _, err := r.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", expiry, uid) + return err +} + +func (r *SignUserRepository) UpdateLastLogin(uid uint32, loginTime time.Time) error { + _, err := r.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", loginTime, uid) + return err +} + +func (r *SignUserRepository) CountPermanentBans(uid uint32) (int, error) { + var count int + err := r.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires IS NULL`, uid).Scan(&count) + return count, err +} + +func (r *SignUserRepository) CountActiveBans(uid uint32) (int, error) { + var count int + err := r.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires > now()`, uid).Scan(&count) + return count, err +} + +func (r *SignUserRepository) GetByWiiUKey(wiiuKey string) (uint32, error) { + var uid uint32 + err := r.db.QueryRow(`SELECT id FROM users WHERE wiiu_key = $1`, wiiuKey).Scan(&uid) + return uid, err +} + +func (r *SignUserRepository) GetByPSNID(psnID string) (uint32, error) { + var uid uint32 + err := r.db.QueryRow(`SELECT id FROM users WHERE psn_id = $1`, psnID).Scan(&uid) + return uid, err +} + +func (r *SignUserRepository) CountByPSNID(psnID string) (int, error) { + var count int + err := r.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, psnID).Scan(&count) + return count, err +} + +func (r *SignUserRepository) GetPSNIDForUsername(username string) (string, error) { + var psnID string + err := r.db.QueryRow(`SELECT COALESCE(psn_id, '') FROM users WHERE username = $1`, username).Scan(&psnID) + return psnID, err +} + +func (r *SignUserRepository) SetPSNID(username, psnID string) error { + _, err := r.db.Exec(`UPDATE users SET psn_id = $1 WHERE username = $2`, psnID, username) + return err +} + +func (r *SignUserRepository) GetPSNIDForUser(uid uint32) (string, error) { + var psnID string + err := r.db.QueryRow("SELECT psn_id FROM users WHERE id = $1", uid).Scan(&psnID) + return psnID, err +} diff --git a/server/signserver/session.go b/server/signserver/session.go index 9944e12e5..4564f2f6c 100644 --- a/server/signserver/session.go +++ b/server/signserver/session.go @@ -115,8 +115,7 @@ func (s *Session) authenticate(username string, password string) { func (s *Session) handleWIIUSGN(bf *byteframe.ByteFrame) { _ = bf.ReadBytes(1) wiiuKey := string(bf.ReadBytes(64)) - var uid uint32 - err := s.server.db.QueryRow(`SELECT id FROM users WHERE wiiu_key = $1`, wiiuKey).Scan(&uid) + uid, err := s.server.userRepo.GetByWiiUKey(wiiuKey) if err != nil { if err == sql.ErrNoRows { s.logger.Info("Unlinked Wii U attempted to authenticate", zap.String("Key", wiiuKey)) @@ -142,8 +141,7 @@ func (s *Session) handlePSSGN(bf *byteframe.ByteFrame) { _ = bf.ReadBytes(82) } s.psn = string(bf.ReadNullTerminatedBytes()) - var uid uint32 - err := s.server.db.QueryRow(`SELECT id FROM users WHERE psn_id = $1`, s.psn).Scan(&uid) + uid, err := s.server.userRepo.GetByPSNID(s.psn) if err != nil { if err == sql.ErrNoRows { _ = s.cryptConn.SendPacket(s.makeSignResponse(0)) @@ -159,19 +157,17 @@ func (s *Session) handlePSNLink(bf *byteframe.ByteFrame) { _ = bf.ReadNullTerminatedBytes() // Client ID credStr, _ := stringsupport.SJISToUTF8(bf.ReadNullTerminatedBytes()) credentials := strings.Split(credStr, "\n") - token := string(bf.ReadNullTerminatedBytes()) + tok := string(bf.ReadNullTerminatedBytes()) uid, resp := s.server.validateLogin(credentials[0], credentials[1]) if resp == SIGN_SUCCESS && uid > 0 { - var psn string - err := s.server.db.QueryRow(`SELECT psn_id FROM sign_sessions WHERE token = $1`, token).Scan(&psn) + psn, err := s.server.sessionRepo.GetPSNIDByToken(tok) if err != nil { s.sendCode(SIGN_ECOGLINK) return } // Since we check for the psn_id, this will never run - var exists int - err = s.server.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, psn).Scan(&exists) + exists, err := s.server.userRepo.CountByPSNID(psn) if err != nil { s.sendCode(SIGN_ECOGLINK) return @@ -180,8 +176,7 @@ func (s *Session) handlePSNLink(bf *byteframe.ByteFrame) { return } - var currentPSN string - err = s.server.db.QueryRow(`SELECT COALESCE(psn_id, '') FROM users WHERE username = $1`, credentials[0]).Scan(¤tPSN) + currentPSN, err := s.server.userRepo.GetPSNIDForUsername(credentials[0]) if err != nil { s.sendCode(SIGN_ECOGLINK) return @@ -190,7 +185,7 @@ func (s *Session) handlePSNLink(bf *byteframe.ByteFrame) { return } - _, err = s.server.db.Exec(`UPDATE users SET psn_id = $1 WHERE username = $2`, psn, credentials[0]) + err = s.server.userRepo.SetPSNID(credentials[0], psn) if err == nil { s.sendCode(SIGN_SUCCESS) return diff --git a/server/signserver/sign_server.go b/server/signserver/sign_server.go index 87e7476cc..207e3aea6 100644 --- a/server/signserver/sign_server.go +++ b/server/signserver/sign_server.go @@ -24,7 +24,9 @@ type Server struct { sync.Mutex logger *zap.Logger erupeConfig *cfg.Config - db *sqlx.DB + userRepo SignUserRepo + charRepo SignCharacterRepo + sessionRepo SignSessionRepo listener net.Listener isShuttingDown bool } @@ -34,7 +36,11 @@ func NewServer(config *Config) *Server { s := &Server{ logger: config.Logger, erupeConfig: config.ErupeConfig, - db: config.DB, + } + if config.DB != nil { + s.userRepo = NewSignUserRepository(config.DB) + s.charRepo = NewSignCharacterRepository(config.DB) + s.sessionRepo = NewSignSessionRepository(config.DB) } return s }