refactor: replace raw SQL with repository interfaces in entranceserver and API server

Extract all direct database calls from entranceserver (2 calls) and
API server (17 calls) into typed repository interfaces with PostgreSQL
implementations, matching the pattern established in signserver and
channelserver.

Entranceserver: EntranceServerRepo, EntranceSessionRepo
API server: APIUserRepo, APICharacterRepo, APISessionRepo

Also fix the 3 remaining fmt.Sprintf calls inside logger invocations
in handlers_commands.go and handlers_stage.go, replacing them with
structured zap fields.

Unskip 5 TestNewAuthData* tests that previously required a real
database — they now run with mock repos.
This commit is contained in:
Houmgaor
2026-02-22 17:04:58 +01:00
parent f640cfee27
commit 82b967b715
18 changed files with 601 additions and 115 deletions

View File

@@ -27,7 +27,9 @@ type APIServer struct {
sync.Mutex sync.Mutex
logger *zap.Logger logger *zap.Logger
erupeConfig *cfg.Config erupeConfig *cfg.Config
db *sqlx.DB userRepo APIUserRepo
charRepo APICharacterRepo
sessionRepo APISessionRepo
httpServer *http.Server httpServer *http.Server
isShuttingDown bool isShuttingDown bool
} }
@@ -37,9 +39,13 @@ func NewAPIServer(config *Config) *APIServer {
s := &APIServer{ s := &APIServer{
logger: config.Logger, logger: config.Logger,
erupeConfig: config.ErupeConfig, erupeConfig: config.ErupeConfig,
db: config.DB,
httpServer: &http.Server{}, httpServer: &http.Server{},
} }
if config.DB != nil {
s.userRepo = NewAPIUserRepository(config.DB)
s.charRepo = NewAPICharacterRepository(config.DB)
s.sessionRepo = NewAPISessionRepository(config.DB)
}
return s return s
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"erupe-ce/common/token" "erupe-ce/common/token"
"errors"
"fmt" "fmt"
"time" "time"
@@ -11,41 +12,25 @@ import (
) )
func (s *APIServer) createNewUser(ctx context.Context, username string, password string) (uint32, uint32, error) { func (s *APIServer) createNewUser(ctx context.Context, username string, password string) (uint32, uint32, error) {
// Create salted hash of user password
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
} }
return s.userRepo.Register(ctx, username, string(passwordHash), time.Now().Add(time.Hour*24*30))
var (
id uint32
rights uint32
)
err = s.db.QueryRowContext(
ctx, `
INSERT INTO users (username, password, return_expires)
VALUES ($1, $2, $3)
RETURNING id, rights
`,
username, string(passwordHash), time.Now().Add(time.Hour*24*30),
).Scan(&id, &rights)
return id, rights, err
} }
func (s *APIServer) createLoginToken(ctx context.Context, uid uint32) (uint32, string, error) { func (s *APIServer) createLoginToken(ctx context.Context, uid uint32) (uint32, string, error) {
loginToken := token.Generate(16) loginToken := token.Generate(16)
var tid uint32 tid, err := s.sessionRepo.CreateToken(ctx, uid, loginToken)
err := s.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, loginToken).Scan(&tid)
if err != nil { if err != nil {
return 0, "", err return 0, "", err
} }
return tid, loginToken, nil return tid, loginToken, nil
} }
func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32, error) { func (s *APIServer) userIDFromToken(ctx context.Context, tkn string) (uint32, error) {
var userID uint32 userID, err := s.sessionRepo.GetUserIDByToken(ctx, tkn)
err := s.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID) if errors.Is(err, sql.ErrNoRows) {
if err == sql.ErrNoRows {
return 0, fmt.Errorf("invalid login token") return 0, fmt.Errorf("invalid login token")
} else if err != nil { } else if err != nil {
return 0, err return 0, err
@@ -54,82 +39,50 @@ func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32,
} }
func (s *APIServer) createCharacter(ctx context.Context, userID uint32) (Character, error) { func (s *APIServer) createCharacter(ctx context.Context, userID uint32) (Character, error) {
var character Character character, err := s.charRepo.GetNewCharacter(ctx, userID)
err := s.db.GetContext(ctx, &character, if errors.Is(err, sql.ErrNoRows) {
"SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1", count, _ := s.charRepo.CountForUser(ctx, userID)
userID,
)
if err == sql.ErrNoRows {
var count int
_ = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count)
if count >= 16 { if count >= 16 {
return character, fmt.Errorf("cannot have more than 16 characters") return character, fmt.Errorf("cannot have more than 16 characters")
} }
err = s.db.GetContext(ctx, &character, ` character, err = s.charRepo.Create(ctx, userID, uint32(time.Now().Unix()))
INSERT INTO characters (
user_id, is_female, is_new_character, name, unk_desc_string,
hr, gr, weapon_type, last_login
)
VALUES ($1, false, true, '', '', 0, 0, 0, $2)
RETURNING id, name, is_female, weapon_type, hr, gr, last_login`,
userID, uint32(time.Now().Unix()),
)
} }
return character, err return character, err
} }
func (s *APIServer) deleteCharacter(ctx context.Context, userID uint32, charID uint32) error { func (s *APIServer) deleteCharacter(_ context.Context, _ uint32, charID uint32) error {
var isNew bool isNew, err := s.charRepo.IsNew(charID)
err := s.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew)
if err != nil { if err != nil {
return err return err
} }
if isNew { if isNew {
_, err = s.db.Exec("DELETE FROM characters WHERE id = $1", charID) return s.charRepo.HardDelete(charID)
} else {
_, err = s.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID)
} }
return err return s.charRepo.SoftDelete(charID)
} }
func (s *APIServer) getCharactersForUser(ctx context.Context, uid uint32) ([]Character, error) { func (s *APIServer) getCharactersForUser(ctx context.Context, uid uint32) ([]Character, error) {
var characters []Character return s.charRepo.GetForUser(ctx, uid)
err := s.db.SelectContext(
ctx, &characters, `
SELECT id, name, is_female, weapon_type, hr, gr, last_login
FROM characters
WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`,
uid,
)
if err != nil {
return nil, err
}
return characters, nil
} }
func (s *APIServer) getReturnExpiry(uid uint32) time.Time { func (s *APIServer) getReturnExpiry(uid uint32) time.Time {
var returnExpiry, lastLogin time.Time lastLogin, _ := s.userRepo.GetLastLogin(uid)
_ = s.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid) var returnExpiry time.Time
if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) { if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) {
returnExpiry = time.Now().Add(time.Hour * 24 * 30) returnExpiry = time.Now().Add(time.Hour * 24 * 30)
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) _ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry)
} else { } else {
err := s.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid) var err error
returnExpiry, err = s.userRepo.GetReturnExpiry(uid)
if err != nil { if err != nil {
returnExpiry = time.Now() returnExpiry = time.Now()
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid) _ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry)
} }
} }
_, _ = s.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", time.Now(), uid) _ = s.userRepo.UpdateLastLogin(uid, time.Now())
return returnExpiry return returnExpiry
} }
func (s *APIServer) exportSave(ctx context.Context, uid uint32, cid uint32) (map[string]interface{}, error) { func (s *APIServer) exportSave(ctx context.Context, uid uint32, cid uint32) (map[string]interface{}, error) {
row := s.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", cid, uid) return s.charRepo.ExportSave(ctx, uid, cid)
result := make(map[string]interface{})
err := row.MapScan(result)
if err != nil {
return nil, err
}
return result, nil
} }

