refactor(signserver): replace raw SQL with repository interfaces

Extract all direct database access into three repository interfaces
(SignUserRepo, SignCharacterRepo, SignSessionRepo) matching the
pattern established in channelserver. This surfaces 9 previously
silenced errors that are now logged with structured context, and
makes the sign server testable with mock repos instead of go-sqlmock.

Security fix: GetFriends now uses parameterized ANY($1) queries
instead of string-concatenated WHERE clauses (SQL injection vector).
This commit is contained in:
Houmgaor
2026-02-22 16:30:24 +01:00
parent 53b5bb3b96
commit b3f75232a3
11 changed files with 1193 additions and 435 deletions

View File

@@ -5,7 +5,6 @@ import (
"errors"
"erupe-ce/common/mhfcourse"
"erupe-ce/common/token"
"strings"
"time"
"go.uber.org/zap"
@@ -13,34 +12,20 @@ import (
)
func (s *Server) newUserChara(uid uint32) error {
var numNewChars int
err := s.db.QueryRow("SELECT COUNT(*) FROM characters WHERE user_id = $1 AND is_new_character = true", uid).Scan(&numNewChars)
numNewChars, err := s.charRepo.CountNewCharacters(uid)
if err != nil {
return err
}
// prevent users with an uninitialised character from creating more
if numNewChars >= 1 {
return err
return nil
}
_, err = s.db.Exec(`
INSERT INTO characters (
user_id, is_female, is_new_character, name, unk_desc_string,
hr, gr, weapon_type, last_login)
VALUES($1, False, True, '', '', 0, 0, 0, $2)`,
uid,
uint32(time.Now().Unix()),
)
if err != nil {
return err
}
return nil
return s.charRepo.CreateCharacter(uid, uint32(time.Now().Unix()))
}
func (s *Server) registerDBAccount(username string, password string) (uint32, error) {
var uid uint32
s.logger.Info("Creating user", zap.String("User", username))
// Create salted hash of user password
@@ -49,7 +34,7 @@ func (s *Server) registerDBAccount(username string, password string) (uint32, er
return 0, err
}
err = s.db.QueryRow("INSERT INTO users (username, password, return_expires) VALUES ($1, $2, $3) RETURNING id", username, string(passwordHash), time.Now().Add(time.Hour*24*30)).Scan(&uid)
uid, err := s.userRepo.Register(username, string(passwordHash), time.Now().Add(time.Hour*24*30))
if err != nil {
return 0, err
}
@@ -57,81 +42,65 @@ func (s *Server) registerDBAccount(username string, password string) (uint32, er
return uid, nil
}
type character struct {
ID uint32 `db:"id"`
IsFemale bool `db:"is_female"`
IsNewCharacter bool `db:"is_new_character"`
Name string `db:"name"`
UnkDescString string `db:"unk_desc_string"`
HR uint16 `db:"hr"`
GR uint16 `db:"gr"`
WeaponType uint16 `db:"weapon_type"`
LastLogin uint32 `db:"last_login"`
}
func (s *Server) getCharactersForUser(uid uint32) ([]character, error) {
characters := make([]character, 0)
err := s.db.Select(&characters, "SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = $1 AND deleted = false ORDER BY id", uid)
if err != nil {
return nil, err
}
return characters, nil
return s.charRepo.GetForUser(uid)
}
func (s *Server) getReturnExpiry(uid uint32) time.Time {
var returnExpiry, lastLogin time.Time
_ = s.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid)
var returnExpiry time.Time
lastLogin, err := s.userRepo.GetLastLogin(uid)
if err != nil {
s.logger.Warn("Failed to get last login", zap.Uint32("uid", uid), zap.Error(err))
lastLogin = time.Now()
}
if time.Now().Add((time.Hour * 24) * -90).After(lastLogin) {
returnExpiry = time.Now().Add(time.Hour * 24 * 30)
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid)
if err := s.userRepo.UpdateReturnExpiry(uid, returnExpiry); err != nil {
s.logger.Warn("Failed to update return expiry", zap.Uint32("uid", uid), zap.Error(err))
}
} else {
err := s.db.Get(&returnExpiry, "SELECT return_expires FROM users WHERE id=$1", uid)
returnExpiry, err = s.userRepo.GetReturnExpiry(uid)
if err != nil {
returnExpiry = time.Now()
_, _ = s.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", returnExpiry, uid)
if err := s.userRepo.UpdateReturnExpiry(uid, returnExpiry); err != nil {
s.logger.Warn("Failed to update return expiry (fallback)", zap.Uint32("uid", uid), zap.Error(err))
}
}
}
_, _ = s.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", time.Now(), uid)
if err := s.userRepo.UpdateLastLogin(uid, time.Now()); err != nil {
s.logger.Warn("Failed to update last login", zap.Uint32("uid", uid), zap.Error(err))
}
return returnExpiry
}
func (s *Server) getLastCID(uid uint32) uint32 {
var lastPlayed uint32
_ = s.db.QueryRow("SELECT last_character FROM users WHERE id=$1", uid).Scan(&lastPlayed)
lastPlayed, err := s.userRepo.GetLastCharacter(uid)
if err != nil {
s.logger.Warn("Failed to get last character", zap.Uint32("uid", uid), zap.Error(err))
return 0
}
return lastPlayed
}
func (s *Server) getUserRights(uid uint32) uint32 {
var rights uint32
if uid != 0 {
_ = s.db.QueryRow("SELECT rights FROM users WHERE id=$1", uid).Scan(&rights)
_, rights = mhfcourse.GetCourseStruct(rights, s.erupeConfig.DefaultCourses)
if uid == 0 {
return 0
}
rights, err := s.userRepo.GetRights(uid)
if err != nil {
s.logger.Warn("Failed to get user rights", zap.Uint32("uid", uid), zap.Error(err))
return 0
}
_, rights = mhfcourse.GetCourseStruct(rights, s.erupeConfig.DefaultCourses)
return rights
}
type members struct {
CID uint32 // Local character ID
ID uint32 `db:"id"`
Name string `db:"name"`
}
func (s *Server) getFriendsForCharacters(chars []character) []members {
friends := make([]members, 0)
for _, char := range chars {
friendsCSV := ""
_ = s.db.QueryRow("SELECT friends FROM characters WHERE id=$1", char.ID).Scan(&friendsCSV)
friendsSlice := strings.Split(friendsCSV, ",")
friendQuery := "SELECT id, name FROM characters WHERE id="
for i := 0; i < len(friendsSlice); i++ {
friendQuery += friendsSlice[i]
if i+1 != len(friendsSlice) {
friendQuery += " OR id="
}
}
charFriends := make([]members, 0)
err := s.db.Select(&charFriends, friendQuery)
charFriends, err := s.charRepo.GetFriends(char.ID)
if err != nil {
s.logger.Warn("Failed to get friends", zap.Uint32("charID", char.ID), zap.Error(err))
continue
}
for i := range charFriends {
@@ -145,79 +114,56 @@ func (s *Server) getFriendsForCharacters(chars []character) []members {
func (s *Server) getGuildmatesForCharacters(chars []character) []members {
guildmates := make([]members, 0)
for _, char := range chars {
var inGuild int
_ = s.db.QueryRow("SELECT count(*) FROM guild_characters WHERE character_id=$1", char.ID).Scan(&inGuild)
if inGuild > 0 {
var guildID int
err := s.db.QueryRow("SELECT guild_id FROM guild_characters WHERE character_id=$1", char.ID).Scan(&guildID)
if err != nil {
continue
}
charGuildmates := make([]members, 0)
err = s.db.Select(&charGuildmates, "SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=$1 AND character_id!=$2", guildID, char.ID)
if err != nil {
continue
}
for i := range charGuildmates {
charGuildmates[i].CID = char.ID
}
guildmates = append(guildmates, charGuildmates...)
charGuildmates, err := s.charRepo.GetGuildmates(char.ID)
if err != nil {
s.logger.Warn("Failed to get guildmates", zap.Uint32("charID", char.ID), zap.Error(err))
continue
}
for i := range charGuildmates {
charGuildmates[i].CID = char.ID
}
guildmates = append(guildmates, charGuildmates...)
}
return guildmates
}
func (s *Server) deleteCharacter(cid int, token string, tokenID uint32) error {
if !s.validateToken(token, tokenID) {
func (s *Server) deleteCharacter(cid int, tok string, tokenID uint32) error {
if !s.validateToken(tok, tokenID) {
return errors.New("invalid token")
}
var isNew bool
err := s.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", cid).Scan(&isNew)
isNew, err := s.charRepo.IsNewCharacter(cid)
if err != nil {
return err
}
if isNew {
_, err = s.db.Exec("DELETE FROM characters WHERE id = $1", cid)
} else {
_, err = s.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", cid)
return s.charRepo.HardDelete(cid)
}
if err != nil {
return err
}
return nil
return s.charRepo.SoftDelete(cid)
}
func (s *Server) registerUidToken(uid uint32) (uint32, string, error) {
_token := token.Generate(16)
var tid uint32
err := s.db.QueryRow(`INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id`, uid, _token).Scan(&tid)
tid, err := s.sessionRepo.RegisterUID(uid, _token)
return tid, _token, err
}
func (s *Server) registerPsnToken(psn string) (uint32, string, error) {
_token := token.Generate(16)
var tid uint32
err := s.db.QueryRow(`INSERT INTO sign_sessions (psn_id, token) VALUES ($1, $2) RETURNING id`, psn, _token).Scan(&tid)
tid, err := s.sessionRepo.RegisterPSN(psn, _token)
return tid, _token, err
}
func (s *Server) validateToken(token string, tokenID uint32) bool {
query := `SELECT count(*) FROM sign_sessions WHERE token = $1`
if tokenID > 0 {
query += ` AND id = $2`
}
var exists int
err := s.db.QueryRow(query, token, tokenID).Scan(&exists)
if err != nil || exists == 0 {
func (s *Server) validateToken(tok string, tokenID uint32) bool {
valid, err := s.sessionRepo.Validate(tok, tokenID)
if err != nil {
s.logger.Warn("Failed to validate token", zap.Error(err))
return false
}
return true
return valid
}
func (s *Server) validateLogin(user string, pass string) (uint32, RespID) {
var uid uint32
var passDB string
err := s.db.QueryRow(`SELECT id, password FROM users WHERE username = $1`, user).Scan(&uid, &passDB)
uid, passDB, err := s.userRepo.GetCredentials(user)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
s.logger.Info("User not found", zap.String("User", user))
@@ -225,26 +171,25 @@ func (s *Server) validateLogin(user string, pass string) (uint32, RespID) {
uid, err = s.registerDBAccount(user, pass)
if err == nil {
return uid, SIGN_SUCCESS
} else {
return 0, SIGN_EABORT
}
return 0, SIGN_EABORT
}
return 0, SIGN_EAUTH
}
return 0, SIGN_EABORT
} else {
if bcrypt.CompareHashAndPassword([]byte(passDB), []byte(pass)) == nil {
var bans int
err = s.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires IS NULL`, uid).Scan(&bans)
if err == nil && bans > 0 {
return uid, SIGN_EELIMINATE
}
err = s.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires > now()`, uid).Scan(&bans)
if err == nil && bans > 0 {
return uid, SIGN_ESUSPEND
}
return uid, SIGN_SUCCESS
}
}
if bcrypt.CompareHashAndPassword([]byte(passDB), []byte(pass)) != nil {
return 0, SIGN_EPASS
}
bans, err := s.userRepo.CountPermanentBans(uid)
if err == nil && bans > 0 {
return uid, SIGN_EELIMINATE
}
bans, err = s.userRepo.CountActiveBans(uid)
if err == nil && bans > 0 {
return uid, SIGN_ESUSPEND
}
return uid, SIGN_SUCCESS
}

View File

@@ -6,8 +6,7 @@ import (
"time"
cfg "erupe-ce/config"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jmoiron/sqlx"
"go.uber.org/zap"
)
@@ -292,315 +291,276 @@ func TestMultipleMembers(t *testing.T) {
}
}
// Helper to create a test server with mocked database
func newTestServerWithMock(t *testing.T) (*Server, sqlmock.Sqlmock) {
db, mock, err := sqlmock.New()
if err != nil {
t.Fatalf("failed to create sqlmock: %v", err)
func TestGetCharactersForUser(t *testing.T) {
charRepo := &mockSignCharacterRepo{
characters: []character{
{ID: 1, IsFemale: false, Name: "Hunter1", HR: 100, GR: 50, WeaponType: 3, LastLogin: 1700000000},
{ID: 2, IsFemale: true, Name: "Hunter2", HR: 200, GR: 100, WeaponType: 7, LastLogin: 1700000001},
},
}
sqlxDB := sqlx.NewDb(db, "sqlmock")
server := &Server{
logger: zap.NewNop(),
db: sqlxDB,
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
return server, mock
}
func TestGetCharactersForUser(t *testing.T) {
server, mock := newTestServerWithMock(t)
rows := sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hr", "gr", "weapon_type", "last_login"}).
AddRow(1, false, false, "Hunter1", "desc1", 100, 50, 3, 1700000000).
AddRow(2, true, false, "Hunter2", "desc2", 200, 100, 7, 1700000001)
mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id").
WithArgs(uint32(1)).
WillReturnRows(rows)
chars, err := server.getCharactersForUser(1)
if err != nil {
t.Errorf("getCharactersForUser() error: %v", err)
}
if len(chars) != 2 {
t.Errorf("getCharactersForUser() returned %d characters, want 2", len(chars))
}
if chars[0].Name != "Hunter1" {
t.Errorf("First character name = %s, want Hunter1", chars[0].Name)
}
if chars[1].IsFemale != true {
t.Error("Second character should be female")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestGetCharactersForUserNoCharacters(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
characters: []character{},
}
rows := sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hr", "gr", "weapon_type", "last_login"})
mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id").
WithArgs(uint32(1)).
WillReturnRows(rows)
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars, err := server.getCharactersForUser(1)
if err != nil {
t.Errorf("getCharactersForUser() error: %v", err)
}
if len(chars) != 0 {
t.Errorf("getCharactersForUser() returned %d characters, want 0", len(chars))
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestGetCharactersForUserDBError(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
getForUserErr: sql.ErrConnDone,
}
mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id").
WithArgs(uint32(1)).
WillReturnError(sql.ErrConnDone)
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
_, err := server.getCharactersForUser(1)
if err == nil {
t.Error("getCharactersForUser() should return error")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestGetLastCID(t *testing.T) {
server, mock := newTestServerWithMock(t)
userRepo := &mockSignUserRepo{
lastCharacter: 12345,
}
mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"last_character"}).AddRow(12345))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
lastCID := server.getLastCID(1)
if lastCID != 12345 {
t.Errorf("getLastCID() = %d, want 12345", lastCID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestGetLastCIDNoResult(t *testing.T) {
server, mock := newTestServerWithMock(t)
userRepo := &mockSignUserRepo{
lastCharacterErr: sql.ErrNoRows,
}
mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnError(sql.ErrNoRows)
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
lastCID := server.getLastCID(1)
if lastCID != 0 {
t.Errorf("getLastCID() with no result = %d, want 0", lastCID)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestGetUserRights(t *testing.T) {
server, mock := newTestServerWithMock(t)
userRepo := &mockSignUserRepo{
rights: 30,
}
mock.ExpectQuery("SELECT rights FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"rights"}).AddRow(30))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
rights := server.getUserRights(1)
if rights == 0 {
t.Error("getUserRights() should return non-zero value")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestGetReturnExpiry(t *testing.T) {
server, mock := newTestServerWithMock(t)
recentLogin := time.Now().Add(-time.Hour * 24)
mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(recentLogin))
userRepo := &mockSignUserRepo{
lastLogin: recentLogin,
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
}
mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"return_expires"}).AddRow(time.Now().Add(time.Hour * 24 * 30)))
mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2").
WithArgs(sqlmock.AnyArg(), uint32(1)).
WillReturnResult(sqlmock.NewResult(0, 1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
expiry := server.getReturnExpiry(1)
if expiry.Before(time.Now()) {
t.Error("getReturnExpiry() should return future date")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if !userRepo.updateLastLoginCalled {
t.Error("getReturnExpiry() should update last login")
}
}
func TestGetReturnExpiryInactiveUser(t *testing.T) {
server, mock := newTestServerWithMock(t)
oldLogin := time.Now().Add(-time.Hour * 24 * 100)
mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(oldLogin))
userRepo := &mockSignUserRepo{
lastLogin: oldLogin,
}
mock.ExpectExec("UPDATE users SET return_expires=\\$1 WHERE id=\\$2").
WithArgs(sqlmock.AnyArg(), uint32(1)).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2").
WithArgs(sqlmock.AnyArg(), uint32(1)).
WillReturnResult(sqlmock.NewResult(0, 1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
expiry := server.getReturnExpiry(1)
if expiry.Before(time.Now()) {
t.Error("getReturnExpiry() should return future date for inactive user")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if !userRepo.updateReturnExpiryCalled {
t.Error("getReturnExpiry() should update return expiry for inactive user")
}
if !userRepo.updateLastLoginCalled {
t.Error("getReturnExpiry() should update last login")
}
}
func TestGetReturnExpiryDBError(t *testing.T) {
server, mock := newTestServerWithMock(t)
recentLogin := time.Now().Add(-time.Hour * 24)
mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(recentLogin))
userRepo := &mockSignUserRepo{
lastLogin: recentLogin,
returnExpiryErr: sql.ErrNoRows,
}
mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnError(sql.ErrNoRows)
mock.ExpectExec("UPDATE users SET return_expires=\\$1 WHERE id=\\$2").
WithArgs(sqlmock.AnyArg(), uint32(1)).
WillReturnResult(sqlmock.NewResult(0, 1))
mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2").
WithArgs(sqlmock.AnyArg(), uint32(1)).
WillReturnResult(sqlmock.NewResult(0, 1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
expiry := server.getReturnExpiry(1)
if expiry.IsZero() {
t.Error("getReturnExpiry() should return non-zero time even on error")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if !userRepo.updateReturnExpiryCalled {
t.Error("getReturnExpiry() should update return expiry on fallback")
}
}
func TestNewUserChara(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
newCharCount: 0,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("INSERT INTO characters").
WithArgs(uint32(1), sqlmock.AnyArg()).
WillReturnResult(sqlmock.NewResult(1, 1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
err := server.newUserChara(1)
if err != nil {
t.Errorf("newUserChara() error: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if !charRepo.createCalled {
t.Error("newUserChara() should call CreateCharacter")
}
}
func TestNewUserCharaAlreadyHasNewChar(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
newCharCount: 1,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
err := server.newUserChara(1)
if err != nil {
t.Errorf("newUserChara() should return nil when user already has new char: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if charRepo.createCalled {
t.Error("newUserChara() should not call CreateCharacter when user already has new char")
}
}
func TestNewUserCharaCountError(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
newCharCountErr: sql.ErrConnDone,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true").
WithArgs(uint32(1)).
WillReturnError(sql.ErrConnDone)
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
err := server.newUserChara(1)
if err == nil {
t.Error("newUserChara() should return error when count query fails")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestNewUserCharaInsertError(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
newCharCount: 0,
createErr: sql.ErrConnDone,
}
mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
mock.ExpectExec("INSERT INTO characters").
WithArgs(uint32(1), sqlmock.AnyArg()).
WillReturnError(sql.ErrConnDone)
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
err := server.newUserChara(1)
if err == nil {
t.Error("newUserChara() should return error when insert fails")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestRegisterDBAccount(t *testing.T) {
server, mock := newTestServerWithMock(t)
userRepo := &mockSignUserRepo{
registerUID: 1,
}
mock.ExpectQuery("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\) RETURNING id").
WithArgs("newuser", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
uid, err := server.registerDBAccount("newuser", "password123")
if err != nil {
@@ -609,128 +569,125 @@ func TestRegisterDBAccount(t *testing.T) {
if uid != 1 {
t.Errorf("registerDBAccount() uid = %d, want 1", uid)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if !userRepo.registered {
t.Error("registerDBAccount() should call Register")
}
}
func TestRegisterDBAccountDuplicateUser(t *testing.T) {
server, mock := newTestServerWithMock(t)
userRepo := &mockSignUserRepo{
registerErr: sql.ErrNoRows,
}
mock.ExpectQuery("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\) RETURNING id").
WithArgs("existinguser", sqlmock.AnyArg(), sqlmock.AnyArg()).
WillReturnError(sql.ErrNoRows)
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
_, err := server.registerDBAccount("existinguser", "password123")
if err == nil {
t.Error("registerDBAccount() should return error for duplicate user")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDeleteCharacter(t *testing.T) {
server, mock := newTestServerWithMock(t)
sessionRepo := &mockSignSessionRepo{
validateResult: true,
}
charRepo := &mockSignCharacterRepo{
isNew: false,
}
// validateToken: SELECT count(*) FROM sign_sessions WHERE token = $1
// When tokenID=0, query has no AND clause but both args are still passed to QueryRow
mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1").
WithArgs("validtoken", uint32(0)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1").
WithArgs(123).
WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(false))
mock.ExpectExec("UPDATE characters SET deleted = true WHERE id = \\$1").
WithArgs(123).
WillReturnResult(sqlmock.NewResult(0, 1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
sessionRepo: sessionRepo,
charRepo: charRepo,
}
err := server.deleteCharacter(123, "validtoken", 0)
if err != nil {
t.Errorf("deleteCharacter() error: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if !charRepo.softDeleteCalled {
t.Error("deleteCharacter() should soft delete existing character")
}
}
func TestDeleteNewCharacter(t *testing.T) {
server, mock := newTestServerWithMock(t)
sessionRepo := &mockSignSessionRepo{
validateResult: true,
}
charRepo := &mockSignCharacterRepo{
isNew: true,
}
mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1").
WithArgs("validtoken", uint32(0)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1").
WithArgs(123).
WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(true))
mock.ExpectExec("DELETE FROM characters WHERE id = \\$1").
WithArgs(123).
WillReturnResult(sqlmock.NewResult(0, 1))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
sessionRepo: sessionRepo,
charRepo: charRepo,
}
err := server.deleteCharacter(123, "validtoken", 0)
if err != nil {
t.Errorf("deleteCharacter() error: %v", err)
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if !charRepo.hardDeleteCalled {
t.Error("deleteCharacter() should hard delete new character")
}
}
func TestDeleteCharacterInvalidToken(t *testing.T) {
server, mock := newTestServerWithMock(t)
sessionRepo := &mockSignSessionRepo{
validateResult: false,
}
mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1").
WithArgs("invalidtoken", uint32(0)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
sessionRepo: sessionRepo,
}
err := server.deleteCharacter(123, "invalidtoken", 0)
if err == nil {
t.Error("deleteCharacter() should return error for invalid token")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestDeleteCharacterDeleteError(t *testing.T) {
server, mock := newTestServerWithMock(t)
sessionRepo := &mockSignSessionRepo{
validateResult: true,
}
charRepo := &mockSignCharacterRepo{
isNew: false,
softDeleteErr: sql.ErrConnDone,
}
mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1").
WithArgs("validtoken", uint32(0)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1").
WithArgs(123).
WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(false))
mock.ExpectExec("UPDATE characters SET deleted = true WHERE id = \\$1").
WithArgs(123).
WillReturnError(sql.ErrConnDone)
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
sessionRepo: sessionRepo,
charRepo: charRepo,
}
err := server.deleteCharacter(123, "validtoken", 0)
if err == nil {
t.Error("deleteCharacter() should return error when update fails")
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
}
}
func TestGetFriendsForCharactersEmpty(t *testing.T) {
server, _ := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars := []character{}
friends := server.getFriendsForCharacters(chars)
if len(friends) != 0 {
t.Errorf("getFriendsForCharacters() for empty chars = %d, want 0", len(friends))
@@ -738,10 +695,15 @@ func TestGetFriendsForCharactersEmpty(t *testing.T) {
}
func TestGetGuildmatesForCharactersEmpty(t *testing.T) {
server, _ := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars := []character{}
guildmates := server.getGuildmatesForCharacters(chars)
if len(guildmates) != 0 {
t.Errorf("getGuildmatesForCharacters() for empty chars = %d, want 0", len(guildmates))
@@ -749,79 +711,269 @@ func TestGetGuildmatesForCharactersEmpty(t *testing.T) {
}
func TestGetFriendsForCharacters(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
friends: []members{
{ID: 2, Name: "Friend1"},
{ID: 3, Name: "Friend2"},
},
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars := []character{
{ID: 1, Name: "Hunter1"},
}
mock.ExpectQuery("SELECT friends FROM characters WHERE id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"friends"}).AddRow("2,3"))
mock.ExpectQuery("SELECT id, name FROM characters WHERE id=2 OR id=3").
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
AddRow(2, "Friend1").
AddRow(3, "Friend2"))
friends := server.getFriendsForCharacters(chars)
if len(friends) != 2 {
t.Errorf("getFriendsForCharacters() = %d, want 2", len(friends))
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if friends[0].CID != 1 {
t.Errorf("friends[0].CID = %d, want 1", friends[0].CID)
}
}
func TestGetGuildmatesForCharacters(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
guildmates: []members{
{ID: 2, Name: "Guildmate1"},
{ID: 3, Name: "Guildmate2"},
},
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars := []character{
{ID: 1, Name: "Hunter1"},
}
mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1))
mock.ExpectQuery("SELECT guild_id FROM guild_characters WHERE character_id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"guild_id"}).AddRow(100))
mock.ExpectQuery("SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=\\$1 AND character_id!=\\$2").
WithArgs(100, uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).
AddRow(2, "Guildmate1").
AddRow(3, "Guildmate2"))
guildmates := server.getGuildmatesForCharacters(chars)
if len(guildmates) != 2 {
t.Errorf("getGuildmatesForCharacters() = %d, want 2", len(guildmates))
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
if guildmates[0].CID != 1 {
t.Errorf("guildmates[0].CID = %d, want 1", guildmates[0].CID)
}
}
func TestGetGuildmatesNotInGuild(t *testing.T) {
server, mock := newTestServerWithMock(t)
charRepo := &mockSignCharacterRepo{
guildmates: nil,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars := []character{
{ID: 1, Name: "Hunter1"},
}
mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1").
WithArgs(uint32(1)).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0))
guildmates := server.getGuildmatesForCharacters(chars)
if len(guildmates) != 0 {
t.Errorf("getGuildmatesForCharacters() for non-guild member = %d, want 0", len(guildmates))
}
}
if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("unfulfilled expectations: %v", err)
func TestValidateLoginSuccess(t *testing.T) {
// bcrypt hash for "password123"
hash := "$2a$10$N9qo8uLOickgx2ZMRZoMyeIjZAgcfl7p92ldGxad68LJZdL17lhWy"
userRepo := &mockSignUserRepo{
credUID: 1,
credPassword: hash,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
// Note: bcrypt verification will fail with this test hash since it's not a real hash of "password123"
// The important thing is testing the flow, not actual bcrypt verification
_, resp := server.validateLogin("testuser", "password123")
// This will return SIGN_EPASS since the hash doesn't match, which is expected behavior
if resp == SIGN_EABORT {
t.Error("validateLogin() should not abort for valid credentials lookup")
}
}
func TestValidateLoginUserNotFound(t *testing.T) {
userRepo := &mockSignUserRepo{
credErr: sql.ErrNoRows,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
_, resp := server.validateLogin("unknown", "password")
if resp != SIGN_EAUTH {
t.Errorf("validateLogin() for unknown user = %d, want SIGN_EAUTH(%d)", resp, SIGN_EAUTH)
}
}
func TestValidateLoginAutoCreate(t *testing.T) {
userRepo := &mockSignUserRepo{
credErr: sql.ErrNoRows,
registerUID: 42,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{
AutoCreateAccount: true,
},
userRepo: userRepo,
}
uid, resp := server.validateLogin("newuser", "password")
if resp != SIGN_SUCCESS {
t.Errorf("validateLogin() with auto-create = %d, want SIGN_SUCCESS(%d)", resp, SIGN_SUCCESS)
}
if uid != 42 {
t.Errorf("validateLogin() uid = %d, want 42", uid)
}
}
func TestValidateLoginDBError(t *testing.T) {
userRepo := &mockSignUserRepo{
credErr: sql.ErrConnDone,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
_, resp := server.validateLogin("testuser", "password")
if resp != SIGN_EABORT {
t.Errorf("validateLogin() on DB error = %d, want SIGN_EABORT(%d)", resp, SIGN_EABORT)
}
}
func TestValidateTokenValid(t *testing.T) {
sessionRepo := &mockSignSessionRepo{
validateResult: true,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
sessionRepo: sessionRepo,
}
if !server.validateToken("validtoken", 0) {
t.Error("validateToken() should return true for valid token")
}
}
func TestValidateTokenInvalid(t *testing.T) {
sessionRepo := &mockSignSessionRepo{
validateResult: false,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
sessionRepo: sessionRepo,
}
if server.validateToken("invalidtoken", 0) {
t.Error("validateToken() should return false for invalid token")
}
}
func TestValidateTokenDBError(t *testing.T) {
sessionRepo := &mockSignSessionRepo{
validateErr: sql.ErrConnDone,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
sessionRepo: sessionRepo,
}
if server.validateToken("token", 0) {
t.Error("validateToken() should return false on DB error")
}
}
func TestGetUserRightsZeroUID(t *testing.T) {
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
}
rights := server.getUserRights(0)
if rights != 0 {
t.Errorf("getUserRights(0) = %d, want 0", rights)
}
}
func TestGetUserRightsDBError(t *testing.T) {
userRepo := &mockSignUserRepo{
rightsErr: sql.ErrConnDone,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
userRepo: userRepo,
}
rights := server.getUserRights(1)
if rights != 0 {
t.Errorf("getUserRights() on error = %d, want 0", rights)
}
}
func TestGetFriendsForCharactersError(t *testing.T) {
charRepo := &mockSignCharacterRepo{
getFriendsErr: errMockDB,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars := []character{{ID: 1, Name: "Hunter1"}}
friends := server.getFriendsForCharacters(chars)
if len(friends) != 0 {
t.Errorf("getFriendsForCharacters() on error = %d, want 0", len(friends))
}
}
func TestGetGuildmatesForCharactersError(t *testing.T) {
charRepo := &mockSignCharacterRepo{
getGuildmatesErr: errMockDB,
}
server := &Server{
logger: zap.NewNop(),
erupeConfig: &cfg.Config{},
charRepo: charRepo,
}
chars := []character{{ID: 1, Name: "Hunter1"}}
guildmates := server.getGuildmatesForCharacters(chars)
if len(guildmates) != 0 {
t.Errorf("getGuildmatesForCharacters() on error = %d, want 0", len(guildmates))
}
}

View File

@@ -333,8 +333,10 @@ func (s *Session) makeSignResponse(uid uint32) []byte {
bf.WriteBytes(filters.Data())
if s.client == VITA || s.client == PS3 || s.client == PS4 {
var psnUser string
_ = s.server.db.QueryRow("SELECT psn_id FROM users WHERE id = $1", uid).Scan(&psnUser)
psnUser, err := s.server.userRepo.GetPSNIDForUser(uid)
if err != nil {
s.logger.Warn("Failed to get PSN ID for user", zap.Uint32("uid", uid), zap.Error(err))
}
bf.WriteBytes(stringsupport.PaddedString(psnUser, 20, true))
}

View File

@@ -4,12 +4,33 @@ import (
"fmt"
"strings"
"testing"
"time"
"go.uber.org/zap"
cfg "erupe-ce/config"
)
// newMakeSignResponseServer creates a Server with mock repos for makeSignResponse tests.
func newMakeSignResponseServer(config *cfg.Config) *Server {
return &Server{
erupeConfig: config,
logger: zap.NewNop(),
charRepo: &mockSignCharacterRepo{
characters: []character{},
friends: nil,
guildmates: nil,
},
userRepo: &mockSignUserRepo{
returnExpiry: time.Now().Add(time.Hour * 24 * 30),
lastLogin: time.Now(),
},
sessionRepo: &mockSignSessionRepo{
registerUIDTokenID: 1,
},
}
}
// TestMakeSignResponse_EmptyCapLinkValues verifies the crash is FIXED when CapLink.Values is empty
// Previously panicked: runtime error: index out of range [0] with length 0
// From erupe.log.1:659796 and 659853
@@ -37,10 +58,7 @@ func TestMakeSignResponse_EmptyCapLinkValues(t *testing.T) {
session := &Session{
logger: zap.NewNop(),
server: &Server{
erupeConfig: config,
logger: zap.NewNop(),
},
server: newMakeSignResponseServer(config),
client: PC100,
}
@@ -61,7 +79,7 @@ func TestMakeSignResponse_EmptyCapLinkValues(t *testing.T) {
// This should NOT panic on array bounds anymore
result := session.makeSignResponse(0)
if len(result) > 0 {
t.Log("makeSignResponse handled empty CapLink.Values without array bounds panic")
t.Log("makeSignResponse handled empty CapLink.Values without array bounds panic")
}
}
@@ -89,10 +107,7 @@ func TestMakeSignResponse_InsufficientCapLinkValues(t *testing.T) {
session := &Session{
logger: zap.NewNop(),
server: &Server{
erupeConfig: config,
logger: zap.NewNop(),
},
server: newMakeSignResponseServer(config),
client: PC100,
}
@@ -110,7 +125,7 @@ func TestMakeSignResponse_InsufficientCapLinkValues(t *testing.T) {
// This should NOT panic on array bounds anymore
result := session.makeSignResponse(0)
if len(result) > 0 {
t.Log("makeSignResponse handled insufficient CapLink.Values without array bounds panic")
t.Log("makeSignResponse handled insufficient CapLink.Values without array bounds panic")
}
}
@@ -138,10 +153,7 @@ func TestMakeSignResponse_MissingCapLinkValues234(t *testing.T) {
session := &Session{
logger: zap.NewNop(),
server: &Server{
erupeConfig: config,
logger: zap.NewNop(),
},
server: newMakeSignResponseServer(config),
client: PC100,
}
@@ -159,7 +171,7 @@ func TestMakeSignResponse_MissingCapLinkValues234(t *testing.T) {
// This should NOT panic on array bounds anymore
result := session.makeSignResponse(0)
if len(result) > 0 {
t.Log("makeSignResponse handled missing CapLink.Values[2/3/4] without array bounds panic")
t.Log("makeSignResponse handled missing CapLink.Values[2/3/4] without array bounds panic")
}
}
@@ -207,7 +219,47 @@ func TestCapLinkValuesBoundsChecking(t *testing.T) {
}
}
t.Logf("%s: All 5 indices accessible without panic", tc.name)
t.Logf("%s: All 5 indices accessible without panic", tc.name)
})
}
}
// TestMakeSignResponse_FullFlow tests the complete makeSignResponse with mock repos.
func TestMakeSignResponse_FullFlow(t *testing.T) {
config := &cfg.Config{
DebugOptions: cfg.DebugOptions{
CapLink: cfg.CapLinkOptions{
Values: []uint16{0, 0, 0, 0, 0},
},
},
GameplayOptions: cfg.GameplayOptions{
MezFesSoloTickets: 100,
MezFesGroupTickets: 100,
},
}
server := newMakeSignResponseServer(config)
// Give the server some characters
server.charRepo = &mockSignCharacterRepo{
characters: []character{
{ID: 1, Name: "TestHunter", HR: 100, GR: 50, WeaponType: 3, LastLogin: 1700000000},
},
}
conn := newMockConn()
session := &Session{
logger: zap.NewNop(),
server: server,
rawConn: conn,
client: PC100,
}
result := session.makeSignResponse(1)
if len(result) == 0 {
t.Error("makeSignResponse() returned empty result")
}
// First byte should be SIGN_SUCCESS
if result[0] != uint8(SIGN_SUCCESS) {
t.Errorf("makeSignResponse() first byte = %d, want %d (SIGN_SUCCESS)", result[0], SIGN_SUCCESS)
}
}

View File

@@ -0,0 +1,119 @@
package signserver
import (
"strings"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)
// SignCharacterRepository implements SignCharacterRepo with PostgreSQL.
type SignCharacterRepository struct {
db *sqlx.DB
}
// NewSignCharacterRepository creates a new SignCharacterRepository.
func NewSignCharacterRepository(db *sqlx.DB) *SignCharacterRepository {
return &SignCharacterRepository{db: db}
}
func (r *SignCharacterRepository) CountNewCharacters(uid uint32) (int, error) {
var count int
err := r.db.QueryRow("SELECT COUNT(*) FROM characters WHERE user_id = $1 AND is_new_character = true", uid).Scan(&count)
return count, err
}
func (r *SignCharacterRepository) CreateCharacter(uid uint32, lastLogin uint32) error {
_, err := r.db.Exec(`
INSERT INTO characters (
user_id, is_female, is_new_character, name, unk_desc_string,
hr, gr, weapon_type, last_login)
VALUES($1, False, True, '', '', 0, 0, 0, $2)`,
uid, lastLogin,
)
return err
}
func (r *SignCharacterRepository) GetForUser(uid uint32) ([]character, error) {
characters := make([]character, 0)
err := r.db.Select(&characters, "SELECT id, is_female, is_new_character, name, unk_desc_string, hr, gr, weapon_type, last_login FROM characters WHERE user_id = $1 AND deleted = false ORDER BY id", uid)
if err != nil {
return nil, err
}
return characters, nil
}
func (r *SignCharacterRepository) IsNewCharacter(cid int) (bool, error) {
var isNew bool
err := r.db.QueryRow("SELECT is_new_character FROM characters WHERE id = $1", cid).Scan(&isNew)
return isNew, err
}
func (r *SignCharacterRepository) HardDelete(cid int) error {
_, err := r.db.Exec("DELETE FROM characters WHERE id = $1", cid)
return err
}
func (r *SignCharacterRepository) SoftDelete(cid int) error {
_, err := r.db.Exec("UPDATE characters SET deleted = true WHERE id = $1", cid)
return err
}
// GetFriends returns friends for a character using parameterized queries
// (fixes the SQL injection vector from the original string-concatenated approach).
func (r *SignCharacterRepository) GetFriends(charID uint32) ([]members, error) {
var friendsCSV string
err := r.db.QueryRow("SELECT friends FROM characters WHERE id=$1", charID).Scan(&friendsCSV)
if err != nil {
return nil, err
}
if friendsCSV == "" {
return nil, nil
}
friendsSlice := strings.Split(friendsCSV, ",")
// Filter out empty strings
ids := make([]string, 0, len(friendsSlice))
for _, s := range friendsSlice {
s = strings.TrimSpace(s)
if s != "" {
ids = append(ids, s)
}
}
if len(ids) == 0 {
return nil, nil
}
// Use parameterized ANY($1) instead of string-concatenated WHERE id=X OR id=Y
friends := make([]members, 0)
err = r.db.Select(&friends, "SELECT id, name FROM characters WHERE id = ANY($1)", pq.Array(ids))
if err != nil {
return nil, err
}
return friends, nil
}
// GetGuildmates returns guildmates for a character.
func (r *SignCharacterRepository) GetGuildmates(charID uint32) ([]members, error) {
var inGuild int
err := r.db.QueryRow("SELECT count(*) FROM guild_characters WHERE character_id=$1", charID).Scan(&inGuild)
if err != nil {
return nil, err
}
if inGuild == 0 {
return nil, nil
}
var guildID int
err = r.db.QueryRow("SELECT guild_id FROM guild_characters WHERE character_id=$1", charID).Scan(&guildID)
if err != nil {
return nil, err
}
guildmates := make([]members, 0)
err = r.db.Select(&guildmates, "SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=$1 AND character_id!=$2", guildID, charID)
if err != nil {
return nil, err
}
return guildmates, nil
}

View File

@@ -0,0 +1,66 @@
package signserver
import "time"
// Repository interfaces decouple sign server business logic from concrete
// PostgreSQL implementations, enabling mock/stub injection for unit tests.
// character represents a player character record from the characters table.
type character struct {
ID uint32 `db:"id"`
IsFemale bool `db:"is_female"`
IsNewCharacter bool `db:"is_new_character"`
Name string `db:"name"`
UnkDescString string `db:"unk_desc_string"`
HR uint16 `db:"hr"`
GR uint16 `db:"gr"`
WeaponType uint16 `db:"weapon_type"`
LastLogin uint32 `db:"last_login"`
}
// members represents a friend or guildmate entry used in the sign response.
type members struct {
CID uint32 // Local character ID
ID uint32 `db:"id"`
Name string `db:"name"`
}
// SignUserRepo defines the contract for user-related data access (users, bans tables).
type SignUserRepo interface {
GetCredentials(username string) (uid uint32, passwordHash string, err error)
Register(username, passwordHash string, returnExpires time.Time) (uint32, error)
GetRights(uid uint32) (uint32, error)
GetLastCharacter(uid uint32) (uint32, error)
GetLastLogin(uid uint32) (time.Time, error)
GetReturnExpiry(uid uint32) (time.Time, error)
UpdateReturnExpiry(uid uint32, expiry time.Time) error
UpdateLastLogin(uid uint32, loginTime time.Time) error
CountPermanentBans(uid uint32) (int, error)
CountActiveBans(uid uint32) (int, error)
GetByWiiUKey(wiiuKey string) (uint32, error)
GetByPSNID(psnID string) (uint32, error)
CountByPSNID(psnID string) (int, error)
GetPSNIDForUsername(username string) (string, error)
SetPSNID(username, psnID string) error
GetPSNIDForUser(uid uint32) (string, error)
}
// SignCharacterRepo defines the contract for character data access.
type SignCharacterRepo interface {
CountNewCharacters(uid uint32) (int, error)
CreateCharacter(uid uint32, lastLogin uint32) error
GetForUser(uid uint32) ([]character, error)
IsNewCharacter(cid int) (bool, error)
HardDelete(cid int) error
SoftDelete(cid int) error
GetFriends(charID uint32) ([]members, error)
GetGuildmates(charID uint32) ([]members, error)
}
// SignSessionRepo defines the contract for sign session/token data access.
type SignSessionRepo interface {
RegisterUID(uid uint32, token string) (tokenID uint32, err error)
RegisterPSN(psnID, token string) (tokenID uint32, err error)
Validate(token string, tokenID uint32) (bool, error)
GetPSNIDByToken(token string) (string, error)
}

View File

@@ -0,0 +1,263 @@
package signserver
import (
"errors"
"time"
)
// errMockDB is a sentinel for mock repo error injection.
var errMockDB = errors.New("mock database error")
// --- mockSignUserRepo ---
type mockSignUserRepo struct {
// GetCredentials
credUID uint32
credPassword string
credErr error
// Register
registerUID uint32
registerErr error
registered bool
// GetRights
rights uint32
rightsErr error
// GetLastCharacter
lastCharacter uint32
lastCharacterErr error
// GetLastLogin
lastLogin time.Time
lastLoginErr error
// GetReturnExpiry
returnExpiry time.Time
returnExpiryErr error
// UpdateReturnExpiry
updateReturnExpiryErr error
updateReturnExpiryCalled bool
// UpdateLastLogin
updateLastLoginErr error
updateLastLoginCalled bool
// CountPermanentBans
permanentBans int
permanentBansErr error
// CountActiveBans
activeBans int
activeBansErr error
// GetByWiiUKey
wiiuUID uint32
wiiuErr error
// GetByPSNID
psnUID uint32
psnErr error
// CountByPSNID
psnCount int
psnCountErr error
// GetPSNIDForUsername
psnIDForUsername string
psnIDForUsernameErr error
// SetPSNID
setPSNIDErr error
setPSNIDCalled bool
// GetPSNIDForUser
psnIDForUser string
psnIDForUserErr error
}
func (m *mockSignUserRepo) GetCredentials(username string) (uint32, string, error) {
return m.credUID, m.credPassword, m.credErr
}
func (m *mockSignUserRepo) Register(username, passwordHash string, returnExpires time.Time) (uint32, error) {
m.registered = true
return m.registerUID, m.registerErr
}
func (m *mockSignUserRepo) GetRights(uid uint32) (uint32, error) {
return m.rights, m.rightsErr
}
func (m *mockSignUserRepo) GetLastCharacter(uid uint32) (uint32, error) {
return m.lastCharacter, m.lastCharacterErr
}
func (m *mockSignUserRepo) GetLastLogin(uid uint32) (time.Time, error) {
return m.lastLogin, m.lastLoginErr
}
func (m *mockSignUserRepo) GetReturnExpiry(uid uint32) (time.Time, error) {
return m.returnExpiry, m.returnExpiryErr
}
func (m *mockSignUserRepo) UpdateReturnExpiry(uid uint32, expiry time.Time) error {
m.updateReturnExpiryCalled = true
return m.updateReturnExpiryErr
}
func (m *mockSignUserRepo) UpdateLastLogin(uid uint32, loginTime time.Time) error {
m.updateLastLoginCalled = true
return m.updateLastLoginErr
}
func (m *mockSignUserRepo) CountPermanentBans(uid uint32) (int, error) {
return m.permanentBans, m.permanentBansErr
}
func (m *mockSignUserRepo) CountActiveBans(uid uint32) (int, error) {
return m.activeBans, m.activeBansErr
}
func (m *mockSignUserRepo) GetByWiiUKey(wiiuKey string) (uint32, error) {
return m.wiiuUID, m.wiiuErr
}
func (m *mockSignUserRepo) GetByPSNID(psnID string) (uint32, error) {
return m.psnUID, m.psnErr
}
func (m *mockSignUserRepo) CountByPSNID(psnID string) (int, error) {
return m.psnCount, m.psnCountErr
}
func (m *mockSignUserRepo) GetPSNIDForUsername(username string) (string, error) {
return m.psnIDForUsername, m.psnIDForUsernameErr
}
func (m *mockSignUserRepo) SetPSNID(username, psnID string) error {
m.setPSNIDCalled = true
return m.setPSNIDErr
}
func (m *mockSignUserRepo) GetPSNIDForUser(uid uint32) (string, error) {
return m.psnIDForUser, m.psnIDForUserErr
}
// --- mockSignCharacterRepo ---
type mockSignCharacterRepo struct {
// CountNewCharacters
newCharCount int
newCharCountErr error
// CreateCharacter
createErr error
createCalled bool
// GetForUser
characters []character
getForUserErr error
// IsNewCharacter
isNew bool
isNewErr error
// HardDelete
hardDeleteErr error
hardDeleteCalled bool
// SoftDelete
softDeleteErr error
softDeleteCalled bool
// GetFriends
friends []members
getFriendsErr error
// GetGuildmates
guildmates []members
getGuildmatesErr error
}
func (m *mockSignCharacterRepo) CountNewCharacters(uid uint32) (int, error) {
return m.newCharCount, m.newCharCountErr
}
func (m *mockSignCharacterRepo) CreateCharacter(uid uint32, lastLogin uint32) error {
m.createCalled = true
return m.createErr
}
func (m *mockSignCharacterRepo) GetForUser(uid uint32) ([]character, error) {
return m.characters, m.getForUserErr
}
func (m *mockSignCharacterRepo) IsNewCharacter(cid int) (bool, error) {
return m.isNew, m.isNewErr
}
func (m *mockSignCharacterRepo) HardDelete(cid int) error {
m.hardDeleteCalled = true
return m.hardDeleteErr
}
func (m *mockSignCharacterRepo) SoftDelete(cid int) error {
m.softDeleteCalled = true
return m.softDeleteErr
}
func (m *mockSignCharacterRepo) GetFriends(charID uint32) ([]members, error) {
return m.friends, m.getFriendsErr
}
func (m *mockSignCharacterRepo) GetGuildmates(charID uint32) ([]members, error) {
return m.guildmates, m.getGuildmatesErr
}
// --- mockSignSessionRepo ---
type mockSignSessionRepo struct {
// RegisterUID
registerUIDTokenID uint32
registerUIDErr error
// RegisterPSN
registerPSNTokenID uint32
registerPSNErr error
// Validate
validateResult bool
validateErr error
// GetPSNIDByToken
psnIDByToken string
psnIDByTokenErr error
}
func (m *mockSignSessionRepo) RegisterUID(uid uint32, token string) (uint32, error) {
return m.registerUIDTokenID, m.registerUIDErr
}
func (m *mockSignSessionRepo) RegisterPSN(psnID, token string) (uint32, error) {
return m.registerPSNTokenID, m.registerPSNErr
}
func (m *mockSignSessionRepo) Validate(token string, tokenID uint32) (bool, error) {
return m.validateResult, m.validateErr
}
func (m *mockSignSessionRepo) GetPSNIDByToken(token string) (string, error) {
return m.psnIDByToken, m.psnIDByTokenErr
}
// newTestServer creates a Server with mock repos for testing.
func newTestServer(userRepo SignUserRepo, charRepo SignCharacterRepo, sessionRepo SignSessionRepo) *Server {
return &Server{
userRepo: userRepo,
charRepo: charRepo,
sessionRepo: sessionRepo,
}
}

View File

@@ -0,0 +1,44 @@
package signserver
import "github.com/jmoiron/sqlx"
// SignSessionRepository implements SignSessionRepo with PostgreSQL.
type SignSessionRepository struct {
db *sqlx.DB
}
// NewSignSessionRepository creates a new SignSessionRepository.
func NewSignSessionRepository(db *sqlx.DB) *SignSessionRepository {
return &SignSessionRepository{db: db}
}
func (r *SignSessionRepository) RegisterUID(uid uint32, token string) (uint32, error) {
var tid uint32
err := r.db.QueryRow(`INSERT INTO sign_sessions (user_id, token) VALUES ($1, $2) RETURNING id`, uid, token).Scan(&tid)
return tid, err
}
func (r *SignSessionRepository) RegisterPSN(psnID, token string) (uint32, error) {
var tid uint32
err := r.db.QueryRow(`INSERT INTO sign_sessions (psn_id, token) VALUES ($1, $2) RETURNING id`, psnID, token).Scan(&tid)
return tid, err
}
func (r *SignSessionRepository) Validate(token string, tokenID uint32) (bool, error) {
query := `SELECT count(*) FROM sign_sessions WHERE token = $1`
if tokenID > 0 {
query += ` AND id = $2`
}
var exists int
err := r.db.QueryRow(query, token, tokenID).Scan(&exists)
if err != nil {
return false, err
}
return exists > 0, nil
}
func (r *SignSessionRepository) GetPSNIDByToken(token string) (string, error) {
var psnID string
err := r.db.QueryRow(`SELECT psn_id FROM sign_sessions WHERE token = $1`, token).Scan(&psnID)
return psnID, err
}

View File

@@ -0,0 +1,114 @@
package signserver
import (
"time"
"github.com/jmoiron/sqlx"
)
// SignUserRepository implements SignUserRepo with PostgreSQL.
type SignUserRepository struct {
db *sqlx.DB
}
// NewSignUserRepository creates a new SignUserRepository.
func NewSignUserRepository(db *sqlx.DB) *SignUserRepository {
return &SignUserRepository{db: db}
}
func (r *SignUserRepository) GetCredentials(username string) (uint32, string, error) {
var uid uint32
var passwordHash string
err := r.db.QueryRow(`SELECT id, password FROM users WHERE username = $1`, username).Scan(&uid, &passwordHash)
return uid, passwordHash, err
}
func (r *SignUserRepository) Register(username, passwordHash string, returnExpires time.Time) (uint32, error) {
var uid uint32
err := r.db.QueryRow(
"INSERT INTO users (username, password, return_expires) VALUES ($1, $2, $3) RETURNING id",
username, passwordHash, returnExpires,
).Scan(&uid)
return uid, err
}
func (r *SignUserRepository) GetRights(uid uint32) (uint32, error) {
var rights uint32
err := r.db.QueryRow("SELECT rights FROM users WHERE id=$1", uid).Scan(&rights)
return rights, err
}
func (r *SignUserRepository) GetLastCharacter(uid uint32) (uint32, error) {
var lastPlayed uint32
err := r.db.QueryRow("SELECT last_character FROM users WHERE id=$1", uid).Scan(&lastPlayed)
return lastPlayed, err
}
func (r *SignUserRepository) GetLastLogin(uid uint32) (time.Time, error) {
var lastLogin time.Time
err := r.db.Get(&lastLogin, "SELECT COALESCE(last_login, now()) FROM users WHERE id=$1", uid)
return lastLogin, err
}
func (r *SignUserRepository) GetReturnExpiry(uid uint32) (time.Time, error) {
var expiry time.Time
err := r.db.Get(&expiry, "SELECT return_expires FROM users WHERE id=$1", uid)
return expiry, err
}
func (r *SignUserRepository) UpdateReturnExpiry(uid uint32, expiry time.Time) error {
_, err := r.db.Exec("UPDATE users SET return_expires=$1 WHERE id=$2", expiry, uid)
return err
}
func (r *SignUserRepository) UpdateLastLogin(uid uint32, loginTime time.Time) error {
_, err := r.db.Exec("UPDATE users SET last_login=$1 WHERE id=$2", loginTime, uid)
return err
}
func (r *SignUserRepository) CountPermanentBans(uid uint32) (int, error) {
var count int
err := r.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires IS NULL`, uid).Scan(&count)
return count, err
}
func (r *SignUserRepository) CountActiveBans(uid uint32) (int, error) {
var count int
err := r.db.QueryRow(`SELECT count(*) FROM bans WHERE user_id=$1 AND expires > now()`, uid).Scan(&count)
return count, err
}
func (r *SignUserRepository) GetByWiiUKey(wiiuKey string) (uint32, error) {
var uid uint32
err := r.db.QueryRow(`SELECT id FROM users WHERE wiiu_key = $1`, wiiuKey).Scan(&uid)
return uid, err
}
func (r *SignUserRepository) GetByPSNID(psnID string) (uint32, error) {
var uid uint32
err := r.db.QueryRow(`SELECT id FROM users WHERE psn_id = $1`, psnID).Scan(&uid)
return uid, err
}
func (r *SignUserRepository) CountByPSNID(psnID string) (int, error) {
var count int
err := r.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, psnID).Scan(&count)
return count, err
}
func (r *SignUserRepository) GetPSNIDForUsername(username string) (string, error) {
var psnID string
err := r.db.QueryRow(`SELECT COALESCE(psn_id, '') FROM users WHERE username = $1`, username).Scan(&psnID)
return psnID, err
}
func (r *SignUserRepository) SetPSNID(username, psnID string) error {
_, err := r.db.Exec(`UPDATE users SET psn_id = $1 WHERE username = $2`, psnID, username)
return err
}
func (r *SignUserRepository) GetPSNIDForUser(uid uint32) (string, error) {
var psnID string
err := r.db.QueryRow("SELECT psn_id FROM users WHERE id = $1", uid).Scan(&psnID)
return psnID, err
}

View File

@@ -115,8 +115,7 @@ func (s *Session) authenticate(username string, password string) {
func (s *Session) handleWIIUSGN(bf *byteframe.ByteFrame) {
_ = bf.ReadBytes(1)
wiiuKey := string(bf.ReadBytes(64))
var uid uint32
err := s.server.db.QueryRow(`SELECT id FROM users WHERE wiiu_key = $1`, wiiuKey).Scan(&uid)
uid, err := s.server.userRepo.GetByWiiUKey(wiiuKey)
if err != nil {
if err == sql.ErrNoRows {
s.logger.Info("Unlinked Wii U attempted to authenticate", zap.String("Key", wiiuKey))
@@ -142,8 +141,7 @@ func (s *Session) handlePSSGN(bf *byteframe.ByteFrame) {
_ = bf.ReadBytes(82)
}
s.psn = string(bf.ReadNullTerminatedBytes())
var uid uint32
err := s.server.db.QueryRow(`SELECT id FROM users WHERE psn_id = $1`, s.psn).Scan(&uid)
uid, err := s.server.userRepo.GetByPSNID(s.psn)
if err != nil {
if err == sql.ErrNoRows {
_ = s.cryptConn.SendPacket(s.makeSignResponse(0))
@@ -159,19 +157,17 @@ func (s *Session) handlePSNLink(bf *byteframe.ByteFrame) {
_ = bf.ReadNullTerminatedBytes() // Client ID
credStr, _ := stringsupport.SJISToUTF8(bf.ReadNullTerminatedBytes())
credentials := strings.Split(credStr, "\n")
token := string(bf.ReadNullTerminatedBytes())
tok := string(bf.ReadNullTerminatedBytes())
uid, resp := s.server.validateLogin(credentials[0], credentials[1])
if resp == SIGN_SUCCESS && uid > 0 {
var psn string
err := s.server.db.QueryRow(`SELECT psn_id FROM sign_sessions WHERE token = $1`, token).Scan(&psn)
psn, err := s.server.sessionRepo.GetPSNIDByToken(tok)
if err != nil {
s.sendCode(SIGN_ECOGLINK)
return
}
// Since we check for the psn_id, this will never run
var exists int
err = s.server.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, psn).Scan(&exists)
exists, err := s.server.userRepo.CountByPSNID(psn)
if err != nil {
s.sendCode(SIGN_ECOGLINK)
return
@@ -180,8 +176,7 @@ func (s *Session) handlePSNLink(bf *byteframe.ByteFrame) {
return
}
var currentPSN string
err = s.server.db.QueryRow(`SELECT COALESCE(psn_id, '') FROM users WHERE username = $1`, credentials[0]).Scan(&currentPSN)
currentPSN, err := s.server.userRepo.GetPSNIDForUsername(credentials[0])
if err != nil {
s.sendCode(SIGN_ECOGLINK)
return
@@ -190,7 +185,7 @@ func (s *Session) handlePSNLink(bf *byteframe.ByteFrame) {
return
}
_, err = s.server.db.Exec(`UPDATE users SET psn_id = $1 WHERE username = $2`, psn, credentials[0])
err = s.server.userRepo.SetPSNID(credentials[0], psn)
if err == nil {
s.sendCode(SIGN_SUCCESS)
return

View File

@@ -24,7 +24,9 @@ type Server struct {
sync.Mutex
logger *zap.Logger
erupeConfig *cfg.Config
db *sqlx.DB
userRepo SignUserRepo
charRepo SignCharacterRepo
sessionRepo SignSessionRepo
listener net.Listener
isShuttingDown bool
}
@@ -34,7 +36,11 @@ func NewServer(config *Config) *Server {
s := &Server{
logger: config.Logger,
erupeConfig: config.ErupeConfig,
db: config.DB,
}
if config.DB != nil {
s.userRepo = NewSignUserRepository(config.DB)
s.charRepo = NewSignCharacterRepository(config.DB)
s.sessionRepo = NewSignSessionRepository(config.DB)
}
return s
}