refactor(channelserver): eliminate *sql.Rows from EventRepo.GetEventQuests

Return []EventQuest instead of a raw database cursor, removing the last
*sql.Rows leak from the repository layer. The handler now iterates a
slice, and makeEventQuest reads fields from the struct directly instead
of scanning rows twice. This makes the method fully mockable and
eliminates the risk of unclosed cursors.
This commit is contained in:
Houmgaor
2026-02-21 14:37:29 +01:00
parent f2f5696a22
commit bd8e30d570
4 changed files with 66 additions and 78 deletions

View File

@@ -1,7 +1,6 @@
package channelserver
import (
"database/sql"
"encoding/binary"
"erupe-ce/common/byteframe"
"erupe-ce/common/decryption"
@@ -264,25 +263,17 @@ func loadQuestFile(s *Session, questId int) []byte {
return result
}
func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) {
var id, mark uint32
var questId, activeDuration, inactiveDuration, flags int
var maxPlayers, questType uint8
var startTime time.Time
if err := rows.Scan(&id, &maxPlayers, &questType, &questId, &mark, &flags, &startTime, &activeDuration, &inactiveDuration); err != nil {
return nil, fmt.Errorf("failed to scan event quest row: %w", err)
}
data := loadQuestFile(s, questId)
func makeEventQuest(s *Session, eq EventQuest) ([]byte, error) {
data := loadQuestFile(s, eq.QuestID)
if data == nil {
return nil, fmt.Errorf("failed to load quest file (%d)", questId)
return nil, fmt.Errorf("failed to load quest file (%d)", eq.QuestID)
}
bf := byteframe.NewByteFrame()
bf.WriteUint32(id)
bf.WriteUint32(eq.ID)
bf.WriteUint32(0) // Unk
bf.WriteUint8(0) // Unk
switch questType {
switch eq.QuestType {
case QuestTypeRegularRaviente:
bf.WriteUint8(s.server.erupeConfig.GameplayOptions.RegularRavienteMaxPlayers)
case QuestTypeViolentRaviente:
@@ -294,17 +285,17 @@ func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) {
case QuestTypeSmallBerserkRavi:
bf.WriteUint8(s.server.erupeConfig.GameplayOptions.SmallBerserkRavienteMaxPlayers)
default:
bf.WriteUint8(maxPlayers)
bf.WriteUint8(eq.MaxPlayers)
}
bf.WriteUint8(questType)
if questType == QuestTypeSpecialTool {
bf.WriteUint8(eq.QuestType)
if eq.QuestType == QuestTypeSpecialTool {
bf.WriteBool(false)
} else {
bf.WriteBool(true)
}
bf.WriteUint16(0) // Unk
if s.server.erupeConfig.RealClientMode >= cfg.G2 {
bf.WriteUint32(mark)
bf.WriteUint32(eq.Mark)
}
bf.WriteUint16(0) // Unk
bf.WriteUint16(uint16(len(data)))
@@ -320,10 +311,10 @@ func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) {
bf.WriteUint8(flagByte & 0b11100000)
} else {
// Allow for seasons to be specified in database, otherwise use the one in the file.
if flags < 0 {
if eq.Flags < 0 {
bf.WriteUint8(flagByte)
} else {
bf.WriteUint8(uint8(flags))
bf.WriteUint8(uint8(eq.Flags))
}
}
@@ -348,59 +339,48 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) {
bf := byteframe.NewByteFrame()
bf.WriteUint16(0)
rows, err := s.server.eventRepo.GetEventQuests()
quests, err := s.server.eventRepo.GetEventQuests()
if err == nil {
currentTime := time.Now()
tx, err := s.server.eventRepo.BeginTx()
if err != nil {
s.logger.Error("Failed to begin transaction for event quests", zap.Error(err))
_ = rows.Close()
doAckBufSucceed(s, pkt.AckHandle, bf.Data())
return
}
for rows.Next() {
var id, mark uint32
var questId, flags, activeDays, inactiveDays int
var maxPlayers, questType uint8
var startTime time.Time
err = rows.Scan(&id, &maxPlayers, &questType, &questId, &mark, &flags, &startTime, &activeDays, &inactiveDays)
if err != nil {
s.logger.Error("Failed to scan event quest row", zap.Error(err))
continue
}
for i, eq := range quests {
// Use the Event Cycling system
if activeDays > 0 {
cycleLength := (time.Duration(activeDays) + time.Duration(inactiveDays)) * 24 * time.Hour
if eq.ActiveDays > 0 {
cycleLength := (time.Duration(eq.ActiveDays) + time.Duration(eq.InactiveDays)) * 24 * time.Hour
// Count the number of full cycles elapsed since the last rotation.
extraCycles := int(currentTime.Sub(startTime) / cycleLength)
extraCycles := int(currentTime.Sub(eq.StartTime) / cycleLength)
if extraCycles > 0 {
// Calculate the rotation time based on start time, active duration, and inactive duration.
rotationTime := startTime.Add(time.Duration(activeDays+inactiveDays) * 24 * time.Hour * time.Duration(extraCycles))
rotationTime := eq.StartTime.Add(time.Duration(eq.ActiveDays+eq.InactiveDays) * 24 * time.Hour * time.Duration(extraCycles))
if currentTime.After(rotationTime) {
// Normalize rotationTime to 12PM JST to align with the in-game events update notification.
newRotationTime := time.Date(rotationTime.Year(), rotationTime.Month(), rotationTime.Day(), 12, 0, 0, 0, TimeAdjusted().Location())
err = s.server.eventRepo.UpdateEventQuestStartTime(tx, id, newRotationTime)
err = s.server.eventRepo.UpdateEventQuestStartTime(tx, eq.ID, newRotationTime)
if err != nil {
_ = tx.Rollback()
break
}
startTime = newRotationTime // Set the new start time so the quest can be used/removed immediately.
quests[i].StartTime = newRotationTime // Set the new start time so the quest can be used/removed immediately.
eq = quests[i]
}
}
// Check if the quest is currently active
if currentTime.Before(startTime) || currentTime.After(startTime.Add(time.Duration(activeDays)*24*time.Hour)) {
if currentTime.Before(eq.StartTime) || currentTime.After(eq.StartTime.Add(time.Duration(eq.ActiveDays)*24*time.Hour)) {
continue
}
}
data, err := makeEventQuest(s, rows)
data, err := makeEventQuest(s, eq)
if err != nil {
s.logger.Error("Failed to make event quest", zap.Error(err))
continue
@@ -419,7 +399,6 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) {
}
}
_ = rows.Close()
_ = tx.Commit()
}

View File

@@ -7,6 +7,19 @@ import (
"github.com/jmoiron/sqlx"
)
// EventQuest represents a row from the event_quests table.
type EventQuest struct {
ID uint32 `db:"id"`
MaxPlayers uint8 `db:"max_players"`
QuestType uint8 `db:"quest_type"`
QuestID int `db:"quest_id"`
Mark uint32 `db:"mark"`
Flags int `db:"flags"`
StartTime time.Time `db:"start_time"`
ActiveDays int `db:"active_days"`
InactiveDays int `db:"inactive_days"`
}
// EventRepository centralizes all database access for event-related tables.
type EventRepository struct {
db *sqlx.DB
@@ -50,8 +63,10 @@ func (r *EventRepository) UpdateLoginBoost(charID uint32, weekReq uint8, expirat
}
// GetEventQuests returns all event quest rows ordered by quest_id.
func (r *EventRepository) GetEventQuests() (*sql.Rows, error) {
return r.db.Query("SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1), start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id")
func (r *EventRepository) GetEventQuests() ([]EventQuest, error) {
var result []EventQuest
err := r.db.Select(&result, "SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1) AS flags, start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id")
return result, err
}
// UpdateEventQuestStartTime updates the start_time for an event quest within a transaction.

View File

@@ -32,14 +32,13 @@ func insertEventQuest(t *testing.T, db *sqlx.DB, questType, questID int, startTi
func TestGetEventQuestsEmpty(t *testing.T) {
repo, _ := setupEventRepo(t)
rows, err := repo.GetEventQuests()
quests, err := repo.GetEventQuests()
if err != nil {
t.Fatalf("GetEventQuests failed: %v", err)
}
defer rows.Close()
if rows.Next() {
t.Error("Expected no rows for empty event_quests table")
if len(quests) != 0 {
t.Errorf("Expected no quests for empty event_quests table, got: %d", len(quests))
}
}
@@ -50,25 +49,28 @@ func TestGetEventQuestsReturnsRows(t *testing.T) {
insertEventQuest(t, db, 1, 100, now, 0, 0)
insertEventQuest(t, db, 2, 200, now, 7, 3)
rows, err := repo.GetEventQuests()
quests, err := repo.GetEventQuests()
if err != nil {
t.Fatalf("GetEventQuests failed: %v", err)
}
defer rows.Close()
count := 0
for rows.Next() {
var id, mark uint32
var questID, flags, activeDays, inactiveDays int
var maxPlayers, questType uint8
var startTime time.Time
if err := rows.Scan(&id, &maxPlayers, &questType, &questID, &mark, &flags, &startTime, &activeDays, &inactiveDays); err != nil {
t.Fatalf("Scan failed: %v", err)
}
count++
if len(quests) != 2 {
t.Errorf("Expected 2 quests, got: %d", len(quests))
}
if count != 2 {
t.Errorf("Expected 2 rows, got: %d", count)
if quests[0].QuestID != 100 {
t.Errorf("Expected first quest ID 100, got: %d", quests[0].QuestID)
}
if quests[1].QuestID != 200 {
t.Errorf("Expected second quest ID 200, got: %d", quests[1].QuestID)
}
if quests[0].QuestType != 1 {
t.Errorf("Expected first quest type 1, got: %d", quests[0].QuestType)
}
if quests[1].ActiveDays != 7 {
t.Errorf("Expected second quest active_days 7, got: %d", quests[1].ActiveDays)
}
if quests[1].InactiveDays != 3 {
t.Errorf("Expected second quest inactive_days 3, got: %d", quests[1].InactiveDays)
}
}
@@ -80,25 +82,17 @@ func TestGetEventQuestsOrderByQuestID(t *testing.T) {
insertEventQuest(t, db, 1, 100, now, 0, 0)
insertEventQuest(t, db, 1, 200, now, 0, 0)
rows, err := repo.GetEventQuests()
quests, err := repo.GetEventQuests()
if err != nil {
t.Fatalf("GetEventQuests failed: %v", err)
}
defer rows.Close()
var questIDs []int
for rows.Next() {
var id, mark uint32
var questID, flags, activeDays, inactiveDays int
var maxPlayers, questType uint8
var startTime time.Time
if err := rows.Scan(&id, &maxPlayers, &questType, &questID, &mark, &flags, &startTime, &activeDays, &inactiveDays); err != nil {
t.Fatalf("Scan failed: %v", err)
if len(quests) != 3 || quests[0].QuestID != 100 || quests[1].QuestID != 200 || quests[2].QuestID != 300 {
ids := make([]int, len(quests))
for i, q := range quests {
ids[i] = q.QuestID
}
questIDs = append(questIDs, questID)
}
if len(questIDs) != 3 || questIDs[0] != 100 || questIDs[1] != 200 || questIDs[2] != 300 {
t.Errorf("Expected quest IDs [100, 200, 300], got: %v", questIDs)
t.Errorf("Expected quest IDs [100, 200, 300], got: %v", ids)
}
}

View File

@@ -271,7 +271,7 @@ type EventRepo interface {
GetLoginBoosts(charID uint32) ([]loginBoost, error)
InsertLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error
UpdateLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error
GetEventQuests() (*sql.Rows, error)
GetEventQuests() ([]EventQuest, error)
UpdateEventQuestStartTime(tx *sql.Tx, id uint32, startTime time.Time) error
BeginTx() (*sql.Tx, error)
}