refactor(channelserver): eliminate *sql.Rows from EventRepo.GetEventQuests

Return []EventQuest instead of a raw database cursor, removing the last
*sql.Rows leak from the repository layer. The handler now iterates a
slice, and makeEventQuest reads fields from the struct directly instead
of scanning rows twice. This makes the method fully mockable and
eliminates the risk of unclosed cursors.
This commit is contained in:
Houmgaor
2026-02-21 14:37:29 +01:00
parent f2f5696a22
commit bd8e30d570
4 changed files with 66 additions and 78 deletions

View File

@@ -1,7 +1,6 @@
package channelserver package channelserver
import ( import (
"database/sql"
"encoding/binary" "encoding/binary"
"erupe-ce/common/byteframe" "erupe-ce/common/byteframe"
"erupe-ce/common/decryption" "erupe-ce/common/decryption"
@@ -264,25 +263,17 @@ func loadQuestFile(s *Session, questId int) []byte {
return result return result
} }
func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) { func makeEventQuest(s *Session, eq EventQuest) ([]byte, error) {
var id, mark uint32 data := loadQuestFile(s, eq.QuestID)
var questId, activeDuration, inactiveDuration, flags int
var maxPlayers, questType uint8
var startTime time.Time
if err := rows.Scan(&id, &maxPlayers, &questType, &questId, &mark, &flags, &startTime, &activeDuration, &inactiveDuration); err != nil {
return nil, fmt.Errorf("failed to scan event quest row: %w", err)
}
data := loadQuestFile(s, questId)
if data == nil { if data == nil {
return nil, fmt.Errorf("failed to load quest file (%d)", questId) return nil, fmt.Errorf("failed to load quest file (%d)", eq.QuestID)
} }
bf := byteframe.NewByteFrame() bf := byteframe.NewByteFrame()
bf.WriteUint32(id) bf.WriteUint32(eq.ID)
bf.WriteUint32(0) // Unk bf.WriteUint32(0) // Unk
bf.WriteUint8(0) // Unk bf.WriteUint8(0) // Unk
switch questType { switch eq.QuestType {
case QuestTypeRegularRaviente: case QuestTypeRegularRaviente:
bf.WriteUint8(s.server.erupeConfig.GameplayOptions.RegularRavienteMaxPlayers) bf.WriteUint8(s.server.erupeConfig.GameplayOptions.RegularRavienteMaxPlayers)
case QuestTypeViolentRaviente: case QuestTypeViolentRaviente:
@@ -294,17 +285,17 @@ func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) {
case QuestTypeSmallBerserkRavi: case QuestTypeSmallBerserkRavi:
bf.WriteUint8(s.server.erupeConfig.GameplayOptions.SmallBerserkRavienteMaxPlayers) bf.WriteUint8(s.server.erupeConfig.GameplayOptions.SmallBerserkRavienteMaxPlayers)
default: default:
bf.WriteUint8(maxPlayers) bf.WriteUint8(eq.MaxPlayers)
} }
bf.WriteUint8(questType) bf.WriteUint8(eq.QuestType)
if questType == QuestTypeSpecialTool { if eq.QuestType == QuestTypeSpecialTool {
bf.WriteBool(false) bf.WriteBool(false)
} else { } else {
bf.WriteBool(true) bf.WriteBool(true)
} }
bf.WriteUint16(0) // Unk bf.WriteUint16(0) // Unk
if s.server.erupeConfig.RealClientMode >= cfg.G2 { if s.server.erupeConfig.RealClientMode >= cfg.G2 {
bf.WriteUint32(mark) bf.WriteUint32(eq.Mark)
} }
bf.WriteUint16(0) // Unk bf.WriteUint16(0) // Unk
bf.WriteUint16(uint16(len(data))) bf.WriteUint16(uint16(len(data)))
@@ -320,10 +311,10 @@ func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) {
bf.WriteUint8(flagByte & 0b11100000) bf.WriteUint8(flagByte & 0b11100000)
} else { } else {
// Allow for seasons to be specified in database, otherwise use the one in the file. // Allow for seasons to be specified in database, otherwise use the one in the file.
if flags < 0 { if eq.Flags < 0 {
bf.WriteUint8(flagByte) bf.WriteUint8(flagByte)
} else { } else {
bf.WriteUint8(uint8(flags)) bf.WriteUint8(uint8(eq.Flags))
} }
} }
@@ -348,59 +339,48 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) {
bf := byteframe.NewByteFrame() bf := byteframe.NewByteFrame()
bf.WriteUint16(0) bf.WriteUint16(0)
rows, err := s.server.eventRepo.GetEventQuests() quests, err := s.server.eventRepo.GetEventQuests()
if err == nil { if err == nil {
currentTime := time.Now() currentTime := time.Now()
tx, err := s.server.eventRepo.BeginTx() tx, err := s.server.eventRepo.BeginTx()
if err != nil { if err != nil {
s.logger.Error("Failed to begin transaction for event quests", zap.Error(err)) s.logger.Error("Failed to begin transaction for event quests", zap.Error(err))
_ = rows.Close()
doAckBufSucceed(s, pkt.AckHandle, bf.Data()) doAckBufSucceed(s, pkt.AckHandle, bf.Data())
return return
} }
for rows.Next() { for i, eq := range quests {
var id, mark uint32
var questId, flags, activeDays, inactiveDays int
var maxPlayers, questType uint8
var startTime time.Time
err = rows.Scan(&id, &maxPlayers, &questType, &questId, &mark, &flags, &startTime, &activeDays, &inactiveDays)
if err != nil {
s.logger.Error("Failed to scan event quest row", zap.Error(err))
continue
}
// Use the Event Cycling system // Use the Event Cycling system
if activeDays > 0 { if eq.ActiveDays > 0 {
cycleLength := (time.Duration(activeDays) + time.Duration(inactiveDays)) * 24 * time.Hour cycleLength := (time.Duration(eq.ActiveDays) + time.Duration(eq.InactiveDays)) * 24 * time.Hour
// Count the number of full cycles elapsed since the last rotation. // Count the number of full cycles elapsed since the last rotation.
extraCycles := int(currentTime.Sub(startTime) / cycleLength) extraCycles := int(currentTime.Sub(eq.StartTime) / cycleLength)
if extraCycles > 0 { if extraCycles > 0 {
// Calculate the rotation time based on start time, active duration, and inactive duration. // Calculate the rotation time based on start time, active duration, and inactive duration.
rotationTime := startTime.Add(time.Duration(activeDays+inactiveDays) * 24 * time.Hour * time.Duration(extraCycles)) rotationTime := eq.StartTime.Add(time.Duration(eq.ActiveDays+eq.InactiveDays) * 24 * time.Hour * time.Duration(extraCycles))
if currentTime.After(rotationTime) { if currentTime.After(rotationTime) {
// Normalize rotationTime to 12PM JST to align with the in-game events update notification. // Normalize rotationTime to 12PM JST to align with the in-game events update notification.
newRotationTime := time.Date(rotationTime.Year(), rotationTime.Month(), rotationTime.Day(), 12, 0, 0, 0, TimeAdjusted().Location()) newRotationTime := time.Date(rotationTime.Year(), rotationTime.Month(), rotationTime.Day(), 12, 0, 0, 0, TimeAdjusted().Location())
err = s.server.eventRepo.UpdateEventQuestStartTime(tx, id, newRotationTime) err = s.server.eventRepo.UpdateEventQuestStartTime(tx, eq.ID, newRotationTime)
if err != nil { if err != nil {
_ = tx.Rollback() _ = tx.Rollback()
break break
} }
startTime = newRotationTime // Set the new start time so the quest can be used/removed immediately. quests[i].StartTime = newRotationTime // Set the new start time so the quest can be used/removed immediately.
eq = quests[i]
} }
} }
// Check if the quest is currently active // Check if the quest is currently active
if currentTime.Before(startTime) || currentTime.After(startTime.Add(time.Duration(activeDays)*24*time.Hour)) { if currentTime.Before(eq.StartTime) || currentTime.After(eq.StartTime.Add(time.Duration(eq.ActiveDays)*24*time.Hour)) {
continue continue
} }
} }
data, err := makeEventQuest(s, rows) data, err := makeEventQuest(s, eq)
if err != nil { if err != nil {
s.logger.Error("Failed to make event quest", zap.Error(err)) s.logger.Error("Failed to make event quest", zap.Error(err))
continue continue
@@ -419,7 +399,6 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) {
} }
} }
_ = rows.Close()
_ = tx.Commit() _ = tx.Commit()
} }

View File

@@ -7,6 +7,19 @@ import (
"github.com/jmoiron/sqlx" "github.com/jmoiron/sqlx"
) )
// EventQuest represents a row from the event_quests table.
type EventQuest struct {
ID uint32 `db:"id"`
MaxPlayers uint8 `db:"max_players"`
QuestType uint8 `db:"quest_type"`
QuestID int `db:"quest_id"`
Mark uint32 `db:"mark"`
Flags int `db:"flags"`
StartTime time.Time `db:"start_time"`
ActiveDays int `db:"active_days"`
InactiveDays int `db:"inactive_days"`
}
// EventRepository centralizes all database access for event-related tables. // EventRepository centralizes all database access for event-related tables.
type EventRepository struct { type EventRepository struct {
db *sqlx.DB db *sqlx.DB
@@ -50,8 +63,10 @@ func (r *EventRepository) UpdateLoginBoost(charID uint32, weekReq uint8, expirat
} }
// GetEventQuests returns all event quest rows ordered by quest_id. // GetEventQuests returns all event quest rows ordered by quest_id.
func (r *EventRepository) GetEventQuests() (*sql.Rows, error) { func (r *EventRepository) GetEventQuests() ([]EventQuest, error) {
return r.db.Query("SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1), start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id") var result []EventQuest
err := r.db.Select(&result, "SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1) AS flags, start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id")
return result, err
} }
// UpdateEventQuestStartTime updates the start_time for an event quest within a transaction. // UpdateEventQuestStartTime updates the start_time for an event quest within a transaction.

View File

@@ -32,14 +32,13 @@ func insertEventQuest(t *testing.T, db *sqlx.DB, questType, questID int, startTi
func TestGetEventQuestsEmpty(t *testing.T) { func TestGetEventQuestsEmpty(t *testing.T) {
repo, _ := setupEventRepo(t) repo, _ := setupEventRepo(t)
rows, err := repo.GetEventQuests() quests, err := repo.GetEventQuests()
if err != nil { if err != nil {
t.Fatalf("GetEventQuests failed: %v", err) t.Fatalf("GetEventQuests failed: %v", err)
} }
defer rows.Close()
if rows.Next() { if len(quests) != 0 {
t.Error("Expected no rows for empty event_quests table") t.Errorf("Expected no quests for empty event_quests table, got: %d", len(quests))
} }
} }
@@ -50,25 +49,28 @@ func TestGetEventQuestsReturnsRows(t *testing.T) {
insertEventQuest(t, db, 1, 100, now, 0, 0) insertEventQuest(t, db, 1, 100, now, 0, 0)
insertEventQuest(t, db, 2, 200, now, 7, 3) insertEventQuest(t, db, 2, 200, now, 7, 3)
rows, err := repo.GetEventQuests() quests, err := repo.GetEventQuests()
if err != nil { if err != nil {
t.Fatalf("GetEventQuests failed: %v", err) t.Fatalf("GetEventQuests failed: %v", err)
} }
defer rows.Close()
count := 0 if len(quests) != 2 {
for rows.Next() { t.Errorf("Expected 2 quests, got: %d", len(quests))
var id, mark uint32
var questID, flags, activeDays, inactiveDays int
var maxPlayers, questType uint8
var startTime time.Time
if err := rows.Scan(&id, &maxPlayers, &questType, &questID, &mark, &flags, &startTime, &activeDays, &inactiveDays); err != nil {
t.Fatalf("Scan failed: %v", err)
}
count++
} }
if count != 2 { if quests[0].QuestID != 100 {
t.Errorf("Expected 2 rows, got: %d", count) t.Errorf("Expected first quest ID 100, got: %d", quests[0].QuestID)
}
if quests[1].QuestID != 200 {
t.Errorf("Expected second quest ID 200, got: %d", quests[1].QuestID)
}
if quests[0].QuestType != 1 {
t.Errorf("Expected first quest type 1, got: %d", quests[0].QuestType)
}
if quests[1].ActiveDays != 7 {
t.Errorf("Expected second quest active_days 7, got: %d", quests[1].ActiveDays)
}
if quests[1].InactiveDays != 3 {
t.Errorf("Expected second quest inactive_days 3, got: %d", quests[1].InactiveDays)
} }
} }
@@ -80,25 +82,17 @@ func TestGetEventQuestsOrderByQuestID(t *testing.T) {
insertEventQuest(t, db, 1, 100, now, 0, 0) insertEventQuest(t, db, 1, 100, now, 0, 0)
insertEventQuest(t, db, 1, 200, now, 0, 0) insertEventQuest(t, db, 1, 200, now, 0, 0)
rows, err := repo.GetEventQuests() quests, err := repo.GetEventQuests()
if err != nil { if err != nil {
t.Fatalf("GetEventQuests failed: %v", err) t.Fatalf("GetEventQuests failed: %v", err)
} }
defer rows.Close()
var questIDs []int if len(quests) != 3 || quests[0].QuestID != 100 || quests[1].QuestID != 200 || quests[2].QuestID != 300 {
for rows.Next() { ids := make([]int, len(quests))
var id, mark uint32 for i, q := range quests {
var questID, flags, activeDays, inactiveDays int ids[i] = q.QuestID
var maxPlayers, questType uint8
var startTime time.Time
if err := rows.Scan(&id, &maxPlayers, &questType, &questID, &mark, &flags, &startTime, &activeDays, &inactiveDays); err != nil {
t.Fatalf("Scan failed: %v", err)
} }
questIDs = append(questIDs, questID) t.Errorf("Expected quest IDs [100, 200, 300], got: %v", ids)
}
if len(questIDs) != 3 || questIDs[0] != 100 || questIDs[1] != 200 || questIDs[2] != 300 {
t.Errorf("Expected quest IDs [100, 200, 300], got: %v", questIDs)
} }
} }

View File

@@ -271,7 +271,7 @@ type EventRepo interface {
GetLoginBoosts(charID uint32) ([]loginBoost, error) GetLoginBoosts(charID uint32) ([]loginBoost, error)
InsertLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error InsertLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error
UpdateLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error UpdateLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error
GetEventQuests() (*sql.Rows, error) GetEventQuests() ([]EventQuest, error)
UpdateEventQuestStartTime(tx *sql.Tx, id uint32, startTime time.Time) error UpdateEventQuestStartTime(tx *sql.Tx, id uint32, startTime time.Time) error
BeginTx() (*sql.Tx, error) BeginTx() (*sql.Tx, error)
} }