diff --git a/server/channelserver/handlers_achievement.go b/server/channelserver/handlers_achievement.go index d26f8a1c2..d466d58e3 100644 --- a/server/channelserver/handlers_achievement.go +++ b/server/channelserver/handlers_achievement.go @@ -1,10 +1,10 @@ package channelserver import ( - "erupe-ce/common/byteframe" - "erupe-ce/network/mhfpacket" "io" + "erupe-ce/common/byteframe" + "erupe-ce/network/mhfpacket" "go.uber.org/zap" ) @@ -97,33 +97,25 @@ func GetAchData(id uint8, score int32) Achievement { func handleMsgMhfGetAchievement(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfGetAchievement) - if err := s.server.achievementRepo.EnsureExists(pkt.CharID); err != nil { - s.logger.Error("Failed to ensure achievements record", zap.Error(err)) - } - - scores, err := s.server.achievementRepo.GetAllScores(pkt.CharID) + summary, err := s.server.achievementService.GetAll(pkt.CharID) if err != nil { doAckBufSucceed(s, pkt.AckHandle, make([]byte, 20)) return } resp := byteframe.NewByteFrame() - var points uint32 resp.WriteBytes(make([]byte, 16)) resp.WriteBytes([]byte{0x02, 0x00, 0x00}) // Unk - var id uint8 - entries := uint8(33) - resp.WriteUint8(entries) // Entry count - for id = 0; id < entries; id++ { - achData := GetAchData(id, scores[id]) - points += achData.Value + resp.WriteUint8(achievementEntryCount) + for id := uint8(0); id < achievementEntryCount; id++ { + ach := summary.Achievements[id] resp.WriteUint8(id) - resp.WriteUint8(achData.Level) - resp.WriteUint16(achData.NextValue) - resp.WriteUint32(achData.Required) + resp.WriteUint8(ach.Level) + resp.WriteUint16(ach.NextValue) + resp.WriteUint32(ach.Required) resp.WriteBool(false) // TODO: Notify on rank increase since last checked, see MhfDisplayedAchievement - resp.WriteUint8(achData.Trophy) + resp.WriteUint8(ach.Trophy) /* Trophy bitfield 0000 0000 abcd efgh @@ -132,13 +124,13 @@ func handleMsgMhfGetAchievement(s *Session, p mhfpacket.MHFPacket) { B-H - Gold (0x7F) */ resp.WriteUint16(0) // Unk - resp.WriteUint32(achData.Progress) + resp.WriteUint32(ach.Progress) } _, _ = resp.Seek(0, io.SeekStart) - resp.WriteUint32(points) - resp.WriteUint32(points) - resp.WriteUint32(points) - resp.WriteUint32(points) + resp.WriteUint32(summary.Points) + resp.WriteUint32(summary.Points) + resp.WriteUint32(summary.Points) + resp.WriteUint32(summary.Points) doAckBufSucceed(s, pkt.AckHandle, resp.Data()) } @@ -151,16 +143,9 @@ func handleMsgMhfResetAchievement(s *Session, p mhfpacket.MHFPacket) {} func handleMsgMhfAddAchievement(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgMhfAddAchievement) - if pkt.AchievementID > 32 { - return - } - if err := s.server.achievementRepo.EnsureExists(s.charID); err != nil { - s.logger.Error("Failed to ensure achievements record", zap.Error(err)) - } - - if err := s.server.achievementRepo.IncrementScore(s.charID, pkt.AchievementID); err != nil { - s.logger.Error("Failed to update achievement score", zap.Error(err)) + if err := s.server.achievementService.Increment(s.charID, pkt.AchievementID); err != nil { + s.logger.Warn("Failed to increment achievement", zap.Error(err)) } } diff --git a/server/channelserver/handlers_achievement_test.go b/server/channelserver/handlers_achievement_test.go index 1cb2dac56..195ece2f6 100644 --- a/server/channelserver/handlers_achievement_test.go +++ b/server/channelserver/handlers_achievement_test.go @@ -461,6 +461,7 @@ func TestHandleMsgMhfGetAchievement_Success(t *testing.T) { scores: [33]int32{5, 0, 20, 0, 0, 0, 0, 1}, // A few non-zero scores } server.achievementRepo = mock + ensureAchievementService(server) session := createMockSession(1, server) pkt := &mhfpacket.MsgMhfGetAchievement{ @@ -492,6 +493,7 @@ func TestHandleMsgMhfGetAchievement_DBError(t *testing.T) { getScoresErr: errNotFound, } server.achievementRepo = mock + ensureAchievementService(server) session := createMockSession(1, server) pkt := &mhfpacket.MsgMhfGetAchievement{ @@ -516,6 +518,7 @@ func TestHandleMsgMhfGetAchievement_AllZeroScores(t *testing.T) { server := createMockServer() mock := &mockAchievementRepo{} // All scores default to 0 server.achievementRepo = mock + ensureAchievementService(server) session := createMockSession(1, server) pkt := &mhfpacket.MsgMhfGetAchievement{ @@ -539,6 +542,7 @@ func TestHandleMsgMhfAddAchievement_Valid(t *testing.T) { server := createMockServer() mock := &mockAchievementRepo{} server.achievementRepo = mock + ensureAchievementService(server) session := createMockSession(42, server) pkt := &mhfpacket.MsgMhfAddAchievement{ @@ -559,6 +563,7 @@ func TestHandleMsgMhfAddAchievement_OutOfRange(t *testing.T) { server := createMockServer() mock := &mockAchievementRepo{} server.achievementRepo = mock + ensureAchievementService(server) session := createMockSession(42, server) pkt := &mhfpacket.MsgMhfAddAchievement{ @@ -576,6 +581,7 @@ func TestHandleMsgMhfAddAchievement_BoundaryID32(t *testing.T) { server := createMockServer() mock := &mockAchievementRepo{} server.achievementRepo = mock + ensureAchievementService(server) session := createMockSession(42, server) pkt := &mhfpacket.MsgMhfAddAchievement{ diff --git a/server/channelserver/svc_achievement.go b/server/channelserver/svc_achievement.go new file mode 100644 index 000000000..01e93b003 --- /dev/null +++ b/server/channelserver/svc_achievement.go @@ -0,0 +1,62 @@ +package channelserver + +import ( + "fmt" + + "go.uber.org/zap" +) + +// AchievementService encapsulates business logic for the achievement system. +type AchievementService struct { + achievementRepo AchievementRepo + logger *zap.Logger +} + +// NewAchievementService creates a new AchievementService. +func NewAchievementService(ar AchievementRepo, log *zap.Logger) *AchievementService { + return &AchievementService{achievementRepo: ar, logger: log} +} + +const achievementEntryCount = uint8(33) + +// AchievementSummary holds the computed achievements and total points for a character. +type AchievementSummary struct { + Points uint32 + Achievements [33]Achievement +} + +// GetAll ensures the achievement record exists, fetches all scores, and computes +// the achievement state for every category. Returns the total accumulated points +// and per-category Achievement data. +func (svc *AchievementService) GetAll(charID uint32) (*AchievementSummary, error) { + if err := svc.achievementRepo.EnsureExists(charID); err != nil { + svc.logger.Error("Failed to ensure achievements record", zap.Error(err)) + } + + scores, err := svc.achievementRepo.GetAllScores(charID) + if err != nil { + return nil, err + } + + var summary AchievementSummary + for id := uint8(0); id < achievementEntryCount; id++ { + ach := GetAchData(id, scores[id]) + summary.Points += ach.Value + summary.Achievements[id] = ach + } + return &summary, nil +} + +// Increment validates the achievement ID, ensures the record exists, and bumps +// the score for the given achievement category. +func (svc *AchievementService) Increment(charID uint32, achievementID uint8) error { + if achievementID > 32 { + return fmt.Errorf("achievement ID %d out of range [0, 32]", achievementID) + } + + if err := svc.achievementRepo.EnsureExists(charID); err != nil { + svc.logger.Error("Failed to ensure achievements record", zap.Error(err)) + } + + return svc.achievementRepo.IncrementScore(charID, achievementID) +} diff --git a/server/channelserver/svc_achievement_test.go b/server/channelserver/svc_achievement_test.go new file mode 100644 index 000000000..c60d6ed19 --- /dev/null +++ b/server/channelserver/svc_achievement_test.go @@ -0,0 +1,169 @@ +package channelserver + +import ( + "testing" + + "go.uber.org/zap" +) + +func newTestAchievementService(repo AchievementRepo) *AchievementService { + logger, _ := zap.NewDevelopment() + return NewAchievementService(repo, logger) +} + +func TestAchievementService_GetAll(t *testing.T) { + tests := []struct { + name string + scores [33]int32 + scoresErr error + wantErr bool + wantPoints uint32 + }{ + { + name: "all zeros", + scores: [33]int32{}, + wantPoints: 0, + }, + { + name: "some scores", + scores: [33]int32{5, 0, 20}, + wantPoints: 5 + 0 + 15, // id0: level1=5pts, id1: level0=0pts, id2: level1(5)+level2(10)=15pts (score=20, curve[0]={5,15,...}: 20-5=15, 15-15=0 → level2=15pts) + }, + { + name: "db error", + scoresErr: errNotFound, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockAchievementRepo{ + scores: tt.scores, + getScoresErr: tt.scoresErr, + } + svc := newTestAchievementService(mock) + + summary, err := svc.GetAll(1) + + if tt.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if !mock.ensureCalled { + t.Error("EnsureExists should have been called") + } + if summary.Points != tt.wantPoints { + t.Errorf("Points = %d, want %d", summary.Points, tt.wantPoints) + } + }) + } +} + +func TestAchievementService_GetAll_EnsureErrorNonFatal(t *testing.T) { + mock := &mockAchievementRepo{ + ensureErr: errNotFound, + scores: [33]int32{}, + } + svc := newTestAchievementService(mock) + + summary, err := svc.GetAll(1) + if err != nil { + t.Fatalf("EnsureExists error should not propagate: %v", err) + } + if summary == nil { + t.Fatal("Summary should not be nil") + } +} + +func TestAchievementService_GetAll_AchievementCount(t *testing.T) { + mock := &mockAchievementRepo{scores: [33]int32{}} + svc := newTestAchievementService(mock) + + summary, err := svc.GetAll(1) + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + + // Verify all 33 entries are populated + for id := uint8(0); id < 33; id++ { + // At score 0, every achievement should be level 0 + if summary.Achievements[id].Level != 0 { + t.Errorf("Achievement[%d].Level = %d, want 0", id, summary.Achievements[id].Level) + } + } +} + +func TestAchievementService_Increment(t *testing.T) { + tests := []struct { + name string + achievementID uint8 + incrementErr error + wantErr bool + wantEnsure bool + wantIncID uint8 + }{ + { + name: "valid ID", + achievementID: 5, + wantEnsure: true, + wantIncID: 5, + }, + { + name: "boundary ID 0", + achievementID: 0, + wantEnsure: true, + wantIncID: 0, + }, + { + name: "boundary ID 32", + achievementID: 32, + wantEnsure: true, + wantIncID: 32, + }, + { + name: "out of range", + achievementID: 33, + wantErr: true, + }, + { + name: "repo error", + achievementID: 5, + incrementErr: errNotFound, + wantErr: true, + wantEnsure: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &mockAchievementRepo{ + incrementErr: tt.incrementErr, + } + svc := newTestAchievementService(mock) + + err := svc.Increment(1, tt.achievementID) + + if tt.wantErr { + if err == nil { + t.Fatal("Expected error, got nil") + } + return + } + if err != nil { + t.Fatalf("Unexpected error: %v", err) + } + if mock.ensureCalled != tt.wantEnsure { + t.Errorf("EnsureExists called = %v, want %v", mock.ensureCalled, tt.wantEnsure) + } + if mock.incrementedID != tt.wantIncID { + t.Errorf("IncrementScore ID = %d, want %d", mock.incrementedID, tt.wantIncID) + } + }) + } +} diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index 6fdca87e7..5e517eacb 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -71,7 +71,8 @@ type Server struct { miscRepo MiscRepo scenarioRepo ScenarioRepo mercenaryRepo MercenaryRepo - guildService *GuildService + guildService *GuildService + achievementService *AchievementService erupeConfig *cfg.Config acceptConns chan net.Conn deleteConns chan net.Conn @@ -155,6 +156,7 @@ func NewServer(config *Config) *Server { s.mercenaryRepo = NewMercenaryRepository(config.DB) s.guildService = NewGuildService(s.guildRepo, s.mailRepo, s.charRepo, s.logger) + s.achievementService = NewAchievementService(s.achievementRepo, s.logger) // Mezeporta s.stages.Store("sl1Ns200p0a0u0", NewStage("sl1Ns200p0a0u0")) diff --git a/server/channelserver/test_helpers_test.go b/server/channelserver/test_helpers_test.go index 5a46e6ff6..e4fbca83b 100644 --- a/server/channelserver/test_helpers_test.go +++ b/server/channelserver/test_helpers_test.go @@ -61,6 +61,11 @@ func ensureGuildService(s *Server) { s.guildService = NewGuildService(s.guildRepo, s.mailRepo, s.charRepo, s.logger) } +// ensureAchievementService wires the AchievementService from the server's current repos. +func ensureAchievementService(s *Server) { + s.achievementService = NewAchievementService(s.achievementRepo, s.logger) +} + // createMockSession creates a minimal Session for testing. // Imported from v9.2.x-stable and adapted for main. func createMockSession(charID uint32, server *Server) *Session {