diff --git a/server/api/api_server.go b/server/api/api_server.go index d1f3ee699..ea048a13c 100644 --- a/server/api/api_server.go +++ b/server/api/api_server.go @@ -27,7 +27,9 @@ type APIServer struct { sync.Mutex logger *zap.Logger erupeConfig *cfg.Config - db *sqlx.DB + userRepo APIUserRepo + charRepo APICharacterRepo + sessionRepo APISessionRepo httpServer *http.Server isShuttingDown bool } @@ -37,9 +39,13 @@ func NewAPIServer(config *Config) *APIServer { s := &APIServer{ logger: config.Logger, erupeConfig: config.ErupeConfig, - db: config.DB, httpServer: &http.Server{}, } + if config.DB != nil { + s.userRepo = NewAPIUserRepository(config.DB) + s.charRepo = NewAPICharacterRepository(config.DB) + s.sessionRepo = NewAPISessionRepository(config.DB) + } return s } diff --git a/server/api/dbutils.go b/server/api/dbutils.go index ecb046a39..1bd8f8397 100644 --- a/server/api/dbutils.go +++ b/server/api/dbutils.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "erupe-ce/common/token" + "errors" "fmt" "time" @@ -11,41 +12,25 @@ import ( ) func (s *APIServer) createNewUser(ctx context.Context, username string, password string) (uint32, uint32, error) { - // Create salted hash of user password passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) if err != nil { return 0, 0, err } - - var ( - id uint32 - rights uint32 - ) - err = s.db.QueryRowContext( - ctx, ` - INSERT INTO users (username, password, return_expires) - VALUES ($1, $2, $3) - RETURNING id, rights - `, - username, string(passwordHash), time.Now().Add(time.Hour*24*30), - ).Scan(&id, &rights) - return id, rights, err + return s.userRepo.Register(ctx, username, string(passwordHash), time.Now().Add(time.Hour*24*30)) } func (s *APIServer) createLoginToken(ctx context.Context, uid uint32) (uint32, string, error) { loginToken := token.Generate(16) - var tid uint32 - err := s.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, loginToken).Scan(&tid) + tid, err := s.sessionRepo.CreateToken(ctx, uid, loginToken) if err != nil { return 0, "", err } return tid, loginToken, nil } -func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32, error) { - var userID uint32 - err := s.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID) - if err == sql.ErrNoRows { +func (s *APIServer) userIDFromToken(ctx context.Context, tkn string) (uint32, error) { + userID, err := s.sessionRepo.GetUserIDByToken(ctx, tkn) + if errors.Is(err, sql.ErrNoRows) { return 0, fmt.Errorf("invalid login token") } else if err != nil { return 0, err @@ -54,82 +39,50 @@ func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32, } func (s *APIServer) createCharacter(ctx context.Context, userID uint32) (Character, error) { - var character Character - err := s.db.GetContext(ctx, &character, - "SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1", - userID, - ) - if err == sql.ErrNoRows { - var count int - _ = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count) + character, err := s.charRepo.GetNewCharacter(ctx, userID) + if errors.Is(err, sql.ErrNoRows) { + count, _ := s.charRepo.CountForUser(ctx, userID) if count >= 16 { return character, fmt.Errorf("cannot have more than 16 characters") } - err = s.db.GetContext(ctx, &character, ` - INSERT INTO characters ( - user_id, is_female, is_new_character, name, unk_desc_string, - hr, gr, weapon_type, last_login - ) - VALUES ($1, false, true, '', '', 0, 0, 0, $2) - RETURNING id, name, is_female, weapon_type, hr, gr, last_login`, - userID, uint32(time.Now().Unix()), - ) + character, err = s.charRepo.Create(ctx, userID, uint32(time.Now().Unix())) } return character, err } -func (s *APIServer) deleteCharacter(ctx context.Context, userID uint32, charID uint32) error { - var isNew bool - err := s.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew) +func (s *APIServer) deleteCharacter(_ context.Context, _ uint32, charID uint32) error { + isNew, err := s.charRepo.IsNew(charID) if err != nil { return err } if isNew { - _, err = s.db.Exec("DELETE FROM characters WHERE id = $1", charID) - } else { - _, err = s.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID) + return s.charRepo.HardDelete(charID) } - return err + return s.charRepo.SoftDelete(charID) } func (s *APIServer) getCharactersForUser(ctx context.Context, uid uint32) ([]Character, error) { - var characters []Character - err := s.db.SelectContext( - ctx, &characters, ` - SELECT id, name, is_female, weapon_type, hr, gr, last_login - FROM characters - WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`, - uid, - ) - if err != nil { - return nil, err - } - return characters, nil + return s.charRepo.GetForUser(ctx, uid) } func (s *APIServer) getReturnExpiry(uid uint32) time.Time { - var returnExpiry, lastLogin time.Time - _ = s.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid) + lastLogin, _ := s.userRepo.GetLastLogin(uid) + var returnExpiry time.Time if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) { returnExpiry = time.Now().Add(time.Hour * 24 * 30) - _, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) + _ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry) } else { - err := s.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid) + var err error + returnExpiry, err = s.userRepo.GetReturnExpiry(uid) if err != nil { returnExpiry = time.Now() - _, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) + _ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry) } } - _, _ = s.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", time.Now(), uid) + _ = s.userRepo.UpdateLastLogin(uid, time.Now()) return returnExpiry } func (s *APIServer) exportSave(ctx context.Context, uid uint32, cid uint32) (map[string]interface{}, error) { - row := s.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", cid, uid) - result := make(map[string]interface{}) - err := row.MapScan(result) - if err != nil { - return nil, err - } - return result, nil + return s.charRepo.ExportSave(ctx, uid, cid) } diff --git a/server/api/endpoints.go b/server/api/endpoints.go index 9b78535e4..25dfb68e9 100644 --- a/server/api/endpoints.go +++ b/server/api/endpoints.go @@ -162,12 +162,7 @@ func (s *APIServer) Login(w http.ResponseWriter, r *http.Request) { w.WriteHeader(400) return } - var ( - userID uint32 - userRights uint32 - password string - ) - err := s.db.QueryRow("SELECT id, password, rights FROM users WHERE username = $1", reqData.Username).Scan(&userID, &password, &userRights) + userID, password, userRights, err := s.userRepo.GetCredentials(ctx, reqData.Username) if err == sql.ErrNoRows { w.WriteHeader(400) _, _ = w.Write([]byte("username-error")) diff --git a/server/api/endpoints_test.go b/server/api/endpoints_test.go index 0722b00bd..1e172faab 100644 --- a/server/api/endpoints_test.go +++ b/server/api/endpoints_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" cfg "erupe-ce/config" "erupe-ce/common/gametime" @@ -33,7 +34,6 @@ func TestLauncherEndpoint(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } // Create test request @@ -123,7 +123,6 @@ func TestLoginEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } // Invalid JSON @@ -148,7 +147,6 @@ func TestLoginEndpointEmptyCredentials(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } tests := []struct { @@ -200,7 +198,6 @@ func TestRegisterEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"username": "test"` @@ -223,7 +220,6 @@ func TestRegisterEndpointEmptyCredentials(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } tests := []struct { @@ -271,7 +267,6 @@ func TestCreateCharacterEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"token": ` @@ -294,7 +289,6 @@ func TestDeleteCharacterEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"token": "test"` @@ -317,7 +311,6 @@ func TestExportSaveEndpointInvalidJSON(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } invalidJSON := `{"token": ` @@ -342,7 +335,6 @@ func TestScreenShotEndpointDisabled(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } req := httptest.NewRequest("POST", "/api/ss/bbs/upload.php", nil) @@ -379,7 +371,6 @@ func TestScreenShotGetInvalidToken(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, } tests := []struct { @@ -408,10 +399,16 @@ func TestScreenShotGetInvalidToken(t *testing.T) { } } +// newTestUserRepo returns a mock user repo suitable for newAuthData tests. +func newTestUserRepo() *mockAPIUserRepo { + return &mockAPIUserRepo{ + lastLogin: time.Now(), + returnExpiry: time.Now().Add(time.Hour * 24 * 30), + } +} + // TestNewAuthDataStructure tests the newAuthData helper function func TestNewAuthDataStructure(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -423,7 +420,7 @@ func TestNewAuthDataStructure(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } characters := []Character{ @@ -466,8 +463,6 @@ func TestNewAuthDataStructure(t *testing.T) { // TestNewAuthDataDebugMode tests newAuthData with debug mode enabled func TestNewAuthDataDebugMode(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -477,7 +472,7 @@ func TestNewAuthDataDebugMode(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } characters := []Character{ @@ -500,8 +495,6 @@ func TestNewAuthDataDebugMode(t *testing.T) { // TestNewAuthDataMezFesConfiguration tests MezFes configuration in newAuthData func TestNewAuthDataMezFesConfiguration(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -513,7 +506,7 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } authData := server.newAuthData(1, 0, 1, "token", []Character{}) @@ -534,8 +527,6 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) { // TestNewAuthDataHideNotices tests notice hiding in newAuthData func TestNewAuthDataHideNotices(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -546,7 +537,7 @@ func TestNewAuthDataHideNotices(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } authData := server.newAuthData(1, 0, 1, "token", []Character{}) @@ -558,8 +549,6 @@ func TestNewAuthDataHideNotices(t *testing.T) { // TestNewAuthDataTimestamps tests timestamp generation in newAuthData func TestNewAuthDataTimestamps(t *testing.T) { - t.Skip("newAuthData requires database for getReturnExpiry - needs integration test") - logger := NewTestLogger(t) defer func() { _ = logger.Sync() }() @@ -567,7 +556,7 @@ func TestNewAuthDataTimestamps(t *testing.T) { server := &APIServer{ logger: logger, erupeConfig: c, - db: nil, + userRepo: newTestUserRepo(), } authData := server.newAuthData(1, 0, 1, "token", []Character{}) @@ -611,6 +600,10 @@ func BenchmarkNewAuthData(b *testing.B) { server := &APIServer{ logger: logger, erupeConfig: c, + userRepo: &mockAPIUserRepo{ + lastLogin: time.Now(), + returnExpiry: time.Now().Add(time.Hour * 24 * 30), + }, } characters := make([]Character, 16) diff --git a/server/api/repo_character.go b/server/api/repo_character.go new file mode 100644 index 000000000..6bddc5815 --- /dev/null +++ b/server/api/repo_character.go @@ -0,0 +1,87 @@ +package api + +import ( + "context" + + "github.com/jmoiron/sqlx" +) + +// APICharacterRepository implements APICharacterRepo with PostgreSQL. +type APICharacterRepository struct { + db *sqlx.DB +} + +// NewAPICharacterRepository creates a new APICharacterRepository. +func NewAPICharacterRepository(db *sqlx.DB) *APICharacterRepository { + return &APICharacterRepository{db: db} +} + +func (r *APICharacterRepository) GetNewCharacter(ctx context.Context, userID uint32) (Character, error) { + var character Character + err := r.db.GetContext(ctx, &character, + "SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1", + userID, + ) + return character, err +} + +func (r *APICharacterRepository) CountForUser(ctx context.Context, userID uint32) (int, error) { + var count int + err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count) + return count, err +} + +func (r *APICharacterRepository) Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error) { + var character Character + err := r.db.GetContext(ctx, &character, ` + INSERT INTO characters ( + user_id, is_female, is_new_character, name, unk_desc_string, + hr, gr, weapon_type, last_login + ) + VALUES ($1, false, true, '', '', 0, 0, 0, $2) + RETURNING id, name, is_female, weapon_type, hr, gr, last_login`, + userID, lastLogin, + ) + return character, err +} + +func (r *APICharacterRepository) IsNew(charID uint32) (bool, error) { + var isNew bool + err := r.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew) + return isNew, err +} + +func (r *APICharacterRepository) HardDelete(charID uint32) error { + _, err := r.db.Exec("DELETE FROM characters WHERE id = $1", charID) + return err +} + +func (r *APICharacterRepository) SoftDelete(charID uint32) error { + _, err := r.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID) + return err +} + +func (r *APICharacterRepository) GetForUser(ctx context.Context, userID uint32) ([]Character, error) { + var characters []Character + err := r.db.SelectContext( + ctx, &characters, ` + SELECT id, name, is_female, weapon_type, hr, gr, last_login + FROM characters + WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`, + userID, + ) + if err != nil { + return nil, err + } + return characters, nil +} + +func (r *APICharacterRepository) ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error) { + row := r.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", charID, userID) + result := make(map[string]interface{}) + err := row.MapScan(result) + if err != nil { + return nil, err + } + return result, nil +} diff --git a/server/api/repo_interfaces.go b/server/api/repo_interfaces.go new file mode 100644 index 000000000..c0e24c3ec --- /dev/null +++ b/server/api/repo_interfaces.go @@ -0,0 +1,53 @@ +package api + +import ( + "context" + "time" +) + +// Repository interfaces decouple API server business logic from concrete +// PostgreSQL implementations, enabling mock/stub injection for unit tests. + +// APIUserRepo defines the contract for user-related data access. +type APIUserRepo interface { + // Register creates a new user and returns their ID and rights. + Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (id uint32, rights uint32, err error) + // GetCredentials returns the user's ID, password hash, and rights. + GetCredentials(ctx context.Context, username string) (id uint32, passwordHash string, rights uint32, err error) + // GetLastLogin returns the user's last login time. + GetLastLogin(uid uint32) (time.Time, error) + // GetReturnExpiry returns the user's return expiry time. + GetReturnExpiry(uid uint32) (time.Time, error) + // UpdateReturnExpiry sets the user's return expiry time. + UpdateReturnExpiry(uid uint32, expiry time.Time) error + // UpdateLastLogin sets the user's last login time. + UpdateLastLogin(uid uint32, loginTime time.Time) error +} + +// APICharacterRepo defines the contract for character-related data access. +type APICharacterRepo interface { + // GetNewCharacter returns an existing new (unfinished) character for a user. + GetNewCharacter(ctx context.Context, userID uint32) (Character, error) + // CountForUser returns the total number of characters for a user. + CountForUser(ctx context.Context, userID uint32) (int, error) + // Create inserts a new character and returns it. + Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error) + // IsNew returns whether a character is a new (unfinished) character. + IsNew(charID uint32) (bool, error) + // HardDelete permanently removes a character. + HardDelete(charID uint32) error + // SoftDelete marks a character as deleted. + SoftDelete(charID uint32) error + // GetForUser returns all finalized (non-deleted) characters for a user. + GetForUser(ctx context.Context, userID uint32) ([]Character, error) + // ExportSave returns the full character row as a map. + ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error) +} + +// APISessionRepo defines the contract for session/token data access. +type APISessionRepo interface { + // CreateToken inserts a new sign session and returns its ID and token. + CreateToken(ctx context.Context, uid uint32, token string) (tokenID uint32, err error) + // GetUserIDByToken returns the user ID for a given session token. + GetUserIDByToken(ctx context.Context, token string) (uint32, error) +} diff --git a/server/api/repo_mocks_test.go b/server/api/repo_mocks_test.go new file mode 100644 index 000000000..ab4bce375 --- /dev/null +++ b/server/api/repo_mocks_test.go @@ -0,0 +1,124 @@ +package api + +import ( + "context" + "time" +) + +// mockAPIUserRepo implements APIUserRepo for testing. +type mockAPIUserRepo struct { + registerID uint32 + registerRights uint32 + registerErr error + + credentialsID uint32 + credentialsPassword string + credentialsRights uint32 + credentialsErr error + + lastLogin time.Time + lastLoginErr error + + returnExpiry time.Time + returnExpiryErr error + + updateReturnExpiryErr error + updateLastLoginErr error +} + +func (m *mockAPIUserRepo) Register(_ context.Context, _, _ string, _ time.Time) (uint32, uint32, error) { + return m.registerID, m.registerRights, m.registerErr +} + +func (m *mockAPIUserRepo) GetCredentials(_ context.Context, _ string) (uint32, string, uint32, error) { + return m.credentialsID, m.credentialsPassword, m.credentialsRights, m.credentialsErr +} + +func (m *mockAPIUserRepo) GetLastLogin(_ uint32) (time.Time, error) { + return m.lastLogin, m.lastLoginErr +} + +func (m *mockAPIUserRepo) GetReturnExpiry(_ uint32) (time.Time, error) { + return m.returnExpiry, m.returnExpiryErr +} + +func (m *mockAPIUserRepo) UpdateReturnExpiry(_ uint32, _ time.Time) error { + return m.updateReturnExpiryErr +} + +func (m *mockAPIUserRepo) UpdateLastLogin(_ uint32, _ time.Time) error { + return m.updateLastLoginErr +} + +// mockAPICharacterRepo implements APICharacterRepo for testing. +type mockAPICharacterRepo struct { + newCharacter Character + newCharacterErr error + + countForUser int + countForUserErr error + + createChar Character + createCharErr error + + isNewResult bool + isNewErr error + + hardDeleteErr error + softDeleteErr error + + characters []Character + charactersErr error + + exportResult map[string]interface{} + exportErr error +} + +func (m *mockAPICharacterRepo) GetNewCharacter(_ context.Context, _ uint32) (Character, error) { + return m.newCharacter, m.newCharacterErr +} + +func (m *mockAPICharacterRepo) CountForUser(_ context.Context, _ uint32) (int, error) { + return m.countForUser, m.countForUserErr +} + +func (m *mockAPICharacterRepo) Create(_ context.Context, _ uint32, _ uint32) (Character, error) { + return m.createChar, m.createCharErr +} + +func (m *mockAPICharacterRepo) IsNew(_ uint32) (bool, error) { + return m.isNewResult, m.isNewErr +} + +func (m *mockAPICharacterRepo) HardDelete(_ uint32) error { + return m.hardDeleteErr +} + +func (m *mockAPICharacterRepo) SoftDelete(_ uint32) error { + return m.softDeleteErr +} + +func (m *mockAPICharacterRepo) GetForUser(_ context.Context, _ uint32) ([]Character, error) { + return m.characters, m.charactersErr +} + +func (m *mockAPICharacterRepo) ExportSave(_ context.Context, _, _ uint32) (map[string]interface{}, error) { + return m.exportResult, m.exportErr +} + +// mockAPISessionRepo implements APISessionRepo for testing. +type mockAPISessionRepo struct { + createTokenID uint32 + createTokenErr error + + userID uint32 + userIDErr error +} + +func (m *mockAPISessionRepo) CreateToken(_ context.Context, _ uint32, _ string) (uint32, error) { + return m.createTokenID, m.createTokenErr +} + +func (m *mockAPISessionRepo) GetUserIDByToken(_ context.Context, _ string) (uint32, error) { + return m.userID, m.userIDErr +} diff --git a/server/api/repo_session.go b/server/api/repo_session.go new file mode 100644 index 000000000..80a842d00 --- /dev/null +++ b/server/api/repo_session.go @@ -0,0 +1,29 @@ +package api + +import ( + "context" + + "github.com/jmoiron/sqlx" +) + +// APISessionRepository implements APISessionRepo with PostgreSQL. +type APISessionRepository struct { + db *sqlx.DB +} + +// NewAPISessionRepository creates a new APISessionRepository. +func NewAPISessionRepository(db *sqlx.DB) *APISessionRepository { + return &APISessionRepository{db: db} +} + +func (r *APISessionRepository) CreateToken(ctx context.Context, uid uint32, token string) (uint32, error) { + var tid uint32 + err := r.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, token).Scan(&tid) + return tid, err +} + +func (r *APISessionRepository) GetUserIDByToken(ctx context.Context, token string) (uint32, error) { + var userID uint32 + err := r.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID) + return userID, err +} diff --git a/server/api/repo_user.go b/server/api/repo_user.go new file mode 100644 index 000000000..dfb25664f --- /dev/null +++ b/server/api/repo_user.go @@ -0,0 +1,66 @@ +package api + +import ( + "context" + "time" + + "github.com/jmoiron/sqlx" +) + +// APIUserRepository implements APIUserRepo with PostgreSQL. +type APIUserRepository struct { + db *sqlx.DB +} + +// NewAPIUserRepository creates a new APIUserRepository. +func NewAPIUserRepository(db *sqlx.DB) *APIUserRepository { + return &APIUserRepository{db: db} +} + +func (r *APIUserRepository) Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (uint32, uint32, error) { + var ( + id uint32 + rights uint32 + ) + err := r.db.QueryRowContext( + ctx, ` + INSERT INTO users (username, password, return_expires) + VALUES ($1, $2, $3) + RETURNING id, rights + `, + username, passwordHash, returnExpires, + ).Scan(&id, &rights) + return id, rights, err +} + +func (r *APIUserRepository) GetCredentials(ctx context.Context, username string) (uint32, string, uint32, error) { + var ( + id uint32 + passwordHash string + rights uint32 + ) + err := r.db.QueryRowContext(ctx, "SELECT id, password, rights FROM users WHERE username = $1", username).Scan(&id, &passwordHash, &rights) + return id, passwordHash, rights, err +} + +func (r *APIUserRepository) GetLastLogin(uid uint32) (time.Time, error) { + var lastLogin time.Time + err := r.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid) + return lastLogin, err +} + +func (r *APIUserRepository) GetReturnExpiry(uid uint32) (time.Time, error) { + var returnExpiry time.Time + err := r.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid) + return returnExpiry, err +} + +func (r *APIUserRepository) UpdateReturnExpiry(uid uint32, expiry time.Time) error { + _, err := r.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", expiry, uid) + return err +} + +func (r *APIUserRepository) UpdateLastLogin(uid uint32, loginTime time.Time) error { + _, err := r.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", loginTime, uid) + return err +} diff --git a/server/channelserver/handlers_commands.go b/server/channelserver/handlers_commands.go index 57407f7e1..98beceab6 100644 --- a/server/channelserver/handlers_commands.go +++ b/server/channelserver/handlers_commands.go @@ -32,9 +32,9 @@ func initCommands(cmds []cfg.Command, logger *zap.Logger) { for _, cmd := range cmds { commands[cmd.Name] = cmd if cmd.Enabled { - logger.Info(fmt.Sprintf("Command %s: Enabled, prefix: %s", cmd.Name, cmd.Prefix)) + logger.Info("Command registered", zap.String("name", cmd.Name), zap.String("prefix", cmd.Prefix), zap.Bool("enabled", true)) } else { - logger.Info(fmt.Sprintf("Command %s: Disabled", cmd.Name)) + logger.Info("Command registered", zap.String("name", cmd.Name), zap.Bool("enabled", false)) } } }) diff --git a/server/channelserver/handlers_stage.go b/server/channelserver/handlers_stage.go index 4ce68e4db..c2780a5f6 100644 --- a/server/channelserver/handlers_stage.go +++ b/server/channelserver/handlers_stage.go @@ -1,7 +1,6 @@ package channelserver import ( - "fmt" "strings" "time" @@ -90,7 +89,7 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) { if s.stage != nil { // avoids lock up when using bed for dream quests // Notify the client to duplicate the existing objects. - s.logger.Info(fmt.Sprintf("Sending existing stage objects to %s", s.Name)) + s.logger.Info("Sending existing stage objects", zap.String("session", s.Name)) // Lock stage to safely iterate over objects map // We need to copy the objects list first to avoid holding the lock during packet building diff --git a/server/entranceserver/entrance_server.go b/server/entranceserver/entrance_server.go index 0f39a70e6..fb98c945c 100644 --- a/server/entranceserver/entrance_server.go +++ b/server/entranceserver/entrance_server.go @@ -19,7 +19,8 @@ type Server struct { sync.Mutex logger *zap.Logger erupeConfig *cfg.Config - db *sqlx.DB + serverRepo EntranceServerRepo + sessionRepo EntranceSessionRepo listener net.Listener isShuttingDown bool } @@ -36,7 +37,10 @@ func NewServer(config *Config) *Server { s := &Server{ logger: config.Logger, erupeConfig: config.ErupeConfig, - db: config.DB, + } + if config.DB != nil { + s.serverRepo = NewEntranceServerRepository(config.DB) + s.sessionRepo = NewEntranceSessionRepository(config.DB) } return s } diff --git a/server/entranceserver/make_resp.go b/server/entranceserver/make_resp.go index 3b7a52067..5a57d9045 100644 --- a/server/entranceserver/make_resp.go +++ b/server/entranceserver/make_resp.go @@ -71,7 +71,9 @@ func encodeServerInfo(config *cfg.Config, s *Server, local bool) []byte { bf.WriteUint16(uint16(channelIdx | 16)) bf.WriteUint16(ci.MaxPlayers) var currentPlayers uint16 - _ = s.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", sid).Scan(¤tPlayers) + if s.serverRepo != nil { + currentPlayers, _ = s.serverRepo.GetCurrentPlayers(sid) + } bf.WriteUint16(currentPlayers) bf.WriteUint16(0) bf.WriteUint16(0) @@ -164,12 +166,10 @@ func makeUsrResp(pkt []byte, s *Server) []byte { for i := 0; i < int(userEntries); i++ { cid := bf.ReadUint32() var sid uint16 - err := s.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", cid).Scan(&sid) - if err != nil { - resp.WriteUint16(0) - } else { - resp.WriteUint16(sid) + if s.sessionRepo != nil { + sid, _ = s.sessionRepo.GetServerIDForCharacter(cid) } + resp.WriteUint16(sid) resp.WriteUint16(0) } diff --git a/server/entranceserver/make_resp_test.go b/server/entranceserver/make_resp_test.go index e397b3547..53192b787 100644 --- a/server/entranceserver/make_resp_test.go +++ b/server/entranceserver/make_resp_test.go @@ -113,6 +113,99 @@ func TestClanMemberLimitsBoundsChecking(t *testing.T) { } +// TestEncodeServerInfo_WithMockRepo tests encodeServerInfo with a mock server repo +func TestEncodeServerInfo_WithMockRepo(t *testing.T) { + config := &cfg.Config{ + RealClientMode: cfg.Z1, + Host: "127.0.0.1", + Entrance: cfg.Entrance{ + Enabled: true, + Port: 53310, + Entries: []cfg.EntranceServerInfo{ + { + Name: "TestServer", + Description: "Test", + IP: "127.0.0.1", + Type: 0, + Recommended: 0, + AllowedClientFlags: 0xFFFFFFFF, + Channels: []cfg.EntranceChannelInfo{ + { + Port: 54001, + MaxPlayers: 100, + }, + }, + }, + }, + }, + GameplayOptions: cfg.GameplayOptions{ + ClanMemberLimits: [][]uint8{{1, 60}}, + }, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: config, + serverRepo: &mockEntranceServerRepo{currentPlayers: 42}, + } + + result := encodeServerInfo(config, server, true) + if len(result) == 0 { + t.Error("encodeServerInfo returned empty result") + } +} + +// TestMakeUsrResp_WithMockRepo tests makeUsrResp with a mock session repo +func TestMakeUsrResp_WithMockRepo(t *testing.T) { + config := &cfg.Config{ + RealClientMode: cfg.Z1, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: config, + sessionRepo: &mockEntranceSessionRepo{serverID: 1234}, + } + + // Build a minimal USR request packet: + // 4 bytes ALL+ prefix, 1 byte 0x00, 2 bytes entry count, then 4 bytes per entry (char ID) + pkt := []byte{ + 'A', 'L', 'L', '+', + 0x00, + 0x00, 0x01, // 1 entry + 0x00, 0x00, 0x00, 0x01, // char_id = 1 + } + + result := makeUsrResp(pkt, server) + if len(result) == 0 { + t.Error("makeUsrResp returned empty result") + } +} + +// TestMakeUsrResp_NilSessionRepo tests makeUsrResp when sessionRepo is nil +func TestMakeUsrResp_NilSessionRepo(t *testing.T) { + config := &cfg.Config{ + RealClientMode: cfg.Z1, + } + + server := &Server{ + logger: zap.NewNop(), + erupeConfig: config, + } + + pkt := []byte{ + 'A', 'L', 'L', '+', + 0x00, + 0x00, 0x01, + 0x00, 0x00, 0x00, 0x01, + } + + result := makeUsrResp(pkt, server) + if len(result) == 0 { + t.Error("makeUsrResp returned empty result") + } +} + // TestEncodeServerInfo_MissingSecondColumnClanMemberLimits tests accessing [last][1] when [last] is too small // Previously panicked: runtime error: index out of range [1] // After fix: Should handle missing column gracefully with default value (60) diff --git a/server/entranceserver/repo_interfaces.go b/server/entranceserver/repo_interfaces.go new file mode 100644 index 000000000..ccfad2964 --- /dev/null +++ b/server/entranceserver/repo_interfaces.go @@ -0,0 +1,19 @@ +package entranceserver + +// Repository interfaces decouple entrance server business logic from concrete +// PostgreSQL implementations, enabling mock/stub injection for unit tests. + +// EntranceServerRepo defines the contract for server-related data access +// used by the entrance server when building server list responses. +type EntranceServerRepo interface { + // GetCurrentPlayers returns the current player count for a given server ID. + GetCurrentPlayers(serverID int) (uint16, error) +} + +// EntranceSessionRepo defines the contract for session-related data access +// used by the entrance server when resolving user locations. +type EntranceSessionRepo interface { + // GetServerIDForCharacter returns the server ID where the given character + // is currently signed in, or 0 if not found. + GetServerIDForCharacter(charID uint32) (uint16, error) +} diff --git a/server/entranceserver/repo_mocks_test.go b/server/entranceserver/repo_mocks_test.go new file mode 100644 index 000000000..64b4776e9 --- /dev/null +++ b/server/entranceserver/repo_mocks_test.go @@ -0,0 +1,21 @@ +package entranceserver + +// mockEntranceServerRepo implements EntranceServerRepo for testing. +type mockEntranceServerRepo struct { + currentPlayers uint16 + currentPlayersErr error +} + +func (m *mockEntranceServerRepo) GetCurrentPlayers(_ int) (uint16, error) { + return m.currentPlayers, m.currentPlayersErr +} + +// mockEntranceSessionRepo implements EntranceSessionRepo for testing. +type mockEntranceSessionRepo struct { + serverID uint16 + serverIDErr error +} + +func (m *mockEntranceSessionRepo) GetServerIDForCharacter(_ uint32) (uint16, error) { + return m.serverID, m.serverIDErr +} diff --git a/server/entranceserver/repo_server.go b/server/entranceserver/repo_server.go new file mode 100644 index 000000000..d45941f9d --- /dev/null +++ b/server/entranceserver/repo_server.go @@ -0,0 +1,22 @@ +package entranceserver + +import "github.com/jmoiron/sqlx" + +// EntranceServerRepository implements EntranceServerRepo with PostgreSQL. +type EntranceServerRepository struct { + db *sqlx.DB +} + +// NewEntranceServerRepository creates a new EntranceServerRepository. +func NewEntranceServerRepository(db *sqlx.DB) *EntranceServerRepository { + return &EntranceServerRepository{db: db} +} + +func (r *EntranceServerRepository) GetCurrentPlayers(serverID int) (uint16, error) { + var currentPlayers uint16 + err := r.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", serverID).Scan(¤tPlayers) + if err != nil { + return 0, err + } + return currentPlayers, nil +} diff --git a/server/entranceserver/repo_session.go b/server/entranceserver/repo_session.go new file mode 100644 index 000000000..008aee8b0 --- /dev/null +++ b/server/entranceserver/repo_session.go @@ -0,0 +1,22 @@ +package entranceserver + +import "github.com/jmoiron/sqlx" + +// EntranceSessionRepository implements EntranceSessionRepo with PostgreSQL. +type EntranceSessionRepository struct { + db *sqlx.DB +} + +// NewEntranceSessionRepository creates a new EntranceSessionRepository. +func NewEntranceSessionRepository(db *sqlx.DB) *EntranceSessionRepository { + return &EntranceSessionRepository{db: db} +} + +func (r *EntranceSessionRepository) GetServerIDForCharacter(charID uint32) (uint16, error) { + var sid uint16 + err := r.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", charID).Scan(&sid) + if err != nil { + return 0, err + } + return sid, nil +}