diff --git a/server/channelserver/handlers_character.go b/server/channelserver/handlers_character.go index a0cf83348..3199ff37b 100644 --- a/server/channelserver/handlers_character.go +++ b/server/channelserver/handlers_character.go @@ -1,6 +1,7 @@ package channelserver import ( + "database/sql" "errors" cfg "erupe-ce/config" @@ -11,26 +12,23 @@ import ( // GetCharacterSaveData loads a character's save data from the database. func GetCharacterSaveData(s *Session, charID uint32) (*CharacterSaveData, error) { - result, err := s.server.db.Query("SELECT id, savedata, is_new_character, name FROM characters WHERE id = $1", charID) + id, savedata, isNew, name, err := s.server.charRepo.LoadSaveData(charID) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + s.logger.Error("No savedata found", zap.Uint32("charID", charID)) + return nil, errors.New("no savedata found") + } s.logger.Error("Failed to get savedata", zap.Error(err), zap.Uint32("charID", charID)) return nil, err } - defer func() { _ = result.Close() }() - if !result.Next() { - err = errors.New("no savedata found") - s.logger.Error("No savedata found", zap.Uint32("charID", charID)) - return nil, err - } saveData := &CharacterSaveData{ - Mode: s.server.erupeConfig.RealClientMode, - Pointers: getPointers(s.server.erupeConfig.RealClientMode), - } - err = result.Scan(&saveData.CharID, &saveData.compSave, &saveData.IsNewCharacter, &saveData.Name) - if err != nil { - s.logger.Error("Failed to scan savedata", zap.Error(err), zap.Uint32("charID", charID)) - return nil, err + CharID: id, + compSave: savedata, + IsNewCharacter: isNew, + Name: name, + Mode: s.server.erupeConfig.RealClientMode, + Pointers: getPointers(s.server.erupeConfig.RealClientMode), } if saveData.compSave == nil { diff --git a/server/channelserver/handlers_commands.go b/server/channelserver/handlers_commands.go index 7d7adcf6a..57407f7e1 100644 --- a/server/channelserver/handlers_commands.go +++ b/server/channelserver/handlers_commands.go @@ -104,14 +104,12 @@ func parseChatCommand(s *Session, command string) { uid, uname, err := s.server.userRepo.GetByIDAndUsername(cid) if err == nil { if expiry.IsZero() { - if _, err := s.server.db.Exec(`INSERT INTO bans VALUES ($1) - ON CONFLICT (user_id) DO UPDATE SET expires=NULL`, uid); err != nil { + if err := s.server.userRepo.BanUser(uid, nil); err != nil { s.logger.Error("Failed to ban user", zap.Error(err)) } sendServerChatMessage(s, fmt.Sprintf(s.server.i18n.commands.ban.success, uname)) } else { - if _, err := s.server.db.Exec(`INSERT INTO bans VALUES ($1, $2) - ON CONFLICT (user_id) DO UPDATE SET expires=$2`, uid, expiry); err != nil { + if err := s.server.userRepo.BanUser(uid, &expiry); err != nil { s.logger.Error("Failed to ban user with expiry", zap.Error(err)) } sendServerChatMessage(s, fmt.Sprintf(s.server.i18n.commands.ban.success, uname)+fmt.Sprintf(s.server.i18n.commands.ban.length, expiry.Format(time.DateTime))) diff --git a/server/channelserver/handlers_quest.go b/server/channelserver/handlers_quest.go index f9ae36256..ca546bab2 100644 --- a/server/channelserver/handlers_quest.go +++ b/server/channelserver/handlers_quest.go @@ -348,10 +348,10 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) { bf := byteframe.NewByteFrame() bf.WriteUint16(0) - rows, err := s.server.db.Query("SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1), start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id") + rows, err := s.server.eventRepo.GetEventQuests() if err == nil { currentTime := time.Now() - tx, err := s.server.db.Begin() + tx, err := s.server.eventRepo.BeginTx() if err != nil { s.logger.Error("Failed to begin transaction for event quests", zap.Error(err)) _ = rows.Close() @@ -385,7 +385,7 @@ func handleMsgMhfEnumerateQuest(s *Session, p mhfpacket.MHFPacket) { // Normalize rotationTime to 12PM JST to align with the in-game events update notification. newRotationTime := time.Date(rotationTime.Year(), rotationTime.Month(), rotationTime.Day(), 12, 0, 0, 0, TimeAdjusted().Location()) - _, err = tx.Exec("UPDATE event_quests SET start_time = $1 WHERE id = $2", newRotationTime, id) + err = s.server.eventRepo.UpdateEventQuestStartTime(tx, id, newRotationTime) if err != nil { _ = tx.Rollback() break diff --git a/server/channelserver/repo_character.go b/server/channelserver/repo_character.go index c118911bd..5d954a7ed 100644 --- a/server/channelserver/repo_character.go +++ b/server/channelserver/repo_character.go @@ -224,3 +224,15 @@ func (r *CharacterRepository) SaveHouseData(charID uint32, houseTier []byte, hou houseTier, houseData, bookshelf, gallery, tore, garden, charID) return err } + +// LoadSaveData reads the core save columns for a character. +// Returns charID, savedata, isNewCharacter, name, and any error. +func (r *CharacterRepository) LoadSaveData(charID uint32) (uint32, []byte, bool, string, error) { + var id uint32 + var savedata []byte + var isNew bool + var name string + err := r.db.QueryRow("SELECT id, savedata, is_new_character, name FROM characters WHERE id = $1", charID). + Scan(&id, &savedata, &isNew, &name) + return id, savedata, isNew, name, err +} diff --git a/server/channelserver/repo_event.go b/server/channelserver/repo_event.go index 9bfb00efa..77ee96791 100644 --- a/server/channelserver/repo_event.go +++ b/server/channelserver/repo_event.go @@ -1,6 +1,7 @@ package channelserver import ( + "database/sql" "time" "github.com/jmoiron/sqlx" @@ -45,3 +46,19 @@ func (r *EventRepository) UpdateLoginBoost(charID uint32, weekReq uint8, expirat _, err := r.db.Exec(`UPDATE login_boost SET expiration=$1, reset=$2 WHERE char_id=$3 AND week_req=$4`, expiration, reset, charID, weekReq) return err } + +// GetEventQuests returns all event quest rows ordered by quest_id. +func (r *EventRepository) GetEventQuests() (*sql.Rows, error) { + return r.db.Query("SELECT id, COALESCE(max_players, 4) AS max_players, quest_type, quest_id, COALESCE(mark, 0) AS mark, COALESCE(flags, -1), start_time, COALESCE(active_days, 0) AS active_days, COALESCE(inactive_days, 0) AS inactive_days FROM event_quests ORDER BY quest_id") +} + +// UpdateEventQuestStartTime updates the start_time for an event quest within a transaction. +func (r *EventRepository) UpdateEventQuestStartTime(tx *sql.Tx, id uint32, startTime time.Time) error { + _, err := tx.Exec("UPDATE event_quests SET start_time = $1 WHERE id = $2", startTime, id) + return err +} + +// BeginTx starts a new database transaction. +func (r *EventRepository) BeginTx() (*sql.Tx, error) { + return r.db.Begin() +} diff --git a/server/channelserver/repo_interfaces.go b/server/channelserver/repo_interfaces.go index fc28bb705..a5bf64494 100644 --- a/server/channelserver/repo_interfaces.go +++ b/server/channelserver/repo_interfaces.go @@ -41,6 +41,7 @@ type CharacterRepo interface { FindByRastaID(rastaID int) (charID uint32, name string, err error) SaveCharacterData(charID uint32, compSave []byte, hr, gr uint16, isFemale bool, weaponType uint8, weaponID uint16) error SaveHouseData(charID uint32, houseTier []byte, houseData, bookshelf, gallery, tore, garden []byte) error + LoadSaveData(charID uint32) (uint32, []byte, bool, string, error) } // GuildRepo defines the contract for guild data access. @@ -141,6 +142,7 @@ type UserRepo interface { LinkDiscord(discordID string, token string) (string, error) SetPasswordByDiscordID(discordID string, hash []byte) error GetByIDAndUsername(charID uint32) (userID uint32, username string, err error) + BanUser(userID uint32, expires *time.Time) error } // GachaRepo defines the contract for gacha system data access. @@ -271,6 +273,9 @@ type EventRepo interface { GetLoginBoosts(charID uint32) (*sqlx.Rows, error) InsertLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error UpdateLoginBoost(charID uint32, weekReq uint8, expiration, reset time.Time) error + GetEventQuests() (*sql.Rows, error) + UpdateEventQuestStartTime(tx *sql.Tx, id uint32, startTime time.Time) error + BeginTx() (*sql.Tx, error) } // AchievementRepo defines the contract for achievement data access. diff --git a/server/channelserver/repo_mocks_test.go b/server/channelserver/repo_mocks_test.go index a84386347..f30f96cd2 100644 --- a/server/channelserver/repo_mocks_test.go +++ b/server/channelserver/repo_mocks_test.go @@ -194,6 +194,7 @@ func (m *mockCharacterRepo) UpdateGCPAndPact(_ uint32, _ uint32, _ uint32) error func (m *mockCharacterRepo) FindByRastaID(_ int) (uint32, string, error) { return 0, "", nil } func (m *mockCharacterRepo) SaveCharacterData(_ uint32, _ []byte, _, _ uint16, _ bool, _ uint8, _ uint16) error { return nil } func (m *mockCharacterRepo) SaveHouseData(_ uint32, _ []byte, _, _, _, _, _ []byte) error { return nil } +func (m *mockCharacterRepo) LoadSaveData(_ uint32) (uint32, []byte, bool, string, error) { return 0, nil, false, "", nil } // --- mockGoocooRepo --- diff --git a/server/channelserver/repo_user.go b/server/channelserver/repo_user.go index b3bc7b003..919f33dd1 100644 --- a/server/channelserver/repo_user.go +++ b/server/channelserver/repo_user.go @@ -2,6 +2,7 @@ package channelserver import ( "database/sql" + "time" "github.com/jmoiron/sqlx" ) @@ -218,3 +219,16 @@ func (r *UserRepository) GetByIDAndUsername(charID uint32) (userID uint32, usern ).Scan(&userID, &username) return } + +// BanUser inserts or updates a ban for the given user. +// A nil expires means a permanent ban; non-nil sets a temporary ban with expiry. +func (r *UserRepository) BanUser(userID uint32, expires *time.Time) error { + if expires == nil { + _, err := r.db.Exec(`INSERT INTO bans VALUES ($1) + ON CONFLICT (user_id) DO UPDATE SET expires=NULL`, userID) + return err + } + _, err := r.db.Exec(`INSERT INTO bans VALUES ($1, $2) + ON CONFLICT (user_id) DO UPDATE SET expires=$2`, userID, *expires) + return err +}