mirror of
https://github.com/Mezeporta/Erupe.git
synced 2026-03-22 07:32:32 +01:00
refactor: replace raw SQL with repository interfaces in entranceserver and API server
Extract all direct database calls from entranceserver (2 calls) and API server (17 calls) into typed repository interfaces with PostgreSQL implementations, matching the pattern established in signserver and channelserver. Entranceserver: EntranceServerRepo, EntranceSessionRepo API server: APIUserRepo, APICharacterRepo, APISessionRepo Also fix the 3 remaining fmt.Sprintf calls inside logger invocations in handlers_commands.go and handlers_stage.go, replacing them with structured zap fields. Unskip 5 TestNewAuthData* tests that previously required a real database — they now run with mock repos.
This commit is contained in:
@@ -27,7 +27,9 @@ type APIServer struct {
|
||||
sync.Mutex
|
||||
logger *zap.Logger
|
||||
erupeConfig *cfg.Config
|
||||
db *sqlx.DB
|
||||
userRepo APIUserRepo
|
||||
charRepo APICharacterRepo
|
||||
sessionRepo APISessionRepo
|
||||
httpServer *http.Server
|
||||
isShuttingDown bool
|
||||
}
|
||||
@@ -37,9 +39,13 @@ func NewAPIServer(config *Config) *APIServer {
|
||||
s := &APIServer{
|
||||
logger: config.Logger,
|
||||
erupeConfig: config.ErupeConfig,
|
||||
db: config.DB,
|
||||
httpServer: &http.Server{},
|
||||
}
|
||||
if config.DB != nil {
|
||||
s.userRepo = NewAPIUserRepository(config.DB)
|
||||
s.charRepo = NewAPICharacterRepository(config.DB)
|
||||
s.sessionRepo = NewAPISessionRepository(config.DB)
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"erupe-ce/common/token"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
@@ -11,41 +12,25 @@ import (
|
||||
)
|
||||
|
||||
func (s *APIServer) createNewUser(ctx context.Context, username string, password string) (uint32, uint32, error) {
|
||||
// Create salted hash of user password
|
||||
passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
var (
|
||||
id uint32
|
||||
rights uint32
|
||||
)
|
||||
err = s.db.QueryRowContext(
|
||||
ctx, `
|
||||
INSERT INTO users (username, password, return_expires)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, rights
|
||||
`,
|
||||
username, string(passwordHash), time.Now().Add(time.Hour*24*30),
|
||||
).Scan(&id, &rights)
|
||||
return id, rights, err
|
||||
return s.userRepo.Register(ctx, username, string(passwordHash), time.Now().Add(time.Hour*24*30))
|
||||
}
|
||||
|
||||
func (s *APIServer) createLoginToken(ctx context.Context, uid uint32) (uint32, string, error) {
|
||||
loginToken := token.Generate(16)
|
||||
var tid uint32
|
||||
err := s.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, loginToken).Scan(&tid)
|
||||
tid, err := s.sessionRepo.CreateToken(ctx, uid, loginToken)
|
||||
if err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
return tid, loginToken, nil
|
||||
}
|
||||
|
||||
func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32, error) {
|
||||
var userID uint32
|
||||
err := s.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID)
|
||||
if err == sql.ErrNoRows {
|
||||
func (s *APIServer) userIDFromToken(ctx context.Context, tkn string) (uint32, error) {
|
||||
userID, err := s.sessionRepo.GetUserIDByToken(ctx, tkn)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return 0, fmt.Errorf("invalid login token")
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
@@ -54,82 +39,50 @@ func (s *APIServer) userIDFromToken(ctx context.Context, token string) (uint32,
|
||||
}
|
||||
|
||||
func (s *APIServer) createCharacter(ctx context.Context, userID uint32) (Character, error) {
|
||||
var character Character
|
||||
err := s.db.GetContext(ctx, &character,
|
||||
"SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1",
|
||||
userID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
var count int
|
||||
_ = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count)
|
||||
character, err := s.charRepo.GetNewCharacter(ctx, userID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
count, _ := s.charRepo.CountForUser(ctx, userID)
|
||||
if count >= 16 {
|
||||
return character, fmt.Errorf("cannot have more than 16 characters")
|
||||
}
|
||||
err = s.db.GetContext(ctx, &character, `
|
||||
INSERT INTO characters (
|
||||
user_id, is_female, is_new_character, name, unk_desc_string,
|
||||
hr, gr, weapon_type, last_login
|
||||
)
|
||||
VALUES ($1, false, true, '', '', 0, 0, 0, $2)
|
||||
RETURNING id, name, is_female, weapon_type, hr, gr, last_login`,
|
||||
userID, uint32(time.Now().Unix()),
|
||||
)
|
||||
character, err = s.charRepo.Create(ctx, userID, uint32(time.Now().Unix()))
|
||||
}
|
||||
return character, err
|
||||
}
|
||||
|
||||
func (s *APIServer) deleteCharacter(ctx context.Context, userID uint32, charID uint32) error {
|
||||
var isNew bool
|
||||
err := s.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew)
|
||||
func (s *APIServer) deleteCharacter(_ context.Context, _ uint32, charID uint32) error {
|
||||
isNew, err := s.charRepo.IsNew(charID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if isNew {
|
||||
_, err = s.db.Exec("DELETE FROM characters WHERE id = $1", charID)
|
||||
} else {
|
||||
_, err = s.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID)
|
||||
return s.charRepo.HardDelete(charID)
|
||||
}
|
||||
return err
|
||||
return s.charRepo.SoftDelete(charID)
|
||||
}
|
||||
|
||||
func (s *APIServer) getCharactersForUser(ctx context.Context, uid uint32) ([]Character, error) {
|
||||
var characters []Character
|
||||
err := s.db.SelectContext(
|
||||
ctx, &characters, `
|
||||
SELECT id, name, is_female, weapon_type, hr, gr, last_login
|
||||
FROM characters
|
||||
WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`,
|
||||
uid,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return characters, nil
|
||||
return s.charRepo.GetForUser(ctx, uid)
|
||||
}
|
||||
|
||||
func (s *APIServer) getReturnExpiry(uid uint32) time.Time {
|
||||
var returnExpiry, lastLogin time.Time
|
||||
_ = s.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid)
|
||||
lastLogin, _ := s.userRepo.GetLastLogin(uid)
|
||||
var returnExpiry time.Time
|
||||
if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) {
|
||||
returnExpiry = time.Now().Add(time.Hour * 24 * 30)
|
||||
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid)
|
||||
_ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry)
|
||||
} else {
|
||||
err := s.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid)
|
||||
var err error
|
||||
returnExpiry, err = s.userRepo.GetReturnExpiry(uid)
|
||||
if err != nil {
|
||||
returnExpiry = time.Now()
|
||||
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid)
|
||||
_ = s.userRepo.UpdateReturnExpiry(uid, returnExpiry)
|
||||
}
|
||||
}
|
||||
_, _ = s.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", time.Now(), uid)
|
||||
_ = s.userRepo.UpdateLastLogin(uid, time.Now())
|
||||
return returnExpiry
|
||||
}
|
||||
|
||||
func (s *APIServer) exportSave(ctx context.Context, uid uint32, cid uint32) (map[string]interface{}, error) {
|
||||
row := s.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", cid, uid)
|
||||
result := make(map[string]interface{})
|
||||
err := row.MapScan(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
return s.charRepo.ExportSave(ctx, uid, cid)
|
||||
}
|
||||
|
||||
@@ -162,12 +162,7 @@ func (s *APIServer) Login(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(400)
|
||||
return
|
||||
}
|
||||
var (
|
||||
userID uint32
|
||||
userRights uint32
|
||||
password string
|
||||
)
|
||||
err := s.db.QueryRow("SELECT id, password, rights FROM users WHERE username = $1", reqData.Username).Scan(&userID, &password, &userRights)
|
||||
userID, password, userRights, err := s.userRepo.GetCredentials(ctx, reqData.Username)
|
||||
if err == sql.ErrNoRows {
|
||||
w.WriteHeader(400)
|
||||
_, _ = w.Write([]byte("username-error"))
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
cfg "erupe-ce/config"
|
||||
"erupe-ce/common/gametime"
|
||||
@@ -33,7 +34,6 @@ func TestLauncherEndpoint(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
// Create test request
|
||||
@@ -123,7 +123,6 @@ func TestLoginEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
// Invalid JSON
|
||||
@@ -148,7 +147,6 @@ func TestLoginEndpointEmptyCredentials(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -200,7 +198,6 @@ func TestRegisterEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"username": "test"`
|
||||
@@ -223,7 +220,6 @@ func TestRegisterEndpointEmptyCredentials(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -271,7 +267,6 @@ func TestCreateCharacterEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"token": `
|
||||
@@ -294,7 +289,6 @@ func TestDeleteCharacterEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"token": "test"`
|
||||
@@ -317,7 +311,6 @@ func TestExportSaveEndpointInvalidJSON(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
invalidJSON := `{"token": `
|
||||
@@ -342,7 +335,6 @@ func TestScreenShotEndpointDisabled(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
req := httptest.NewRequest("POST", "/api/ss/bbs/upload.php", nil)
|
||||
@@ -379,7 +371,6 @@ func TestScreenShotGetInvalidToken(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -408,10 +399,16 @@ func TestScreenShotGetInvalidToken(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// newTestUserRepo returns a mock user repo suitable for newAuthData tests.
|
||||
func newTestUserRepo() *mockAPIUserRepo {
|
||||
return &mockAPIUserRepo{
|
||||
lastLogin: time.Now(),
|
||||
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
|
||||
}
|
||||
}
|
||||
|
||||
// TestNewAuthDataStructure tests the newAuthData helper function
|
||||
func TestNewAuthDataStructure(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -423,7 +420,7 @@ func TestNewAuthDataStructure(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
characters := []Character{
|
||||
@@ -466,8 +463,6 @@ func TestNewAuthDataStructure(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataDebugMode tests newAuthData with debug mode enabled
|
||||
func TestNewAuthDataDebugMode(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -477,7 +472,7 @@ func TestNewAuthDataDebugMode(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
characters := []Character{
|
||||
@@ -500,8 +495,6 @@ func TestNewAuthDataDebugMode(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataMezFesConfiguration tests MezFes configuration in newAuthData
|
||||
func TestNewAuthDataMezFesConfiguration(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -513,7 +506,7 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
authData := server.newAuthData(1, 0, 1, "token", []Character{})
|
||||
@@ -534,8 +527,6 @@ func TestNewAuthDataMezFesConfiguration(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataHideNotices tests notice hiding in newAuthData
|
||||
func TestNewAuthDataHideNotices(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -546,7 +537,7 @@ func TestNewAuthDataHideNotices(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
authData := server.newAuthData(1, 0, 1, "token", []Character{})
|
||||
@@ -558,8 +549,6 @@ func TestNewAuthDataHideNotices(t *testing.T) {
|
||||
|
||||
// TestNewAuthDataTimestamps tests timestamp generation in newAuthData
|
||||
func TestNewAuthDataTimestamps(t *testing.T) {
|
||||
t.Skip("newAuthData requires database for getReturnExpiry - needs integration test")
|
||||
|
||||
logger := NewTestLogger(t)
|
||||
defer func() { _ = logger.Sync() }()
|
||||
|
||||
@@ -567,7 +556,7 @@ func TestNewAuthDataTimestamps(t *testing.T) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
db: nil,
|
||||
userRepo: newTestUserRepo(),
|
||||
}
|
||||
|
||||
authData := server.newAuthData(1, 0, 1, "token", []Character{})
|
||||
@@ -611,6 +600,10 @@ func BenchmarkNewAuthData(b *testing.B) {
|
||||
server := &APIServer{
|
||||
logger: logger,
|
||||
erupeConfig: c,
|
||||
userRepo: &mockAPIUserRepo{
|
||||
lastLogin: time.Now(),
|
||||
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
|
||||
},
|
||||
}
|
||||
|
||||
characters := make([]Character, 16)
|
||||
|
||||
87
server/api/repo_character.go
Normal file
87
server/api/repo_character.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// APICharacterRepository implements APICharacterRepo with PostgreSQL.
|
||||
type APICharacterRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewAPICharacterRepository creates a new APICharacterRepository.
|
||||
func NewAPICharacterRepository(db *sqlx.DB) *APICharacterRepository {
|
||||
return &APICharacterRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) GetNewCharacter(ctx context.Context, userID uint32) (Character, error) {
|
||||
var character Character
|
||||
err := r.db.GetContext(ctx, &character,
|
||||
"SELECT id, name, is_female, weapon_type, hr, gr, last_login FROM characters WHERE is_new_character = true AND user_id = $1 LIMIT 1",
|
||||
userID,
|
||||
)
|
||||
return character, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) CountForUser(ctx context.Context, userID uint32) (int, error) {
|
||||
var count int
|
||||
err := r.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM characters WHERE user_id = $1", userID).Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error) {
|
||||
var character Character
|
||||
err := r.db.GetContext(ctx, &character, `
|
||||
INSERT INTO characters (
|
||||
user_id, is_female, is_new_character, name, unk_desc_string,
|
||||
hr, gr, weapon_type, last_login
|
||||
)
|
||||
VALUES ($1, false, true, '', '', 0, 0, 0, $2)
|
||||
RETURNING id, name, is_female, weapon_type, hr, gr, last_login`,
|
||||
userID, lastLogin,
|
||||
)
|
||||
return character, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) IsNew(charID uint32) (bool, error) {
|
||||
var isNew bool
|
||||
err := r.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", charID).Scan(&isNew)
|
||||
return isNew, err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) HardDelete(charID uint32) error {
|
||||
_, err := r.db.Exec("DELETE FROM characters WHERE id = $1", charID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) SoftDelete(charID uint32) error {
|
||||
_, err := r.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", charID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) GetForUser(ctx context.Context, userID uint32) ([]Character, error) {
|
||||
var characters []Character
|
||||
err := r.db.SelectContext(
|
||||
ctx, &characters, `
|
||||
SELECT id, name, is_female, weapon_type, hr, gr, last_login
|
||||
FROM characters
|
||||
WHERE user_id = $1 AND deleted = false AND is_new_character = false ORDER BY id ASC`,
|
||||
userID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return characters, nil
|
||||
}
|
||||
|
||||
func (r *APICharacterRepository) ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error) {
|
||||
row := r.db.QueryRowxContext(ctx, "SELECT * FROM characters WHERE id=$1 AND user_id=$2", charID, userID)
|
||||
result := make(map[string]interface{})
|
||||
err := row.MapScan(result)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
53
server/api/repo_interfaces.go
Normal file
53
server/api/repo_interfaces.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Repository interfaces decouple API server business logic from concrete
|
||||
// PostgreSQL implementations, enabling mock/stub injection for unit tests.
|
||||
|
||||
// APIUserRepo defines the contract for user-related data access.
|
||||
type APIUserRepo interface {
|
||||
// Register creates a new user and returns their ID and rights.
|
||||
Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (id uint32, rights uint32, err error)
|
||||
// GetCredentials returns the user's ID, password hash, and rights.
|
||||
GetCredentials(ctx context.Context, username string) (id uint32, passwordHash string, rights uint32, err error)
|
||||
// GetLastLogin returns the user's last login time.
|
||||
GetLastLogin(uid uint32) (time.Time, error)
|
||||
// GetReturnExpiry returns the user's return expiry time.
|
||||
GetReturnExpiry(uid uint32) (time.Time, error)
|
||||
// UpdateReturnExpiry sets the user's return expiry time.
|
||||
UpdateReturnExpiry(uid uint32, expiry time.Time) error
|
||||
// UpdateLastLogin sets the user's last login time.
|
||||
UpdateLastLogin(uid uint32, loginTime time.Time) error
|
||||
}
|
||||
|
||||
// APICharacterRepo defines the contract for character-related data access.
|
||||
type APICharacterRepo interface {
|
||||
// GetNewCharacter returns an existing new (unfinished) character for a user.
|
||||
GetNewCharacter(ctx context.Context, userID uint32) (Character, error)
|
||||
// CountForUser returns the total number of characters for a user.
|
||||
CountForUser(ctx context.Context, userID uint32) (int, error)
|
||||
// Create inserts a new character and returns it.
|
||||
Create(ctx context.Context, userID uint32, lastLogin uint32) (Character, error)
|
||||
// IsNew returns whether a character is a new (unfinished) character.
|
||||
IsNew(charID uint32) (bool, error)
|
||||
// HardDelete permanently removes a character.
|
||||
HardDelete(charID uint32) error
|
||||
// SoftDelete marks a character as deleted.
|
||||
SoftDelete(charID uint32) error
|
||||
// GetForUser returns all finalized (non-deleted) characters for a user.
|
||||
GetForUser(ctx context.Context, userID uint32) ([]Character, error)
|
||||
// ExportSave returns the full character row as a map.
|
||||
ExportSave(ctx context.Context, userID, charID uint32) (map[string]interface{}, error)
|
||||
}
|
||||
|
||||
// APISessionRepo defines the contract for session/token data access.
|
||||
type APISessionRepo interface {
|
||||
// CreateToken inserts a new sign session and returns its ID and token.
|
||||
CreateToken(ctx context.Context, uid uint32, token string) (tokenID uint32, err error)
|
||||
// GetUserIDByToken returns the user ID for a given session token.
|
||||
GetUserIDByToken(ctx context.Context, token string) (uint32, error)
|
||||
}
|
||||
124
server/api/repo_mocks_test.go
Normal file
124
server/api/repo_mocks_test.go
Normal file
@@ -0,0 +1,124 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
// mockAPIUserRepo implements APIUserRepo for testing.
|
||||
type mockAPIUserRepo struct {
|
||||
registerID uint32
|
||||
registerRights uint32
|
||||
registerErr error
|
||||
|
||||
credentialsID uint32
|
||||
credentialsPassword string
|
||||
credentialsRights uint32
|
||||
credentialsErr error
|
||||
|
||||
lastLogin time.Time
|
||||
lastLoginErr error
|
||||
|
||||
returnExpiry time.Time
|
||||
returnExpiryErr error
|
||||
|
||||
updateReturnExpiryErr error
|
||||
updateLastLoginErr error
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) Register(_ context.Context, _, _ string, _ time.Time) (uint32, uint32, error) {
|
||||
return m.registerID, m.registerRights, m.registerErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) GetCredentials(_ context.Context, _ string) (uint32, string, uint32, error) {
|
||||
return m.credentialsID, m.credentialsPassword, m.credentialsRights, m.credentialsErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) GetLastLogin(_ uint32) (time.Time, error) {
|
||||
return m.lastLogin, m.lastLoginErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) GetReturnExpiry(_ uint32) (time.Time, error) {
|
||||
return m.returnExpiry, m.returnExpiryErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) UpdateReturnExpiry(_ uint32, _ time.Time) error {
|
||||
return m.updateReturnExpiryErr
|
||||
}
|
||||
|
||||
func (m *mockAPIUserRepo) UpdateLastLogin(_ uint32, _ time.Time) error {
|
||||
return m.updateLastLoginErr
|
||||
}
|
||||
|
||||
// mockAPICharacterRepo implements APICharacterRepo for testing.
|
||||
type mockAPICharacterRepo struct {
|
||||
newCharacter Character
|
||||
newCharacterErr error
|
||||
|
||||
countForUser int
|
||||
countForUserErr error
|
||||
|
||||
createChar Character
|
||||
createCharErr error
|
||||
|
||||
isNewResult bool
|
||||
isNewErr error
|
||||
|
||||
hardDeleteErr error
|
||||
softDeleteErr error
|
||||
|
||||
characters []Character
|
||||
charactersErr error
|
||||
|
||||
exportResult map[string]interface{}
|
||||
exportErr error
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) GetNewCharacter(_ context.Context, _ uint32) (Character, error) {
|
||||
return m.newCharacter, m.newCharacterErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) CountForUser(_ context.Context, _ uint32) (int, error) {
|
||||
return m.countForUser, m.countForUserErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) Create(_ context.Context, _ uint32, _ uint32) (Character, error) {
|
||||
return m.createChar, m.createCharErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) IsNew(_ uint32) (bool, error) {
|
||||
return m.isNewResult, m.isNewErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) HardDelete(_ uint32) error {
|
||||
return m.hardDeleteErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) SoftDelete(_ uint32) error {
|
||||
return m.softDeleteErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) GetForUser(_ context.Context, _ uint32) ([]Character, error) {
|
||||
return m.characters, m.charactersErr
|
||||
}
|
||||
|
||||
func (m *mockAPICharacterRepo) ExportSave(_ context.Context, _, _ uint32) (map[string]interface{}, error) {
|
||||
return m.exportResult, m.exportErr
|
||||
}
|
||||
|
||||
// mockAPISessionRepo implements APISessionRepo for testing.
|
||||
type mockAPISessionRepo struct {
|
||||
createTokenID uint32
|
||||
createTokenErr error
|
||||
|
||||
userID uint32
|
||||
userIDErr error
|
||||
}
|
||||
|
||||
func (m *mockAPISessionRepo) CreateToken(_ context.Context, _ uint32, _ string) (uint32, error) {
|
||||
return m.createTokenID, m.createTokenErr
|
||||
}
|
||||
|
||||
func (m *mockAPISessionRepo) GetUserIDByToken(_ context.Context, _ string) (uint32, error) {
|
||||
return m.userID, m.userIDErr
|
||||
}
|
||||
29
server/api/repo_session.go
Normal file
29
server/api/repo_session.go
Normal file
@@ -0,0 +1,29 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// APISessionRepository implements APISessionRepo with PostgreSQL.
|
||||
type APISessionRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewAPISessionRepository creates a new APISessionRepository.
|
||||
func NewAPISessionRepository(db *sqlx.DB) *APISessionRepository {
|
||||
return &APISessionRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *APISessionRepository) CreateToken(ctx context.Context, uid uint32, token string) (uint32, error) {
|
||||
var tid uint32
|
||||
err := r.db.QueryRowContext(ctx, "INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id", uid, token).Scan(&tid)
|
||||
return tid, err
|
||||
}
|
||||
|
||||
func (r *APISessionRepository) GetUserIDByToken(ctx context.Context, token string) (uint32, error) {
|
||||
var userID uint32
|
||||
err := r.db.QueryRowContext(ctx, "SELECT user_id FROM sign_sessions WHERE token = $1", token).Scan(&userID)
|
||||
return userID, err
|
||||
}
|
||||
66
server/api/repo_user.go
Normal file
66
server/api/repo_user.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// APIUserRepository implements APIUserRepo with PostgreSQL.
|
||||
type APIUserRepository struct {
|
||||
db *sqlx.DB
|
||||
}
|
||||
|
||||
// NewAPIUserRepository creates a new APIUserRepository.
|
||||
func NewAPIUserRepository(db *sqlx.DB) *APIUserRepository {
|
||||
return &APIUserRepository{db: db}
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) Register(ctx context.Context, username, passwordHash string, returnExpires time.Time) (uint32, uint32, error) {
|
||||
var (
|
||||
id uint32
|
||||
rights uint32
|
||||
)
|
||||
err := r.db.QueryRowContext(
|
||||
ctx, `
|
||||
INSERT INTO users (username, password, return_expires)
|
||||
VALUES ($1, $2, $3)
|
||||
RETURNING id, rights
|
||||
`,
|
||||
username, passwordHash, returnExpires,
|
||||
).Scan(&id, &rights)
|
||||
return id, rights, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) GetCredentials(ctx context.Context, username string) (uint32, string, uint32, error) {
|
||||
var (
|
||||
id uint32
|
||||
passwordHash string
|
||||
rights uint32
|
||||
)
|
||||
err := r.db.QueryRowContext(ctx, "SELECT id, password, rights FROM users WHERE username = $1", username).Scan(&id, &passwordHash, &rights)
|
||||
return id, passwordHash, rights, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) GetLastLogin(uid uint32) (time.Time, error) {
|
||||
var lastLogin time.Time
|
||||
err := r.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid)
|
||||
return lastLogin, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) GetReturnExpiry(uid uint32) (time.Time, error) {
|
||||
var returnExpiry time.Time
|
||||
err := r.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid)
|
||||
return returnExpiry, err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) UpdateReturnExpiry(uid uint32, expiry time.Time) error {
|
||||
_, err := r.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", expiry, uid)
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *APIUserRepository) UpdateLastLogin(uid uint32, loginTime time.Time) error {
|
||||
_, err := r.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", loginTime, uid)
|
||||
return err
|
||||
}
|
||||
Reference in New Issue
Block a user