View File

@@ -162,12 +162,7 @@ func (s *APIServer) Login(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(400) w.WriteHeader(400)
return return
} }
var ( userID, password, userRights, err := s.userRepo.GetCredentials(ctx, reqData.Username)
userID uint32
userRights uint32
password string
)
err := s.db.QueryRow("SELECT id, password, rights FROM users WHERE username = $1", reqData.Username).Scan(&userID, &password, &userRights)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
w.WriteHeader(400) w.WriteHeader(400)
_, _ = w.Write([]byte("username-error")) _, _ = w.Write([]byte("username-error"))

View File

@@ -8,6 +8,7 @@ import (
"net/http/httptest" "net/http/httptest"
"strings" "strings"
"testing" "testing"
"time"
cfg "erupe-ce/config" cfg "erupe-ce/config"
"erupe-ce/common/gametime" "erupe-ce/common/gametime"
@@ -33,7 +34,6 @@ func TestLauncherEndpoint(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
// Create test request // Create test request
@@ -123,7 +123,6 @@ func TestLoginEndpointInvalidJSON(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
// Invalid JSON // Invalid JSON
@@ -148,7 +147,6 @@ func TestLoginEndpointEmptyCredentials(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
tests := []struct { tests := []struct {
@@ -200,7 +198,6 @@ func TestRegisterEndpointInvalidJSON(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
invalidJSON := `{"username": "test"` invalidJSON := `{"username": "test"`
@@ -223,7 +220,6 @@ func TestRegisterEndpointEmptyCredentials(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
tests := []struct { tests := []struct {
@@ -271,7 +267,6 @@ func TestCreateCharacterEndpointInvalidJSON(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
invalidJSON := `{"token": ` invalidJSON := `{"token": `
@@ -294,7 +289,6 @@ func TestDeleteCharacterEndpointInvalidJSON(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
invalidJSON := `{"token": "test"` invalidJSON := `{"token": "test"`
@@ -317,7 +311,6 @@ func TestExportSaveEndpointInvalidJSON(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
invalidJSON := `{"token": ` invalidJSON := `{"token": `
@@ -342,7 +335,6 @@ func TestScreenShotEndpointDisabled(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
req := httptest.NewRequest("POST", "/api/ss/bbs/upload.php", nil) req := httptest.NewRequest("POST", "/api/ss/bbs/upload.php", nil)
@@ -379,7 +371,6 @@ func TestScreenShotGetInvalidToken(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil,
} }
tests := []struct { tests := []struct {
@@ -408,10 +399,16 @@ func TestScreenShotGetInvalidToken(t *testing.T) {
} }
} }
// newTestUserRepo returns a mock user repo suitable for newAuthData tests.
func newTestUserRepo() *mockAPIUserRepo {
return &mockAPIUserRepo{
lastLogin: time.Now(),
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
}
}
// TestNewAuthDataStructure tests the newAuthData helper function // TestNewAuthDataStructure tests the newAuthData helper function
func TestNewAuthDataStructure(t *testing.T) { func TestNewAuthDataStructure(t *testing.T) {
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
logger := NewTestLogger(t) logger := NewTestLogger(t)
defer func() { _ = logger.Sync() }() defer func() { _ = logger.Sync() }()
@@ -423,7 +420,7 @@ func TestNewAuthDataStructure(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil, userRepo: newTestUserRepo(),
} }
characters := []Character{ characters := []Character{
@@ -466,8 +463,6 @@ func TestNewAuthDataStructure(t *testing.T) {
// TestNewAuthDataDebugMode tests newAuthData with debug mode enabled // TestNewAuthDataDebugMode tests newAuthData with debug mode enabled
func TestNewAuthDataDebugMode(t *testing.T) { func TestNewAuthDataDebugMode(t *testing.T) {
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
logger := NewTestLogger(t) logger := NewTestLogger(t)
defer func() { _ = logger.Sync() }() defer func() { _ = logger.Sync() }()
@@ -477,7 +472,7 @@ func TestNewAuthDataDebugMode(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil, userRepo: newTestUserRepo(),
} }
characters := []Character{ characters := []Character{
@@ -500,8 +495,6 @@ func TestNewAuthDataDebugMode(t *testing.T) {
// TestNewAuthDataMezFesConfiguration tests MezFes configuration in newAuthData // TestNewAuthDataMezFesConfiguration tests MezFes configuration in newAuthData
func TestNewAuthDataMezFesConfiguration(t *testing.T) { func TestNewAuthDataMezFesConfiguration(t *testing.T) {
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
logger := NewTestLogger(t) logger := NewTestLogger(t)
defer func() { _ = logger.Sync() }() defer func() { _ = logger.Sync() }()
@@ -513,7 +506,7 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil, userRepo: newTestUserRepo(),
} }
authData := server.newAuthData(1, 0, 1, "token", []Character{}) authData := server.newAuthData(1, 0, 1, "token", []Character{})
@@ -534,8 +527,6 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) {
// TestNewAuthDataHideNotices tests notice hiding in newAuthData // TestNewAuthDataHideNotices tests notice hiding in newAuthData
func TestNewAuthDataHideNotices(t *testing.T) { func TestNewAuthDataHideNotices(t *testing.T) {
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
logger := NewTestLogger(t) logger := NewTestLogger(t)
defer func() { _ = logger.Sync() }() defer func() { _ = logger.Sync() }()
@@ -546,7 +537,7 @@ func TestNewAuthDataHideNotices(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil, userRepo: newTestUserRepo(),
} }
authData := server.newAuthData(1, 0, 1, "token", []Character{}) authData := server.newAuthData(1, 0, 1, "token", []Character{})
@@ -558,8 +549,6 @@ func TestNewAuthDataHideNotices(t *testing.T) {
// TestNewAuthDataTimestamps tests timestamp generation in newAuthData // TestNewAuthDataTimestamps tests timestamp generation in newAuthData
func TestNewAuthDataTimestamps(t *testing.T) { func TestNewAuthDataTimestamps(t *testing.T) {
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
logger := NewTestLogger(t) logger := NewTestLogger(t)
defer func() { _ = logger.Sync() }() defer func() { _ = logger.Sync() }()
@@ -567,7 +556,7 @@ func TestNewAuthDataTimestamps(t *testing.T) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
db: nil, userRepo: newTestUserRepo(),
} }
authData := server.newAuthData(1, 0, 1, "token", []Character{}) authData := server.newAuthData(1, 0, 1, "token", []Character{})
@@ -611,6 +600,10 @@ func BenchmarkNewAuthData(b *testing.B) {
server := &APIServer{ server := &APIServer{
logger: logger, logger: logger,
erupeConfig: c, erupeConfig: c,
userRepo: &mockAPIUserRepo{
lastLogin: time.Now(),
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
},
} }
characters := make([]Character, 16) characters := make([]Character, 16)

View File

@@ -0,0 +1,87 @@
package api
import (
"context"
"github.com/jmoiron/sqlx"
)
// APICharacterRepository implements APICharacterRepo with PostgreSQL.
type APICharacterRepository struct {
db *sqlx.DB
}
// NewAPICharacterRepository creates a new APICharacterRepository.
func NewAPICharacterRepository(db *sqlx.DB) *APICharacterRepository {
return &APICharacterRepository{db: db}
}
func (r *APICharacterRepository) GetNewCharacter(ctx context.Context, userID uint32) (Character, error) {
var character Character
err := r.db.GetContext(ctx, &character,
"SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1",
userID,
)
return character, err
}
func (r *APICharacterRepository) CountForUser(ctx context.Context, userID uint32) (int, error) {
var count int
err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count)
return count, err
}
func (r *APICharacterRepository) Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error) {
var character Character
err := r.db.GetContext(ctx, &character, `
INSERT INTO characters (
user_id, is_female, is_new_character, name, unk_desc_string,
hr, gr, weapon_type, last_login
)
VALUES ($1, false, true, '', '', 0, 0, 0, $2)
RETURNING id, name, is_female, weapon_type, hr, gr, last_login`,
userID, lastLogin,
)
return character, err
}
func (r *APICharacterRepository) IsNew(charID uint32) (bool, error) {
var isNew bool
err := r.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew)
return isNew, err
}
func (r *APICharacterRepository) HardDelete(charID uint32) error {
_, err := r.db.Exec("DELETE FROM characters WHERE id = $1", charID)
return err
}
func (r *APICharacterRepository) SoftDelete(charID uint32) error {
_, err := r.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID)
return err
}
func (r *APICharacterRepository) GetForUser(ctx context.Context, userID uint32) ([]Character, error) {
var characters []Character
err := r.db.SelectContext(
ctx, &characters, `
SELECT id, name, is_female, weapon_type, hr, gr, last_login
FROM characters
WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`,
userID,
)
if err != nil {
return nil, err
}
return characters, nil
}
func (r *APICharacterRepository) ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error) {
row := r.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", charID, userID)
result := make(map[string]interface{})
err := row.MapScan(result)
if err != nil {
return nil, err
}
return result, nil
}

View File

@@ -0,0 +1,53 @@
package api
import (
"context"
"time"
)
// Repository interfaces decouple API server business logic from concrete
// PostgreSQL implementations, enabling mock/stub injection for unit tests.
// APIUserRepo defines the contract for user-related data access.
type APIUserRepo interface {
// Register creates a new user and returns their ID and rights.
Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (id uint32, rights uint32, err error)
// GetCredentials returns the user's ID, password hash, and rights.
GetCredentials(ctx context.Context, username string) (id uint32, passwordHash string, rights uint32, err error)
// GetLastLogin returns the user's last login time.
GetLastLogin(uid uint32) (time.Time, error)
// GetReturnExpiry returns the user's return expiry time.
GetReturnExpiry(uid uint32) (time.Time, error)
// UpdateReturnExpiry sets the user's return expiry time.
UpdateReturnExpiry(uid uint32, expiry time.Time) error
// UpdateLastLogin sets the user's last login time.
UpdateLastLogin(uid uint32, loginTime time.Time) error
}
// APICharacterRepo defines the contract for character-related data access.
type APICharacterRepo interface {
// GetNewCharacter returns an existing new (unfinished) character for a user.
GetNewCharacter(ctx context.Context, userID uint32) (Character, error)
// CountForUser returns the total number of characters for a user.
CountForUser(ctx context.Context, userID uint32) (int, error)
// Create inserts a new character and returns it.
Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error)
// IsNew returns whether a character is a new (unfinished) character.
IsNew(charID uint32) (bool, error)
// HardDelete permanently removes a character.
HardDelete(charID uint32) error
// SoftDelete marks a character as deleted.
SoftDelete(charID uint32) error
// GetForUser returns all finalized (non-deleted) characters for a user.
GetForUser(ctx context.Context, userID uint32) ([]Character, error)
// ExportSave returns the full character row as a map.
ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error)
}
// APISessionRepo defines the contract for session/token data access.
type APISessionRepo interface {
// CreateToken inserts a new sign session and returns its ID and token.
CreateToken(ctx context.Context, uid uint32, token string) (tokenID uint32, err error)
// GetUserIDByToken returns the user ID for a given session token.
GetUserIDByToken(ctx context.Context, token string) (uint32, error)
}

View File

@@ -0,0 +1,124 @@
package api
import (
"context"
"time"
)
// mockAPIUserRepo implements APIUserRepo for testing.
type mockAPIUserRepo struct {
registerID uint32
registerRights uint32
registerErr error
credentialsID uint32
credentialsPassword string
credentialsRights uint32
credentialsErr error
lastLogin time.Time
lastLoginErr error
returnExpiry time.Time
returnExpiryErr error
updateReturnExpiryErr error
updateLastLoginErr error
}
func (m *mockAPIUserRepo) Register(_ context.Context, _, _ string, _ time.Time) (uint32, uint32, error) {
return m.registerID, m.registerRights, m.registerErr
}
func (m *mockAPIUserRepo) GetCredentials(_ context.Context, _ string) (uint32, string, uint32, error) {
return m.credentialsID, m.credentialsPassword, m.credentialsRights, m.credentialsErr
}
func (m *mockAPIUserRepo) GetLastLogin(_ uint32) (time.Time, error) {
return m.lastLogin, m.lastLoginErr
}
func (m *mockAPIUserRepo) GetReturnExpiry(_ uint32) (time.Time, error) {
return m.returnExpiry, m.returnExpiryErr
}
func (m *mockAPIUserRepo) UpdateReturnExpiry(_ uint32, _ time.Time) error {
return m.updateReturnExpiryErr
}
func (m *mockAPIUserRepo) UpdateLastLogin(_ uint32, _ time.Time) error {
return m.updateLastLoginErr
}
// mockAPICharacterRepo implements APICharacterRepo for testing.
type mockAPICharacterRepo struct {
newCharacter Character
newCharacterErr error
countForUser int
countForUserErr error
createChar Character
createCharErr error
isNewResult bool
isNewErr error
hardDeleteErr error
softDeleteErr error
characters []Character
charactersErr error
exportResult map[string]interface{}
exportErr error
}
func (m *mockAPICharacterRepo) GetNewCharacter(_ context.Context, _ uint32) (Character, error) {
return m.newCharacter, m.newCharacterErr
}
func (m *mockAPICharacterRepo) CountForUser(_ context.Context, _ uint32) (int, error) {
return m.countForUser, m.countForUserErr
}
func (m *mockAPICharacterRepo) Create(_ context.Context, _ uint32, _ uint32) (Character, error) {
return m.createChar, m.createCharErr
}
func (m *mockAPICharacterRepo) IsNew(_ uint32) (bool, error) {
return m.isNewResult, m.isNewErr
}
func (m *mockAPICharacterRepo) HardDelete(_ uint32) error {
return m.hardDeleteErr
}
func (m *mockAPICharacterRepo) SoftDelete(_ uint32) error {
return m.softDeleteErr
}
func (m *mockAPICharacterRepo) GetForUser(_ context.Context, _ uint32) ([]Character, error) {
return m.characters, m.charactersErr
}
func (m *mockAPICharacterRepo) ExportSave(_ context.Context, _, _ uint32) (map[string]interface{}, error) {
return m.exportResult, m.exportErr
}
// mockAPISessionRepo implements APISessionRepo for testing.
type mockAPISessionRepo struct {
createTokenID uint32
createTokenErr error
userID uint32
userIDErr error
}
func (m *mockAPISessionRepo) CreateToken(_ context.Context, _ uint32, _ string) (uint32, error) {
return m.createTokenID, m.createTokenErr
}
func (m *mockAPISessionRepo) GetUserIDByToken(_ context.Context, _ string) (uint32, error) {
return m.userID, m.userIDErr
}

View File

@@ -0,0 +1,29 @@
package api
import (
"context"
"github.com/jmoiron/sqlx"
)
// APISessionRepository implements APISessionRepo with PostgreSQL.
type APISessionRepository struct {
db *sqlx.DB
}
// NewAPISessionRepository creates a new APISessionRepository.
func NewAPISessionRepository(db *sqlx.DB) *APISessionRepository {
return &APISessionRepository{db: db}
}
func (r *APISessionRepository) CreateToken(ctx context.Context, uid uint32, token string) (uint32, error) {
var tid uint32
err := r.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, token).Scan(&tid)
return tid, err
}
func (r *APISessionRepository) GetUserIDByToken(ctx context.Context, token string) (uint32, error) {
var userID uint32
err := r.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID)
return userID, err
}

66
server/api/repo_user.go Normal file
View File

@@ -0,0 +1,66 @@
package api
import (
"context"
"time"
"github.com/jmoiron/sqlx"
)
// APIUserRepository implements APIUserRepo with PostgreSQL.
type APIUserRepository struct {
db *sqlx.DB
}
// NewAPIUserRepository creates a new APIUserRepository.
func NewAPIUserRepository(db *sqlx.DB) *APIUserRepository {
return &APIUserRepository{db: db}
}
func (r *APIUserRepository) Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (uint32, uint32, error) {
var (
id uint32
rights uint32
)
err := r.db.QueryRowContext(
ctx, `
INSERT INTO users (username, password, return_expires)
VALUES ($1, $2, $3)
RETURNING id, rights
`,
username, passwordHash, returnExpires,
).Scan(&id, &rights)
return id, rights, err
}
func (r *APIUserRepository) GetCredentials(ctx context.Context, username string) (uint32, string, uint32, error) {
var (
id uint32
passwordHash string
rights uint32
)
err := r.db.QueryRowContext(ctx, "SELECT id, password, rights FROM users WHERE username = $1", username).Scan(&id, &passwordHash, &rights)
return id, passwordHash, rights, err
}
func (r *APIUserRepository) GetLastLogin(uid uint32) (time.Time, error) {
var lastLogin time.Time
err := r.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid)
return lastLogin, err
}
func (r *APIUserRepository) GetReturnExpiry(uid uint32) (time.Time, error) {
var returnExpiry time.Time
err := r.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid)
return returnExpiry, err
}
func (r *APIUserRepository) UpdateReturnExpiry(uid uint32, expiry time.Time) error {
_, err := r.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", expiry, uid)
return err
}
func (r *APIUserRepository) UpdateLastLogin(uid uint32, loginTime time.Time) error {
_, err := r.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", loginTime, uid)
return err
}

View File

@@ -32,9 +32,9 @@ func initCommands(cmds []cfg.Command, logger *zap.Logger) {
for _, cmd := range cmds { for _, cmd := range cmds {
commands[cmd.Name] = cmd commands[cmd.Name] = cmd
if cmd.Enabled { if cmd.Enabled {
logger.Info(fmt.Sprintf("Command %s: Enabled, prefix: %s", cmd.Name, cmd.Prefix)) logger.Info("Command registered", zap.String("name", cmd.Name), zap.String("prefix", cmd.Prefix), zap.Bool("enabled", true))
} else { } else {
logger.Info(fmt.Sprintf("Command %s: Disabled", cmd.Name)) logger.Info("Command registered", zap.String("name", cmd.Name), zap.Bool("enabled", false))
} }
} }
}) })

View File

@@ -1,7 +1,6 @@
package channelserver package channelserver
import ( import (
"fmt"
"strings" "strings"
"time" "time"
@@ -90,7 +89,7 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
if s.stage != nil { // avoids lock up when using bed for dream quests if s.stage != nil { // avoids lock up when using bed for dream quests
// Notify the client to duplicate the existing objects. // Notify the client to duplicate the existing objects.
s.logger.Info(fmt.Sprintf("Sending existing stage objects to %s", s.Name)) s.logger.Info("Sending existing stage objects", zap.String("session", s.Name))
// Lock stage to safely iterate over objects map // Lock stage to safely iterate over objects map
// We need to copy the objects list first to avoid holding the lock during packet building // We need to copy the objects list first to avoid holding the lock during packet building

View File

@@ -19,7 +19,8 @@ type Server struct {
sync.Mutex sync.Mutex
logger *zap.Logger logger *zap.Logger
erupeConfig *cfg.Config erupeConfig *cfg.Config
db *sqlx.DB serverRepo EntranceServerRepo
sessionRepo EntranceSessionRepo
listener net.Listener listener net.Listener
isShuttingDown bool isShuttingDown bool
} }
@@ -36,7 +37,10 @@ func NewServer(config *Config) *Server {
s := &Server{ s := &Server{
logger: config.Logger, logger: config.Logger,
erupeConfig: config.ErupeConfig, erupeConfig: config.ErupeConfig,
db: config.DB, }
if config.DB != nil {
s.serverRepo = NewEntranceServerRepository(config.DB)
s.sessionRepo = NewEntranceSessionRepository(config.DB)
} }
return s return s
} }

View File

@@ -71,7 +71,9 @@ func encodeServerInfo(config *cfg.Config, s *Server, local bool) []byte {
bf.WriteUint16(uint16(channelIdx | 16)) bf.WriteUint16(uint16(channelIdx | 16))
bf.WriteUint16(ci.MaxPlayers) bf.WriteUint16(ci.MaxPlayers)
var currentPlayers uint16 var currentPlayers uint16
_ = s.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", sid).Scan(&currentPlayers) if s.serverRepo != nil {
currentPlayers, _ = s.serverRepo.GetCurrentPlayers(sid)
}
bf.WriteUint16(currentPlayers) bf.WriteUint16(currentPlayers)
bf.WriteUint16(0) bf.WriteUint16(0)
bf.WriteUint16(0) bf.WriteUint16(0)
@@ -164,12 +166,10 @@ func makeUsrResp(pkt []byte, s *Server) []byte {
for i := 0; i < int(userEntries); i++ { for i := 0; i < int(userEntries); i++ {
cid := bf.ReadUint32() cid := bf.ReadUint32()
var sid uint16 var sid uint16
err := s.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", cid).Scan(&sid) if s.sessionRepo != nil {
if err != nil { sid, _ = s.sessionRepo.GetServerIDForCharacter(cid)
resp.WriteUint16(0)
} else {
resp.WriteUint16(sid)
} }
resp.WriteUint16(sid)
resp.WriteUint16(0) resp.WriteUint16(0)
} }

View File

@@ -113,6 +113,99 @@ func TestClanMemberLimitsBoundsChecking(t *testing.T) {
} }
// TestEncodeServerInfo_WithMockRepo tests encodeServerInfo with a mock server repo
func TestEncodeServerInfo_WithMockRepo(t *testing.T) {
config := &cfg.Config{
RealClientMode: cfg.Z1,
Host: "127.0.0.1",
Entrance: cfg.Entrance{
Enabled: true,
Port: 53310,
Entries: []cfg.EntranceServerInfo{
{
Name: "TestServer",
Description: "Test",
IP: "127.0.0.1",
Type: 0,
Recommended: 0,
AllowedClientFlags: 0xFFFFFFFF,
Channels: []cfg.EntranceChannelInfo{
{
Port: 54001,
MaxPlayers: 100,
},
},
},
},
},
GameplayOptions: cfg.GameplayOptions{
ClanMemberLimits: [][]uint8{{1, 60}},
},
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: config,
serverRepo: &mockEntranceServerRepo{currentPlayers: 42},
}
result := encodeServerInfo(config, server, true)
if len(result) == 0 {
t.Error("encodeServerInfo returned empty result")
}
}
// TestMakeUsrResp_WithMockRepo tests makeUsrResp with a mock session repo
func TestMakeUsrResp_WithMockRepo(t *testing.T) {
config := &cfg.Config{
RealClientMode: cfg.Z1,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: config,
sessionRepo: &mockEntranceSessionRepo{serverID: 1234},
}
// Build a minimal USR request packet:
// 4 bytes ALL+ prefix, 1 byte 0x00, 2 bytes entry count, then 4 bytes per entry (char ID)
pkt := []byte{
'A', 'L', 'L', '+',
0x00,
0x00, 0x01, // 1 entry
0x00, 0x00, 0x00, 0x01, // char_id = 1
}
result := makeUsrResp(pkt, server)
if len(result) == 0 {
t.Error("makeUsrResp returned empty result")
}
}
// TestMakeUsrResp_NilSessionRepo tests makeUsrResp when sessionRepo is nil
func TestMakeUsrResp_NilSessionRepo(t *testing.T) {
config := &cfg.Config{
RealClientMode: cfg.Z1,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: config,
}
pkt := []byte{
'A', 'L', 'L', '+',
0x00,
0x00, 0x01,
0x00, 0x00, 0x00, 0x01,
}
result := makeUsrResp(pkt, server)
if len(result) == 0 {
t.Error("makeUsrResp returned empty result")
}
}
// TestEncodeServerInfo_MissingSecondColumnClanMemberLimits tests accessing [last][1] when [last] is too small // TestEncodeServerInfo_MissingSecondColumnClanMemberLimits tests accessing [last][1] when [last] is too small
// Previously panicked: runtime error: index out of range [1] // Previously panicked: runtime error: index out of range [1]
// After fix: Should handle missing column gracefully with default value (60) // After fix: Should handle missing column gracefully with default value (60)

View File

@@ -0,0 +1,19 @@
package entranceserver
// Repository interfaces decouple entrance server business logic from concrete
// PostgreSQL implementations, enabling mock/stub injection for unit tests.
// EntranceServerRepo defines the contract for server-related data access
// used by the entrance server when building server list responses.
type EntranceServerRepo interface {
// GetCurrentPlayers returns the current player count for a given server ID.
GetCurrentPlayers(serverID int) (uint16, error)
}
// EntranceSessionRepo defines the contract for session-related data access
// used by the entrance server when resolving user locations.
type EntranceSessionRepo interface {
// GetServerIDForCharacter returns the server ID where the given character
// is currently signed in, or 0 if not found.
GetServerIDForCharacter(charID uint32) (uint16, error)
}

View File

@@ -0,0 +1,21 @@
package entranceserver
// mockEntranceServerRepo implements EntranceServerRepo for testing.
type mockEntranceServerRepo struct {
currentPlayers uint16
currentPlayersErr error
}
func (m *mockEntranceServerRepo) GetCurrentPlayers(_ int) (uint16, error) {
return m.currentPlayers, m.currentPlayersErr
}
// mockEntranceSessionRepo implements EntranceSessionRepo for testing.
type mockEntranceSessionRepo struct {
serverID uint16
serverIDErr error
}
func (m *mockEntranceSessionRepo) GetServerIDForCharacter(_ uint32) (uint16, error) {
return m.serverID, m.serverIDErr
}

View File

@@ -0,0 +1,22 @@
package entranceserver
import "github.com/jmoiron/sqlx"
// EntranceServerRepository implements EntranceServerRepo with PostgreSQL.
type EntranceServerRepository struct {
db *sqlx.DB
}
// NewEntranceServerRepository creates a new EntranceServerRepository.
func NewEntranceServerRepository(db *sqlx.DB) *EntranceServerRepository {
return &EntranceServerRepository{db: db}
}
func (r *EntranceServerRepository) GetCurrentPlayers(serverID int) (uint16, error) {
var currentPlayers uint16
err := r.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", serverID).Scan(&currentPlayers)
if err != nil {
return 0, err
}
return currentPlayers, nil
}

View File

@@ -0,0 +1,22 @@
package entranceserver
import "github.com/jmoiron/sqlx"
// EntranceSessionRepository implements EntranceSessionRepo with PostgreSQL.
type EntranceSessionRepository struct {
db *sqlx.DB
}
// NewEntranceSessionRepository creates a new EntranceSessionRepository.
func NewEntranceSessionRepository(db *sqlx.DB) *EntranceSessionRepository {
return &EntranceSessionRepository{db: db}
}
func (r *EntranceSessionRepository) GetServerIDForCharacter(charID uint32) (uint16, error) {
var sid uint16
err := r.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", charID).Scan(&sid)
if err != nil {
return 0, err
}
return sid, nil
}