From 82b967b71535475ea2217bdb64b8a3e74c547657 Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Sun, 22 Feb 2026 17:04:58 +0100 Subject: [PATCH] refactor: replace raw SQL with repository interfaces in entranceserver and API server MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extract all direct database calls from entranceserver (2 calls) and API server (17 calls) into typed repository interfaces with PostgreSQL implementations, matching the pattern established in signserver and channelserver. Entranceserver: EntranceServerRepo, EntranceSessionRepo API server: APIUserRepo, APICharacterRepo, APISessionRepo Also fix the 3 remaining fmt.Sprintf calls inside logger invocations in handlers_commands.go and handlers_stage.go, replacing them with structured zap fields. Unskip 5 TestNewAuthData* tests that previously required a real database — they now run with mock repos. --- server/api/api_server.go | 10 +- server/api/dbutils.go | 93 ++++------------ server/api/endpoints.go | 7 +- server/api/endpoints_test.go | 43 ++++---- server/api/repo_character.go | 87 +++++++++++++++ server/api/repo_interfaces.go | 53 +++++++++ server/api/repo_mocks_test.go | 124 ++++++++++++++++++++++ server/api/repo_session.go | 29 +++++ server/api/repo_user.go | 66 ++++++++++++ server/channelserver/handlers_commands.go | 4 +- server/channelserver/handlers_stage.go | 3 +- server/entranceserver/entrance_server.go | 8 +- server/entranceserver/make_resp.go | 12 +-- server/entranceserver/make_resp_test.go | 93 ++++++++++++++++ server/entranceserver/repo_interfaces.go | 19 ++++ server/entranceserver/repo_mocks_test.go | 21 ++++ server/entranceserver/repo_server.go | 22 ++++ server/entranceserver/repo_session.go | 22 ++++ 18 files changed, 601 insertions(+), 115 deletions(-) create mode 100644 server/api/repo_character.go create mode 100644 server/api/repo_interfaces.go create mode 100644 server/api/repo_mocks_test.go create mode 100644 server/api/repo_session.go create mode 100644 server/api/repo_user.go create mode 100644 server/entranceserver/repo_interfaces.go create mode 100644 server/entranceserver/repo_mocks_test.go create mode 100644 server/entranceserver/repo_server.go create mode 100644 server/entranceserver/repo_session.go diff --git a/server/api/api_server.go b/server/api/api_server.go index d1f3ee699..ea048a13c 100644 --- a/server/api/api_server.go +++ b/server/api/api_server.go @@ -27,7 +27,9 @@ type APIServer struct { sync.Mutex logger *zap.Logger erupeConfig *cfg.Config - db *sqlx.DB + userRepo APIUserRepo + charRepo APICharacterRepo + sessionRepo APISessionRepo httpServer *http.Server isShuttingDown bool } @@ -37,9 +39,13 @@ func NewAPIServer(config *Config) *APIServer { s := &APIServer{ logger: config.Logger, erupeConfig: config.ErupeConfig, - db: config.DB, httpServer: &http.Server{}, } + if config.DB != nil { + s.userRepo = NewAPIUserRepository(config.DB) + s.charRepo = NewAPICharacterRepository(config.DB) + s.sessionRepo = NewAPISessionRepository(config.DB) + } return s } diff --git a/server/api/dbutils.go b/server/api/dbutils.go index ecb046a39..1bd8f8397 100644 --- a/server/api/dbutils.go +++ b/server/api/dbutils.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "erupe-ce/common/token" + "errors" "fmt" "time" @@ -11,41 +12,25 @@ import ( ) func (s *APIServer) createNewUser(ctx context.Context, username string, password string) (uint32, uint32, error) { - // Create salted hash of user password passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return 0, 0, err } - - var ( - id uint32 - rights uint32 - ) - err = s.db.QueryRowContext( - ctx, ` - INSERT INTO users (username, password, return_expires) - VALUES ($1, $2, $3) - RETURNING id, rights - `, - username, string(passwordHash), time.Now().Add(time.Hour*24*30), - ).Scan(&id, &rights) - return id, rights, err + return s.userRepo.Register(ctx, username, string(passwordHash), time.Now().Add(time.Hour*24*30)) } func (s *APIServer) createLoginToken(ctx context.Context, uid uint32) (uint32, string, error) { loginToken := token.Generate(16) - var tid uint32 - err := s.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, loginToken).Scan(&tid) + tid, err := s.sessionRepo.CreateToken(ctx, uid, loginToken) if err != nil { return 0, "", err } return tid, loginToken, nil } -func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32, error) { - var userID uint32 - err := s.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID) - if err == sql.ErrNoRows { +func (s *APIServer) userIDFromToken(ctx context.Context, tkn string) (uint32, error) { + userID, err := s.sessionRepo.GetUserIDByToken(ctx, tkn) + if errors.Is(err, sql.ErrNoRows) { return 0, fmt.Errorf("invalid login token") } else if err != nil { return 0, err @@ -54,82 +39,50 @@ func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32, } func (s *APIServer) createCharacter(ctx context.Context, userID uint32) (Character, error) { - var character Character - err := s.db.GetContext(ctx, &character, - "SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1", - userID, - ) - if err == sql.ErrNoRows { - var count int - _ = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count) + character, err := s.charRepo.GetNewCharacter(ctx, userID) + if errors.Is(err, sql.ErrNoRows) { + count, _ := s.charRepo.CountForUser(ctx, userID) if count >= 16 { return character, fmt.Errorf("cannot have more than 16 characters") } - err = s.db.GetContext(ctx, &character, ` - INSERT INTO characters ( - user_id, is_female, is_new_character, name, unk_desc_string, - hr, gr, weapon_type, last_login - ) - VALUES ($1, false, true, '', '', 0, 0, 0, $2) - RETURNING id, name, is_female, weapon_type, hr, gr, last_login`, - userID, uint32(time.Now().Unix()), - ) + character, err = s.charRepo.Create(ctx, userID, uint32(time.Now().Unix())) } return character, err } -func (s *APIServer) deleteCharacter(ctx context.Context, userID uint32, charID uint32) error { - var isNew bool - err := s.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew) +func (s *APIServer) deleteCharacter(_ context.Context, _ uint32, charID uint32) error { + isNew, err := s.charRepo.IsNew(charID) if err != nil { return err } if isNew { - _, err = s.db.Exec("DELETE FROM characters WHERE id = $1", charID) - } else { - _, err = s.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID) + return s.charRepo.HardDelete(charID) } - return err + return s.charRepo.SoftDelete(charID) } func (s *APIServer) getCharactersForUser(ctx context.Context, uid uint32) ([]Character, error) { - var characters []Character - err := s.db.SelectContext( - ctx, &characters, ` - SELECT id, name, is_female, weapon_type, hr, gr, last_login - FROM characters - WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`, - uid, - ) - if err != nil { - return nil, err - } - return characters, nil + return s.charRepo.GetForUser(ctx, uid) } func (s *APIServer) getReturnExpiry(uid uint32) time.Time { - var returnExpiry, lastLogin time.Time - _ = s.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid) + lastLogin, _ := s.userRepo.GetLastLogin(uid) + var returnExpiry time.Time if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) { returnExpiry = time.Now().Add(time.Hour * 24 * 30) - _, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) + _ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry) } else { - err := s.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid) + var err error + returnExpiry, err = s.userRepo.GetReturnExpiry(uid) if err != nil { returnExpiry = time.Now() - _, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) + _ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry) } } - _, _ = s.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", time.Now(), uid) + _ = s.userRepo.UpdateLastLogin(uid, time.Now()) return returnExpiry } func (s *APIServer) exportSave(ctx context.Context, uid uint32, cid uint32) (map[string]interface{}, error) { - row := s.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", cid, uid) - result := make(map[string]interface{}) - err := row.MapScan(result) - if err != nil { - return nil, err - } - return result, nil + return s.charRepo.ExportSave(ctx, uid, cid) } diff --git a/server/api/endpoints.go b/server/api/endpoints.go index 9b78535e4..25dfb68e9 100644 --- a/server/api/endpoints.go +++ b/server/api/endpoints.go @@ -162,12 +162,7 @@ func (s *APIServer) Login(w http.ResponseWriter, r *http.Request) { w.WriteHeader(400) return } - var ( - userID uint32 - userRights uint32 - password string - ) - err := s.db.QueryRow("SELECT id, password, rights FROM users WHERE username = $1", reqData.Username).Scan(&userID, &password, &userRights) + userID, password, userRights, err := s.userRepo.GetCredentials(ctx, reqData.Username) if err == sql.ErrNoRows { w.WriteHeader(400) _, _ = w.Write([]byte("username-error")) diff --git a/server/api/endpoints_test.go b/server/api/endpoints_test.go index 0722b00bd..1e172faab 100644 --- a/server/api/endpoints_test.go +++ b/server/api/endpoints_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" cfg "erupe-ce/config" "erupe-ce/common/gametime" @@ -33,7 +34,6 @@ func TestLauncherEndpoint(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } // Create test request @@ -123,7 +123,6 @@ func TestLoginEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } // Invalid JSON @@ -148,7 +147,6 @@ func TestLoginEndpointEmptyCredentials(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } tests := []struct { @@ -200,7 +198,6 @@ func TestRegisterEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"username": "test"` @@ -223,7 +220,6 @@ func TestRegisterEndpointEmptyCredentials(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } tests := []struct { @@ -271,7 +267,6 @@ func TestCreateCharacterEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"token": ` @@ -294,7 +289,6 @@ func TestDeleteCharacterEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"token": "test"` @@ -317,7 +311,6 @@ func TestExportSaveEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"token": ` @@ -342,7 +335,6 @@ func TestScreenShotEndpointDisabled(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } req := httptest.NewRequest("POST", "/api/ss/bbs/upload.php", nil) @@ -379,7 +371,6 @@ func TestScreenShotGetInvalidToken(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } tests := []struct { @@ -408,10 +399,16 @@ func TestScreenShotGetInvalidToken(t *testing.T) { } } +// newTestUserRepo returns a mock user repo suitable for newAuthData tests. +func newTestUserRepo() *mockAPIUserRepo { + return &mockAPIUserRepo{ + lastLogin: time.Now(), + returnExpiry: time.Now().Add(time.Hour * 24 * 30), + } +} + // TestNewAuthDataStructure tests the newAuthData helper function func TestNewAuthDataStructure(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -423,7 +420,7 @@ func TestNewAuthDataStructure(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } characters := []Character{ @@ -466,8 +463,6 @@ func TestNewAuthDataStructure(t *testing.T) { // TestNewAuthDataDebugMode tests newAuthData with debug mode enabled func TestNewAuthDataDebugMode(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -477,7 +472,7 @@ func TestNewAuthDataDebugMode(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } characters := []Character{ @@ -500,8 +495,6 @@ func TestNewAuthDataDebugMode(t *testing.T) { // TestNewAuthDataMezFesConfiguration tests MezFes configuration in newAuthData func TestNewAuthDataMezFesConfiguration(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -513,7 +506,7 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } authData := server.newAuthData(1, 0, 1, "token", []Character{}) @@ -534,8 +527,6 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) { // TestNewAuthDataHideNotices tests notice hiding in newAuthData func TestNewAuthDataHideNotices(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -546,7 +537,7 @@ func TestNewAuthDataHideNotices(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } authData := server.newAuthData(1, 0, 1, "token", []Character{}) @@ -558,8 +549,6 @@ func TestNewAuthDataHideNotices(t *testing.T) { // TestNewAuthDataTimestamps tests timestamp generation in newAuthData func TestNewAuthDataTimestamps(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -567,7 +556,7 @@ func TestNewAuthDataTimestamps(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } authData := server.newAuthData(1, 0, 1, "token", []Character{}) @@ -611,6 +600,10 @@ func BenchmarkNewAuthData(b *testing.B) { server := &APIServer{ logger: logger, erupeConfig: c, + userRepo: &mockAPIUserRepo{ + lastLogin: time.Now(), + returnExpiry: time.Now().Add(time.Hour * 24 * 30), + }, } characters := make([]Character, 16) diff --git a/server/api/repo_character.go b/server/api/repo_character.go new file mode 100644 index 000000000..6bddc5815 --- /dev/null +++ b/server/api/repo_character.go @@ -0,0 +1,87 @@ +package api + +import ( + "context" + + "github.com/jmoiron/sqlx" +) + +// APICharacterRepository implements APICharacterRepo with PostgreSQL. +type APICharacterRepository struct { + db *sqlx.DB +} + +// NewAPICharacterRepository creates a new APICharacterRepository. +func NewAPICharacterRepository(db *sqlx.DB) *APICharacterRepository { + return &APICharacterRepository{db: db} +} + +func (r *APICharacterRepository) GetNewCharacter(ctx context.Context, userID uint32) (Character, error) { + var character Character + err := r.db.GetContext(ctx, &character, + "SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1", + userID, + ) + return character, err +} + +func (r *APICharacterRepository) CountForUser(ctx context.Context, userID uint32) (int, error) { + var count int + err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count) + return count, err +} + +func (r *APICharacterRepository) Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error) { + var character Character + err := r.db.GetContext(ctx, &character, ` + INSERT INTO characters ( + user_id, is_female, is_new_character, name, unk_desc_string, + hr, gr, weapon_type, last_login + ) + VALUES ($1, false, true, '', '', 0, 0, 0, $2) + RETURNING id, name, is_female, weapon_type, hr, gr, last_login`, + userID, lastLogin, + ) + return character, err +} + +func (r *APICharacterRepository) IsNew(charID uint32) (bool, error) { + var isNew bool + err := r.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew) + return isNew, err +} + +func (r *APICharacterRepository) HardDelete(charID uint32) error { + _, err := r.db.Exec("DELETE FROM characters WHERE id = $1", charID) + return err +} + +func (r *APICharacterRepository) SoftDelete(charID uint32) error { + _, err := r.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID) + return err +} + +func (r *APICharacterRepository) GetForUser(ctx context.Context, userID uint32) ([]Character, error) { + var characters []Character + err := r.db.SelectContext( + ctx, &characters, ` + SELECT id, name, is_female, weapon_type, hr, gr, last_login + FROM characters + WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`, + userID, + ) + if err != nil { + return nil, err + } + return characters, nil +} + +func (r *APICharacterRepository) ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error) { + row := r.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", charID, userID) + result := make(map[string]interface{}) + err := row.MapScan(result) + if err != nil { + return nil, err + } + return result, nil +} diff --git a/server/api/repo_interfaces.go b/server/api/repo_interfaces.go new file mode 100644 index 000000000..c0e24c3ec --- /dev/null +++ b/server/api/repo_interfaces.go @@ -0,0 +1,53 @@ +package api + +import ( + "context" + "time" +) + +// Repository interfaces decouple API server business logic from concrete +// PostgreSQL implementations, enabling mock/stub injection for unit tests. + +// APIUserRepo defines the contract for user-related data access. +type APIUserRepo interface { + // Register creates a new user and returns their ID and rights. + Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (id uint32, rights uint32, err error) + // GetCredentials returns the user's ID, password hash, and rights. + GetCredentials(ctx context.Context, username string) (id uint32, passwordHash string, rights uint32, err error) + // GetLastLogin returns the user's last login time. + GetLastLogin(uid uint32) (time.Time, error) + // GetReturnExpiry returns the user's return expiry time. + GetReturnExpiry(uid uint32) (time.Time, error) + // UpdateReturnExpiry sets the user's return expiry time. + UpdateReturnExpiry(uid uint32, expiry time.Time) error + // UpdateLastLogin sets the user's last login time. + UpdateLastLogin(uid uint32, loginTime time.Time) error +} + +// APICharacterRepo defines the contract for character-related data access. +type APICharacterRepo interface { + // GetNewCharacter returns an existing new (unfinished) character for a user. + GetNewCharacter(ctx context.Context, userID uint32) (Character, error) + // CountForUser returns the total number of characters for a user. + CountForUser(ctx context.Context, userID uint32) (int, error) + // Create inserts a new character and returns it. + Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error) + // IsNew returns whether a character is a new (unfinished) character. + IsNew(charID uint32) (bool, error) + // HardDelete permanently removes a character. + HardDelete(charID uint32) error + // SoftDelete marks a character as deleted. + SoftDelete(charID uint32) error + // GetForUser returns all finalized (non-deleted) characters for a user. + GetForUser(ctx context.Context, userID uint32) ([]Character, error) + // ExportSave returns the full character row as a map. + ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error) +} + +// APISessionRepo defines the contract for session/token data access. +type APISessionRepo interface { + // CreateToken inserts a new sign session and returns its ID and token. + CreateToken(ctx context.Context, uid uint32, token string) (tokenID uint32, err error) + // GetUserIDByToken returns the user ID for a given session token. + GetUserIDByToken(ctx context.Context, token string) (uint32, error) +} diff --git a/server/api/repo_mocks_test.go b/server/api/repo_mocks_test.go new file mode 100644 index 000000000..ab4bce375 --- /dev/null +++ b/server/api/repo_mocks_test.go @@ -0,0 +1,124 @@ +package api + +import ( + "context" + "time" +) + +// mockAPIUserRepo implements APIUserRepo for testing. +type mockAPIUserRepo struct { + registerID uint32 + registerRights uint32 + registerErr error + + credentialsID uint32 + credentialsPassword string + credentialsRights uint32 + credentialsErr error + + lastLogin time.Time + lastLoginErr error + + returnExpiry time.Time + returnExpiryErr error + + updateReturnExpiryErr error + updateLastLoginErr error +} + +func (m *mockAPIUserRepo) Register(_ context.Context, _, _ string, _ time.Time) (uint32, uint32, error) { + return m.registerID, m.registerRights, m.registerErr +} + +func (m *mockAPIUserRepo) GetCredentials(_ context.Context, _ string) (uint32, string, uint32, error) { + return m.credentialsID, m.credentialsPassword, m.credentialsRights, m.credentialsErr +} + +func (m *mockAPIUserRepo) GetLastLogin(_ uint32) (time.Time, error) { + return m.lastLogin, m.lastLoginErr +} + +func (m *mockAPIUserRepo) GetReturnExpiry(_ uint32) (time.Time, error) { + return m.returnExpiry, m.returnExpiryErr +} + +func (m *mockAPIUserRepo) UpdateReturnExpiry(_ uint32, _ time.Time) error { + return m.updateReturnExpiryErr +} + +func (m *mockAPIUserRepo) UpdateLastLogin(_ uint32, _ time.Time) error { + return m.updateLastLoginErr +} + +// mockAPICharacterRepo implements APICharacterRepo for testing. +type mockAPICharacterRepo struct { + newCharacter Character + newCharacterErr error + + countForUser int + countForUserErr error + + createChar Character + createCharErr error + + isNewResult bool + isNewErr error + + hardDeleteErr error + softDeleteErr error + + characters []Character + charactersErr error + + exportResult map[string]interface{} + exportErr error +} + +func (m *mockAPICharacterRepo) GetNewCharacter(_ context.Context, _ uint32) (Character, error) { + return m.newCharacter, m.newCharacterErr +} + +func (m *mockAPICharacterRepo) CountForUser(_ context.Context, _ uint32) (int, error) { + return m.countForUser, m.countForUserErr +} + +func (m *mockAPICharacterRepo) Create(_ context.Context, _ uint32, _ uint32) (Character, error) { + return m.createChar, m.createCharErr +} + +func (m *mockAPICharacterRepo) IsNew(_ uint32) (bool, error) { + return m.isNewResult, m.isNewErr +} + +func (m *mockAPICharacterRepo) HardDelete(_ uint32) error { + return m.hardDeleteErr +} + +func (m *mockAPICharacterRepo) SoftDelete(_ uint32) error { + return m.softDeleteErr +} + +func (m *mockAPICharacterRepo) GetForUser(_ context.Context, _ uint32) ([]Character, error) { + return m.characters, m.charactersErr +} + +func (m *mockAPICharacterRepo) ExportSave(_ context.Context, _, _ uint32) (map[string]interface{}, error) { + return m.exportResult, m.exportErr +} + +// mockAPISessionRepo implements APISessionRepo for testing. +type mockAPISessionRepo struct { + createTokenID uint32 + createTokenErr error + + userID uint32 + userIDErr error +} + +func (m *mockAPISessionRepo) CreateToken(_ context.Context, _ uint32, _ string) (uint32, error) { + return m.createTokenID, m.createTokenErr +} + +func (m *mockAPISessionRepo) GetUserIDByToken(_ context.Context, _ string) (uint32, error) { + return m.userID, m.userIDErr +} diff --git a/server/api/repo_session.go b/server/api/repo_session.go new file mode 100644 index 000000000..80a842d00 --- /dev/null +++ b/server/api/repo_session.go @@ -0,0 +1,29 @@ +package api + +import ( + "context" + + "github.com/jmoiron/sqlx" +) + +// APISessionRepository implements APISessionRepo with PostgreSQL. +type APISessionRepository struct { + db *sqlx.DB +} + +// NewAPISessionRepository creates a new APISessionRepository. +func NewAPISessionRepository(db *sqlx.DB) *APISessionRepository { + return &APISessionRepository{db: db} +} + +func (r *APISessionRepository) CreateToken(ctx context.Context, uid uint32, token string) (uint32, error) { + var tid uint32 + err := r.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, token).Scan(&tid) + return tid, err +} + +func (r *APISessionRepository) GetUserIDByToken(ctx context.Context, token string) (uint32, error) { + var userID uint32 + err := r.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID) + return userID, err +} diff --git a/server/api/repo_user.go b/server/api/repo_user.go new file mode 100644 index 000000000..dfb25664f --- /dev/null +++ b/server/api/repo_user.go @@ -0,0 +1,66 @@ +package api + +import ( + "context" + "time" + + "github.com/jmoiron/sqlx" +) + +// APIUserRepository implements APIUserRepo with PostgreSQL. +type APIUserRepository struct { + db *sqlx.DB +} + +// NewAPIUserRepository creates a new APIUserRepository. +func NewAPIUserRepository(db *sqlx.DB) *APIUserRepository { + return &APIUserRepository{db: db} +} + +func (r *APIUserRepository) Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (uint32, uint32, error) { + var ( + id uint32 + rights uint32 + ) + err := r.db.QueryRowContext( + ctx, ` + INSERT INTO users (username, password, return_expires) + VALUES ($1, $2, $3) + RETURNING id, rights + `, + username, passwordHash, returnExpires, + ).Scan(&id, &rights) + return id, rights, err +} + +func (r *APIUserRepository) GetCredentials(ctx context.Context, username string) (uint32, string, uint32, error) { + var ( + id uint32 + passwordHash string + rights uint32 + ) + err := r.db.QueryRowContext(ctx, "SELECT id, password, rights FROM users WHERE username = $1", username).Scan(&id, &passwordHash, &rights) + return id, passwordHash, rights, err +} + +func (r *APIUserRepository) GetLastLogin(uid uint32) (time.Time, error) { + var lastLogin time.Time + err := r.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid) + return lastLogin, err +} + +func (r *APIUserRepository) GetReturnExpiry(uid uint32) (time.Time, error) { + var returnExpiry time.Time + err := r.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid) + return returnExpiry, err +} + +func (r *APIUserRepository) UpdateReturnExpiry(uid uint32, expiry time.Time) error { + _, err := r.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", expiry, uid) + return err +} + +func (r *APIUserRepository) UpdateLastLogin(uid uint32, loginTime time.Time) error { + _, err := r.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", loginTime, uid) + return err +} diff --git a/server/channelserver/handlers_commands.go b/server/channelserver/handlers_commands.go index 57407f7e1..98beceab6 100644 --- a/server/channelserver/handlers_commands.go +++ b/server/channelserver/handlers_commands.go @@ -32,9 +32,9 @@ func initCommands(cmds []cfg.Command, logger *zap.Logger) { for _, cmd := range cmds { commands[cmd.Name] = cmd if cmd.Enabled { - logger.Info(fmt.Sprintf("Command %s: Enabled, prefix: %s", cmd.Name, cmd.Prefix)) + logger.Info("Command registered", zap.String("name", cmd.Name), zap.String("prefix", cmd.Prefix), zap.Bool("enabled", true)) } else { - logger.Info(fmt.Sprintf("Command %s: Disabled", cmd.Name)) + logger.Info("Command registered", zap.String("name", cmd.Name), zap.Bool("enabled", false)) } } }) diff --git a/server/channelserver/handlers_stage.go b/server/channelserver/handlers_stage.go index 4ce68e4db..c2780a5f6 100644 --- a/server/channelserver/handlers_stage.go +++ b/server/channelserver/handlers_stage.go @@ -1,7 +1,6 @@ package channelserver import ( - "fmt" "strings" "time" @@ -90,7 +89,7 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) { if s.stage != nil { // avoids lock up when using bed for dream quests // Notify the client to duplicate the existing objects. - s.logger.Info(fmt.Sprintf("Sending existing stage objects to %s", s.Name)) + s.logger.Info("Sending existing stage objects", zap.String("session", s.Name)) // Lock stage to safely iterate over objects map // We need to copy the objects list first to avoid holding the lock during packet building diff --git a/server/entranceserver/entrance_server.go b/server/entranceserver/entrance_server.go index 0f39a70e6..fb98c945c 100644 --- a/server/entranceserver/entrance_server.go +++ b/server/entranceserver/entrance_server.go @@ -19,7 +19,8 @@ type Server struct { sync.Mutex logger *zap.Logger erupeConfig *cfg.Config - db *sqlx.DB + serverRepo EntranceServerRepo + sessionRepo EntranceSessionRepo listener net.Listener isShuttingDown bool } @@ -36,7 +37,10 @@ func NewServer(config *Config) *Server { s := &Server{ logger: config.Logger, erupeConfig: config.ErupeConfig, - db: config.DB, + } + if config.DB != nil { + s.serverRepo = NewEntranceServerRepository(config.DB) + s.sessionRepo = NewEntranceSessionRepository(config.DB) } return s } diff --git a/server/entranceserver/make_resp.go b/server/entranceserver/make_resp.go index 3b7a52067..5a57d9045 100644 --- a/server/entranceserver/make_resp.go +++ b/server/entranceserver/make_resp.go @@ -71,7 +71,9 @@ func encodeServerInfo(config *cfg.Config, s *Server, local bool) []byte { bf.WriteUint16(uint16(channelIdx | 16)) bf.WriteUint16(ci.MaxPlayers) var currentPlayers uint16 - _ = s.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", sid).Scan(¤tPlayers) + if s.serverRepo != nil { + currentPlayers, _ = s.serverRepo.GetCurrentPlayers(sid) + } bf.WriteUint16(currentPlayers) bf.WriteUint16(0) bf.WriteUint16(0) @@ -164,12 +166,10 @@ func makeUsrResp(pkt []byte, s *Server) []byte { for i := 0; i < int(userEntries); i++ { cid := bf.ReadUint32() var sid uint16 - err := s.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", cid).Scan(&sid) - if err != nil { - resp.WriteUint16(0) - } else { - resp.WriteUint16(sid) + if s.sessionRepo != nil { + sid, _ = s.sessionRepo.GetServerIDForCharacter(cid) } + resp.WriteUint16(sid) resp.WriteUint16(0) } diff --git a/server/entranceserver/make_resp_test.go b/server/entranceserver/make_resp_test.go index e397b3547..53192b787 100644 --- a/server/entranceserver/make_resp_test.go +++ b/server/entranceserver/make_resp_test.go @@ -113,6 +113,99 @@ func TestClanMemberLimitsBoundsChecking(t *testing.T) { } +// TestEncodeServerInfo_WithMockRepo tests encodeServerInfo with a mock server repo +func TestEncodeServerInfo_WithMockRepo(t *testing.T) { + config := &cfg.Config{ + RealClientMode: cfg.Z1, + Host: "127.0.0.1", + Entrance: cfg.Entrance{ + Enabled: true, + Port: 53310, + Entries: []cfg.EntranceServerInfo{ + { + Name: "TestServer", + Description: "Test", + IP: "127.0.0.1", + Type: 0, + Recommended: 0, + AllowedClientFlags: 0xFFFFFFFF, + Channels: []cfg.EntranceChannelInfo{ + { + Port: 54001, + MaxPlayers: 100, + }, + }, + }, + }, + }, + GameplayOptions: cfg.GameplayOptions{ + ClanMemberLimits: [][]uint8{{1, 60}}, + }, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: config, + serverRepo: &mockEntranceServerRepo{currentPlayers: 42}, + } + + result := encodeServerInfo(config, server, true) + if len(result) == 0 { + t.Error("encodeServerInfo returned empty result") + } +} + +// TestMakeUsrResp_WithMockRepo tests makeUsrResp with a mock session repo +func TestMakeUsrResp_WithMockRepo(t *testing.T) { + config := &cfg.Config{ + RealClientMode: cfg.Z1, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: config, + sessionRepo: &mockEntranceSessionRepo{serverID: 1234}, + } + + // Build a minimal USR request packet: + // 4 bytes ALL+ prefix, 1 byte 0x00, 2 bytes entry count, then 4 bytes per entry (char ID) + pkt := []byte{ + 'A', 'L', 'L', '+', + 0x00, + 0x00, 0x01, // 1 entry + 0x00, 0x00, 0x00, 0x01, // char_id = 1 + } + + result := makeUsrResp(pkt, server) + if len(result) == 0 { + t.Error("makeUsrResp returned empty result") + } +} + +// TestMakeUsrResp_NilSessionRepo tests makeUsrResp when sessionRepo is nil +func TestMakeUsrResp_NilSessionRepo(t *testing.T) { + config := &cfg.Config{ + RealClientMode: cfg.Z1, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: config, + } + + pkt := []byte{ + 'A', 'L', 'L', '+', + 0x00, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x01, + } + + result := makeUsrResp(pkt, server) + if len(result) == 0 { + t.Error("makeUsrResp returned empty result") + } +} + // TestEncodeServerInfo_MissingSecondColumnClanMemberLimits tests accessing [last][1] when [last] is too small // Previously panicked: runtime error: index out of range [1] // After fix: Should handle missing column gracefully with default value (60) diff --git a/server/entranceserver/repo_interfaces.go b/server/entranceserver/repo_interfaces.go new file mode 100644 index 000000000..ccfad2964 --- /dev/null +++ b/server/entranceserver/repo_interfaces.go @@ -0,0 +1,19 @@ +package entranceserver + +// Repository interfaces decouple entrance server business logic from concrete +// PostgreSQL implementations, enabling mock/stub injection for unit tests. + +// EntranceServerRepo defines the contract for server-related data access +// used by the entrance server when building server list responses. +type EntranceServerRepo interface { + // GetCurrentPlayers returns the current player count for a given server ID. + GetCurrentPlayers(serverID int) (uint16, error) +} + +// EntranceSessionRepo defines the contract for session-related data access +// used by the entrance server when resolving user locations. +type EntranceSessionRepo interface { + // GetServerIDForCharacter returns the server ID where the given character + // is currently signed in, or 0 if not found. + GetServerIDForCharacter(charID uint32) (uint16, error) +} diff --git a/server/entranceserver/repo_mocks_test.go b/server/entranceserver/repo_mocks_test.go new file mode 100644 index 000000000..64b4776e9 --- /dev/null +++ b/server/entranceserver/repo_mocks_test.go @@ -0,0 +1,21 @@ +package entranceserver + +// mockEntranceServerRepo implements EntranceServerRepo for testing. +type mockEntranceServerRepo struct { + currentPlayers uint16 + currentPlayersErr error +} + +func (m *mockEntranceServerRepo) GetCurrentPlayers(_ int) (uint16, error) { + return m.currentPlayers, m.currentPlayersErr +} + +// mockEntranceSessionRepo implements EntranceSessionRepo for testing. +type mockEntranceSessionRepo struct { + serverID uint16 + serverIDErr error +} + +func (m *mockEntranceSessionRepo) GetServerIDForCharacter(_ uint32) (uint16, error) { + return m.serverID, m.serverIDErr +} diff --git a/server/entranceserver/repo_server.go b/server/entranceserver/repo_server.go new file mode 100644 index 000000000..d45941f9d --- /dev/null +++ b/server/entranceserver/repo_server.go @@ -0,0 +1,22 @@ +package entranceserver + +import "github.com/jmoiron/sqlx" + +// EntranceServerRepository implements EntranceServerRepo with PostgreSQL. +type EntranceServerRepository struct { + db *sqlx.DB +} + +// NewEntranceServerRepository creates a new EntranceServerRepository. +func NewEntranceServerRepository(db *sqlx.DB) *EntranceServerRepository { + return &EntranceServerRepository{db: db} +} + +func (r *EntranceServerRepository) GetCurrentPlayers(serverID int) (uint16, error) { + var currentPlayers uint16 + err := r.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", serverID).Scan(¤tPlayers) + if err != nil { + return 0, err + } + return currentPlayers, nil +} diff --git a/server/entranceserver/repo_session.go b/server/entranceserver/repo_session.go new file mode 100644 index 000000000..008aee8b0 --- /dev/null +++ b/server/entranceserver/repo_session.go @@ -0,0 +1,22 @@ +package entranceserver + +import "github.com/jmoiron/sqlx" + +// EntranceSessionRepository implements EntranceSessionRepo with PostgreSQL. +type EntranceSessionRepository struct { + db *sqlx.DB +} + +// NewEntranceSessionRepository creates a new EntranceSessionRepository. +func NewEntranceSessionRepository(db *sqlx.DB) *EntranceSessionRepository { + return &EntranceSessionRepository{db: db} +} + +func (r *EntranceSessionRepository) GetServerIDForCharacter(charID uint32) (uint16, error) { + var sid uint16 + err := r.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", charID).Scan(&sid) + if err != nil { + return 0, err + } + return sid, nil +}