From ab9fd0bc9ccb82ccfe32eb18089fe73af37d524a Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Fri, 20 Feb 2026 22:18:46 +0100 Subject: [PATCH] refactor(channelserver): extract UserRepository for users table access Centralizes all 31 direct users-table SQL queries from 11 handler files into a single UserRepository, following the same pattern as CharacterRepository and GuildRepository. The only excluded query is the sign_sessions JOIN in handleMsgSysLogin which spans multiple tables. --- server/channelserver/handlers_cast_binary.go | 4 +- server/channelserver/handlers_commands.go | 29 ++- server/channelserver/handlers_discord.go | 5 +- server/channelserver/handlers_distitem.go | 6 +- server/channelserver/handlers_gacha.go | 18 +- server/channelserver/handlers_helpers.go | 6 +- server/channelserver/handlers_items.go | 4 +- server/channelserver/handlers_session.go | 2 +- server/channelserver/handlers_shop.go | 8 +- server/channelserver/repo_user.go | 220 +++++++++++++++++++ server/channelserver/sys_channel_server.go | 2 + server/channelserver/sys_session.go | 9 +- 12 files changed, 265 insertions(+), 48 deletions(-) create mode 100644 server/channelserver/repo_user.go diff --git a/server/channelserver/handlers_cast_binary.go b/server/channelserver/handlers_cast_binary.go index 85dabe08c..7bcbd6f0c 100644 --- a/server/channelserver/handlers_cast_binary.go +++ b/server/channelserver/handlers_cast_binary.go @@ -41,8 +41,8 @@ func handleMsgSysCastBinary(s *Session, p mhfpacket.MHFPacket) { ) if pkt.BroadcastType == BroadcastTypeStage && pkt.MessageType == BinaryMessageTypeData && len(pkt.RawDataPayload) == timerPayloadSize { if tmp.ReadUint16() == timerSubtype && tmp.ReadUint8() == timerFlag { - var timer bool - if err := s.server.db.QueryRow(`SELECT COALESCE(timer, false) FROM users WHERE id=$1`, s.userID).Scan(&timer); err != nil { + timer, err := s.server.userRepo.GetTimer(s.userID) + if err != nil { s.logger.Error("Failed to get timer setting", zap.Error(err)) } if timer { diff --git a/server/channelserver/handlers_commands.go b/server/channelserver/handlers_commands.go index 718505536..a75706a00 100644 --- a/server/channelserver/handlers_commands.go +++ b/server/channelserver/handlers_commands.go @@ -101,9 +101,7 @@ func parseChatCommand(s *Session, command string) { } cid := mhfcid.ConvertCID(args[1]) if cid > 0 { - var uid uint32 - var uname string - err := s.server.db.QueryRow(`SELECT id, username FROM users u WHERE u.id=(SELECT c.user_id FROM characters c WHERE c.id=$1)`, cid).Scan(&uid, &uname) + uid, uname, err := s.server.userRepo.GetByIDAndUsername(cid) if err == nil { if expiry.IsZero() { if _, err := s.server.db.Exec(`INSERT INTO bans VALUES ($1) @@ -133,11 +131,11 @@ func parseChatCommand(s *Session, command string) { } case commands["Timer"].Prefix: if commands["Timer"].Enabled || s.isOp() { - var state bool - if err := s.server.db.QueryRow(`SELECT COALESCE(timer, false) FROM users WHERE id=$1`, s.userID).Scan(&state); err != nil { + state, err := s.server.userRepo.GetTimer(s.userID) + if err != nil { s.logger.Error("Failed to get timer state", zap.Error(err)) } - if _, err := s.server.db.Exec(`UPDATE users SET timer=$1 WHERE id=$2`, !state, s.userID); err != nil { + if err := s.server.userRepo.SetTimer(s.userID, !state); err != nil { s.logger.Error("Failed to update timer setting", zap.Error(err)) } if state { @@ -151,12 +149,12 @@ func parseChatCommand(s *Session, command string) { case commands["PSN"].Prefix: if commands["PSN"].Enabled || s.isOp() { if len(args) > 1 { - var exists int - if err := s.server.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, args[1]).Scan(&exists); err != nil { + exists, err := s.server.userRepo.CountByPSNID(args[1]) + if err != nil { s.logger.Error("Failed to check PSN ID existence", zap.Error(err)) } if exists == 0 { - _, err := s.server.db.Exec(`UPDATE users SET psn_id=$1 WHERE id=$2`, args[1], s.userID) + err := s.server.userRepo.SetPSNID(s.userID, args[1]) if err == nil { sendServerChatMessage(s, fmt.Sprintf(s.server.i18n.commands.psn.success, args[1])) } @@ -258,7 +256,7 @@ func parseChatCommand(s *Session, command string) { if commands["Rights"].Enabled || s.isOp() { if len(args) > 1 { v, _ := strconv.Atoi(args[1]) - _, err := s.server.db.Exec("UPDATE users SET rights=$1 WHERE id=$2", v, s.userID) + err := s.server.userRepo.SetRights(s.userID, uint32(v)) if err == nil { sendServerChatMessage(s, fmt.Sprintf(s.server.i18n.commands.rights.success, v)) } else { @@ -277,7 +275,7 @@ func parseChatCommand(s *Session, command string) { for _, alias := range course.Aliases() { if strings.EqualFold(args[1], alias) { if slices.Contains(s.server.erupeConfig.Courses, _config.Course{Name: course.Aliases()[0], Enabled: true}) { - var delta, rightsInt uint32 + var delta uint32 if mhfcourse.CourseExists(course.ID, s.courses) { ei := slices.IndexFunc(s.courses, func(c mhfcourse.Course) bool { for _, alias := range c.Aliases() { @@ -295,9 +293,9 @@ func parseChatCommand(s *Session, command string) { delta = uint32(math.Pow(2, float64(course.ID))) sendServerChatMessage(s, fmt.Sprintf(s.server.i18n.commands.course.enabled, course.Aliases()[0])) } - err := s.server.db.QueryRow("SELECT rights FROM users WHERE id=$1", s.userID).Scan(&rightsInt) + rightsInt, err := s.server.userRepo.GetRights(s.userID) if err == nil { - if _, err := s.server.db.Exec("UPDATE users SET rights=$1 WHERE id=$2", rightsInt+delta, s.userID); err != nil { + if err := s.server.userRepo.SetRights(s.userID, rightsInt+delta); err != nil { s.logger.Error("Failed to update user rights", zap.Error(err)) } } @@ -391,13 +389,12 @@ func parseChatCommand(s *Session, command string) { } case commands["Discord"].Prefix: if commands["Discord"].Enabled || s.isOp() { - var _token string - err := s.server.db.QueryRow(`SELECT discord_token FROM users WHERE id=$1`, s.userID).Scan(&_token) + _token, err := s.server.userRepo.GetDiscordToken(s.userID) if err != nil { randToken := make([]byte, 4) _, _ = rand.Read(randToken) _token = fmt.Sprintf("%x-%x", randToken[:2], randToken[2:]) - if _, err := s.server.db.Exec(`UPDATE users SET discord_token = $1 WHERE id=$2`, _token, s.userID); err != nil { + if err := s.server.userRepo.SetDiscordToken(s.userID, _token); err != nil { s.logger.Error("Failed to update discord token", zap.Error(err)) } } diff --git a/server/channelserver/handlers_discord.go b/server/channelserver/handlers_discord.go index cd27ce5bd..06b7dcf99 100644 --- a/server/channelserver/handlers_discord.go +++ b/server/channelserver/handlers_discord.go @@ -12,8 +12,7 @@ import ( func (s *Server) onInteraction(ds *discordgo.Session, i *discordgo.InteractionCreate) { switch i.Interaction.ApplicationCommandData().Name { case "link": - var temp string - err := s.db.QueryRow(`UPDATE users SET discord_id = $1 WHERE discord_token = $2 RETURNING discord_id`, i.Member.User.ID, i.ApplicationCommandData().Options[0].StringValue()).Scan(&temp) + _, err := s.userRepo.LinkDiscord(i.Member.User.ID, i.ApplicationCommandData().Options[0].StringValue()) if err == nil { _ = ds.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, @@ -33,7 +32,7 @@ func (s *Server) onInteraction(ds *discordgo.Session, i *discordgo.InteractionCr } case "password": password, _ := bcrypt.GenerateFromPassword([]byte(i.ApplicationCommandData().Options[0].StringValue()), 10) - _, err := s.db.Exec(`UPDATE users SET password = $1 WHERE discord_id = $2`, password, i.Member.User.ID) + err := s.userRepo.SetPasswordByDiscordID(i.Member.User.ID, password) if err == nil { _ = ds.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ Type: discordgo.InteractionResponseChannelMessageWithSource, diff --git a/server/channelserver/handlers_distitem.go b/server/channelserver/handlers_distitem.go index 6311fce2d..901e24593 100644 --- a/server/channelserver/handlers_distitem.go +++ b/server/channelserver/handlers_distitem.go @@ -172,15 +172,15 @@ func handleMsgMhfAcquireDistItem(s *Session, p mhfpacket.MHFPacket) { case 17: _ = addPointNetcafe(s, int(item.Quantity)) case 19: - if _, err := s.server.db.Exec("UPDATE users SET gacha_premium=gacha_premium+$1 WHERE id=$2", item.Quantity, s.userID); err != nil { + if err := s.server.userRepo.AddPremiumCoins(s.userID, item.Quantity); err != nil { s.logger.Error("Failed to update gacha premium", zap.Error(err)) } case 20: - if _, err := s.server.db.Exec("UPDATE users SET gacha_trial=gacha_trial+$1 WHERE id=$2", item.Quantity, s.userID); err != nil { + if err := s.server.userRepo.AddTrialCoins(s.userID, item.Quantity); err != nil { s.logger.Error("Failed to update gacha trial", zap.Error(err)) } case 21: - if _, err := s.server.db.Exec("UPDATE users SET frontier_points=frontier_points+$1 WHERE id=$2", item.Quantity, s.userID); err != nil { + if err := s.server.userRepo.AddFrontierPoints(s.userID, item.Quantity); err != nil { s.logger.Error("Failed to update frontier points", zap.Error(err)) } case 23: diff --git a/server/channelserver/handlers_gacha.go b/server/channelserver/handlers_gacha.go index 795e8785e..8dbf1debe 100644 --- a/server/channelserver/handlers_gacha.go +++ b/server/channelserver/handlers_gacha.go @@ -54,8 +54,7 @@ func handleMsgMhfGetGachaPlayHistory(s *Session, p mhfpacket.MHFPacket) { func handleMsgMhfGetGachaPoint(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfGetGachaPoint) - var fp, gp, gt uint32 - _ = s.server.db.QueryRow("SELECT COALESCE(frontier_points, 0), COALESCE(gacha_premium, 0), COALESCE(gacha_trial, 0) FROM users WHERE id=$1", s.userID).Scan(&fp, &gp, >) + fp, gp, gt, _ := s.server.userRepo.GetGachaPoints(s.userID) resp := byteframe.NewByteFrame() resp.WriteUint32(gp) resp.WriteUint32(gt) @@ -66,12 +65,12 @@ func handleMsgMhfGetGachaPoint(s *Session, p mhfpacket.MHFPacket) { func handleMsgMhfUseGachaPoint(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfUseGachaPoint) if pkt.TrialCoins > 0 { - if _, err := s.server.db.Exec(`UPDATE users SET gacha_trial=gacha_trial-$1 WHERE id=$2`, pkt.TrialCoins, s.userID); err != nil { + if err := s.server.userRepo.DeductTrialCoins(s.userID, pkt.TrialCoins); err != nil { s.logger.Error("Failed to deduct gacha trial coins", zap.Error(err)) } } if pkt.PremiumCoins > 0 { - if _, err := s.server.db.Exec(`UPDATE users SET gacha_premium=gacha_premium-$1 WHERE id=$2`, pkt.PremiumCoins, s.userID); err != nil { + if err := s.server.userRepo.DeductPremiumCoins(s.userID, pkt.PremiumCoins); err != nil { s.logger.Error("Failed to deduct gacha premium coins", zap.Error(err)) } } @@ -79,14 +78,13 @@ func handleMsgMhfUseGachaPoint(s *Session, p mhfpacket.MHFPacket) { } func spendGachaCoin(s *Session, quantity uint16) { - var gt uint16 - _ = s.server.db.QueryRow(`SELECT COALESCE(gacha_trial, 0) FROM users WHERE id=$1`, s.userID).Scan(>) + gt, _ := s.server.userRepo.GetTrialCoins(s.userID) if quantity <= gt { - if _, err := s.server.db.Exec(`UPDATE users SET gacha_trial=gacha_trial-$1 WHERE id=$2`, quantity, s.userID); err != nil { + if err := s.server.userRepo.DeductTrialCoins(s.userID, uint32(quantity)); err != nil { s.logger.Error("Failed to deduct gacha trial coins", zap.Error(err)) } } else { - if _, err := s.server.db.Exec(`UPDATE users SET gacha_premium=gacha_premium-$1 WHERE id=$2`, quantity, s.userID); err != nil { + if err := s.server.userRepo.DeductPremiumCoins(s.userID, uint32(quantity)); err != nil { s.logger.Error("Failed to deduct gacha premium coins", zap.Error(err)) } } @@ -117,7 +115,7 @@ func transactGacha(s *Session, gachaID uint32, rollID uint8) (int, error) { case 20: spendGachaCoin(s, itemNumber) case 21: - if _, err := s.server.db.Exec("UPDATE users SET frontier_points=frontier_points-$1 WHERE id=$2", itemNumber, s.userID); err != nil { + if err := s.server.userRepo.DeductFrontierPoints(s.userID, uint32(itemNumber)); err != nil { s.logger.Error("Failed to deduct frontier points for gacha", zap.Error(err)) } } @@ -287,7 +285,7 @@ func handleMsgMhfPlayStepupGacha(s *Session, p mhfpacket.MHFPacket) { doAckBufSucceed(s, pkt.AckHandle, make([]byte, 1)) return } - if _, err := s.server.db.Exec("UPDATE users SET frontier_points=frontier_points+(SELECT frontier_points FROM gacha_entries WHERE gacha_id = $1 AND entry_type = $2) WHERE id=$3", pkt.GachaID, pkt.RollType, s.userID); err != nil { + if err := s.server.userRepo.AddFrontierPointsFromGacha(s.userID, pkt.GachaID, pkt.RollType); err != nil { s.logger.Error("Failed to award stepup gacha frontier points", zap.Error(err)) } if _, err := s.server.db.Exec(`DELETE FROM gacha_stepup WHERE gacha_id = $1 AND character_id = $2`, pkt.GachaID, s.charID); err != nil { diff --git a/server/channelserver/handlers_helpers.go b/server/channelserver/handlers_helpers.go index ae7b254c4..73204fdcc 100644 --- a/server/channelserver/handlers_helpers.go +++ b/server/channelserver/handlers_helpers.go @@ -108,8 +108,10 @@ func adjustCharacterInt(s *Session, column string, delta int) (int, error) { } func updateRights(s *Session) { - rightsInt := uint32(2) - _ = s.server.db.QueryRow("SELECT rights FROM users WHERE id=$1", s.userID).Scan(&rightsInt) + rightsInt, err := s.server.userRepo.GetRights(s.userID) + if err != nil { + rightsInt = 2 + } s.courses, rightsInt = mhfcourse.GetCourseStruct(rightsInt, s.server.erupeConfig.DefaultCourses) update := &mhfpacket.MsgSysUpdateRight{ ClientRespAckHandle: 0, diff --git a/server/channelserver/handlers_items.go b/server/channelserver/handlers_items.go index 744bd3ee2..7a5afafbb 100644 --- a/server/channelserver/handlers_items.go +++ b/server/channelserver/handlers_items.go @@ -192,7 +192,7 @@ func handleMsgMhfGetExtraInfo(s *Session, p mhfpacket.MHFPacket) {} func userGetItems(s *Session) []mhfitem.MHFItemStack { var data []byte var items []mhfitem.MHFItemStack - _ = s.server.db.QueryRow(`SELECT item_box FROM users WHERE id=$1`, s.userID).Scan(&data) + data, _ = s.server.userRepo.GetItemBox(s.userID) if len(data) > 0 { box := byteframe.NewByteFrameFromBytes(data) numStacks := box.ReadUint16() @@ -215,7 +215,7 @@ func handleMsgMhfEnumerateUnionItem(s *Session, p mhfpacket.MHFPacket) { func handleMsgMhfUpdateUnionItem(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfUpdateUnionItem) newStacks := mhfitem.DiffItemStacks(userGetItems(s), pkt.UpdatedItems) - if _, err := s.server.db.Exec(`UPDATE users SET item_box=$1 WHERE id=$2`, mhfitem.SerializeWarehouseItems(newStacks), s.userID); err != nil { + if err := s.server.userRepo.SetItemBox(s.userID, mhfitem.SerializeWarehouseItems(newStacks)); err != nil { s.logger.Error("Failed to update union item box", zap.Error(err)) } doAckSimpleSucceed(s, pkt.AckHandle, make([]byte, 4)) diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go index f7ef7d87c..75eaac40e 100644 --- a/server/channelserver/handlers_session.go +++ b/server/channelserver/handlers_session.go @@ -102,7 +102,7 @@ func handleMsgSysLogin(s *Session, p mhfpacket.MHFPacket) { return } - _, err = s.server.db.Exec("UPDATE users SET last_character=$1 WHERE id=$2", s.charID, s.userID) + err = s.server.userRepo.SetLastCharacter(s.userID, s.charID) if err != nil { s.logger.Error("Failed to update last character", zap.Error(err)) doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) diff --git a/server/channelserver/handlers_shop.go b/server/channelserver/handlers_shop.go index 72ff9b8b6..cf1dcc170 100644 --- a/server/channelserver/handlers_shop.go +++ b/server/channelserver/handlers_shop.go @@ -261,7 +261,6 @@ type FPointExchange struct { func handleMsgMhfExchangeFpoint2Item(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfExchangeFpoint2Item) - var balance uint32 var itemValue, quantity int if err := s.server.db.QueryRow("SELECT quantity, fpoints FROM fpoint_items WHERE id=$1", pkt.TradeID).Scan(&quantity, &itemValue); err != nil { s.logger.Error("Failed to read fpoint item cost", zap.Error(err)) @@ -269,7 +268,8 @@ func handleMsgMhfExchangeFpoint2Item(s *Session, p mhfpacket.MHFPacket) { return } cost := (int(pkt.Quantity) * quantity) * itemValue - if err := s.server.db.QueryRow("UPDATE users SET frontier_points=frontier_points::int - $1 WHERE id=$2 RETURNING frontier_points", cost, s.userID).Scan(&balance); err != nil { + balance, err := s.server.userRepo.AdjustFrontierPointsDeduct(s.userID, cost) + if err != nil { s.logger.Error("Failed to deduct frontier points", zap.Error(err)) doAckSimpleFail(s, pkt.AckHandle, nil) return @@ -281,7 +281,6 @@ func handleMsgMhfExchangeFpoint2Item(s *Session, p mhfpacket.MHFPacket) { func handleMsgMhfExchangeItem2Fpoint(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfExchangeItem2Fpoint) - var balance uint32 var itemValue, quantity int if err := s.server.db.QueryRow("SELECT quantity, fpoints FROM fpoint_items WHERE id=$1", pkt.TradeID).Scan(&quantity, &itemValue); err != nil { s.logger.Error("Failed to read fpoint item value", zap.Error(err)) @@ -289,7 +288,8 @@ func handleMsgMhfExchangeItem2Fpoint(s *Session, p mhfpacket.MHFPacket) { return } cost := (int(pkt.Quantity) / quantity) * itemValue - if err := s.server.db.QueryRow("UPDATE users SET frontier_points=COALESCE(frontier_points::int + $1, $1) WHERE id=$2 RETURNING frontier_points", cost, s.userID).Scan(&balance); err != nil { + balance, err := s.server.userRepo.AdjustFrontierPointsCredit(s.userID, cost) + if err != nil { s.logger.Error("Failed to credit frontier points", zap.Error(err)) doAckSimpleFail(s, pkt.AckHandle, nil) return diff --git a/server/channelserver/repo_user.go b/server/channelserver/repo_user.go new file mode 100644 index 000000000..b3bc7b003 --- /dev/null +++ b/server/channelserver/repo_user.go @@ -0,0 +1,220 @@ +package channelserver + +import ( + "database/sql" + + "github.com/jmoiron/sqlx" +) + +// UserRepository centralizes all database access for the users table. +type UserRepository struct { + db *sqlx.DB +} + +// NewUserRepository creates a new UserRepository. +func NewUserRepository(db *sqlx.DB) *UserRepository { + return &UserRepository{db: db} +} + +// Gacha/Currency methods + +// GetGachaPoints returns the user's frontier points, premium gacha coins, and trial gacha coins. +func (r *UserRepository) GetGachaPoints(userID uint32) (fp, premium, trial uint32, err error) { + err = r.db.QueryRow( + `SELECT COALESCE(frontier_points, 0), COALESCE(gacha_premium, 0), COALESCE(gacha_trial, 0) FROM users WHERE id=$1`, + userID, + ).Scan(&fp, &premium, &trial) + return +} + +// GetTrialCoins returns the user's trial gacha coin balance. +func (r *UserRepository) GetTrialCoins(userID uint32) (uint16, error) { + var balance uint16 + err := r.db.QueryRow(`SELECT COALESCE(gacha_trial, 0) FROM users WHERE id=$1`, userID).Scan(&balance) + return balance, err +} + +// DeductTrialCoins subtracts the given amount from the user's trial gacha coins. +func (r *UserRepository) DeductTrialCoins(userID uint32, amount uint32) error { + _, err := r.db.Exec(`UPDATE users SET gacha_trial=gacha_trial-$1 WHERE id=$2`, amount, userID) + return err +} + +// DeductPremiumCoins subtracts the given amount from the user's premium gacha coins. +func (r *UserRepository) DeductPremiumCoins(userID uint32, amount uint32) error { + _, err := r.db.Exec(`UPDATE users SET gacha_premium=gacha_premium-$1 WHERE id=$2`, amount, userID) + return err +} + +// AddPremiumCoins adds the given amount to the user's premium gacha coins. +func (r *UserRepository) AddPremiumCoins(userID uint32, amount uint32) error { + _, err := r.db.Exec(`UPDATE users SET gacha_premium=gacha_premium+$1 WHERE id=$2`, amount, userID) + return err +} + +// AddTrialCoins adds the given amount to the user's trial gacha coins. +func (r *UserRepository) AddTrialCoins(userID uint32, amount uint32) error { + _, err := r.db.Exec(`UPDATE users SET gacha_trial=gacha_trial+$1 WHERE id=$2`, amount, userID) + return err +} + +// DeductFrontierPoints subtracts the given amount from the user's frontier points. +func (r *UserRepository) DeductFrontierPoints(userID uint32, amount uint32) error { + _, err := r.db.Exec(`UPDATE users SET frontier_points=frontier_points-$1 WHERE id=$2`, amount, userID) + return err +} + +// AddFrontierPoints adds the given amount to the user's frontier points. +func (r *UserRepository) AddFrontierPoints(userID uint32, amount uint32) error { + _, err := r.db.Exec(`UPDATE users SET frontier_points=frontier_points+$1 WHERE id=$2`, amount, userID) + return err +} + +// AdjustFrontierPointsDeduct atomically deducts frontier points and returns the new balance. +func (r *UserRepository) AdjustFrontierPointsDeduct(userID uint32, amount int) (uint32, error) { + var balance uint32 + err := r.db.QueryRow( + `UPDATE users SET frontier_points=frontier_points::int - $1 WHERE id=$2 RETURNING frontier_points`, + amount, userID, + ).Scan(&balance) + return balance, err +} + +// AdjustFrontierPointsCredit atomically credits frontier points and returns the new balance. +func (r *UserRepository) AdjustFrontierPointsCredit(userID uint32, amount int) (uint32, error) { + var balance uint32 + err := r.db.QueryRow( + `UPDATE users SET frontier_points=COALESCE(frontier_points::int + $1, $1) WHERE id=$2 RETURNING frontier_points`, + amount, userID, + ).Scan(&balance) + return balance, err +} + +// AddFrontierPointsFromGacha awards frontier points from a gacha entry's defined value. +func (r *UserRepository) AddFrontierPointsFromGacha(userID uint32, gachaID uint32, entryType uint8) error { + _, err := r.db.Exec( + `UPDATE users SET frontier_points=frontier_points+(SELECT frontier_points FROM gacha_entries WHERE gacha_id = $1 AND entry_type = $2) WHERE id=$3`, + gachaID, entryType, userID, + ) + return err +} + +// Rights/Permissions methods + +// GetRights returns the user's rights bitmask. +func (r *UserRepository) GetRights(userID uint32) (uint32, error) { + var rights uint32 + err := r.db.QueryRow(`SELECT rights FROM users WHERE id=$1`, userID).Scan(&rights) + return rights, err +} + +// SetRights sets the user's rights bitmask. +func (r *UserRepository) SetRights(userID uint32, rights uint32) error { + _, err := r.db.Exec(`UPDATE users SET rights=$1 WHERE id=$2`, rights, userID) + return err +} + +// IsOp returns whether the user has operator privileges. +func (r *UserRepository) IsOp(userID uint32) (bool, error) { + var op bool + err := r.db.QueryRow(`SELECT op FROM users WHERE id=$1`, userID).Scan(&op) + if err != nil { + return false, err + } + return op, nil +} + +// User metadata methods + +// SetLastCharacter records the last-played character for a user. +func (r *UserRepository) SetLastCharacter(userID uint32, charID uint32) error { + _, err := r.db.Exec(`UPDATE users SET last_character=$1 WHERE id=$2`, charID, userID) + return err +} + +// GetTimer returns whether the user has the quest timer display enabled. +func (r *UserRepository) GetTimer(userID uint32) (bool, error) { + var timer bool + err := r.db.QueryRow(`SELECT COALESCE(timer, false) FROM users WHERE id=$1`, userID).Scan(&timer) + return timer, err +} + +// SetTimer sets the user's quest timer display preference. +func (r *UserRepository) SetTimer(userID uint32, value bool) error { + _, err := r.db.Exec(`UPDATE users SET timer=$1 WHERE id=$2`, value, userID) + return err +} + +// CountByPSNID returns the number of users with the given PSN ID. +func (r *UserRepository) CountByPSNID(psnID string) (int, error) { + var count int + err := r.db.QueryRow(`SELECT count(*) FROM users WHERE psn_id = $1`, psnID).Scan(&count) + return count, err +} + +// SetPSNID associates a PSN ID with the user's account. +func (r *UserRepository) SetPSNID(userID uint32, psnID string) error { + _, err := r.db.Exec(`UPDATE users SET psn_id=$1 WHERE id=$2`, psnID, userID) + return err +} + +// GetDiscordToken returns the user's discord link token. +func (r *UserRepository) GetDiscordToken(userID uint32) (string, error) { + var token string + err := r.db.QueryRow(`SELECT discord_token FROM users WHERE id=$1`, userID).Scan(&token) + return token, err +} + +// SetDiscordToken sets the user's discord link token. +func (r *UserRepository) SetDiscordToken(userID uint32, token string) error { + _, err := r.db.Exec(`UPDATE users SET discord_token = $1 WHERE id=$2`, token, userID) + return err +} + +// Warehouse methods + +// GetItemBox returns the user's serialized warehouse item data. +func (r *UserRepository) GetItemBox(userID uint32) ([]byte, error) { + var data []byte + err := r.db.QueryRow(`SELECT item_box FROM users WHERE id=$1`, userID).Scan(&data) + if err == sql.ErrNoRows { + return nil, nil + } + return data, err +} + +// SetItemBox persists the user's warehouse item data. +func (r *UserRepository) SetItemBox(userID uint32, data []byte) error { + _, err := r.db.Exec(`UPDATE users SET item_box=$1 WHERE id=$2`, data, userID) + return err +} + +// Discord bot methods (Server-level) + +// LinkDiscord associates a Discord user ID with the account matching the given token. +// Returns the discord_id on success. +func (r *UserRepository) LinkDiscord(discordID string, token string) (string, error) { + var result string + err := r.db.QueryRow( + `UPDATE users SET discord_id = $1 WHERE discord_token = $2 RETURNING discord_id`, + discordID, token, + ).Scan(&result) + return result, err +} + +// SetPasswordByDiscordID updates the password for the user linked to the given Discord ID. +func (r *UserRepository) SetPasswordByDiscordID(discordID string, hash []byte) error { + _, err := r.db.Exec(`UPDATE users SET password = $1 WHERE discord_id = $2`, hash, discordID) + return err +} + +// Auth methods + +// GetByIDAndUsername resolves a character ID to the owning user's ID and username. +func (r *UserRepository) GetByIDAndUsername(charID uint32) (userID uint32, username string, err error) { + err = r.db.QueryRow( + `SELECT id, username FROM users u WHERE u.id=(SELECT c.user_id FROM characters c WHERE c.id=$1)`, + charID, + ).Scan(&userID, &username) + return +} diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index e89bb7ffb..6622afc5f 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -47,6 +47,7 @@ type Server struct { db *sqlx.DB charRepo *CharacterRepository guildRepo *GuildRepository + userRepo *UserRepository erupeConfig *_config.Config acceptConns chan net.Conn deleteConns chan net.Conn @@ -119,6 +120,7 @@ func NewServer(config *Config) *Server { s.charRepo = NewCharacterRepository(config.DB) s.guildRepo = NewGuildRepository(config.DB) + s.userRepo = NewUserRepository(config.DB) // Mezeporta s.stages["sl1Ns200p0a0u0"] = NewStage("sl1Ns200p0a0u0") diff --git a/server/channelserver/sys_session.go b/server/channelserver/sys_session.go index 703672776..468ddd5f6 100644 --- a/server/channelserver/sys_session.go +++ b/server/channelserver/sys_session.go @@ -352,10 +352,9 @@ func (s *Session) GetSemaphoreID() uint32 { } func (s *Session) isOp() bool { - var op bool - err := s.server.db.QueryRow(`SELECT op FROM users WHERE id=$1`, s.userID).Scan(&op) - if err == nil && op { - return true + op, err := s.server.userRepo.IsOp(s.userID) + if err != nil { + return false } - return false + return op }