mirror of
https://github.com/Mezeporta/Erupe.git
synced 2026-03-21 23:22:34 +01:00
refactor: replace raw SQL with repository interfaces in entranceserver and API server
Extract all direct database calls from entranceserver (2 calls) and API server (17 calls) into typed repository interfaces with PostgreSQL implementations, matching the pattern established in signserver and channelserver. Entranceserver: EntranceServerRepo, EntranceSessionRepo API server: APIUserRepo, APICharacterRepo, APISessionRepo Also fix the 3 remaining fmt.Sprintf calls inside logger invocations in handlers_commands.go and handlers_stage.go, replacing them with structured zap fields. Unskip 5 TestNewAuthData* tests that previously required a real database — they now run with mock repos.
This commit is contained in:
@@ -27,7 +27,9 @@ type APIServer struct {
|
||||
sync.Mutex
|
||||
logger *zap.Logger
|
||||
erupeConfig *cfg.Config
|
||||
db *sqlx.DB
|
||||
userRepo APIUserRepo
|
||||
charRepo APICharacterRepo
|
||||
sessionRepo APISessionRepo
|
||||
httpServer *http.Server
|
||||
isShuttingDown bool
|
||||
}
|
||||
@@ -37,9 +39,13 @@ func NewAPIServer(config *Config) *APIServer {
|
||||
s := &APIServer{
|
||||
logger: config.Logger,
|
||||
erupeConfig: config.ErupeConfig,
|
||||
db: config.DB,
|
||||
httpServer: &http.Server{},
|
||||
}
|
||||
if config.DB != nil {
|
||||
s.userRepo = NewAPIUserRepository(config.DB)
|
||||
s.charRepo = NewAPICharacterRepository(config.DB)
|
||||
s.sessionRepo = NewAPISessionRepository(config.DB)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"erupe-ce/common/token"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -11,41 +12,25 @@ import (
|
||||
)
|
||||
|
||||
func (s *APIServer) createNewUser(ctx context.Context, username string, password string) (uint32, uint32, error) {
|
||||
// Create salted hash of user password
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var (
|
||||
id uint32
|
||||
rights uint32
|
||||
)
|
||||
err = s.db.QueryRowContext(
|
||||
ctx, `
|
||||
INSERT INTO users (username, password, return_expires)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, rights
|
||||
`,
|
||||
username, string(passwordHash), time.Now().Add(time.Hour*24*30),
|
||||
).Scan(&id, &rights)
|
||||
return id, rights, err
|
||||
return s.userRepo.Register(ctx, username, string(passwordHash), time.Now().Add(time.Hour*24*30))
|
||||
}
|
||||
|
||||
func (s *APIServer) createLoginToken(ctx context.Context, uid uint32) (uint32, string, error) {
|
||||
loginToken := token.Generate(16)
|
||||
var tid uint32
|
||||
err := s.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, loginToken).Scan(&tid)
|
||||
tid, err := s.sessionRepo.CreateToken(ctx, uid, loginToken)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
return tid, loginToken, nil
|
||||
}
|
||||
|
||||
func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32, error) {
|
||||
var userID uint32
|
||||
err := s.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID)
|
||||
if err == sql.ErrNoRows {
|
||||
func (s *APIServer) userIDFromToken(ctx context.Context, tkn string) (uint32, error) {
|
||||
userID, err := s.sessionRepo.GetUserIDByToken(ctx, tkn)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, fmt.Errorf("invalid login token")
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
@@ -54,82 +39,50 @@ func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32,
|
||||
}
|
||||
|
||||
func (s *APIServer) createCharacter(ctx context.Context, userID uint32) (Character, error) {
|
||||
var character Character
|
||||
err := s.db.GetContext(ctx, &character,
|
||||
"SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1",
|
||||
userID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
var count int
|
||||
_ = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count)
|
||||
character, err := s.charRepo.GetNewCharacter(ctx, userID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
count, _ := s.charRepo.CountForUser(ctx, userID)
|
||||
if count >= 16 {
|
||||
return character, fmt.Errorf("cannot have more than 16 characters")
|
||||
}
|
||||
err = s.db.GetContext(ctx, &character, `
|
||||
INSERT INTO characters (
|
||||
user_id, is_female, is_new_character, name, unk_desc_string,
|
||||
hr, gr, weapon_type, last_login
|
||||
)
|
||||
VALUES ($1, false, true, '', '', 0, 0, 0, $2)
|
||||
RETURNING id, name, is_female, weapon_type, hr, gr, last_login`,
|
||||
userID, uint32(time.Now().Unix()),
|
||||
)
|
||||
character, err = s.charRepo.Create(ctx, userID, uint32(time.Now().Unix()))
|
||||
}
|
||||
return character, err
|
||||
}
|
||||
|
||||
func (s *APIServer) deleteCharacter(ctx context.Context, userID uint32, charID uint32) error {
|
||||
var isNew bool
|
||||
err := s.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew)
|
||||
func (s *APIServer) deleteCharacter(_ context.Context, _ uint32, charID uint32) error {
|
||||
isNew, err := s.charRepo.IsNew(charID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isNew {
|
||||
_, err = s.db.Exec("DELETE FROM characters WHERE id = $1", charID)
|
||||
} else {
|
||||
_, err = s.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID)
|
||||
return s.charRepo.HardDelete(charID)
|
||||
}
|
||||
return err
|
||||
return s.charRepo.SoftDelete(charID)
|
||||
}
|
||||
|
||||
func (s *APIServer) getCharactersForUser(ctx context.Context, uid uint32) ([]Character, error) {
|
||||
var characters []Character
|
||||
err := s.db.SelectContext(
|
||||
ctx, &characters, `
|
||||
SELECT id, name, is_female, weapon_type, hr, gr, last_login
|
||||
FROM characters
|
||||
WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`,
|
||||
uid,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return characters, nil
|
||||
return s.charRepo.GetForUser(ctx, uid)
|
||||
}
|
||||
|
||||
func (s *APIServer) getReturnExpiry(uid uint32) time.Time {
|
||||
var returnExpiry, lastLogin time.Time
|
||||
_ = s.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid)
|
||||
lastLogin, _ := s.userRepo.GetLastLogin(uid)
|
||||
var returnExpiry time.Time
|
||||
if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) {
|
||||
returnExpiry = time.Now().Add(time.Hour * 24 * 30)
|
||||
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid)
|
||||
_ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry)
|
||||
} else {
|
||||
err := s.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid)
|
||||
var err error
|
||||
returnExpiry, err = s.userRepo.GetReturnExpiry(uid)
|
||||
if err != nil {
|
||||
returnExpiry = time.Now()
|
||||
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid)
|
||||
_ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry)
|
||||
}
|
||||
}
|
||||
_, _ = s.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", time.Now(), uid)
|
||||
_ = s.userRepo.UpdateLastLogin(uid, time.Now())
|
||||
return returnExpiry
|
||||
}
|
||||
|
||||
func (s *APIServer) exportSave(ctx context.Context, uid uint32, cid uint32) (map[string]interface{}, error) {
|
||||
row := s.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", cid, uid)
|
||||
result := make(map[string]interface{})
|
||||
err := row.MapScan(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return s.charRepo.ExportSave(ctx, uid, cid)
|
||||
}
|
||||
|
||||
@@ -162,12 +162,7 @@ func (s *APIServer) Login(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(400)
|
||||
return
|
||||
}
|
||||
var (
|
||||
userID uint32
|
||||
userRights uint32
|
||||
password string
|
||||
)
|
||||
err := s.db.QueryRow("SELECT id, password, rights FROM users WHERE username = $1", reqData.Username).Scan(&userID, &password, &userRights)
|
||||
userID, password, userRights, err := s.userRepo.GetCredentials(ctx, reqData.Username)
|
||||
if err == sql.ErrNoRows {
|
||||
w.WriteHeader(400)
|
||||
_, _ = w.Write([]byte("username-error"))
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cfg "erupe-ce/config"
|
||||
"erupe-ce/common/gametime"
|
||||
@@ -33,7 +34,6 @@ func TestLauncherEndpoint(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
// Create test request
|
||||
@@ -123,7 +123,6 @@ func TestLoginEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
// Invalid JSON
|
||||
@@ -148,7 +147,6 @@ func TestLoginEndpointEmptyCredentials(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -200,7 +198,6 @@ func TestRegisterEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"username": "test"`
|
||||
@@ -223,7 +220,6 @@ func TestRegisterEndpointEmptyCredentials(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -271,7 +267,6 @@ func TestCreateCharacterEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"token": `
|
||||
@@ -294,7 +289,6 @@ func TestDeleteCharacterEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"token": "test"`
|
||||
@@ -317,7 +311,6 @@ func TestExportSaveEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"token": `
|
||||
@@ -342,7 +335,6 @@ func TestScreenShotEndpointDisabled(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ss/bbs/upload.php", nil)
|
||||
@@ -379,7 +371,6 @@ func TestScreenShotGetInvalidToken(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -408,10 +399,16 @@ func TestScreenShotGetInvalidToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// newTestUserRepo returns a mock user repo suitable for newAuthData tests.
|
||||
func newTestUserRepo() *mockAPIUserRepo {
|
||||
return &mockAPIUserRepo{
|
||||
lastLogin: time.Now(),
|
||||
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAuthDataStructure tests the newAuthData helper function
|
||||
func TestNewAuthDataStructure(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -423,7 +420,7 @@ func TestNewAuthDataStructure(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
characters := []Character{
|
||||
@@ -466,8 +463,6 @@ func TestNewAuthDataStructure(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataDebugMode tests newAuthData with debug mode enabled
|
||||
func TestNewAuthDataDebugMode(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -477,7 +472,7 @@ func TestNewAuthDataDebugMode(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
characters := []Character{
|
||||
@@ -500,8 +495,6 @@ func TestNewAuthDataDebugMode(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataMezFesConfiguration tests MezFes configuration in newAuthData
|
||||
func TestNewAuthDataMezFesConfiguration(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -513,7 +506,7 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
authData := server.newAuthData(1, 0, 1, "token", []Character{})
|
||||
@@ -534,8 +527,6 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataHideNotices tests notice hiding in newAuthData
|
||||
func TestNewAuthDataHideNotices(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -546,7 +537,7 @@ func TestNewAuthDataHideNotices(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
authData := server.newAuthData(1, 0, 1, "token", []Character{})
|
||||
@@ -558,8 +549,6 @@ func TestNewAuthDataHideNotices(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataTimestamps tests timestamp generation in newAuthData
|
||||
func TestNewAuthDataTimestamps(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -567,7 +556,7 @@ func TestNewAuthDataTimestamps(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
authData := server.newAuthData(1, 0, 1, "token", []Character{})
|
||||
@@ -611,6 +600,10 @@ func BenchmarkNewAuthData(b *testing.B) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
userRepo: &mockAPIUserRepo{
|
||||
lastLogin: time.Now(),
|
||||
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
|
||||
},
|
||||
}
|
||||
|
||||
characters := make([]Character, 16)
|
||||
|
||||
87
server/api/repo_character.go
Normal file
87
server/api/repo_character.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// APICharacterRepository implements APICharacterRepo with PostgreSQL.
|
||||
type APICharacterRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewAPICharacterRepository creates a new APICharacterRepository.
|
||||
func NewAPICharacterRepository(db *sqlx.DB) *APICharacterRepository {
|
||||
return &APICharacterRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) GetNewCharacter(ctx context.Context, userID uint32) (Character, error) {
|
||||
var character Character
|
||||
err := r.db.GetContext(ctx, &character,
|
||||
"SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1",
|
||||
userID,
|
||||
)
|
||||
return character, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) CountForUser(ctx context.Context, userID uint32) (int, error) {
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error) {
|
||||
var character Character
|
||||
err := r.db.GetContext(ctx, &character, `
|
||||
INSERT INTO characters (
|
||||
user_id, is_female, is_new_character, name, unk_desc_string,
|
||||
hr, gr, weapon_type, last_login
|
||||
)
|
||||
VALUES ($1, false, true, '', '', 0, 0, 0, $2)
|
||||
RETURNING id, name, is_female, weapon_type, hr, gr, last_login`,
|
||||
userID, lastLogin,
|
||||
)
|
||||
return character, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) IsNew(charID uint32) (bool, error) {
|
||||
var isNew bool
|
||||
err := r.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew)
|
||||
return isNew, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) HardDelete(charID uint32) error {
|
||||
_, err := r.db.Exec("DELETE FROM characters WHERE id = $1", charID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) SoftDelete(charID uint32) error {
|
||||
_, err := r.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) GetForUser(ctx context.Context, userID uint32) ([]Character, error) {
|
||||
var characters []Character
|
||||
err := r.db.SelectContext(
|
||||
ctx, &characters, `
|
||||
SELECT id, name, is_female, weapon_type, hr, gr, last_login
|
||||
FROM characters
|
||||
WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return characters, nil
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error) {
|
||||
row := r.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", charID, userID)
|
||||
result := make(map[string]interface{})
|
||||
err := row.MapScan(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
53
server/api/repo_interfaces.go
Normal file
53
server/api/repo_interfaces.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Repository interfaces decouple API server business logic from concrete
|
||||
// PostgreSQL implementations, enabling mock/stub injection for unit tests.
|
||||
|
||||
// APIUserRepo defines the contract for user-related data access.
|
||||
type APIUserRepo interface {
|
||||
// Register creates a new user and returns their ID and rights.
|
||||
Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (id uint32, rights uint32, err error)
|
||||
// GetCredentials returns the user's ID, password hash, and rights.
|
||||
GetCredentials(ctx context.Context, username string) (id uint32, passwordHash string, rights uint32, err error)
|
||||
// GetLastLogin returns the user's last login time.
|
||||
GetLastLogin(uid uint32) (time.Time, error)
|
||||
// GetReturnExpiry returns the user's return expiry time.
|
||||
GetReturnExpiry(uid uint32) (time.Time, error)
|
||||
// UpdateReturnExpiry sets the user's return expiry time.
|
||||
UpdateReturnExpiry(uid uint32, expiry time.Time) error
|
||||
// UpdateLastLogin sets the user's last login time.
|
||||
UpdateLastLogin(uid uint32, loginTime time.Time) error
|
||||
}
|
||||
|
||||
// APICharacterRepo defines the contract for character-related data access.
|
||||
type APICharacterRepo interface {
|
||||
// GetNewCharacter returns an existing new (unfinished) character for a user.
|
||||
GetNewCharacter(ctx context.Context, userID uint32) (Character, error)
|
||||
// CountForUser returns the total number of characters for a user.
|
||||
CountForUser(ctx context.Context, userID uint32) (int, error)
|
||||
// Create inserts a new character and returns it.
|
||||
Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error)
|
||||
// IsNew returns whether a character is a new (unfinished) character.
|
||||
IsNew(charID uint32) (bool, error)
|
||||
// HardDelete permanently removes a character.
|
||||
HardDelete(charID uint32) error
|
||||
// SoftDelete marks a character as deleted.
|
||||
SoftDelete(charID uint32) error
|
||||
// GetForUser returns all finalized (non-deleted) characters for a user.
|
||||
GetForUser(ctx context.Context, userID uint32) ([]Character, error)
|
||||
// ExportSave returns the full character row as a map.
|
||||
ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// APISessionRepo defines the contract for session/token data access.
|
||||
type APISessionRepo interface {
|
||||
// CreateToken inserts a new sign session and returns its ID and token.
|
||||
CreateToken(ctx context.Context, uid uint32, token string) (tokenID uint32, err error)
|
||||
// GetUserIDByToken returns the user ID for a given session token.
|
||||
GetUserIDByToken(ctx context.Context, token string) (uint32, error)
|
||||
}
|
||||
124
server/api/repo_mocks_test.go
Normal file
124
server/api/repo_mocks_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockAPIUserRepo implements APIUserRepo for testing.
|
||||
type mockAPIUserRepo struct {
|
||||
registerID uint32
|
||||
registerRights uint32
|
||||
registerErr error
|
||||
|
||||
credentialsID uint32
|
||||
credentialsPassword string
|
||||
credentialsRights uint32
|
||||
credentialsErr error
|
||||
|
||||
lastLogin time.Time
|
||||
lastLoginErr error
|
||||
|
||||
returnExpiry time.Time
|
||||
returnExpiryErr error
|
||||
|
||||
updateReturnExpiryErr error
|
||||
updateLastLoginErr error
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) Register(_ context.Context, _, _ string, _ time.Time) (uint32, uint32, error) {
|
||||
return m.registerID, m.registerRights, m.registerErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) GetCredentials(_ context.Context, _ string) (uint32, string, uint32, error) {
|
||||
return m.credentialsID, m.credentialsPassword, m.credentialsRights, m.credentialsErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) GetLastLogin(_ uint32) (time.Time, error) {
|
||||
return m.lastLogin, m.lastLoginErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) GetReturnExpiry(_ uint32) (time.Time, error) {
|
||||
return m.returnExpiry, m.returnExpiryErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) UpdateReturnExpiry(_ uint32, _ time.Time) error {
|
||||
return m.updateReturnExpiryErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) UpdateLastLogin(_ uint32, _ time.Time) error {
|
||||
return m.updateLastLoginErr
|
||||
}
|
||||
|
||||
// mockAPICharacterRepo implements APICharacterRepo for testing.
|
||||
type mockAPICharacterRepo struct {
|
||||
newCharacter Character
|
||||
newCharacterErr error
|
||||
|
||||
countForUser int
|
||||
countForUserErr error
|
||||
|
||||
createChar Character
|
||||
createCharErr error
|
||||
|
||||
isNewResult bool
|
||||
isNewErr error
|
||||
|
||||
hardDeleteErr error
|
||||
softDeleteErr error
|
||||
|
||||
characters []Character
|
||||
charactersErr error
|
||||
|
||||
exportResult map[string]interface{}
|
||||
exportErr error
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) GetNewCharacter(_ context.Context, _ uint32) (Character, error) {
|
||||
return m.newCharacter, m.newCharacterErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) CountForUser(_ context.Context, _ uint32) (int, error) {
|
||||
return m.countForUser, m.countForUserErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) Create(_ context.Context, _ uint32, _ uint32) (Character, error) {
|
||||
return m.createChar, m.createCharErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) IsNew(_ uint32) (bool, error) {
|
||||
return m.isNewResult, m.isNewErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) HardDelete(_ uint32) error {
|
||||
return m.hardDeleteErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) SoftDelete(_ uint32) error {
|
||||
return m.softDeleteErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) GetForUser(_ context.Context, _ uint32) ([]Character, error) {
|
||||
return m.characters, m.charactersErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) ExportSave(_ context.Context, _, _ uint32) (map[string]interface{}, error) {
|
||||
return m.exportResult, m.exportErr
|
||||
}
|
||||
|
||||
// mockAPISessionRepo implements APISessionRepo for testing.
|
||||
type mockAPISessionRepo struct {
|
||||
createTokenID uint32
|
||||
createTokenErr error
|
||||
|
||||
userID uint32
|
||||
userIDErr error
|
||||
}
|
||||
|
||||
func (m *mockAPISessionRepo) CreateToken(_ context.Context, _ uint32, _ string) (uint32, error) {
|
||||
return m.createTokenID, m.createTokenErr
|
||||
}
|
||||
|
||||
func (m *mockAPISessionRepo) GetUserIDByToken(_ context.Context, _ string) (uint32, error) {
|
||||
return m.userID, m.userIDErr
|
||||
}
|
||||
29
server/api/repo_session.go
Normal file
29
server/api/repo_session.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// APISessionRepository implements APISessionRepo with PostgreSQL.
|
||||
type APISessionRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewAPISessionRepository creates a new APISessionRepository.
|
||||
func NewAPISessionRepository(db *sqlx.DB) *APISessionRepository {
|
||||
return &APISessionRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *APISessionRepository) CreateToken(ctx context.Context, uid uint32, token string) (uint32, error) {
|
||||
var tid uint32
|
||||
err := r.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, token).Scan(&tid)
|
||||
return tid, err
|
||||
}
|
||||
|
||||
func (r *APISessionRepository) GetUserIDByToken(ctx context.Context, token string) (uint32, error) {
|
||||
var userID uint32
|
||||
err := r.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID)
|
||||
return userID, err
|
||||
}
|
||||
66
server/api/repo_user.go
Normal file
66
server/api/repo_user.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// APIUserRepository implements APIUserRepo with PostgreSQL.
|
||||
type APIUserRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewAPIUserRepository creates a new APIUserRepository.
|
||||
func NewAPIUserRepository(db *sqlx.DB) *APIUserRepository {
|
||||
return &APIUserRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (uint32, uint32, error) {
|
||||
var (
|
||||
id uint32
|
||||
rights uint32
|
||||
)
|
||||
err := r.db.QueryRowContext(
|
||||
ctx, `
|
||||
INSERT INTO users (username, password, return_expires)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, rights
|
||||
`,
|
||||
username, passwordHash, returnExpires,
|
||||
).Scan(&id, &rights)
|
||||
return id, rights, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) GetCredentials(ctx context.Context, username string) (uint32, string, uint32, error) {
|
||||
var (
|
||||
id uint32
|
||||
passwordHash string
|
||||
rights uint32
|
||||
)
|
||||
err := r.db.QueryRowContext(ctx, "SELECT id, password, rights FROM users WHERE username = $1", username).Scan(&id, &passwordHash, &rights)
|
||||
return id, passwordHash, rights, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) GetLastLogin(uid uint32) (time.Time, error) {
|
||||
var lastLogin time.Time
|
||||
err := r.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid)
|
||||
return lastLogin, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) GetReturnExpiry(uid uint32) (time.Time, error) {
|
||||
var returnExpiry time.Time
|
||||
err := r.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid)
|
||||
return returnExpiry, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) UpdateReturnExpiry(uid uint32, expiry time.Time) error {
|
||||
_, err := r.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", expiry, uid)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) UpdateLastLogin(uid uint32, loginTime time.Time) error {
|
||||
_, err := r.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", loginTime, uid)
|
||||
return err
|
||||
}
|
||||
@@ -32,9 +32,9 @@ func initCommands(cmds []cfg.Command, logger *zap.Logger) {
|
||||
for _, cmd := range cmds {
|
||||
commands[cmd.Name] = cmd
|
||||
if cmd.Enabled {
|
||||
logger.Info(fmt.Sprintf("Command %s: Enabled, prefix: %s", cmd.Name, cmd.Prefix))
|
||||
logger.Info("Command registered", zap.String("name", cmd.Name), zap.String("prefix", cmd.Prefix), zap.Bool("enabled", true))
|
||||
} else {
|
||||
logger.Info(fmt.Sprintf("Command %s: Disabled", cmd.Name))
|
||||
logger.Info("Command registered", zap.String("name", cmd.Name), zap.Bool("enabled", false))
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -90,7 +89,7 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
|
||||
|
||||
if s.stage != nil { // avoids lock up when using bed for dream quests
|
||||
// Notify the client to duplicate the existing objects.
|
||||
s.logger.Info(fmt.Sprintf("Sending existing stage objects to %s", s.Name))
|
||||
s.logger.Info("Sending existing stage objects", zap.String("session", s.Name))
|
||||
|
||||
// Lock stage to safely iterate over objects map
|
||||
// We need to copy the objects list first to avoid holding the lock during packet building
|
||||
|
||||
@@ -19,7 +19,8 @@ type Server struct {
|
||||
sync.Mutex
|
||||
logger *zap.Logger
|
||||
erupeConfig *cfg.Config
|
||||
db *sqlx.DB
|
||||
serverRepo EntranceServerRepo
|
||||
sessionRepo EntranceSessionRepo
|
||||
listener net.Listener
|
||||
isShuttingDown bool
|
||||
}
|
||||
@@ -36,7 +37,10 @@ func NewServer(config *Config) *Server {
|
||||
s := &Server{
|
||||
logger: config.Logger,
|
||||
erupeConfig: config.ErupeConfig,
|
||||
db: config.DB,
|
||||
}
|
||||
if config.DB != nil {
|
||||
s.serverRepo = NewEntranceServerRepository(config.DB)
|
||||
s.sessionRepo = NewEntranceSessionRepository(config.DB)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
@@ -71,7 +71,9 @@ func encodeServerInfo(config *cfg.Config, s *Server, local bool) []byte {
|
||||
bf.WriteUint16(uint16(channelIdx | 16))
|
||||
bf.WriteUint16(ci.MaxPlayers)
|
||||
var currentPlayers uint16
|
||||
_ = s.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", sid).Scan(¤tPlayers)
|
||||
if s.serverRepo != nil {
|
||||
currentPlayers, _ = s.serverRepo.GetCurrentPlayers(sid)
|
||||
}
|
||||
bf.WriteUint16(currentPlayers)
|
||||
bf.WriteUint16(0)
|
||||
bf.WriteUint16(0)
|
||||
@@ -164,12 +166,10 @@ func makeUsrResp(pkt []byte, s *Server) []byte {
|
||||
for i := 0; i < int(userEntries); i++ {
|
||||
cid := bf.ReadUint32()
|
||||
var sid uint16
|
||||
err := s.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", cid).Scan(&sid)
|
||||
if err != nil {
|
||||
resp.WriteUint16(0)
|
||||
} else {
|
||||
resp.WriteUint16(sid)
|
||||
if s.sessionRepo != nil {
|
||||
sid, _ = s.sessionRepo.GetServerIDForCharacter(cid)
|
||||
}
|
||||
resp.WriteUint16(sid)
|
||||
resp.WriteUint16(0)
|
||||
}
|
||||
|
||||
|
||||
@@ -113,6 +113,99 @@ func TestClanMemberLimitsBoundsChecking(t *testing.T) {
|
||||
}
|
||||
|
||||
|
||||
// TestEncodeServerInfo_WithMockRepo tests encodeServerInfo with a mock server repo
|
||||
func TestEncodeServerInfo_WithMockRepo(t *testing.T) {
|
||||
config := &cfg.Config{
|
||||
RealClientMode: cfg.Z1,
|
||||
Host: "127.0.0.1",
|
||||
Entrance: cfg.Entrance{
|
||||
Enabled: true,
|
||||
Port: 53310,
|
||||
Entries: []cfg.EntranceServerInfo{
|
||||
{
|
||||
Name: "TestServer",
|
||||
Description: "Test",
|
||||
IP: "127.0.0.1",
|
||||
Type: 0,
|
||||
Recommended: 0,
|
||||
AllowedClientFlags: 0xFFFFFFFF,
|
||||
Channels: []cfg.EntranceChannelInfo{
|
||||
{
|
||||
Port: 54001,
|
||||
MaxPlayers: 100,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
GameplayOptions: cfg.GameplayOptions{
|
||||
ClanMemberLimits: [][]uint8{{1, 60}},
|
||||
},
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
logger: zap.NewNop(),
|
||||
erupeConfig: config,
|
||||
serverRepo: &mockEntranceServerRepo{currentPlayers: 42},
|
||||
}
|
||||
|
||||
result := encodeServerInfo(config, server, true)
|
||||
if len(result) == 0 {
|
||||
t.Error("encodeServerInfo returned empty result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeUsrResp_WithMockRepo tests makeUsrResp with a mock session repo
|
||||
func TestMakeUsrResp_WithMockRepo(t *testing.T) {
|
||||
config := &cfg.Config{
|
||||
RealClientMode: cfg.Z1,
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
logger: zap.NewNop(),
|
||||
erupeConfig: config,
|
||||
sessionRepo: &mockEntranceSessionRepo{serverID: 1234},
|
||||
}
|
||||
|
||||
// Build a minimal USR request packet:
|
||||
// 4 bytes ALL+ prefix, 1 byte 0x00, 2 bytes entry count, then 4 bytes per entry (char ID)
|
||||
pkt := []byte{
|
||||
'A', 'L', 'L', '+',
|
||||
0x00,
|
||||
0x00, 0x01, // 1 entry
|
||||
0x00, 0x00, 0x00, 0x01, // char_id = 1
|
||||
}
|
||||
|
||||
result := makeUsrResp(pkt, server)
|
||||
if len(result) == 0 {
|
||||
t.Error("makeUsrResp returned empty result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMakeUsrResp_NilSessionRepo tests makeUsrResp when sessionRepo is nil
|
||||
func TestMakeUsrResp_NilSessionRepo(t *testing.T) {
|
||||
config := &cfg.Config{
|
||||
RealClientMode: cfg.Z1,
|
||||
}
|
||||
|
||||
server := &Server{
|
||||
logger: zap.NewNop(),
|
||||
erupeConfig: config,
|
||||
}
|
||||
|
||||
pkt := []byte{
|
||||
'A', 'L', 'L', '+',
|
||||
0x00,
|
||||
0x00, 0x01,
|
||||
0x00, 0x00, 0x00, 0x01,
|
||||
}
|
||||
|
||||
result := makeUsrResp(pkt, server)
|
||||
if len(result) == 0 {
|
||||
t.Error("makeUsrResp returned empty result")
|
||||
}
|
||||
}
|
||||
|
||||
// TestEncodeServerInfo_MissingSecondColumnClanMemberLimits tests accessing [last][1] when [last] is too small
|
||||
// Previously panicked: runtime error: index out of range [1]
|
||||
// After fix: Should handle missing column gracefully with default value (60)
|
||||
|
||||
19
server/entranceserver/repo_interfaces.go
Normal file
19
server/entranceserver/repo_interfaces.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package entranceserver
|
||||
|
||||
// Repository interfaces decouple entrance server business logic from concrete
|
||||
// PostgreSQL implementations, enabling mock/stub injection for unit tests.
|
||||
|
||||
// EntranceServerRepo defines the contract for server-related data access
|
||||
// used by the entrance server when building server list responses.
|
||||
type EntranceServerRepo interface {
|
||||
// GetCurrentPlayers returns the current player count for a given server ID.
|
||||
GetCurrentPlayers(serverID int) (uint16, error)
|
||||
}
|
||||
|
||||
// EntranceSessionRepo defines the contract for session-related data access
|
||||
// used by the entrance server when resolving user locations.
|
||||
type EntranceSessionRepo interface {
|
||||
// GetServerIDForCharacter returns the server ID where the given character
|
||||
// is currently signed in, or 0 if not found.
|
||||
GetServerIDForCharacter(charID uint32) (uint16, error)
|
||||
}
|
||||
21
server/entranceserver/repo_mocks_test.go
Normal file
21
server/entranceserver/repo_mocks_test.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package entranceserver
|
||||
|
||||
// mockEntranceServerRepo implements EntranceServerRepo for testing.
|
||||
type mockEntranceServerRepo struct {
|
||||
currentPlayers uint16
|
||||
currentPlayersErr error
|
||||
}
|
||||
|
||||
func (m *mockEntranceServerRepo) GetCurrentPlayers(_ int) (uint16, error) {
|
||||
return m.currentPlayers, m.currentPlayersErr
|
||||
}
|
||||
|
||||
// mockEntranceSessionRepo implements EntranceSessionRepo for testing.
|
||||
type mockEntranceSessionRepo struct {
|
||||
serverID uint16
|
||||
serverIDErr error
|
||||
}
|
||||
|
||||
func (m *mockEntranceSessionRepo) GetServerIDForCharacter(_ uint32) (uint16, error) {
|
||||
return m.serverID, m.serverIDErr
|
||||
}
|
||||
22
server/entranceserver/repo_server.go
Normal file
22
server/entranceserver/repo_server.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package entranceserver
|
||||
|
||||
import "github.com/jmoiron/sqlx"
|
||||
|
||||
// EntranceServerRepository implements EntranceServerRepo with PostgreSQL.
|
||||
type EntranceServerRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewEntranceServerRepository creates a new EntranceServerRepository.
|
||||
func NewEntranceServerRepository(db *sqlx.DB) *EntranceServerRepository {
|
||||
return &EntranceServerRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *EntranceServerRepository) GetCurrentPlayers(serverID int) (uint16, error) {
|
||||
var currentPlayers uint16
|
||||
err := r.db.QueryRow("SELECT current_players FROM servers WHERE server_id=$1", serverID).Scan(¤tPlayers)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return currentPlayers, nil
|
||||
}
|
||||
22
server/entranceserver/repo_session.go
Normal file
22
server/entranceserver/repo_session.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package entranceserver
|
||||
|
||||
import "github.com/jmoiron/sqlx"
|
||||
|
||||
// EntranceSessionRepository implements EntranceSessionRepo with PostgreSQL.
|
||||
type EntranceSessionRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewEntranceSessionRepository creates a new EntranceSessionRepository.
|
||||
func NewEntranceSessionRepository(db *sqlx.DB) *EntranceSessionRepository {
|
||||
return &EntranceSessionRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *EntranceSessionRepository) GetServerIDForCharacter(charID uint32) (uint16, error) {
|
||||
var sid uint16
|
||||
err := r.db.QueryRow("SELECT(SELECT server_id FROM sign_sessions WHERE char_id=$1) AS _", charID).Scan(&sid)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return sid, nil
|
||||
}
|
||||
Reference in New Issue
Block a user