From bd8e30d570c05997ec57cbb65d15c14e3e7e0437 Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Sat, 21 Feb 2026 14:37:29 +0100 Subject: [PATCH] refactor(channelserver): eliminate *sql.Rows from EventRepo.GetEventQuests Return []EventQuest instead of a raw database cursor, removing the last *sql.Rows leak from the repository layer. The handler now iterates a slice, and makeEventQuest reads fields from the struct directly instead of scanning rows twice. This makes the method fully mockable and eliminates the risk of unclosed cursors. --- server/channelserver/handlers_quest.go | 65 +++++++++---------------- server/channelserver/repo_event.go | 19 +++++++- server/channelserver/repo_event_test.go | 58 ++++++++++------------ server/channelserver/repo_interfaces.go | 2 +- 4 files changed, 66 insertions(+), 78 deletions(-) diff --git a/server/channelserver/handlers_quest.go b/server/channelserver/handlers_quest.go index ca546bab2..10e6611de 100644 --- a/server/channelserver/handlers_quest.go +++ b/server/channelserver/handlers_quest.go @@ -1,7 +1,6 @@ package channelserver import ( - "database/sql" "encoding/binary" "erupe-ce/common/byteframe" "erupe-ce/common/decryption" @@ -264,25 +263,17 @@ func loadQuestFile(s *Session, questId int) []byte { return result } -func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) { - var id, mark uint32 - var questId, activeDuration, inactiveDuration, flags int - var maxPlayers, questType uint8 - var startTime time.Time - if err := rows.Scan(&id, &maxPlayers, &questType, &questId, &mark, &flags, &startTime, &activeDuration, &inactiveDuration); err != nil { - return nil, fmt.Errorf("failed to scan event quest row: %w", err) - } - - data := loadQuestFile(s, questId) +func makeEventQuest(s *Session, eq EventQuest) ([]byte, error) { + data := loadQuestFile(s, eq.QuestID) if data == nil { - return nil, fmt.Errorf("failed to load quest file (%d)", questId) + return nil, fmt.Errorf("failed to load quest file (%d)", eq.QuestID) } bf := byteframe.NewByteFrame() - bf.WriteUint32(id) + bf.WriteUint32(eq.ID) bf.WriteUint32(0) // Unk bf.WriteUint8(0) // Unk - switch questType { + switch eq.QuestType { case QuestTypeRegularRaviente: bf.WriteUint8(s.server.erupeConfig.GameplayOptions.RegularRavienteMaxPlayers) case QuestTypeViolentRaviente: @@ -294,17 +285,17 @@ func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) { case QuestTypeSmallBerserkRavi: bf.WriteUint8(s.server.erupeConfig.GameplayOptions.SmallBerserkRavienteMaxPlayers) default: - bf.WriteUint8(maxPlayers) + bf.WriteUint8(eq.MaxPlayers) } - bf.WriteUint8(questType) - if questType == QuestTypeSpecialTool { + bf.WriteUint8(eq.QuestType) + if eq.QuestType == QuestTypeSpecialTool { bf.WriteBool(false) } else { bf.WriteBool(true) } bf.WriteUint16(0) // Unk if s.server.erupeConfig.RealClientMode >= cfg.G2 { - bf.WriteUint32(mark) + bf.WriteUint32(eq.Mark) } bf.WriteUint16(0) // Unk bf.WriteUint16(uint16(len(data))) @@ -320,10 +311,10 @@ func makeEventQuest(s *Session, rows *sql.Rows) ([]byte, error) { bf.WriteUint8(flagByte & 0b11100000) } else { // Allow for seasons to be specified in database, otherwise use the one in the file. - if flags < 0 { + if eq.Flags < 0 { bf.WriteUint8(flagByte) } else { - bf.WriteUint8(uint8(flags)) + bf.WriteUint8(uint8(eq.Flags)) } } @@ -348,59 +339,48 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) { bf := byteframe.NewByteFrame() bf.WriteUint16(0) - rows, err := s.server.eventRepo.GetEventQuests() + quests, err := s.server.eventRepo.GetEventQuests() if err == nil { currentTime := time.Now() tx, err := s.server.eventRepo.BeginTx() if err != nil { s.logger.Error("Failed to begin transaction for event quests", zap.Error(err)) - _ = rows.Close() doAckBufSucceed(s, pkt.AckHandle, bf.Data()) return } - for rows.Next() { - var id, mark uint32 - var questId, flags, activeDays, inactiveDays int - var maxPlayers, questType uint8 - var startTime time.Time - - err = rows.Scan(&id, &maxPlayers, &questType, &questId, &mark, &flags, &startTime, &activeDays, &inactiveDays) - if err != nil { - s.logger.Error("Failed to scan event quest row", zap.Error(err)) - continue - } - + for i, eq := range quests { // Use the Event Cycling system - if activeDays > 0 { - cycleLength := (time.Duration(activeDays) + time.Duration(inactiveDays)) * 24 * time.Hour + if eq.ActiveDays > 0 { + cycleLength := (time.Duration(eq.ActiveDays) + time.Duration(eq.InactiveDays)) * 24 * time.Hour // Count the number of full cycles elapsed since the last rotation. - extraCycles := int(currentTime.Sub(startTime) / cycleLength) + extraCycles := int(currentTime.Sub(eq.StartTime) / cycleLength) if extraCycles > 0 { // Calculate the rotation time based on start time, active duration, and inactive duration. - rotationTime := startTime.Add(time.Duration(activeDays+inactiveDays) * 24 * time.Hour * time.Duration(extraCycles)) + rotationTime := eq.StartTime.Add(time.Duration(eq.ActiveDays+eq.InactiveDays) * 24 * time.Hour * time.Duration(extraCycles)) if currentTime.After(rotationTime) { // Normalize rotationTime to 12PM JST to align with the in-game events update notification. newRotationTime := time.Date(rotationTime.Year(), rotationTime.Month(), rotationTime.Day(), 12, 0, 0, 0, TimeAdjusted().Location()) - err = s.server.eventRepo.UpdateEventQuestStartTime(tx, id, newRotationTime) + err = s.server.eventRepo.UpdateEventQuestStartTime(tx, eq.ID, newRotationTime) if err != nil { _ = tx.Rollback() break } - startTime = newRotationTime // Set the new start time so the quest can be used/removed immediately. + quests[i].StartTime = newRotationTime // Set the new start time so the quest can be used/removed immediately. + eq = quests[i] } } // Check if the quest is currently active - if currentTime.Before(startTime) || currentTime.After(startTime.Add(time.Duration(activeDays)*24*time.Hour)) { + if currentTime.Before(eq.StartTime) || currentTime.After(eq.StartTime.Add(time.Duration(eq.ActiveDays)*24*time.Hour)) { continue } } - data, err := makeEventQuest(s, rows) + data, err := makeEventQuest(s, eq) if err != nil { s.logger.Error("Failed to make event quest", zap.Error(err)) continue @@ -419,7 +399,6 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) { } } - _ = rows.Close() _ = tx.Commit() } diff --git a/server/channelserver/repo_event.go b/server/channelserver/repo_event.go index dd1154fd1..a8929dfd4 100644 --- a/server/channelserver/repo_event.go +++ b/server/channelserver/repo_event.go @@ -7,6 +7,19 @@ import ( "github.com/jmoiron/sqlx" ) +// EventQuest represents a row from the event_quests table. +type EventQuest struct { + ID uint32 `db:"id"` + MaxPlayers uint8 `db:"max_players"` + QuestType uint8 `db:"quest_type"` + QuestID int `db:"quest_id"` + Mark uint32 `db:"mark"` + Flags int `db:"flags"` + StartTime time.Time `db:"start_time"` + ActiveDays int `db:"active_days"` + InactiveDays int `db:"inactive_days"` +} + // EventRepository centralizes all database access for event-related tables. type EventRepository struct { db *sqlx.DB @@ -50,8 +63,10 @@ func (r *EventRepository) UpdateLoginBoost(charID uint32, weekReq uint8, expirat } // GetEventQuests returns all event quest rows ordered by quest_id. -func (r *EventRepository) GetEventQuests() (*sql.Rows, error) { - return r.db.Query("SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1), start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id") +func (r *EventRepository) GetEventQuests() ([]EventQuest, error) { + var result []EventQuest + err := r.db.Select(&result, "SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1) AS flags, start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id") + return result, err } // UpdateEventQuestStartTime updates the start_time for an event quest within a transaction. diff --git a/server/channelserver/repo_event_test.go b/server/channelserver/repo_event_test.go index 0c8b60499..f9c5a14fe 100644 --- a/server/channelserver/repo_event_test.go +++ b/server/channelserver/repo_event_test.go @@ -32,14 +32,13 @@ func insertEventQuest(t *testing.T, db *sqlx.DB, questType, questID int, startTi func TestGetEventQuestsEmpty(t *testing.T) { repo, _ := setupEventRepo(t) - rows, err := repo.GetEventQuests() + quests, err := repo.GetEventQuests() if err != nil { t.Fatalf("GetEventQuests failed: %v", err) } - defer rows.Close() - if rows.Next() { - t.Error("Expected no rows for empty event_quests table") + if len(quests) != 0 { + t.Errorf("Expected no quests for empty event_quests table, got: %d", len(quests)) } } @@ -50,25 +49,28 @@ func TestGetEventQuestsReturnsRows(t *testing.T) { insertEventQuest(t, db, 1, 100, now, 0, 0) insertEventQuest(t, db, 2, 200, now, 7, 3) - rows, err := repo.GetEventQuests() + quests, err := repo.GetEventQuests() if err != nil { t.Fatalf("GetEventQuests failed: %v", err) } - defer rows.Close() - count := 0 - for rows.Next() { - var id, mark uint32 - var questID, flags, activeDays, inactiveDays int - var maxPlayers, questType uint8 - var startTime time.Time - if err := rows.Scan(&id, &maxPlayers, &questType, &questID, &mark, &flags, &startTime, &activeDays, &inactiveDays); err != nil { - t.Fatalf("Scan failed: %v", err) - } - count++ + if len(quests) != 2 { + t.Errorf("Expected 2 quests, got: %d", len(quests)) } - if count != 2 { - t.Errorf("Expected 2 rows, got: %d", count) + if quests[0].QuestID != 100 { + t.Errorf("Expected first quest ID 100, got: %d", quests[0].QuestID) + } + if quests[1].QuestID != 200 { + t.Errorf("Expected second quest ID 200, got: %d", quests[1].QuestID) + } + if quests[0].QuestType != 1 { + t.Errorf("Expected first quest type 1, got: %d", quests[0].QuestType) + } + if quests[1].ActiveDays != 7 { + t.Errorf("Expected second quest active_days 7, got: %d", quests[1].ActiveDays) + } + if quests[1].InactiveDays != 3 { + t.Errorf("Expected second quest inactive_days 3, got: %d", quests[1].InactiveDays) } } @@ -80,25 +82,17 @@ func TestGetEventQuestsOrderByQuestID(t *testing.T) { insertEventQuest(t, db, 1, 100, now, 0, 0) insertEventQuest(t, db, 1, 200, now, 0, 0) - rows, err := repo.GetEventQuests() + quests, err := repo.GetEventQuests() if err != nil { t.Fatalf("GetEventQuests failed: %v", err) } - defer rows.Close() - var questIDs []int - for rows.Next() { - var id, mark uint32 - var questID, flags, activeDays, inactiveDays int - var maxPlayers, questType uint8 - var startTime time.Time - if err := rows.Scan(&id, &maxPlayers, &questType, &questID, &mark, &flags, &startTime, &activeDays, &inactiveDays); err != nil { - t.Fatalf("Scan failed: %v", err) + if len(quests) != 3 || quests[0].QuestID != 100 || quests[1].QuestID != 200 || quests[2].QuestID != 300 { + ids := make([]int, len(quests)) + for i, q := range quests { + ids[i] = q.QuestID } - questIDs = append(questIDs, questID) - } - if len(questIDs) != 3 || questIDs[0] != 100 || questIDs[1] != 200 || questIDs[2] != 300 { - t.Errorf("Expected quest IDs [100, 200, 300], got: %v", questIDs) + t.Errorf("Expected quest IDs [100, 200, 300], got: %v", ids) } } diff --git a/server/channelserver/repo_interfaces.go b/server/channelserver/repo_interfaces.go index c7802ba38..c0cc52a3c 100644 --- a/server/channelserver/repo_interfaces.go +++ b/server/channelserver/repo_interfaces.go @@ -271,7 +271,7 @@ type EventRepo interface { GetLoginBoosts(charID uint32) ([]loginBoost, error) InsertLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error UpdateLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error - GetEventQuests() (*sql.Rows, error) + GetEventQuests() ([]EventQuest, error) UpdateEventQuestStartTime(tx *sql.Tx, id uint32, startTime time.Time) error BeginTx() (*sql.Tx, error) }