diff --git a/server/channelserver/handlers_distitem.go b/server/channelserver/handlers_distitem.go index 901e24593..3db86b39e 100644 --- a/server/channelserver/handlers_distitem.go +++ b/server/channelserver/handlers_distitem.go @@ -31,32 +31,8 @@ type Distribution struct { func handleMsgMhfEnumerateDistItem(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfEnumerateDistItem) - var itemDists []Distribution bf := byteframe.NewByteFrame() - rows, err := s.server.db.Queryx(` - SELECT d.id, event_name, description, COALESCE(rights, 0) AS rights, COALESCE(selection, false) AS selection, times_acceptable, - COALESCE(min_hr, -1) AS min_hr, COALESCE(max_hr, -1) AS max_hr, - COALESCE(min_sr, -1) AS min_sr, COALESCE(max_sr, -1) AS max_sr, - COALESCE(min_gr, -1) AS min_gr, COALESCE(max_gr, -1) AS max_gr, - ( - SELECT count(*) FROM distributions_accepted da - WHERE d.id = da.distribution_id AND da.character_id = $1 - ) AS times_accepted, - COALESCE(deadline, TO_TIMESTAMP(0)) AS deadline - FROM distribution d - WHERE character_id = $1 AND type = $2 OR character_id IS NULL AND type = $2 ORDER BY id DESC - `, s.charID, pkt.DistType) - - if err == nil { - var itemDist Distribution - for rows.Next() { - err = rows.StructScan(&itemDist) - if err != nil { - continue - } - itemDists = append(itemDists, itemDist) - } - } + itemDists, _ := s.server.distRepo.List(s.charID, pkt.DistType) bf.WriteUint16(uint16(len(itemDists))) for _, dist := range itemDists { @@ -128,27 +104,11 @@ type DistributionItem struct { Quantity uint32 `db:"quantity"` } -func getDistributionItems(s *Session, i uint32) []DistributionItem { - var distItems []DistributionItem - rows, err := s.server.db.Queryx(`SELECT id, item_type, COALESCE(item_id, 0) AS item_id, COALESCE(quantity, 0) AS quantity FROM distribution_items WHERE distribution_id=$1`, i) - if err == nil { - var distItem DistributionItem - for rows.Next() { - err = rows.StructScan(&distItem) - if err != nil { - continue - } - distItems = append(distItems, distItem) - } - } - return distItems -} - func handleMsgMhfApplyDistItem(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfApplyDistItem) bf := byteframe.NewByteFrame() bf.WriteUint32(pkt.DistributionID) - distItems := getDistributionItems(s, pkt.DistributionID) + distItems, _ := s.server.distRepo.GetItems(pkt.DistributionID) bf.WriteUint16(uint16(len(distItems))) for _, item := range distItems { bf.WriteUint8(item.ItemType) @@ -164,9 +124,9 @@ func handleMsgMhfApplyDistItem(s *Session, p mhfpacket.MHFPacket) { func handleMsgMhfAcquireDistItem(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfAcquireDistItem) if pkt.DistributionID > 0 { - _, err := s.server.db.Exec(`INSERT INTO public.distributions_accepted VALUES ($1, $2)`, pkt.DistributionID, s.charID) + err := s.server.distRepo.RecordAccepted(pkt.DistributionID, s.charID) if err == nil { - distItems := getDistributionItems(s, pkt.DistributionID) + distItems, _ := s.server.distRepo.GetItems(pkt.DistributionID) for _, item := range distItems { switch item.ItemType { case 17: @@ -198,8 +158,7 @@ func handleMsgMhfAcquireDistItem(s *Session, p mhfpacket.MHFPacket) { func handleMsgMhfGetDistDescription(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfGetDistDescription) - var desc string - err := s.server.db.QueryRow("SELECT description FROM distribution WHERE id = $1", pkt.DistributionID).Scan(&desc) + desc, err := s.server.distRepo.GetDescription(pkt.DistributionID) if err != nil { s.logger.Error("Error parsing item distribution description", zap.Error(err)) doAckBufSucceed(s, pkt.AckHandle, make([]byte, 4)) diff --git a/server/channelserver/handlers_items.go b/server/channelserver/handlers_items.go index 7a5afafbb..46838492e 100644 --- a/server/channelserver/handlers_items.go +++ b/server/channelserver/handlers_items.go @@ -6,8 +6,6 @@ import ( "erupe-ce/common/mhfmon" _config "erupe-ce/config" "erupe-ce/network/mhfpacket" - "fmt" - "time" "go.uber.org/zap" ) @@ -230,27 +228,26 @@ func handleMsgMhfCheckWeeklyStamp(s *Session, p mhfpacket.MHFPacket) { return } var total, redeemed, updated uint16 - var lastCheck time.Time - err := s.server.db.QueryRow(fmt.Sprintf("SELECT %s_checked FROM stamps WHERE character_id=$1", pkt.StampType), s.charID).Scan(&lastCheck) + lastCheck, err := s.server.stampRepo.GetChecked(s.charID, pkt.StampType) if err != nil { lastCheck = TimeAdjusted() - if _, err := s.server.db.Exec("INSERT INTO stamps (character_id, hl_checked, ex_checked) VALUES ($1, $2, $2)", s.charID, TimeAdjusted()); err != nil { + if err := s.server.stampRepo.Init(s.charID, TimeAdjusted()); err != nil { s.logger.Error("Failed to insert stamps record", zap.Error(err)) } } else { - if _, err := s.server.db.Exec(fmt.Sprintf(`UPDATE stamps SET %s_checked=$1 WHERE character_id=$2`, pkt.StampType), TimeAdjusted(), s.charID); err != nil { + if err := s.server.stampRepo.SetChecked(s.charID, pkt.StampType, TimeAdjusted()); err != nil { s.logger.Error("Failed to update stamp check time", zap.Error(err)) } } if lastCheck.Before(TimeWeekStart()) { - if _, err := s.server.db.Exec(fmt.Sprintf("UPDATE stamps SET %s_total=%s_total+1 WHERE character_id=$1", pkt.StampType, pkt.StampType), s.charID); err != nil { + if err := s.server.stampRepo.IncrementTotal(s.charID, pkt.StampType); err != nil { s.logger.Error("Failed to increment stamp total", zap.Error(err)) } updated = 1 } - _ = s.server.db.QueryRow(fmt.Sprintf("SELECT %s_total, %s_redeemed FROM stamps WHERE character_id=$1", pkt.StampType, pkt.StampType), s.charID).Scan(&total, &redeemed) + total, redeemed, _ = s.server.stampRepo.GetTotals(s.charID, pkt.StampType) bf := byteframe.NewByteFrame() bf.WriteUint16(total) bf.WriteUint16(redeemed) @@ -268,16 +265,17 @@ func handleMsgMhfExchangeWeeklyStamp(s *Session, p mhfpacket.MHFPacket) { return } var total, redeemed uint16 + var err error var tktStack mhfitem.MHFItemStack if pkt.ExchangeType == 10 { // Yearly Sub Ex - if err := s.server.db.QueryRow("UPDATE stamps SET hl_total=hl_total-48, hl_redeemed=hl_redeemed-48 WHERE character_id=$1 RETURNING hl_total, hl_redeemed", s.charID).Scan(&total, &redeemed); err != nil { + if total, redeemed, err = s.server.stampRepo.ExchangeYearly(s.charID); err != nil { s.logger.Error("Failed to update yearly stamp exchange", zap.Error(err)) doAckBufFail(s, pkt.AckHandle, nil) return } tktStack = mhfitem.MHFItemStack{Item: mhfitem.MHFItem{ItemID: 2210}, Quantity: 1} } else { - if err := s.server.db.QueryRow(fmt.Sprintf("UPDATE stamps SET %s_redeemed=%s_redeemed+8 WHERE character_id=$1 RETURNING %s_total, %s_redeemed", pkt.StampType, pkt.StampType, pkt.StampType, pkt.StampType), s.charID).Scan(&total, &redeemed); err != nil { + if total, redeemed, err = s.server.stampRepo.Exchange(s.charID, pkt.StampType); err != nil { s.logger.Error("Failed to update stamp redemption", zap.Error(err)) doAckBufFail(s, pkt.AckHandle, nil) return diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go index 80bf3de4f..e2f093271 100644 --- a/server/channelserver/handlers_session.go +++ b/server/channelserver/handlers_session.go @@ -57,9 +57,7 @@ func handleMsgSysLogin(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysLogin) if !s.server.erupeConfig.DebugOptions.DisableTokenCheck { - var token string - err := s.server.db.QueryRow("SELECT token FROM sign_sessions ss INNER JOIN public.users u on ss.user_id = u.id WHERE token=$1 AND ss.id=$2 AND u.id=(SELECT c.user_id FROM characters c WHERE c.id=$3)", pkt.LoginTokenString, pkt.LoginTokenNumber, pkt.CharID0).Scan(&token) - if err != nil { + if err := s.server.sessionRepo.ValidateLoginToken(pkt.LoginTokenString, pkt.LoginTokenNumber, pkt.CharID0); err != nil { _ = s.rawConn.Close() s.logger.Warn(fmt.Sprintf("Invalid login token, offending CID: (%d)", pkt.CharID0)) return @@ -82,14 +80,14 @@ func handleMsgSysLogin(s *Session, p mhfpacket.MHFPacket) { bf := byteframe.NewByteFrame() bf.WriteUint32(uint32(TimeAdjusted().Unix())) // Unix timestamp - _, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID) + err = s.server.sessionRepo.UpdatePlayerCount(s.server.ID, len(s.server.sessions)) if err != nil { s.logger.Error("Failed to update current players", zap.Error(err)) doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) return } - _, err = s.server.db.Exec("UPDATE sign_sessions SET server_id=$1, char_id=$2 WHERE token=$3", s.server.ID, s.charID, s.token) + err = s.server.sessionRepo.BindSession(s.token, s.server.ID, s.charID) if err != nil { s.logger.Error("Failed to update sign session", zap.Error(err)) doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) @@ -281,7 +279,7 @@ func logoutPlayer(s *Session) { if err := s.server.charRepo.UpdateTimePlayed(s.charID, timePlayed); err != nil { s.logger.Error("Failed to update time played", zap.Error(err)) } - if _, err := s.server.db.Exec(`UPDATE guild_characters SET treasure_hunt=NULL WHERE character_id=$1`, s.charID); err != nil { + if err := s.server.guildRepo.ClearTreasureHunt(s.charID); err != nil { s.logger.Error("Failed to clear treasure hunt", zap.Error(err)) } } @@ -324,13 +322,11 @@ func logoutPlayer(s *Session) { // Update sign sessions and server player count if s.server.db != nil { - _, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token) - if err != nil { + if err := s.server.sessionRepo.ClearSession(s.token); err != nil { s.logger.Error("Failed to clear sign session", zap.Error(err)) } - _, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID) - if err != nil { + if err := s.server.sessionRepo.UpdatePlayerCount(s.server.ID, len(s.server.sessions)); err != nil { s.logger.Error("Failed to update player count", zap.Error(err)) } } @@ -433,7 +429,7 @@ func handleMsgSysRecordLog(s *Session, p mhfpacket.MHFPacket) { for i := 0; i < killLogMonsterCount; i++ { val = bf.ReadUint8() if val > 0 && mhfmon.Monsters[i].Large { - if _, err := s.server.db.Exec(`INSERT INTO kill_logs (character_id, monster, quantity, timestamp) VALUES ($1, $2, $3, $4)`, s.charID, i, val, TimeAdjusted()); err != nil { + if err := s.server.guildRepo.InsertKillLog(s.charID, i, val, TimeAdjusted()); err != nil { s.logger.Error("Failed to insert kill log", zap.Error(err)) } } diff --git a/server/channelserver/repo_distribution.go b/server/channelserver/repo_distribution.go new file mode 100644 index 000000000..c54e047cb --- /dev/null +++ b/server/channelserver/repo_distribution.go @@ -0,0 +1,79 @@ +package channelserver + +import ( + "github.com/jmoiron/sqlx" +) + +// DistributionRepository centralizes all database access for the distribution, +// distribution_items, and distributions_accepted tables. +type DistributionRepository struct { + db *sqlx.DB +} + +// NewDistributionRepository creates a new DistributionRepository. +func NewDistributionRepository(db *sqlx.DB) *DistributionRepository { + return &DistributionRepository{db: db} +} + +// List returns all distributions matching the given character and type. +func (r *DistributionRepository) List(charID uint32, distType uint8) ([]Distribution, error) { + rows, err := r.db.Queryx(` + SELECT d.id, event_name, description, COALESCE(rights, 0) AS rights, COALESCE(selection, false) AS selection, times_acceptable, + COALESCE(min_hr, -1) AS min_hr, COALESCE(max_hr, -1) AS max_hr, + COALESCE(min_sr, -1) AS min_sr, COALESCE(max_sr, -1) AS max_sr, + COALESCE(min_gr, -1) AS min_gr, COALESCE(max_gr, -1) AS max_gr, + ( + SELECT count(*) FROM distributions_accepted da + WHERE d.id = da.distribution_id AND da.character_id = $1 + ) AS times_accepted, + COALESCE(deadline, TO_TIMESTAMP(0)) AS deadline + FROM distribution d + WHERE character_id = $1 AND type = $2 OR character_id IS NULL AND type = $2 ORDER BY id DESC + `, charID, distType) + if err != nil { + return nil, err + } + defer rows.Close() + + var dists []Distribution + for rows.Next() { + var d Distribution + if err := rows.StructScan(&d); err != nil { + continue + } + dists = append(dists, d) + } + return dists, nil +} + +// GetItems returns all items for a given distribution. +func (r *DistributionRepository) GetItems(distributionID uint32) ([]DistributionItem, error) { + rows, err := r.db.Queryx(`SELECT id, item_type, COALESCE(item_id, 0) AS item_id, COALESCE(quantity, 0) AS quantity FROM distribution_items WHERE distribution_id=$1`, distributionID) + if err != nil { + return nil, err + } + defer rows.Close() + + var items []DistributionItem + for rows.Next() { + var item DistributionItem + if err := rows.StructScan(&item); err != nil { + continue + } + items = append(items, item) + } + return items, nil +} + +// RecordAccepted records that a character has accepted a distribution. +func (r *DistributionRepository) RecordAccepted(distributionID, charID uint32) error { + _, err := r.db.Exec(`INSERT INTO public.distributions_accepted VALUES ($1, $2)`, distributionID, charID) + return err +} + +// GetDescription returns the description text for a distribution. +func (r *DistributionRepository) GetDescription(distributionID uint32) (string, error) { + var desc string + err := r.db.QueryRow("SELECT description FROM distribution WHERE id = $1", distributionID).Scan(&desc) + return desc, err +} diff --git a/server/channelserver/repo_guild.go b/server/channelserver/repo_guild.go index aeec4f043..6ed3ff7a7 100644 --- a/server/channelserver/repo_guild.go +++ b/server/channelserver/repo_guild.go @@ -886,6 +886,18 @@ type ScoutedCharacter struct { ActorID uint32 `db:"actor_id"` } +// ClearTreasureHunt clears the treasure_hunt field for a character on logout. +func (r *GuildRepository) ClearTreasureHunt(charID uint32) error { + _, err := r.db.Exec(`UPDATE guild_characters SET treasure_hunt=NULL WHERE character_id=$1`, charID) + return err +} + +// InsertKillLog records a monster kill log entry for a character. +func (r *GuildRepository) InsertKillLog(charID uint32, monster int, quantity uint8, timestamp time.Time) error { + _, err := r.db.Exec(`INSERT INTO kill_logs (character_id, monster, quantity, timestamp) VALUES ($1, $2, $3, $4)`, charID, monster, quantity, timestamp) + return err +} + // ListInvitedCharacters returns all characters with pending guild invitations. func (r *GuildRepository) ListInvitedCharacters(guildID uint32) ([]*ScoutedCharacter, error) { rows, err := r.db.Queryx(` diff --git a/server/channelserver/repo_session.go b/server/channelserver/repo_session.go new file mode 100644 index 000000000..bb8a0dc6e --- /dev/null +++ b/server/channelserver/repo_session.go @@ -0,0 +1,40 @@ +package channelserver + +import ( + "github.com/jmoiron/sqlx" +) + +// SessionRepository centralizes all database access for sign_sessions and servers tables. +type SessionRepository struct { + db *sqlx.DB +} + +// NewSessionRepository creates a new SessionRepository. +func NewSessionRepository(db *sqlx.DB) *SessionRepository { + return &SessionRepository{db: db} +} + +// ValidateLoginToken validates that the given token, session ID, and character ID +// correspond to a valid sign session. Returns an error if the token is invalid. +func (r *SessionRepository) ValidateLoginToken(token string, sessionID uint32, charID uint32) error { + var t string + return r.db.QueryRow("SELECT token FROM sign_sessions ss INNER JOIN public.users u on ss.user_id = u.id WHERE token=$1 AND ss.id=$2 AND u.id=(SELECT c.user_id FROM characters c WHERE c.id=$3)", token, sessionID, charID).Scan(&t) +} + +// BindSession associates a sign session token with a server and character. +func (r *SessionRepository) BindSession(token string, serverID uint16, charID uint32) error { + _, err := r.db.Exec("UPDATE sign_sessions SET server_id=$1, char_id=$2 WHERE token=$3", serverID, charID, token) + return err +} + +// ClearSession removes the server and character association from a sign session. +func (r *SessionRepository) ClearSession(token string) error { + _, err := r.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", token) + return err +} + +// UpdatePlayerCount updates the current player count for a server. +func (r *SessionRepository) UpdatePlayerCount(serverID uint16, count int) error { + _, err := r.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", count, serverID) + return err +} diff --git a/server/channelserver/repo_stamp.go b/server/channelserver/repo_stamp.go new file mode 100644 index 000000000..28c65de0e --- /dev/null +++ b/server/channelserver/repo_stamp.go @@ -0,0 +1,61 @@ +package channelserver + +import ( + "fmt" + "time" + + "github.com/jmoiron/sqlx" +) + +// StampRepository centralizes all database access for the stamps table. +type StampRepository struct { + db *sqlx.DB +} + +// NewStampRepository creates a new StampRepository. +func NewStampRepository(db *sqlx.DB) *StampRepository { + return &StampRepository{db: db} +} + +// GetChecked returns the last check time for the given stamp type ("hl" or "ex"). +func (r *StampRepository) GetChecked(charID uint32, stampType string) (time.Time, error) { + var lastCheck time.Time + err := r.db.QueryRow(fmt.Sprintf("SELECT %s_checked FROM stamps WHERE character_id=$1", stampType), charID).Scan(&lastCheck) + return lastCheck, err +} + +// Init inserts a new stamps record for a character with both check times set to now. +func (r *StampRepository) Init(charID uint32, now time.Time) error { + _, err := r.db.Exec("INSERT INTO stamps (character_id, hl_checked, ex_checked) VALUES ($1, $2, $2)", charID, now) + return err +} + +// SetChecked updates the check time for a given stamp type. +func (r *StampRepository) SetChecked(charID uint32, stampType string, now time.Time) error { + _, err := r.db.Exec(fmt.Sprintf(`UPDATE stamps SET %s_checked=$1 WHERE character_id=$2`, stampType), now, charID) + return err +} + +// IncrementTotal increments the total stamp count for a given stamp type. +func (r *StampRepository) IncrementTotal(charID uint32, stampType string) error { + _, err := r.db.Exec(fmt.Sprintf("UPDATE stamps SET %s_total=%s_total+1 WHERE character_id=$1", stampType, stampType), charID) + return err +} + +// GetTotals returns the total and redeemed counts for a given stamp type. +func (r *StampRepository) GetTotals(charID uint32, stampType string) (total, redeemed uint16, err error) { + err = r.db.QueryRow(fmt.Sprintf("SELECT %s_total, %s_redeemed FROM stamps WHERE character_id=$1", stampType, stampType), charID).Scan(&total, &redeemed) + return +} + +// ExchangeYearly performs a yearly stamp exchange, subtracting 48 from both hl_total and hl_redeemed. +func (r *StampRepository) ExchangeYearly(charID uint32) (total, redeemed uint16, err error) { + err = r.db.QueryRow("UPDATE stamps SET hl_total=hl_total-48, hl_redeemed=hl_redeemed-48 WHERE character_id=$1 RETURNING hl_total, hl_redeemed", charID).Scan(&total, &redeemed) + return +} + +// Exchange performs a stamp exchange, adding 8 to the redeemed count for a given stamp type. +func (r *StampRepository) Exchange(charID uint32, stampType string) (total, redeemed uint16, err error) { + err = r.db.QueryRow(fmt.Sprintf("UPDATE stamps SET %s_redeemed=%s_redeemed+8 WHERE character_id=$1 RETURNING %s_total, %s_redeemed", stampType, stampType, stampType, stampType), charID).Scan(&total, &redeemed) + return +} diff --git a/server/channelserver/session_lifecycle_integration_test.go b/server/channelserver/session_lifecycle_integration_test.go index 275a9da2e..3725afca0 100644 --- a/server/channelserver/session_lifecycle_integration_test.go +++ b/server/channelserver/session_lifecycle_integration_test.go @@ -607,6 +607,9 @@ func createTestServerWithDB(t *testing.T, db *sqlx.DB) *Server { server.towerRepo = NewTowerRepository(db) server.rengokuRepo = NewRengokuRepository(db) server.mailRepo = NewMailRepository(db) + server.stampRepo = NewStampRepository(db) + server.distRepo = NewDistributionRepository(db) + server.sessionRepo = NewSessionRepository(db) return server } diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index e814442f6..605a457d7 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -55,6 +55,9 @@ type Server struct { towerRepo *TowerRepository rengokuRepo *RengokuRepository mailRepo *MailRepository + stampRepo *StampRepository + distRepo *DistributionRepository + sessionRepo *SessionRepository erupeConfig *_config.Config acceptConns chan net.Conn deleteConns chan net.Conn @@ -134,6 +137,9 @@ func NewServer(config *Config) *Server { s.towerRepo = NewTowerRepository(config.DB) s.rengokuRepo = NewRengokuRepository(config.DB) s.mailRepo = NewMailRepository(config.DB) + s.stampRepo = NewStampRepository(config.DB) + s.distRepo = NewDistributionRepository(config.DB) + s.sessionRepo = NewSessionRepository(config.DB) // Mezeporta s.stages["sl1Ns200p0a0u0"] = NewStage("sl1Ns200p0a0u0") diff --git a/server/channelserver/testhelpers_db.go b/server/channelserver/testhelpers_db.go index a96472b65..87dc599ab 100644 --- a/server/channelserver/testhelpers_db.go +++ b/server/channelserver/testhelpers_db.go @@ -378,4 +378,7 @@ func SetTestDB(s *Server, db *sqlx.DB) { s.towerRepo = NewTowerRepository(db) s.rengokuRepo = NewRengokuRepository(db) s.mailRepo = NewMailRepository(db) + s.stampRepo = NewStampRepository(db) + s.distRepo = NewDistributionRepository(db) + s.sessionRepo = NewSessionRepository(db) }