diff --git a/server/channelserver/handlers_achievement_test.go b/server/channelserver/handlers_achievement_test.go index 992e64a7d..e7c5d2869 100644 --- a/server/channelserver/handlers_achievement_test.go +++ b/server/channelserver/handlers_achievement_test.go @@ -240,3 +240,215 @@ func TestEmptyAchievementHandlers(t *testing.T) { }) } } + +// --- NEW TESTS --- + +// TestGetAchData_Level6BronzeTrophy tests that level 6 (in-progress toward level 7) +// awards the bronze trophy (0x40). +// Curve 0: {5, 15, 30, 50, 100, 150, 200, 300} +// Cumulative at each level: L1=5, L2=20, L3=50, L4=100, L5=200, L6=350, L7=550, L8=850 +// At cumulative 350, we reach level 6. Score 400 means level 6 with progress 50 toward next. +func TestGetAchData_Level6BronzeTrophy(t *testing.T) { + // Score to reach level 6 and be partway to level 7: + // cumulative to level 6 = 5+15+30+50+100+150 = 350 + // score 400 = level 6 with 50 remaining progress + ach := GetAchData(0, 400) + if ach.Level != 6 { + t.Errorf("Level = %d, want 6", ach.Level) + } + if ach.Trophy != 0x40 { + t.Errorf("Trophy = 0x%02x, want 0x40 (bronze)", ach.Trophy) + } + if ach.NextValue != 15 { + t.Errorf("NextValue = %d, want 15", ach.NextValue) + } + if ach.Progress != 50 { + t.Errorf("Progress = %d, want 50", ach.Progress) + } + if ach.Required != 200 { + t.Errorf("Required = %d, want 200 (curve[6])", ach.Required) + } +} + +// TestGetAchData_Level7SilverTrophy tests that level 7 (in-progress toward level 8) +// awards the silver trophy (0x60). +// cumulative to level 7 = 5+15+30+50+100+150+200 = 550 +// score 600 = level 7 with 50 remaining progress +func TestGetAchData_Level7SilverTrophy(t *testing.T) { + ach := GetAchData(0, 600) + if ach.Level != 7 { + t.Errorf("Level = %d, want 7", ach.Level) + } + if ach.Trophy != 0x60 { + t.Errorf("Trophy = 0x%02x, want 0x60 (silver)", ach.Trophy) + } + if ach.NextValue != 20 { + t.Errorf("NextValue = %d, want 20", ach.NextValue) + } + if ach.Progress != 50 { + t.Errorf("Progress = %d, want 50", ach.Progress) + } + if ach.Required != 300 { + t.Errorf("Required = %d, want 300 (curve[7])", ach.Required) + } +} + +// TestGetAchData_MaxedOut_AllCurves tests that reaching max level on each curve +// produces the correct gold trophy and the last threshold as Required/Progress. +func TestGetAchData_MaxedOut_AllCurves(t *testing.T) { + tests := []struct { + name string + id uint8 + score int32 + lastThresh int32 + }{ + // Curve 0: {5,15,30,50,100,150,200,300} sum=850, last=300 + {"Curve0_ID0", 0, 5000, 300}, + // Curve 1: {1,5,10,15,30,50,75,100} sum=286, last=100 + {"Curve1_ID7", 7, 5000, 100}, + // Curve 2: {1,2,3,4,5,6,7,8} sum=36, last=8 + {"Curve2_ID8", 8, 5000, 8}, + // Curve 3: {10,50,100,200,350,500,750,999} sum=2959, last=999 + {"Curve3_ID16", 16, 50000, 999}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ach := GetAchData(tt.id, tt.score) + if ach.Level != 8 { + t.Errorf("Level = %d, want 8 (max)", ach.Level) + } + if ach.Trophy != 0x7F { + t.Errorf("Trophy = 0x%02x, want 0x7F (gold)", ach.Trophy) + } + if ach.Required != uint32(tt.lastThresh) { + t.Errorf("Required = %d, want %d", ach.Required, tt.lastThresh) + } + if ach.Progress != ach.Required { + t.Errorf("Progress = %d, want %d (should equal Required at max)", ach.Progress, ach.Required) + } + }) + } +} + +// TestGetAchData_ExactlyAtEachThreshold tests the exact cumulative score at each +// threshold boundary for curve 0. +func TestGetAchData_ExactlyAtEachThreshold(t *testing.T) { + // Curve 0: {5, 15, 30, 50, 100, 150, 200, 300} + // Cumulative thresholds (exact score to reach each level): + // L1: 5, L2: 20, L3: 50, L4: 100, L5: 200, L6: 350, L7: 550, L8: 850 + cumulativeScores := []int32{5, 20, 50, 100, 200, 350, 550, 850} + expectedLevels := []uint8{1, 2, 3, 4, 5, 6, 7, 8} + expectedValues := []uint32{5, 15, 25, 35, 50, 65, 80, 100} + + for i, score := range cumulativeScores { + t.Run("ExactThreshold_L"+string(rune('1'+i)), func(t *testing.T) { + ach := GetAchData(0, score) + if ach.Level != expectedLevels[i] { + t.Errorf("score=%d: Level = %d, want %d", score, ach.Level, expectedLevels[i]) + } + if ach.Value != expectedValues[i] { + t.Errorf("score=%d: Value = %d, want %d", score, ach.Value, expectedValues[i]) + } + }) + } +} + +// TestGetAchData_OneBeforeEachThreshold tests scores that are one less than +// each cumulative threshold, verifying they stay at the previous level. +func TestGetAchData_OneBeforeEachThreshold(t *testing.T) { + // Curve 0: cumulative thresholds: 5, 20, 50, 100, 200, 350, 550, 850 + cumulativeScores := []int32{4, 19, 49, 99, 199, 349, 549, 849} + expectedLevels := []uint8{0, 1, 2, 3, 4, 5, 6, 7} + + for i, score := range cumulativeScores { + t.Run("OneBeforeThreshold_L"+string(rune('0'+i)), func(t *testing.T) { + ach := GetAchData(0, score) + if ach.Level != expectedLevels[i] { + t.Errorf("score=%d: Level = %d, want %d", score, ach.Level, expectedLevels[i]) + } + }) + } +} + +// TestGetAchData_Curve2_FestaWins exercises the "Festa wins" curve which has +// small thresholds: {1, 2, 3, 4, 5, 6, 7, 8} +func TestGetAchData_Curve2_FestaWins(t *testing.T) { + // Curve 2: {1, 2, 3, 4, 5, 6, 7, 8} + // Cumulative: 1, 3, 6, 10, 15, 21, 28, 36 + tests := []struct { + score int32 + wantLvl uint8 + wantProg uint32 + wantReq uint32 + }{ + {0, 0, 0, 1}, + {1, 1, 0, 2}, // Exactly at first threshold + {2, 1, 1, 2}, // One into second threshold + {3, 2, 0, 3}, // Exactly at second cumulative + {36, 8, 8, 8}, // Max level (sum of all thresholds) + {100, 8, 8, 8}, // Well above max + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + ach := GetAchData(8, tt.score) // ID 8 maps to curve 2 + if ach.Level != tt.wantLvl { + t.Errorf("score=%d: Level = %d, want %d", tt.score, ach.Level, tt.wantLvl) + } + if ach.Progress != tt.wantProg { + t.Errorf("score=%d: Progress = %d, want %d", tt.score, ach.Progress, tt.wantProg) + } + if ach.Required != tt.wantReq { + t.Errorf("score=%d: Required = %d, want %d", tt.score, ach.Required, tt.wantReq) + } + }) + } +} + +// TestGetAchData_AllIDs_ZeroScore verifies that calling GetAchData with score=0 +// for every valid ID returns level 0 without panicking. +func TestGetAchData_AllIDs_ZeroScore(t *testing.T) { + for id := uint8(0); id <= 32; id++ { + ach := GetAchData(id, 0) + if ach.Level != 0 { + t.Errorf("ID %d, score 0: Level = %d, want 0", id, ach.Level) + } + if ach.Value != 0 { + t.Errorf("ID %d, score 0: Value = %d, want 0", id, ach.Value) + } + if ach.Trophy != 0 { + t.Errorf("ID %d, score 0: Trophy = 0x%02x, want 0x00", id, ach.Trophy) + } + } +} + +// TestGetAchData_AllIDs_MaxScore verifies that calling GetAchData with a very +// high score for every valid ID returns level 8 with gold trophy. +func TestGetAchData_AllIDs_MaxScore(t *testing.T) { + for id := uint8(0); id <= 32; id++ { + ach := GetAchData(id, 99999) + if ach.Level != 8 { + t.Errorf("ID %d: Level = %d, want 8", id, ach.Level) + } + if ach.Trophy != 0x7F { + t.Errorf("ID %d: Trophy = 0x%02x, want 0x7F", id, ach.Trophy) + } + // At max, Progress should equal Required + if ach.Progress != ach.Required { + t.Errorf("ID %d: Progress (%d) != Required (%d) at max", id, ach.Progress, ach.Required) + } + } +} + +// TestGetAchData_UpdatedAlwaysFalse confirms Updated is always false since +// GetAchData never sets it. +func TestGetAchData_UpdatedAlwaysFalse(t *testing.T) { + scores := []int32{0, 1, 5, 50, 500, 5000} + for _, score := range scores { + ach := GetAchData(0, score) + if ach.Updated { + t.Errorf("score=%d: Updated should always be false, got true", score) + } + } +} diff --git a/server/channelserver/handlers_coverage2_test.go b/server/channelserver/handlers_coverage2_test.go new file mode 100644 index 000000000..9254eaf2d --- /dev/null +++ b/server/channelserver/handlers_coverage2_test.go @@ -0,0 +1,942 @@ +package channelserver + +import ( + "testing" + + "erupe-ce/config" + "erupe-ce/network/mhfpacket" +) + +// Tests for guild handlers that do not require database access. + +func TestHandleMsgMhfEntryRookieGuild(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfEntryRookieGuild{ + AckHandle: 12345, + Unk: 42, + } + + handleMsgMhfEntryRookieGuild(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfGenerateUdGuildMap(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfGenerateUdGuildMap{ + AckHandle: 12345, + } + + handleMsgMhfGenerateUdGuildMap(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfCheckMonthlyItem(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfCheckMonthlyItem{ + AckHandle: 12345, + Type: 0, + } + + handleMsgMhfCheckMonthlyItem(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfAcquireMonthlyItem(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfAcquireMonthlyItem{ + AckHandle: 12345, + } + + handleMsgMhfAcquireMonthlyItem(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfEnumerateInvGuild(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfEnumerateInvGuild{ + AckHandle: 12345, + } + + handleMsgMhfEnumerateInvGuild(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfOperationInvGuild(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfOperationInvGuild{ + AckHandle: 12345, + Operation: 1, + } + + handleMsgMhfOperationInvGuild(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Tests for mercenary handlers that do not require database access. + +func TestHandleMsgMhfMercenaryHuntdata_Unk0Is1(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfMercenaryHuntdata{ + AckHandle: 12345, + Unk0: 1, + } + + handleMsgMhfMercenaryHuntdata(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfMercenaryHuntdata_Unk0Is0(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfMercenaryHuntdata{ + AckHandle: 12345, + Unk0: 0, + } + + handleMsgMhfMercenaryHuntdata(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfMercenaryHuntdata_Unk0Is2(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfMercenaryHuntdata{ + AckHandle: 12345, + Unk0: 2, + } + + handleMsgMhfMercenaryHuntdata(session, pkt) + + // Unk0=2 takes the else branch (same as 0) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Tests for festa/ranking handlers. + +func TestHandleMsgMhfEnumerateRanking_DefaultBranch(t *testing.T) { + server := createMockServer() + server.erupeConfig = &config.Config{ + DevMode: false, + DevModeOptions: config.DevModeOptions{ + TournamentEvent: 0, + }, + } + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfEnumerateRanking{ + AckHandle: 99999, + } + + handleMsgMhfEnumerateRanking(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfEnumerateRanking_NegativeState(t *testing.T) { + server := createMockServer() + server.erupeConfig = &config.Config{ + DevMode: true, + DevModeOptions: config.DevModeOptions{ + TournamentEvent: -1, + }, + } + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfEnumerateRanking{ + AckHandle: 99999, + } + + handleMsgMhfEnumerateRanking(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Tests for rengoku handlers. + +func TestHandleMsgMhfGetRengokuRankingRank_ResponseData(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfGetRengokuRankingRank{ + AckHandle: 55555, + } + + handleMsgMhfGetRengokuRankingRank(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Tests for empty handlers that are not covered in other test files. + +func TestEmptyHandlers_Coverage2(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + handler func(s *Session, p mhfpacket.MHFPacket) + }{ + {"handleMsgSysCastedBinary", handleMsgSysCastedBinary}, + {"handleMsgMhfResetTitle", handleMsgMhfResetTitle}, + {"handleMsgMhfUpdateForceGuildRank", handleMsgMhfUpdateForceGuildRank}, + {"handleMsgMhfUpdateGuild", handleMsgMhfUpdateGuild}, + {"handleMsgMhfUpdateGuildcard", handleMsgMhfUpdateGuildcard}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.handler(session, nil) + }) + } +} + +// Tests for handlers.go - handlers that produce responses without DB access. + +func TestHandleMsgSysTerminalLog_MultipleEntries(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysTerminalLog{ + AckHandle: 12345, + LogID: 200, + Entries: []*mhfpacket.TerminalLogEntry{ + {Type1: 10, Type2: 20, Data: []int16{100, 200, 300}}, + {Type1: 11, Type2: 21, Data: []int16{400, 500}}, + {Type1: 12, Type2: 22, Data: []int16{600}}, + }, + } + + handleMsgSysTerminalLog(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgSysTerminalLog_ZeroLogID(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysTerminalLog{ + AckHandle: 12345, + LogID: 0, + Entries: []*mhfpacket.TerminalLogEntry{}, + } + + handleMsgSysTerminalLog(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgSysPing_DifferentAckHandle(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysPing{ + AckHandle: 0xFFFFFFFF, + } + + handleMsgSysPing(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgSysTime_GetRemoteTimeFalse(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysTime{ + GetRemoteTime: false, + } + + handleMsgSysTime(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgSysIssueLogkey_LogKeyGenerated(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysIssueLogkey{ + AckHandle: 77777, + } + + handleMsgSysIssueLogkey(session, pkt) + + // Verify that the logKey was set on the session + session.Lock() + keyLen := len(session.logKey) + session.Unlock() + + if keyLen != 16 { + t.Errorf("logKey length = %d, want 16", keyLen) + } + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgSysIssueLogkey_Uniqueness(t *testing.T) { + server := createMockServer() + + // Generate two logkeys and verify they differ + session1 := createMockSession(1, server) + session2 := createMockSession(2, server) + + pkt1 := &mhfpacket.MsgSysIssueLogkey{AckHandle: 1} + pkt2 := &mhfpacket.MsgSysIssueLogkey{AckHandle: 2} + + handleMsgSysIssueLogkey(session1, pkt1) + handleMsgSysIssueLogkey(session2, pkt2) + + // Drain send packets + <-session1.sendPackets + <-session2.sendPackets + + session1.Lock() + key1 := make([]byte, len(session1.logKey)) + copy(key1, session1.logKey) + session1.Unlock() + + session2.Lock() + key2 := make([]byte, len(session2.logKey)) + copy(key2, session2.logKey) + session2.Unlock() + + if len(key1) != 16 || len(key2) != 16 { + t.Fatalf("logKeys should be 16 bytes each, got %d and %d", len(key1), len(key2)) + } + + same := true + for i := range key1 { + if key1[i] != key2[i] { + same = false + break + } + } + if same { + t.Error("Two generated logkeys should differ (extremely unlikely to be the same)") + } +} + +// Tests for event handlers. + +func TestHandleMsgMhfReleaseEvent_ErrorCode(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfReleaseEvent{ + AckHandle: 88888, + } + + handleMsgMhfReleaseEvent(session, pkt) + + // This handler manually sends a response with error code 0x41 + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfEnumerateEvent_Stub(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfEnumerateEvent{ + AckHandle: 77777, + } + + handleMsgMhfEnumerateEvent(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Tests for achievement handler. + +func TestHandleMsgMhfSetCaAchievementHist_Response(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfSetCaAchievementHist{ + AckHandle: 44444, + } + + handleMsgMhfSetCaAchievementHist(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Test concurrent handler invocations to catch potential data races. + +func TestHandlersConcurrentInvocations(t *testing.T) { + server := createMockServer() + + done := make(chan struct{}) + const numGoroutines = 10 + + for i := 0; i < numGoroutines; i++ { + go func(id uint32) { + defer func() { + if r := recover(); r != nil { + t.Errorf("goroutine %d panicked: %v", id, r) + } + done <- struct{}{} + }() + + session := createMockSession(id, server) + + // Run several handlers concurrently + handleMsgSysPing(session, &mhfpacket.MsgSysPing{AckHandle: id}) + <-session.sendPackets + + handleMsgSysTime(session, &mhfpacket.MsgSysTime{GetRemoteTime: true}) + <-session.sendPackets + + handleMsgSysIssueLogkey(session, &mhfpacket.MsgSysIssueLogkey{AckHandle: id}) + <-session.sendPackets + + handleMsgMhfMercenaryHuntdata(session, &mhfpacket.MsgMhfMercenaryHuntdata{AckHandle: id, Unk0: 1}) + <-session.sendPackets + + handleMsgMhfEnumerateMercenaryLog(session, &mhfpacket.MsgMhfEnumerateMercenaryLog{AckHandle: id}) + <-session.sendPackets + }(uint32(i + 100)) + } + + for i := 0; i < numGoroutines; i++ { + <-done + } +} + +// Test festa handler with various config states. + +func TestHandleMsgMhfVoteFesta_Response(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfVoteFesta{ + AckHandle: 33333, + } + + handleMsgMhfVoteFesta(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Test record log handler with stage setup. + +func TestHandleMsgSysRecordLog_RemovesReservation(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + stage := NewStage("test_stage_record") + session.stage = stage + stage.reservedClientSlots[session.charID] = true + + pkt := &mhfpacket.MsgSysRecordLog{ + AckHandle: 55555, + } + + handleMsgSysRecordLog(session, pkt) + + if _, exists := stage.reservedClientSlots[session.charID]; exists { + t.Error("charID should be removed from reserved slots after record log") + } + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgSysRecordLog_NoExistingReservation(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + stage := NewStage("test_stage_no_reservation") + session.stage = stage + // No reservation exists for this charID + + pkt := &mhfpacket.MsgSysRecordLog{ + AckHandle: 55556, + } + + // Should not panic even if charID is not in reservedClientSlots + handleMsgSysRecordLog(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Test unlock global sema handler. + +func TestHandleMsgSysUnlockGlobalSema_Response(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysUnlockGlobalSema{ + AckHandle: 66666, + } + + handleMsgSysUnlockGlobalSema(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Test handlers from handlers_event.go with edge cases. + +func TestHandleMsgMhfSetRestrictionEvent_Response(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfSetRestrictionEvent{ + AckHandle: 11111, + } + + handleMsgMhfSetRestrictionEvent(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +func TestHandleMsgMhfGetRestrictionEvent_Empty(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + defer func() { + if r := recover(); r != nil { + t.Errorf("handleMsgMhfGetRestrictionEvent panicked: %v", r) + } + }() + + handleMsgMhfGetRestrictionEvent(session, nil) +} + +// Test handlers from handlers_mercenary.go - legend dispatch (no DB). + +func TestHandleMsgMhfLoadLegendDispatch_Response(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfLoadLegendDispatch{ + AckHandle: 22222, + } + + handleMsgMhfLoadLegendDispatch(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// Test multiple handler invocations on the same session to verify session state is not corrupted. + +func TestMultipleHandlersOnSameSession(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Call multiple handlers in sequence + handleMsgSysPing(session, &mhfpacket.MsgSysPing{AckHandle: 1}) + select { + case <-session.sendPackets: + default: + t.Fatal("Expected packet from Ping handler") + } + + handleMsgSysTime(session, &mhfpacket.MsgSysTime{GetRemoteTime: true}) + select { + case <-session.sendPackets: + default: + t.Fatal("Expected packet from Time handler") + } + + handleMsgMhfRegisterEvent(session, &mhfpacket.MsgMhfRegisterEvent{AckHandle: 2, Unk2: 5, Unk4: 10}) + select { + case <-session.sendPackets: + default: + t.Fatal("Expected packet from RegisterEvent handler") + } + + handleMsgMhfReleaseEvent(session, &mhfpacket.MsgMhfReleaseEvent{AckHandle: 3}) + select { + case <-session.sendPackets: + default: + t.Fatal("Expected packet from ReleaseEvent handler") + } + + handleMsgMhfEnumerateEvent(session, &mhfpacket.MsgMhfEnumerateEvent{AckHandle: 4}) + select { + case <-session.sendPackets: + default: + t.Fatal("Expected packet from EnumerateEvent handler") + } + + handleMsgMhfSetCaAchievementHist(session, &mhfpacket.MsgMhfSetCaAchievementHist{AckHandle: 5}) + select { + case <-session.sendPackets: + default: + t.Fatal("Expected packet from SetCaAchievementHist handler") + } + + handleMsgMhfGetRengokuRankingRank(session, &mhfpacket.MsgMhfGetRengokuRankingRank{AckHandle: 6}) + select { + case <-session.sendPackets: + default: + t.Fatal("Expected packet from GetRengokuRankingRank handler") + } +} + +// Test festa timestamp generation. + +func TestGenerateFestaTimestamps_Debug(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + start uint32 + }{ + {"Debug_Start1", 1}, + {"Debug_Start2", 2}, + {"Debug_Start3", 3}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + timestamps := generateFestaTimestamps(session, tt.start, true) + if len(timestamps) != 5 { + t.Errorf("Expected 5 timestamps, got %d", len(timestamps)) + } + for i, ts := range timestamps { + if ts == 0 { + t.Errorf("Timestamp %d should not be zero", i) + } + } + }) + } +} + +func TestGenerateFestaTimestamps_NonDebug_FutureStart(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Use a far-future start time so it does not trigger cleanup + futureStart := uint32(TimeAdjusted().Unix() + 5000000) + timestamps := generateFestaTimestamps(session, futureStart, false) + + if len(timestamps) != 5 { + t.Errorf("Expected 5 timestamps, got %d", len(timestamps)) + } + if timestamps[0] != futureStart { + t.Errorf("First timestamp = %d, want %d", timestamps[0], futureStart) + } + // Verify intervals + if timestamps[1] != timestamps[0]+604800 { + t.Errorf("Second timestamp should be start+604800, got %d", timestamps[1]) + } + if timestamps[2] != timestamps[1]+604800 { + t.Errorf("Third timestamp should be second+604800, got %d", timestamps[2]) + } + if timestamps[3] != timestamps[2]+9000 { + t.Errorf("Fourth timestamp should be third+9000, got %d", timestamps[3]) + } + if timestamps[4] != timestamps[3]+1240200 { + t.Errorf("Fifth timestamp should be fourth+1240200, got %d", timestamps[4]) + } +} + +// Test trial struct from handlers_festa.go. + +func TestTrialStruct(t *testing.T) { + trial := Trial{ + ID: 100, + Objective: 2, + GoalID: 500, + TimesReq: 10, + Locale: 1, + Reward: 50, + } + if trial.ID != 100 { + t.Errorf("ID = %d, want 100", trial.ID) + } + if trial.Objective != 2 { + t.Errorf("Objective = %d, want 2", trial.Objective) + } + if trial.GoalID != 500 { + t.Errorf("GoalID = %d, want 500", trial.GoalID) + } + if trial.TimesReq != 10 { + t.Errorf("TimesReq = %d, want 10", trial.TimesReq) + } +} + +// Test prize struct from handlers_festa.go. + +func TestPrizeStruct(t *testing.T) { + prize := Prize{ + ID: 1, + Tier: 2, + SoulsReq: 100, + ItemID: 0x1234, + NumItem: 5, + Claimed: 1, + } + if prize.ID != 1 { + t.Errorf("ID = %d, want 1", prize.ID) + } + if prize.Tier != 2 { + t.Errorf("Tier = %d, want 2", prize.Tier) + } + if prize.SoulsReq != 100 { + t.Errorf("SoulsReq = %d, want 100", prize.SoulsReq) + } + if prize.Claimed != 1 { + t.Errorf("Claimed = %d, want 1", prize.Claimed) + } +} + +// Test CatDefinition struct from handlers_mercenary.go. + +func TestCatDefinitionStruct(t *testing.T) { + cat := CatDefinition{ + CatID: 42, + CatName: []byte("TestCat"), + CurrentTask: 4, + Personality: 2, + Class: 1, + Experience: 1500, + WeaponType: 6, + WeaponID: 100, + } + + if cat.CatID != 42 { + t.Errorf("CatID = %d, want 42", cat.CatID) + } + if cat.CurrentTask != 4 { + t.Errorf("CurrentTask = %d, want 4", cat.CurrentTask) + } + if cat.Experience != 1500 { + t.Errorf("Experience = %d, want 1500", cat.Experience) + } + if cat.WeaponType != 6 { + t.Errorf("WeaponType = %d, want 6", cat.WeaponType) + } + if cat.WeaponID != 100 { + t.Errorf("WeaponID = %d, want 100", cat.WeaponID) + } +} + +// Test RengokuScore struct default values. + +func TestRengokuScoreStruct_Fields(t *testing.T) { + score := RengokuScore{ + Name: "Hunter", + Score: 99999, + } + + if score.Name != "Hunter" { + t.Errorf("Name = %s, want Hunter", score.Name) + } + if score.Score != 99999 { + t.Errorf("Score = %d, want 99999", score.Score) + } +} diff --git a/server/channelserver/handlers_coverage3_test.go b/server/channelserver/handlers_coverage3_test.go new file mode 100644 index 000000000..a5eafa379 --- /dev/null +++ b/server/channelserver/handlers_coverage3_test.go @@ -0,0 +1,1319 @@ +package channelserver + +import ( + "sync" + "testing" + + "erupe-ce/network/mhfpacket" +) + +// ============================================================================= +// Category 1: Empty handlers from handlers.go +// These have empty function bodies and can be called with nil packet safely. +// ============================================================================= + +func TestEmptyHandlers_HandlersGo(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + fn func() + }{ + {"handleMsgSysEcho", func() { handleMsgSysEcho(session, nil) }}, + {"handleMsgSysUpdateRight", func() { handleMsgSysUpdateRight(session, nil) }}, + {"handleMsgSysAuthQuery", func() { handleMsgSysAuthQuery(session, nil) }}, + {"handleMsgSysAuthTerminal", func() { handleMsgSysAuthTerminal(session, nil) }}, + {"handleMsgCaExchangeItem", func() { handleMsgCaExchangeItem(session, nil) }}, + {"handleMsgMhfServerCommand", func() { handleMsgMhfServerCommand(session, nil) }}, + {"handleMsgMhfSetLoginwindow", func() { handleMsgMhfSetLoginwindow(session, nil) }}, + {"handleMsgSysTransBinary", func() { handleMsgSysTransBinary(session, nil) }}, + {"handleMsgSysCollectBinary", func() { handleMsgSysCollectBinary(session, nil) }}, + {"handleMsgSysGetState", func() { handleMsgSysGetState(session, nil) }}, + {"handleMsgSysSerialize", func() { handleMsgSysSerialize(session, nil) }}, + {"handleMsgSysEnumlobby", func() { handleMsgSysEnumlobby(session, nil) }}, + {"handleMsgSysEnumuser", func() { handleMsgSysEnumuser(session, nil) }}, + {"handleMsgSysInfokyserver", func() { handleMsgSysInfokyserver(session, nil) }}, + {"handleMsgMhfGetCaUniqueID", func() { handleMsgMhfGetCaUniqueID(session, nil) }}, + {"handleMsgMhfEnumerateItem", func() { handleMsgMhfEnumerateItem(session, nil) }}, + {"handleMsgMhfAcquireItem", func() { handleMsgMhfAcquireItem(session, nil) }}, + {"handleMsgMhfGetExtraInfo", func() { handleMsgMhfGetExtraInfo(session, nil) }}, + {"handleMsgSysSetStatus", func() { handleMsgSysSetStatus(session, nil) }}, + // Also empty in handlers.go but have non-empty struct bodies + {"handleMsgMhfStampcardPrize", func() { handleMsgMhfStampcardPrize(session, nil) }}, + {"handleMsgMhfUnreserveSrg", func() { handleMsgMhfUnreserveSrg(session, nil) }}, + {"handleMsgMhfReadBeatLevelAllRanking", func() { handleMsgMhfReadBeatLevelAllRanking(session, nil) }}, + {"handleMsgMhfReadBeatLevelMyRanking", func() { handleMsgMhfReadBeatLevelMyRanking(session, nil) }}, + {"handleMsgMhfReadLastWeekBeatRanking", func() { handleMsgMhfReadLastWeekBeatRanking(session, nil) }}, + {"handleMsgMhfGetFixedSeibatuRankingTable", func() { handleMsgMhfGetFixedSeibatuRankingTable(session, nil) }}, + {"handleMsgMhfKickExportForce", func() { handleMsgMhfKickExportForce(session, nil) }}, + {"handleMsgMhfRegistSpabiTime", func() { handleMsgMhfRegistSpabiTime(session, nil) }}, + {"handleMsgMhfDebugPostValue", func() { handleMsgMhfDebugPostValue(session, nil) }}, + {"handleMsgMhfGetNotice", func() { handleMsgMhfGetNotice(session, nil) }}, + {"handleMsgMhfPostNotice", func() { handleMsgMhfPostNotice(session, nil) }}, + {"handleMsgMhfGetRandFromTable", func() { handleMsgMhfGetRandFromTable(session, nil) }}, + {"handleMsgMhfGetSenyuDailyCount", func() { handleMsgMhfGetSenyuDailyCount(session, nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.fn() + }) + } +} + +// ============================================================================= +// Category 2: Empty handlers from handlers_object.go +// All empty function bodies, safe to call with nil packet. +// ============================================================================= + +func TestEmptyHandlers_ObjectGo(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + fn func() + }{ + {"handleMsgSysDeleteObject", func() { handleMsgSysDeleteObject(session, nil) }}, + {"handleMsgSysRotateObject", func() { handleMsgSysRotateObject(session, nil) }}, + {"handleMsgSysDuplicateObject", func() { handleMsgSysDuplicateObject(session, nil) }}, + {"handleMsgSysGetObjectBinary", func() { handleMsgSysGetObjectBinary(session, nil) }}, + {"handleMsgSysGetObjectOwner", func() { handleMsgSysGetObjectOwner(session, nil) }}, + {"handleMsgSysUpdateObjectBinary", func() { handleMsgSysUpdateObjectBinary(session, nil) }}, + {"handleMsgSysCleanupObject", func() { handleMsgSysCleanupObject(session, nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.fn() + }) + } +} + +// ============================================================================= +// Category 3: Empty handlers from handlers_clients.go +// All empty function bodies, safe to call with nil packet. +// ============================================================================= + +func TestEmptyHandlers_ClientsGo(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + fn func() + }{ + {"handleMsgMhfShutClient", func() { handleMsgMhfShutClient(session, nil) }}, + {"handleMsgSysHideClient", func() { handleMsgSysHideClient(session, nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.fn() + }) + } +} + +// ============================================================================= +// Category 4: Empty handler from handlers_stage.go +// ============================================================================= + +func TestEmptyHandlers_StageGo(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + fn func() + }{ + {"handleMsgSysStageDestruct", func() { handleMsgSysStageDestruct(session, nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.fn() + }) + } +} + +// ============================================================================= +// Category 5: Empty handlers from handlers_achievement.go +// ============================================================================= + +func TestEmptyHandlers_AchievementGo(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + fn func() + }{ + {"handleMsgMhfDisplayedAchievement", func() { + handleMsgMhfDisplayedAchievement(session, &mhfpacket.MsgMhfDisplayedAchievement{}) + }}, + {"handleMsgMhfGetCaAchievementHist", func() { handleMsgMhfGetCaAchievementHist(session, nil) }}, + {"handleMsgMhfSetCaAchievement", func() { handleMsgMhfSetCaAchievement(session, nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.fn() + }) + } +} + +// ============================================================================= +// Category 6: Empty handlers from handlers_caravan.go +// ============================================================================= + +func TestEmptyHandlers_CaravanGo(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + fn func() + }{ + {"handleMsgMhfCaravanMyScore", func() { handleMsgMhfCaravanMyScore(session, nil) }}, + {"handleMsgMhfCaravanRanking", func() { handleMsgMhfCaravanRanking(session, nil) }}, + {"handleMsgMhfCaravanMyRank", func() { handleMsgMhfCaravanMyRank(session, nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.fn() + }) + } +} + +// ============================================================================= +// Category 7: Simple ack handlers from handlers_tactics.go (no DB needed) +// ============================================================================= + +func TestSimpleAckHandlers_TacticsGo(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + {"handleMsgMhfAddUdTacticsPoint", func(s *Session) { + handleMsgMhfAddUdTacticsPoint(s, &mhfpacket.MsgMhfAddUdTacticsPoint{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 8: Simple ack handlers from handlers_tower.go (no DB needed) +// ============================================================================= + +func TestSimpleAckHandlers_TowerGo(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + {"handleMsgMhfGetTowerInfo_TowerRankPoint", func(s *Session) { + handleMsgMhfGetTowerInfo(s, &mhfpacket.MsgMhfGetTowerInfo{ + AckHandle: 1, + InfoType: mhfpacket.TowerInfoTypeTowerRankPoint, + }) + }}, + {"handleMsgMhfGetTowerInfo_GetOwnTowerSkill", func(s *Session) { + handleMsgMhfGetTowerInfo(s, &mhfpacket.MsgMhfGetTowerInfo{ + AckHandle: 1, + InfoType: mhfpacket.TowerInfoTypeGetOwnTowerSkill, + }) + }}, + {"handleMsgMhfGetTowerInfo_GetOwnTowerLevelV3", func(s *Session) { + handleMsgMhfGetTowerInfo(s, &mhfpacket.MsgMhfGetTowerInfo{ + AckHandle: 1, + InfoType: mhfpacket.TowerInfoTypeGetOwnTowerLevelV3, + }) + }}, + {"handleMsgMhfGetTowerInfo_TowerTouhaHistory", func(s *Session) { + handleMsgMhfGetTowerInfo(s, &mhfpacket.MsgMhfGetTowerInfo{ + AckHandle: 1, + InfoType: mhfpacket.TowerInfoTypeTowerTouhaHistory, + }) + }}, + {"handleMsgMhfGetTowerInfo_Unk5", func(s *Session) { + handleMsgMhfGetTowerInfo(s, &mhfpacket.MsgMhfGetTowerInfo{ + AckHandle: 1, + InfoType: mhfpacket.TowerInfoTypeUnk5, + }) + }}, + {"handleMsgMhfPostTowerInfo", func(s *Session) { + handleMsgMhfPostTowerInfo(s, &mhfpacket.MsgMhfPostTowerInfo{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 9: Simple ack handlers from handlers_reward.go (no DB needed) +// ============================================================================= + +func TestSimpleAckHandlers_RewardGo(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + {"handleMsgMhfGetRewardSong", func(s *Session) { + handleMsgMhfGetRewardSong(s, &mhfpacket.MsgMhfGetRewardSong{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 10: Simple ack handler from handlers_semaphore.go (no DB needed) +// handleMsgSysCreateSemaphore produces a response via doAckSimpleSucceed. +// ============================================================================= + +func TestSimpleAckHandlers_SemaphoreGo(t *testing.T) { + server := createMockServer() + + t.Run("handleMsgSysCreateSemaphore", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysCreateSemaphore(session, &mhfpacket.MsgSysCreateSemaphore{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("handleMsgSysCreateSemaphore: response should have data") + } + default: + t.Error("handleMsgSysCreateSemaphore: no response queued") + } + }) +} + +// ============================================================================= +// Category 11: handleMsgSysCreateAcquireSemaphore from handlers_semaphore.go +// This handler accesses s.server.semaphore map. It creates or acquires a +// semaphore, so it needs the semaphore map initialized on the server. +// ============================================================================= + +func TestHandleMsgSysCreateAcquireSemaphore(t *testing.T) { + server := createMockServer() + server.semaphore = make(map[string]*Semaphore) + + t.Run("creates_new_semaphore", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysCreateAcquireSemaphore(session, &mhfpacket.MsgSysCreateAcquireSemaphore{ + AckHandle: 1, + SemaphoreID: "test_sema_1", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + // Verify semaphore was created + if _, exists := server.semaphore["test_sema_1"]; !exists { + t.Error("semaphore should have been created in server map") + } + }) + + t.Run("acquires_existing_semaphore", func(t *testing.T) { + session := createMockSession(2, server) + // Acquire the same semaphore again + handleMsgSysCreateAcquireSemaphore(session, &mhfpacket.MsgSysCreateAcquireSemaphore{ + AckHandle: 2, + SemaphoreID: "test_sema_1", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("creates_ravi_semaphore", func(t *testing.T) { + session := createMockSession(3, server) + handleMsgSysCreateAcquireSemaphore(session, &mhfpacket.MsgSysCreateAcquireSemaphore{ + AckHandle: 3, + SemaphoreID: "hs_l0u3B51", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + if _, exists := server.semaphore["hs_l0u3B51"]; !exists { + t.Error("ravi semaphore should have been created") + } + }) +} + +// ============================================================================= +// Category 12: Additional simple ack handlers from various files (no DB) +// ============================================================================= + +func TestSimpleAckHandlers_MiscFiles(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + // From handlers_rengoku.go - handleMsgMhfGetRengokuBinary reads from file, + // but returns empty data on error which still produces a response. + {"handleMsgMhfGetRengokuBinary_noFile", func(s *Session) { + handleMsgMhfGetRengokuBinary(s, &mhfpacket.MsgMhfGetRengokuBinary{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 13: Other empty handlers from various files +// ============================================================================= + +func TestEmptyHandlers_MiscFiles(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + tests := []struct { + name string + fn func() + }{ + // From handlers_reward.go + {"handleMsgMhfUseRewardSong", func() { handleMsgMhfUseRewardSong(session, nil) }}, + {"handleMsgMhfAddRewardSongCount", func() { handleMsgMhfAddRewardSongCount(session, nil) }}, + {"handleMsgMhfAcceptReadReward", func() { handleMsgMhfAcceptReadReward(session, nil) }}, + // From handlers_tower.go + {"handleMsgMhfGetBreakSeibatuLevelReward", func() { handleMsgMhfGetBreakSeibatuLevelReward(session, nil) }}, + {"handleMsgMhfPostGemInfo", func() { handleMsgMhfPostGemInfo(session, nil) }}, + // From handlers_caravan.go + {"handleMsgMhfPostRyoudama", func() { handleMsgMhfPostRyoudama(session, nil) }}, + // From handlers_tactics.go + {"handleMsgMhfSetUdTacticsFollower", func() { handleMsgMhfSetUdTacticsFollower(session, nil) }}, + {"handleMsgMhfGetUdTacticsLog", func() { handleMsgMhfGetUdTacticsLog(session, nil) }}, + // From handlers_achievement.go + {"handleMsgMhfPaymentAchievement", func() { handleMsgMhfPaymentAchievement(session, nil) }}, + // From handlers.go (additional empty ones) + {"handleMsgMhfGetCogInfo", func() { handleMsgMhfGetCogInfo(session, nil) }}, + {"handleMsgMhfUseUdShopCoin", func() { handleMsgMhfUseUdShopCoin(session, nil) }}, + {"handleMsgMhfGetDailyMissionMaster", func() { handleMsgMhfGetDailyMissionMaster(session, nil) }}, + {"handleMsgMhfGetDailyMissionPersonal", func() { handleMsgMhfGetDailyMissionPersonal(session, nil) }}, + {"handleMsgMhfSetDailyMissionPersonal", func() { handleMsgMhfSetDailyMissionPersonal(session, nil) }}, + {"handleMsgMhfPostSeibattle", func() { handleMsgMhfPostSeibattle(session, nil) }}, + // From handlers_object.go (additional empty ones) + {"handleMsgSysAddObject", func() { handleMsgSysAddObject(session, nil) }}, + {"handleMsgSysDelObject", func() { handleMsgSysDelObject(session, nil) }}, + {"handleMsgSysDispObject", func() { handleMsgSysDispObject(session, nil) }}, + {"handleMsgSysHideObject", func() { handleMsgSysHideObject(session, nil) }}, + // From handlers.go (non-trivial but no pkt dereference) + {"handleMsgHead", func() { handleMsgHead(session, nil) }}, + {"handleMsgSysExtendThreshold", func() { handleMsgSysExtendThreshold(session, nil) }}, + {"handleMsgSysEnd", func() { handleMsgSysEnd(session, nil) }}, + {"handleMsgSysNop", func() { handleMsgSysNop(session, nil) }}, + {"handleMsgSysAck", func() { handleMsgSysAck(session, nil) }}, + // From handlers_semaphore.go + {"handleMsgSysReleaseSemaphore", func() { handleMsgSysReleaseSemaphore(session, nil) }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("%s panicked: %v", tt.name, r) + } + }() + tt.fn() + }) + } +} + +// ============================================================================= +// Category 14: Handlers that produce responses without DB access +// These are non-trivial handlers with static/canned responses. +// ============================================================================= + +func TestNonTrivialHandlers_NoDB(t *testing.T) { + server := createMockServer() + + t.Run("handleMsgMhfGetEarthStatus", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetEarthStatus(session, &mhfpacket.MsgMhfGetEarthStatus{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfGetEarthValue_Type1", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetEarthValue(session, &mhfpacket.MsgMhfGetEarthValue{AckHandle: 1, ReqType: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfGetEarthValue_Type2", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetEarthValue(session, &mhfpacket.MsgMhfGetEarthValue{AckHandle: 1, ReqType: 2}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfGetEarthValue_Type3", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetEarthValue(session, &mhfpacket.MsgMhfGetEarthValue{AckHandle: 1, ReqType: 3}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfGetSeibattle", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetSeibattle(session, &mhfpacket.MsgMhfGetSeibattle{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfGetTrendWeapon", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetTrendWeapon(session, &mhfpacket.MsgMhfGetTrendWeapon{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfUpdateUseTrendWeaponLog", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfUpdateUseTrendWeaponLog(session, &mhfpacket.MsgMhfUpdateUseTrendWeaponLog{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfUpdateBeatLevel", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfUpdateBeatLevel(session, &mhfpacket.MsgMhfUpdateBeatLevel{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfReadBeatLevel", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfReadBeatLevel(session, &mhfpacket.MsgMhfReadBeatLevel{ + AckHandle: 1, + ValidIDCount: 2, + IDs: [16]uint32{100, 200}, + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfTransferItem", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfTransferItem(session, &mhfpacket.MsgMhfTransferItem{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfEnumerateOrder", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfEnumerateOrder(session, &mhfpacket.MsgMhfEnumerateOrder{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfGetUdShopCoin", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetUdShopCoin(session, &mhfpacket.MsgMhfGetUdShopCoin{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfGetLobbyCrowd", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetLobbyCrowd(session, &mhfpacket.MsgMhfGetLobbyCrowd{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgMhfEnumeratePrice", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfEnumeratePrice(session, &mhfpacket.MsgMhfEnumeratePrice{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 15: Handlers from handlers_tactics.go that produce responses (no DB) +// ============================================================================= + +func TestNonTrivialHandlers_TacticsGo(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + {"handleMsgMhfGetUdTacticsPoint", func(s *Session) { + handleMsgMhfGetUdTacticsPoint(s, &mhfpacket.MsgMhfGetUdTacticsPoint{AckHandle: 1}) + }}, + {"handleMsgMhfGetUdTacticsRewardList", func(s *Session) { + handleMsgMhfGetUdTacticsRewardList(s, &mhfpacket.MsgMhfGetUdTacticsRewardList{AckHandle: 1}) + }}, + {"handleMsgMhfGetUdTacticsFollower", func(s *Session) { + handleMsgMhfGetUdTacticsFollower(s, &mhfpacket.MsgMhfGetUdTacticsFollower{AckHandle: 1}) + }}, + {"handleMsgMhfGetUdTacticsBonusQuest", func(s *Session) { + handleMsgMhfGetUdTacticsBonusQuest(s, &mhfpacket.MsgMhfGetUdTacticsBonusQuest{AckHandle: 1}) + }}, + {"handleMsgMhfGetUdTacticsFirstQuestBonus", func(s *Session) { + handleMsgMhfGetUdTacticsFirstQuestBonus(s, &mhfpacket.MsgMhfGetUdTacticsFirstQuestBonus{AckHandle: 1}) + }}, + {"handleMsgMhfGetUdTacticsRemainingPoint", func(s *Session) { + handleMsgMhfGetUdTacticsRemainingPoint(s, &mhfpacket.MsgMhfGetUdTacticsRemainingPoint{AckHandle: 1}) + }}, + {"handleMsgMhfGetUdTacticsRanking", func(s *Session) { + handleMsgMhfGetUdTacticsRanking(s, &mhfpacket.MsgMhfGetUdTacticsRanking{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 16: Handlers from handlers_tower.go that produce responses (no DB) +// ============================================================================= + +func TestNonTrivialHandlers_TowerGo(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + {"handleMsgMhfGetTenrouirai_Type1", func(s *Session) { + handleMsgMhfGetTenrouirai(s, &mhfpacket.MsgMhfGetTenrouirai{AckHandle: 1, Unk0: 1}) + }}, + {"handleMsgMhfGetTenrouirai_Type4", func(s *Session) { + handleMsgMhfGetTenrouirai(s, &mhfpacket.MsgMhfGetTenrouirai{AckHandle: 1, Unk0: 0, Unk2: 4}) + }}, + {"handleMsgMhfGetTenrouirai_Unknown", func(s *Session) { + handleMsgMhfGetTenrouirai(s, &mhfpacket.MsgMhfGetTenrouirai{AckHandle: 1, Unk0: 0, Unk2: 0}) + }}, + {"handleMsgMhfPostTenrouirai", func(s *Session) { + handleMsgMhfPostTenrouirai(s, &mhfpacket.MsgMhfPostTenrouirai{AckHandle: 1}) + }}, + {"handleMsgMhfGetWeeklySeibatuRankingReward", func(s *Session) { + handleMsgMhfGetWeeklySeibatuRankingReward(s, &mhfpacket.MsgMhfGetWeeklySeibatuRankingReward{AckHandle: 1}) + }}, + {"handleMsgMhfPresentBox", func(s *Session) { + handleMsgMhfPresentBox(s, &mhfpacket.MsgMhfPresentBox{AckHandle: 1}) + }}, + {"handleMsgMhfGetGemInfo", func(s *Session) { + handleMsgMhfGetGemInfo(s, &mhfpacket.MsgMhfGetGemInfo{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 17: Handlers from handlers_reward.go that produce responses (no DB) +// ============================================================================= + +func TestNonTrivialHandlers_RewardGo(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + {"handleMsgMhfGetAdditionalBeatReward", func(s *Session) { + handleMsgMhfGetAdditionalBeatReward(s, &mhfpacket.MsgMhfGetAdditionalBeatReward{AckHandle: 1}) + }}, + {"handleMsgMhfGetUdRankingRewardList", func(s *Session) { + handleMsgMhfGetUdRankingRewardList(s, &mhfpacket.MsgMhfGetUdRankingRewardList{AckHandle: 1}) + }}, + {"handleMsgMhfAcquireMonthlyReward", func(s *Session) { + handleMsgMhfAcquireMonthlyReward(s, &mhfpacket.MsgMhfAcquireMonthlyReward{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 18: Handlers from handlers_caravan.go that produce responses (no DB) +// ============================================================================= + +func TestNonTrivialHandlers_CaravanGo(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + fn func(s *Session) + }{ + {"handleMsgMhfGetRyoudama", func(s *Session) { + handleMsgMhfGetRyoudama(s, &mhfpacket.MsgMhfGetRyoudama{AckHandle: 1}) + }}, + {"handleMsgMhfGetTinyBin", func(s *Session) { + handleMsgMhfGetTinyBin(s, &mhfpacket.MsgMhfGetTinyBin{AckHandle: 1}) + }}, + {"handleMsgMhfPostTinyBin", func(s *Session) { + handleMsgMhfPostTinyBin(s, &mhfpacket.MsgMhfPostTinyBin{AckHandle: 1}) + }}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + tt.fn(session) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("%s: response should have data", tt.name) + } + default: + t.Errorf("%s: no response queued", tt.name) + } + }) + } +} + +// ============================================================================= +// Category 19: Handlers from handlers_rengoku.go (no DB needed) +// ============================================================================= + +func TestNonTrivialHandlers_RengokuGo(t *testing.T) { + server := createMockServer() + + t.Run("handleMsgMhfGetRengokuRankingRank", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfGetRengokuRankingRank(session, &mhfpacket.MsgMhfGetRengokuRankingRank{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 20: Handlers from handlers.go that produce responses (no DB) +// ============================================================================= + +func TestNonTrivialHandlers_InfoScenarioCounter(t *testing.T) { + server := createMockServer() + + t.Run("handleMsgMhfInfoScenarioCounter", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgMhfInfoScenarioCounter(session, &mhfpacket.MsgMhfInfoScenarioCounter{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 21: handleMsgSysPing and handleMsgSysTime (no DB) +// ============================================================================= + +func TestSimpleHandlers_PingAndTime(t *testing.T) { + server := createMockServer() + + t.Run("handleMsgSysPing", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysPing(session, &mhfpacket.MsgSysPing{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("handleMsgSysTime", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysTime(session, &mhfpacket.MsgSysTime{}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 22: handleMsgSysIssueLogkey (no DB, uses crypto/rand) +// ============================================================================= + +func TestHandleMsgSysIssueLogkey_Coverage3(t *testing.T) { + server := createMockServer() + + t.Run("generates_logkey", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysIssueLogkey(session, &mhfpacket.MsgSysIssueLogkey{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + if session.logKey == nil { + t.Error("logKey should be set after IssueLogkey") + } + if len(session.logKey) != 16 { + t.Errorf("logKey length = %d, want 16", len(session.logKey)) + } + }) +} + +// ============================================================================= +// Category 23: handleMsgSysUnlockGlobalSema (no DB) +// ============================================================================= + +func TestHandleMsgSysUnlockGlobalSema_Coverage3(t *testing.T) { + server := createMockServer() + + t.Run("produces_response", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysUnlockGlobalSema(session, &mhfpacket.MsgSysUnlockGlobalSema{AckHandle: 1}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 24: handleMsgSysLockGlobalSema (no DB, but needs Channels) +// ============================================================================= + +func TestHandleMsgSysLockGlobalSema(t *testing.T) { + server := createMockServer() + server.Channels = make([]*Server, 0) + + t.Run("no_channels_returns_response", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysLockGlobalSema(session, &mhfpacket.MsgSysLockGlobalSema{ + AckHandle: 1, + UserIDString: "testuser", + ServerChannelIDString: "ch1", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 25: handleMsgSysCheckSemaphore (no DB) +// ============================================================================= + +func TestHandleMsgSysCheckSemaphore(t *testing.T) { + server := createMockServer() + server.semaphore = make(map[string]*Semaphore) + + t.Run("semaphore_not_exists", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysCheckSemaphore(session, &mhfpacket.MsgSysCheckSemaphore{ + AckHandle: 1, + SemaphoreID: "nonexistent", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("semaphore_exists", func(t *testing.T) { + server.semaphore["existing_sema"] = NewSemaphore(server, "existing_sema", 1) + session := createMockSession(1, server) + handleMsgSysCheckSemaphore(session, &mhfpacket.MsgSysCheckSemaphore{ + AckHandle: 1, + SemaphoreID: "existing_sema", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 26: handleMsgSysAcquireSemaphore (no DB) +// ============================================================================= + +func TestHandleMsgSysAcquireSemaphore(t *testing.T) { + server := createMockServer() + server.semaphore = make(map[string]*Semaphore) + + t.Run("semaphore_exists", func(t *testing.T) { + server.semaphore["acquire_sema"] = NewSemaphore(server, "acquire_sema", 1) + session := createMockSession(1, server) + handleMsgSysAcquireSemaphore(session, &mhfpacket.MsgSysAcquireSemaphore{ + AckHandle: 1, + SemaphoreID: "acquire_sema", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("semaphore_not_exists", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysAcquireSemaphore(session, &mhfpacket.MsgSysAcquireSemaphore{ + AckHandle: 1, + SemaphoreID: "nonexistent_sema", + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 27: handleMsgSysCreateStage (no DB) +// ============================================================================= + +func TestHandleMsgSysCreateStage_Coverage3(t *testing.T) { + server := createMockServer() + + t.Run("creates_new_stage", func(t *testing.T) { + session := createMockSession(1, server) + handleMsgSysCreateStage(session, &mhfpacket.MsgSysCreateStage{ + AckHandle: 1, + StageID: "test_create_stage", + PlayerCount: 4, + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + if _, exists := server.stages["test_create_stage"]; !exists { + t.Error("stage should have been created") + } + }) + + t.Run("duplicate_stage_fails", func(t *testing.T) { + session := createMockSession(1, server) + // Stage already exists from the previous test + handleMsgSysCreateStage(session, &mhfpacket.MsgSysCreateStage{ + AckHandle: 2, + StageID: "test_create_stage", + PlayerCount: 4, + }) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data even on failure") + } + default: + t.Error("no response queued") + } + }) +} + +// ============================================================================= +// Category 28: Concurrency test for empty handlers +// Verify that calling empty handlers concurrently does not panic. +// ============================================================================= + +func TestEmptyHandlers_Concurrent(t *testing.T) { + server := createMockServer() + + handlers := []func(*Session, mhfpacket.MHFPacket){ + handleMsgSysEcho, + handleMsgSysUpdateRight, + handleMsgSysAuthQuery, + handleMsgSysAuthTerminal, + handleMsgCaExchangeItem, + handleMsgMhfServerCommand, + handleMsgMhfSetLoginwindow, + handleMsgSysTransBinary, + handleMsgSysCollectBinary, + handleMsgSysGetState, + handleMsgSysSerialize, + handleMsgSysEnumlobby, + handleMsgSysEnumuser, + handleMsgSysInfokyserver, + handleMsgMhfGetCaUniqueID, + handleMsgMhfEnumerateItem, + handleMsgMhfAcquireItem, + handleMsgMhfGetExtraInfo, + handleMsgSysSetStatus, + handleMsgSysDeleteObject, + handleMsgSysRotateObject, + handleMsgSysDuplicateObject, + handleMsgSysGetObjectBinary, + handleMsgSysGetObjectOwner, + handleMsgSysUpdateObjectBinary, + handleMsgSysCleanupObject, + handleMsgMhfShutClient, + handleMsgSysHideClient, + handleMsgSysStageDestruct, + } + + var wg sync.WaitGroup + for _, h := range handlers { + for i := 0; i < 10; i++ { + wg.Add(1) + go func(handler func(*Session, mhfpacket.MHFPacket)) { + defer wg.Done() + session := createMockSession(1, server) + handler(session, nil) + }(h) + } + } + wg.Wait() +} + +// ============================================================================= +// Category 29: stubEnumerateNoResults and stubGetNoResults helper coverage +// These are called by many handlers; test them directly too. +// ============================================================================= + +func TestStubHelpers(t *testing.T) { + server := createMockServer() + + t.Run("stubEnumerateNoResults", func(t *testing.T) { + session := createMockSession(1, server) + stubEnumerateNoResults(session, 1) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("stubGetNoResults", func(t *testing.T) { + session := createMockSession(1, server) + stubGetNoResults(session, 1) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("doAckBufSucceed", func(t *testing.T) { + session := createMockSession(1, server) + doAckBufSucceed(session, 1, []byte{0x01, 0x02, 0x03}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("doAckBufFail", func(t *testing.T) { + session := createMockSession(1, server) + doAckBufFail(session, 1, []byte{0x01, 0x02, 0x03}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("doAckSimpleSucceed", func(t *testing.T) { + session := createMockSession(1, server) + doAckSimpleSucceed(session, 1, []byte{0x00, 0x00, 0x00, 0x00}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) + + t.Run("doAckSimpleFail", func(t *testing.T) { + session := createMockSession(1, server) + doAckSimpleFail(session, 1, []byte{0x00, 0x00, 0x00, 0x00}) + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } + }) +} diff --git a/server/channelserver/handlers_coverage_test.go b/server/channelserver/handlers_coverage_test.go new file mode 100644 index 000000000..b5908b672 --- /dev/null +++ b/server/channelserver/handlers_coverage_test.go @@ -0,0 +1,145 @@ +package channelserver + +import ( + "testing" + + "erupe-ce/network/mhfpacket" +) + +// Tests for handlers that do NOT require database access, exercising additional +// code paths not covered by existing test files (handlers_core_test.go, +// handlers_rengoku_test.go, etc.). + +// TestHandleMsgSysPing_DifferentAckHandles verifies ping works with various ack handles. +func TestHandleMsgSysPing_DifferentAckHandles(t *testing.T) { + server := createMockServer() + + ackHandles := []uint32{0, 1, 99999, 0xFFFFFFFF} + for _, ack := range ackHandles { + session := createMockSession(1, server) + pkt := &mhfpacket.MsgSysPing{AckHandle: ack} + + handleMsgSysPing(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("AckHandle=%d: Response packet should have data", ack) + } + default: + t.Errorf("AckHandle=%d: No response packet queued", ack) + } + } +} + +// TestHandleMsgSysTerminalLog_NoEntries verifies the handler works with nil entries. +func TestHandleMsgSysTerminalLog_NoEntries(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysTerminalLog{ + AckHandle: 99999, + LogID: 0, + Entries: nil, + } + + handleMsgSysTerminalLog(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// TestHandleMsgSysTerminalLog_ManyEntries verifies the handler with many log entries. +func TestHandleMsgSysTerminalLog_ManyEntries(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + entries := make([]*mhfpacket.TerminalLogEntry, 20) + for i := range entries { + entries[i] = &mhfpacket.TerminalLogEntry{ + Index: uint32(i), + Type1: uint8(i % 256), + Type2: uint8((i + 1) % 256), + Data: make([]int16, 15), + } + } + + pkt := &mhfpacket.MsgSysTerminalLog{ + AckHandle: 55555, + LogID: 42, + Entries: entries, + } + + handleMsgSysTerminalLog(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// TestHandleMsgSysTime_MultipleCalls verifies calling time handler repeatedly. +func TestHandleMsgSysTime_MultipleCalls(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysTime{ + GetRemoteTime: false, + Timestamp: 0, + } + + for i := 0; i < 5; i++ { + handleMsgSysTime(session, pkt) + } + + // Should have 5 queued responses + count := 0 + for { + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + count++ + default: + goto done + } + } +done: + if count != 5 { + t.Errorf("Expected 5 queued responses, got %d", count) + } +} + +// TestHandleMsgMhfGetRengokuRankingRank_DifferentAck verifies rengoku ranking +// works with different ack handles. +func TestHandleMsgMhfGetRengokuRankingRank_DifferentAck(t *testing.T) { + server := createMockServer() + + ackHandles := []uint32{0, 1, 54321, 0xDEADBEEF} + for _, ack := range ackHandles { + session := createMockSession(1, server) + pkt := &mhfpacket.MsgMhfGetRengokuRankingRank{AckHandle: ack} + + handleMsgMhfGetRengokuRankingRank(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Errorf("AckHandle=%d: Response packet should have data", ack) + } + default: + t.Errorf("AckHandle=%d: No response packet queued", ack) + } + } +} diff --git a/server/channelserver/handlers_event_test.go b/server/channelserver/handlers_event_test.go index d283a4224..eed99b817 100644 --- a/server/channelserver/handlers_event_test.go +++ b/server/channelserver/handlers_event_test.go @@ -1,6 +1,7 @@ package channelserver import ( + "math/bits" "testing" "erupe-ce/network/mhfpacket" @@ -159,3 +160,99 @@ func TestGenerateFeatureWeapons_ZeroCount(t *testing.T) { t.Errorf("Expected 0 for zero count, got %d", result.ActiveFeatures) } } + +// --- NEW TESTS --- + +// TestGenerateFeatureWeapons_BitCount verifies that the number of set bits +// in ActiveFeatures matches the requested count (capped at 14). +func TestGenerateFeatureWeapons_BitCount(t *testing.T) { + tests := []struct { + name string + count int + wantBits int + }{ + {"1 weapon", 1, 1}, + {"5 weapons", 5, 5}, + {"10 weapons", 10, 10}, + {"14 weapons", 14, 14}, + {"20 capped to 14", 20, 14}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := generateFeatureWeapons(tt.count) + setBits := bits.OnesCount32(result.ActiveFeatures) + if setBits != tt.wantBits { + t.Errorf("Set bits = %d, want %d (ActiveFeatures=0b%032b)", + setBits, tt.wantBits, result.ActiveFeatures) + } + }) + } +} + +// TestGenerateFeatureWeapons_BitsInRange verifies that all set bits are within +// bits 0-13 (no bits above bit 13 should be set). +func TestGenerateFeatureWeapons_BitsInRange(t *testing.T) { + for i := 0; i < 50; i++ { + result := generateFeatureWeapons(7) + // Bits 14+ should never be set + if result.ActiveFeatures&^uint32(0x3FFF) != 0 { + t.Errorf("Bits above 13 are set: 0x%08X", result.ActiveFeatures) + } + } +} + +// TestGenerateFeatureWeapons_MaxYieldsAllBits verifies that requesting 14 +// weapons sets exactly bits 0-13 (the value 16383 = 0x3FFF). +func TestGenerateFeatureWeapons_MaxYieldsAllBits(t *testing.T) { + result := generateFeatureWeapons(14) + if result.ActiveFeatures != 0x3FFF { + t.Errorf("ActiveFeatures = 0x%04X, want 0x3FFF (all 14 bits set)", result.ActiveFeatures) + } +} + +// TestGenerateFeatureWeapons_StartTimeZero verifies that the returned +// activeFeature has a zero StartTime (not set by generateFeatureWeapons). +func TestGenerateFeatureWeapons_StartTimeZero(t *testing.T) { + result := generateFeatureWeapons(5) + if !result.StartTime.IsZero() { + t.Errorf("StartTime should be zero, got %v", result.StartTime) + } +} + +// TestHandleMsgMhfRegisterEvent_DifferentValues tests with various Unk2/Unk4 values. +func TestHandleMsgMhfRegisterEvent_DifferentValues(t *testing.T) { + server := createMockServer() + + tests := []struct { + name string + unk2 uint8 + unk4 uint8 + }{ + {"zeros", 0, 0}, + {"max values", 255, 255}, + {"typical", 5, 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + session := createMockSession(1, server) + pkt := &mhfpacket.MsgMhfRegisterEvent{ + AckHandle: 99999, + Unk2: tt.unk2, + Unk4: tt.unk4, + } + + handleMsgMhfRegisterEvent(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } + }) + } +} diff --git a/server/channelserver/handlers_mercenary_test.go b/server/channelserver/handlers_mercenary_test.go index 340e4024c..a4273196e 100644 --- a/server/channelserver/handlers_mercenary_test.go +++ b/server/channelserver/handlers_mercenary_test.go @@ -1,8 +1,11 @@ package channelserver import ( + "bytes" + "encoding/binary" "testing" + "erupe-ce/common/byteframe" "erupe-ce/network/mhfpacket" ) @@ -25,3 +28,271 @@ func TestHandleMsgMhfLoadLegendDispatch(t *testing.T) { t.Error("No response packet queued") } } + +// --- NEW TESTS --- + +// buildCatBytes constructs a binary cat data payload suitable for GetCatDetails. +func buildCatBytes(cats []CatDefinition) []byte { + buf := new(bytes.Buffer) + // catCount + buf.WriteByte(byte(len(cats))) + for _, cat := range cats { + catBuf := new(bytes.Buffer) + // catID (uint32) + binary.Write(catBuf, binary.BigEndian, cat.CatID) + // 1 byte skip (unknown bool) + catBuf.WriteByte(0) + // CatName (18 bytes) + name := make([]byte, 18) + copy(name, cat.CatName) + catBuf.Write(name) + // CurrentTask (uint8) + catBuf.WriteByte(cat.CurrentTask) + // 16 bytes skip (appearance data) + catBuf.Write(make([]byte, 16)) + // Personality (uint8) + catBuf.WriteByte(cat.Personality) + // Class (uint8) + catBuf.WriteByte(cat.Class) + // 5 bytes skip (affection and colour sliders) + catBuf.Write(make([]byte, 5)) + // Experience (uint32) + binary.Write(catBuf, binary.BigEndian, cat.Experience) + // 1 byte skip (bool for weapon equipped) + catBuf.WriteByte(0) + // WeaponType (uint8) + catBuf.WriteByte(cat.WeaponType) + // WeaponID (uint16) + binary.Write(catBuf, binary.BigEndian, cat.WeaponID) + + catData := catBuf.Bytes() + // catDefLen (uint32) - total length of the cat data after this field + binary.Write(buf, binary.BigEndian, uint32(len(catData))) + buf.Write(catData) + } + return buf.Bytes() +} + +func TestGetCatDetails_Empty(t *testing.T) { + // Zero cats + data := []byte{0x00} + bf := byteframe.NewByteFrameFromBytes(data) + cats := GetCatDetails(bf) + + if len(cats) != 0 { + t.Errorf("Expected 0 cats, got %d", len(cats)) + } +} + +func TestGetCatDetails_SingleCat(t *testing.T) { + input := CatDefinition{ + CatID: 42, + CatName: []byte("TestCat"), + CurrentTask: 4, + Personality: 3, + Class: 2, + Experience: 1500, + WeaponType: 6, + WeaponID: 100, + } + + data := buildCatBytes([]CatDefinition{input}) + bf := byteframe.NewByteFrameFromBytes(data) + cats := GetCatDetails(bf) + + if len(cats) != 1 { + t.Fatalf("Expected 1 cat, got %d", len(cats)) + } + + cat := cats[0] + if cat.CatID != 42 { + t.Errorf("CatID = %d, want 42", cat.CatID) + } + if cat.CurrentTask != 4 { + t.Errorf("CurrentTask = %d, want 4", cat.CurrentTask) + } + if cat.Personality != 3 { + t.Errorf("Personality = %d, want 3", cat.Personality) + } + if cat.Class != 2 { + t.Errorf("Class = %d, want 2", cat.Class) + } + if cat.Experience != 1500 { + t.Errorf("Experience = %d, want 1500", cat.Experience) + } + if cat.WeaponType != 6 { + t.Errorf("WeaponType = %d, want 6", cat.WeaponType) + } + if cat.WeaponID != 100 { + t.Errorf("WeaponID = %d, want 100", cat.WeaponID) + } + // Name should be 18 bytes (padded with nulls) + if len(cat.CatName) != 18 { + t.Errorf("CatName length = %d, want 18", len(cat.CatName)) + } + // First bytes should match "TestCat" + if !bytes.HasPrefix(cat.CatName, []byte("TestCat")) { + t.Errorf("CatName does not start with 'TestCat', got %v", cat.CatName) + } +} + +func TestGetCatDetails_MultipleCats(t *testing.T) { + inputs := []CatDefinition{ + {CatID: 1, CatName: []byte("Alpha"), CurrentTask: 1, Personality: 0, Class: 0, Experience: 100, WeaponType: 6, WeaponID: 10}, + {CatID: 2, CatName: []byte("Beta"), CurrentTask: 2, Personality: 1, Class: 1, Experience: 200, WeaponType: 6, WeaponID: 20}, + {CatID: 3, CatName: []byte("Gamma"), CurrentTask: 4, Personality: 2, Class: 2, Experience: 300, WeaponType: 6, WeaponID: 30}, + } + + data := buildCatBytes(inputs) + bf := byteframe.NewByteFrameFromBytes(data) + cats := GetCatDetails(bf) + + if len(cats) != 3 { + t.Fatalf("Expected 3 cats, got %d", len(cats)) + } + + for i, cat := range cats { + if cat.CatID != inputs[i].CatID { + t.Errorf("Cat %d: CatID = %d, want %d", i, cat.CatID, inputs[i].CatID) + } + if cat.CurrentTask != inputs[i].CurrentTask { + t.Errorf("Cat %d: CurrentTask = %d, want %d", i, cat.CurrentTask, inputs[i].CurrentTask) + } + if cat.Experience != inputs[i].Experience { + t.Errorf("Cat %d: Experience = %d, want %d", i, cat.Experience, inputs[i].Experience) + } + if cat.WeaponID != inputs[i].WeaponID { + t.Errorf("Cat %d: WeaponID = %d, want %d", i, cat.WeaponID, inputs[i].WeaponID) + } + } +} + +func TestGetCatDetails_ExtraTrailingBytes(t *testing.T) { + // The GetCatDetails function handles extra bytes by seeking to catStart+catDefLen. + // Simulate a cat definition with extra trailing bytes by increasing catDefLen. + buf := new(bytes.Buffer) + buf.WriteByte(1) // catCount = 1 + + catBuf := new(bytes.Buffer) + binary.Write(catBuf, binary.BigEndian, uint32(99)) // catID + catBuf.WriteByte(0) // skip + catBuf.Write(make([]byte, 18)) // name + catBuf.WriteByte(3) // currentTask + catBuf.Write(make([]byte, 16)) // appearance skip + catBuf.WriteByte(1) // personality + catBuf.WriteByte(2) // class + catBuf.Write(make([]byte, 5)) // affection skip + binary.Write(catBuf, binary.BigEndian, uint32(500)) // experience + catBuf.WriteByte(0) // weapon equipped bool + catBuf.WriteByte(6) // weaponType + binary.Write(catBuf, binary.BigEndian, uint16(50)) // weaponID + + catData := catBuf.Bytes() + // Add 10 extra trailing bytes + extra := make([]byte, 10) + catDataWithExtra := append(catData, extra...) + + binary.Write(buf, binary.BigEndian, uint32(len(catDataWithExtra))) + buf.Write(catDataWithExtra) + + bf := byteframe.NewByteFrameFromBytes(buf.Bytes()) + cats := GetCatDetails(bf) + + if len(cats) != 1 { + t.Fatalf("Expected 1 cat, got %d", len(cats)) + } + if cats[0].CatID != 99 { + t.Errorf("CatID = %d, want 99", cats[0].CatID) + } + if cats[0].Experience != 500 { + t.Errorf("Experience = %d, want 500", cats[0].Experience) + } +} + +func TestGetCatDetails_CatNamePadding(t *testing.T) { + // Verify that names shorter than 18 bytes are correctly padded with null bytes. + input := CatDefinition{ + CatID: 1, + CatName: []byte("Hi"), + } + + data := buildCatBytes([]CatDefinition{input}) + bf := byteframe.NewByteFrameFromBytes(data) + cats := GetCatDetails(bf) + + if len(cats) != 1 { + t.Fatalf("Expected 1 cat, got %d", len(cats)) + } + if len(cats[0].CatName) != 18 { + t.Errorf("CatName length = %d, want 18", len(cats[0].CatName)) + } + // "Hi" followed by null bytes + if cats[0].CatName[0] != 'H' || cats[0].CatName[1] != 'i' { + t.Errorf("CatName first bytes = %v, want 'Hi...'", cats[0].CatName[:2]) + } +} + +// TestHandleMsgMhfMercenaryHuntdata_Unk0_1 tests with Unk0=1 (returns 1 byte) +func TestHandleMsgMhfMercenaryHuntdata_Unk0_1(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfMercenaryHuntdata{ + AckHandle: 12345, + Unk0: 1, + } + + handleMsgMhfMercenaryHuntdata(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// TestHandleMsgMhfMercenaryHuntdata_Unk0_0 tests with Unk0=0 (returns 0 bytes payload) +func TestHandleMsgMhfMercenaryHuntdata_Unk0_0(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfMercenaryHuntdata{ + AckHandle: 12345, + Unk0: 0, + } + + handleMsgMhfMercenaryHuntdata(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} + +// TestHandleMsgMhfEnumerateMercenaryLog tests the mercenary log enumeration handler +func TestHandleMsgMhfEnumerateMercenaryLog(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgMhfEnumerateMercenaryLog{ + AckHandle: 12345, + } + + handleMsgMhfEnumerateMercenaryLog(session, pkt) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("Response packet should have data") + } + default: + t.Error("No response packet queued") + } +} diff --git a/server/channelserver/handlers_misc_test.go b/server/channelserver/handlers_misc_test.go index 48e12c9ae..fc66628dd 100644 --- a/server/channelserver/handlers_misc_test.go +++ b/server/channelserver/handlers_misc_test.go @@ -568,30 +568,3 @@ func TestDistributionItemStruct(t *testing.T) { t.Errorf("ItemID = %d, want 1234", item.ItemID) } } - -// Login boost struct test -func TestLoginBoostStruct(t *testing.T) { - boost := loginBoost{ - WeekReq: 1, - WeekCount: 2, - Active: true, - } - - if boost.WeekReq != 1 { - t.Errorf("WeekReq = %d, want 1", boost.WeekReq) - } - if !boost.Active { - t.Error("Active should be true") - } -} - -// ActiveFeature struct test -func TestActiveFeatureStruct(t *testing.T) { - feature := activeFeature{ - ActiveFeatures: 0x0FFF, - } - - if feature.ActiveFeatures != 0x0FFF { - t.Errorf("ActiveFeatures = %x, want 0x0FFF", feature.ActiveFeatures) - } -} diff --git a/server/channelserver/handlers_register_test.go b/server/channelserver/handlers_register_test.go index a4bc55dce..d897254fc 100644 --- a/server/channelserver/handlers_register_test.go +++ b/server/channelserver/handlers_register_test.go @@ -2,62 +2,1563 @@ package channelserver import ( "testing" + + "erupe-ce/common/byteframe" + "erupe-ce/network/mhfpacket" ) -func TestHandleMsgSysNotifyRegister(t *testing.T) { - server := createMockServer() +// createMockServerWithRaviente creates a mock server with raviente and semaphore +// initialized, which the base createMockServer() does not do. +func createMockServerWithRaviente() *Server { + s := createMockServer() + s.raviente = NewRaviente() + s.semaphore = make(map[string]*Semaphore) + return s +} + +// --- NewRaviente --- + +func TestNewRaviente_FullValidation(t *testing.T) { + r := NewRaviente() + if r == nil { + t.Fatal("NewRaviente returned nil") + } + if r.register == nil { + t.Fatal("register is nil") + } + if r.state == nil { + t.Fatal("state is nil") + } + if r.support == nil { + t.Fatal("support is nil") + } + if len(r.register.register) != 5 { + t.Errorf("register length = %d, want 5", len(r.register.register)) + } + if len(r.state.stateData) != 29 { + t.Errorf("stateData length = %d, want 29", len(r.state.stateData)) + } + if len(r.support.supportData) != 25 { + t.Errorf("supportData length = %d, want 25", len(r.support.supportData)) + } + // All values should be zero-initialized + for i, v := range r.register.register { + if v != 0 { + t.Errorf("register[%d] = %d, want 0", i, v) + } + } + for i, v := range r.state.stateData { + if v != 0 { + t.Errorf("stateData[%d] = %d, want 0", i, v) + } + } + for i, v := range r.support.supportData { + if v != 0 { + t.Errorf("supportData[%d] = %d, want 0", i, v) + } + } + if r.register.nextTime != 0 { + t.Errorf("nextTime = %d, want 0", r.register.nextTime) + } + if r.register.startTime != 0 { + t.Errorf("startTime = %d, want 0", r.register.startTime) + } + if r.register.killedTime != 0 { + t.Errorf("killedTime = %d, want 0", r.register.killedTime) + } + if r.register.postTime != 0 { + t.Errorf("postTime = %d, want 0", r.register.postTime) + } + if r.register.ravienteType != 0 { + t.Errorf("ravienteType = %d, want 0", r.register.ravienteType) + } + if r.register.maxPlayers != 0 { + t.Errorf("maxPlayers = %d, want 0", r.register.maxPlayers) + } + if r.register.carveQuest != 0 { + t.Errorf("carveQuest = %d, want 0", r.register.carveQuest) + } +} + +// --- handleMsgSysLoadRegister --- + +func TestHandleMsgSysLoadRegister_Case12(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.nextTime = 100 + server.raviente.register.startTime = 200 + server.raviente.register.killedTime = 300 + server.raviente.register.postTime = 400 + server.raviente.register.register[0] = 10 + server.raviente.register.register[1] = 20 + server.raviente.register.register[2] = 30 + server.raviente.register.register[3] = 40 + server.raviente.register.register[4] = 50 + server.raviente.register.carveQuest = 500 + server.raviente.register.maxPlayers = 32 + server.raviente.register.ravienteType = 2 session := createMockSession(1, server) - // Should not panic (empty handler) - defer func() { - if r := recover(); r != nil { - t.Errorf("handleMsgSysNotifyRegister panicked: %v", r) + handleMsgSysLoadRegister(session, &mhfpacket.MsgSysLoadRegister{ + AckHandle: 1, RegisterID: 0, Unk1: 12, + }) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") } - }() - - handleMsgSysNotifyRegister(session, nil) + default: + t.Error("no response queued") + } } -func TestGetRaviSemaphore_None(t *testing.T) { - server := createMockServer() - server.semaphore = make(map[string]*Semaphore) +func TestHandleMsgSysLoadRegister_Case29(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.state.stateData[0] = 111 + server.raviente.state.stateData[14] = 222 + server.raviente.state.stateData[28] = 333 + session := createMockSession(1, server) + handleMsgSysLoadRegister(session, &mhfpacket.MsgSysLoadRegister{ + AckHandle: 2, RegisterID: 0, Unk1: 29, + }) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysLoadRegister_Case25(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.support.supportData[0] = 777 + server.raviente.support.supportData[12] = 888 + server.raviente.support.supportData[24] = 999 + session := createMockSession(1, server) + + handleMsgSysLoadRegister(session, &mhfpacket.MsgSysLoadRegister{ + AckHandle: 3, RegisterID: 0, Unk1: 25, + }) + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysLoadRegister_UnknownCase(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + // Unk1=99 doesn't match any case, so no response should be sent + handleMsgSysLoadRegister(session, &mhfpacket.MsgSysLoadRegister{ + AckHandle: 4, RegisterID: 0, Unk1: 99, + }) + + select { + case <-session.sendPackets: + t.Error("no response expected for unknown Unk1 value") + default: + // Expected: no packet queued + } +} + +// --- handleMsgSysOperateRegister --- + +type opEntry struct { + op uint8 + dest uint8 + data uint32 +} + +func buildPayload(entries ...opEntry) []byte { + bf := byteframe.NewByteFrame() + for _, e := range entries { + bf.WriteUint8(e.op) + bf.WriteUint8(e.dest) + bf.WriteUint32(e.data) + } + bf.WriteUint8(0) // terminator + return bf.Data() +} + +// --- SemaphoreID=4 (stateData) --- + +func TestHandleMsgSysOperateRegister_State_Op2_Normal(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.state.stateData[0] = 100 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 0, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + // With no ravi semaphore, GetRaviMultiplier returns 0, so data becomes 0 + // *ref += 0 => stateData[0] stays 100 + if server.raviente.state.stateData[0] != 100 { + t.Errorf("stateData[0] = %d, want 100 (multiplier=0 makes data=0)", server.raviente.state.stateData[0]) + } + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_State_Op2_Dest28(t *testing.T) { + // dest=28 is the Berserk resurrection tracker, no multiplier applied + server := createMockServerWithRaviente() + server.raviente.state.stateData[28] = 100 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 28, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + // dest=28 adds data directly without multiplier + if server.raviente.state.stateData[28] != 150 { + t.Errorf("stateData[28] = %d, want 150", server.raviente.state.stateData[28]) + } + + select { + case p := <-session.sendPackets: + if len(p.data) == 0 { + t.Error("response should have data") + } + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_State_Op2_Dest17_MultiplierIs1(t *testing.T) { + // dest=17 is Berserk poison tracker, only adds when damageMultiplier==1 + server := createMockServerWithRaviente() + server.raviente.state.stateData[17] = 100 + server.raviente.register.maxPlayers = 4 // small ravi, minPlayers=4 + + // Create a ravi semaphore with enough clients for multiplier=1 + sema := &Semaphore{ + id_semaphore: "hs_l0u3B51234_3", + id: 7, + clients: make(map[*Session]uint32), + } + // Need > 4 clients (minPlayers) for multiplier=1 + for i := 0; i < 5; i++ { + s := createMockSession(uint32(100+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + session := createMockSession(1, server) + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 17, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + // multiplier=1, so dest=17 adds data + if server.raviente.state.stateData[17] != 150 { + t.Errorf("stateData[17] = %d, want 150", server.raviente.state.stateData[17]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_State_Op2_Dest17_MultiplierNot1(t *testing.T) { + // dest=17 with multiplier != 1 should NOT add data + server := createMockServerWithRaviente() + server.raviente.state.stateData[17] = 100 + server.raviente.register.maxPlayers = 4 // small ravi, minPlayers=4 + + // Create a ravi semaphore with fewer clients than minPlayers for multiplier > 1 + sema := &Semaphore{ + id_semaphore: "hs_l0u3B51234_3", + id: 7, + clients: make(map[*Session]uint32), + } + // Need <= 4 clients so multiplier = 4/len(clients) != 1 + for i := 0; i < 2; i++ { + s := createMockSession(uint32(100+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + session := createMockSession(1, server) + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 17, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + // multiplier=4/2=2 != 1, so dest=17 does NOT add data + if server.raviente.state.stateData[17] != 100 { + t.Errorf("stateData[17] = %d, want 100 (should not change)", server.raviente.state.stateData[17]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_State_Op14(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.state.stateData[5] = 999 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 5, data: 42}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.state.stateData[5] != 42 { + t.Errorf("stateData[5] = %d, want 42", server.raviente.state.stateData[5]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_State_Op13(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.state.stateData[3] = 888 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 13, dest: 3, data: 77}), + } + handleMsgSysOperateRegister(session, pkt) + + // op=13 falls through to op=14 behavior: sets value + if server.raviente.state.stateData[3] != 77 { + t.Errorf("stateData[3] = %d, want 77", server.raviente.state.stateData[3]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_State_MultipleEntries(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload( + opEntry{op: 14, dest: 0, data: 10}, + opEntry{op: 14, dest: 1, data: 20}, + opEntry{op: 14, dest: 2, data: 30}, + ), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.state.stateData[0] != 10 { + t.Errorf("stateData[0] = %d, want 10", server.raviente.state.stateData[0]) + } + if server.raviente.state.stateData[1] != 20 { + t.Errorf("stateData[1] = %d, want 20", server.raviente.state.stateData[1]) + } + if server.raviente.state.stateData[2] != 30 { + t.Errorf("stateData[2] = %d, want 30", server.raviente.state.stateData[2]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +// --- SemaphoreID=5 (supportData) --- + +func TestHandleMsgSysOperateRegister_Support_Op2(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.support.supportData[0] = 100 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 5, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 0, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.support.supportData[0] != 150 { + t.Errorf("supportData[0] = %d, want 150", server.raviente.support.supportData[0]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Support_Op14(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.support.supportData[10] = 999 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 5, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 10, data: 42}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.support.supportData[10] != 42 { + t.Errorf("supportData[10] = %d, want 42", server.raviente.support.supportData[10]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Support_Op13(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.support.supportData[5] = 888 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 5, + RawDataPayload: buildPayload(opEntry{op: 13, dest: 5, data: 77}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.support.supportData[5] != 77 { + t.Errorf("supportData[5] = %d, want 77", server.raviente.support.supportData[5]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Support_MultipleEntries(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 5, + RawDataPayload: buildPayload( + opEntry{op: 2, dest: 0, data: 10}, + opEntry{op: 14, dest: 1, data: 20}, + opEntry{op: 13, dest: 2, data: 30}, + ), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.support.supportData[0] != 10 { + t.Errorf("supportData[0] = %d, want 10", server.raviente.support.supportData[0]) + } + if server.raviente.support.supportData[1] != 20 { + t.Errorf("supportData[1] = %d, want 20", server.raviente.support.supportData[1]) + } + if server.raviente.support.supportData[2] != 30 { + t.Errorf("supportData[2] = %d, want 30", server.raviente.support.supportData[2]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +// --- SemaphoreID=6 (register fields) --- + +func TestHandleMsgSysOperateRegister_Register_Dest0_NextTime(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 0, data: 12345}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.nextTime != 12345 { + t.Errorf("nextTime = %d, want 12345", server.raviente.register.nextTime) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest1_StartTime(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 1, data: 67890}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.startTime != 67890 { + t.Errorf("startTime = %d, want 67890", server.raviente.register.startTime) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest2_KilledTime(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 2, data: 11111}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.killedTime != 11111 { + t.Errorf("killedTime = %d, want 11111", server.raviente.register.killedTime) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest3_PostTime(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 3, data: 22222}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.postTime != 22222 { + t.Errorf("postTime = %d, want 22222", server.raviente.register.postTime) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest4_Register0_Op2(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[0] = 100 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 4, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[0] != 150 { + t.Errorf("register[0] = %d, want 150", server.raviente.register.register[0]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest4_Register0_Op13(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[0] = 999 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 13, dest: 4, data: 42}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[0] != 42 { + t.Errorf("register[0] = %d, want 42", server.raviente.register.register[0]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest4_Register0_Op14(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[0] = 999 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 4, data: 77}), + } + handleMsgSysOperateRegister(session, pkt) + + // op=14 for dest=4 writes response data but does NOT set *ref (unlike op=13) + if server.raviente.register.register[0] != 999 { + t.Errorf("register[0] = %d, want 999 (op=14 doesn't set ref)", server.raviente.register.register[0]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest5_CarveQuest(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 5, data: 33333}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.carveQuest != 33333 { + t.Errorf("carveQuest = %d, want 33333", server.raviente.register.carveQuest) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest6_Register1_Op2(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[1] = 200 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 6, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[1] != 250 { + t.Errorf("register[1] = %d, want 250", server.raviente.register.register[1]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest6_Register1_Op13(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 13, dest: 6, data: 55}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[1] != 55 { + t.Errorf("register[1] = %d, want 55", server.raviente.register.register[1]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest6_Register1_Op14(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[1] = 999 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 6, data: 77}), + } + handleMsgSysOperateRegister(session, pkt) + + // op=14 for register dests doesn't set *ref + if server.raviente.register.register[1] != 999 { + t.Errorf("register[1] = %d, want 999", server.raviente.register.register[1]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest7_Register2_Op2(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[2] = 300 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 7, data: 25}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[2] != 325 { + t.Errorf("register[2] = %d, want 325", server.raviente.register.register[2]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest7_Register2_Op13(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 13, dest: 7, data: 66}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[2] != 66 { + t.Errorf("register[2] = %d, want 66", server.raviente.register.register[2]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest7_Register2_Op14(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[2] = 500 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 7, data: 88}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[2] != 500 { + t.Errorf("register[2] = %d, want 500", server.raviente.register.register[2]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest8_Register3_Op2(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[3] = 400 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 8, data: 10}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[3] != 410 { + t.Errorf("register[3] = %d, want 410", server.raviente.register.register[3]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest8_Register3_Op13(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 13, dest: 8, data: 99}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[3] != 99 { + t.Errorf("register[3] = %d, want 99", server.raviente.register.register[3]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest8_Register3_Op14(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[3] = 777 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 8, data: 11}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[3] != 777 { + t.Errorf("register[3] = %d, want 777", server.raviente.register.register[3]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest9_MaxPlayers(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 9, data: 32}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.maxPlayers != 32 { + t.Errorf("maxPlayers = %d, want 32", server.raviente.register.maxPlayers) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest10_RavienteType(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 10, data: 3}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.ravienteType != 3 { + t.Errorf("ravienteType = %d, want 3", server.raviente.register.ravienteType) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest11_Register4_Op2(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[4] = 500 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 11, data: 100}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[4] != 600 { + t.Errorf("register[4] = %d, want 600", server.raviente.register.register[4]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest11_Register4_Op13(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 13, dest: 11, data: 44}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[4] != 44 { + t.Errorf("register[4] = %d, want 44", server.raviente.register.register[4]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_Dest11_Register4_Op14(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.register[4] = 888 + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 11, data: 55}), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.register[4] != 888 { + t.Errorf("register[4] = %d, want 888", server.raviente.register.register[4]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_DefaultDest(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + // dest=99 doesn't match any case, hits default branch + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 99, data: 123}), + } + handleMsgSysOperateRegister(session, pkt) + + select { + case <-session.sendPackets: + // Default case writes zeros and sends response + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_Register_AllDests(t *testing.T) { + // Exercise all dest cases in a single operation to test full branch coverage + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 6, + RawDataPayload: buildPayload( + opEntry{op: 14, dest: 0, data: 1}, // nextTime + opEntry{op: 14, dest: 1, data: 2}, // startTime + opEntry{op: 14, dest: 2, data: 3}, // killedTime + opEntry{op: 14, dest: 3, data: 4}, // postTime + opEntry{op: 2, dest: 4, data: 5}, // register[0] op2 + opEntry{op: 14, dest: 5, data: 6}, // carveQuest + opEntry{op: 2, dest: 6, data: 7}, // register[1] op2 + opEntry{op: 2, dest: 7, data: 8}, // register[2] op2 + opEntry{op: 2, dest: 8, data: 9}, // register[3] op2 + opEntry{op: 14, dest: 9, data: 10}, // maxPlayers + opEntry{op: 14, dest: 10, data: 11}, // ravienteType + opEntry{op: 2, dest: 11, data: 12}, // register[4] op2 + opEntry{op: 14, dest: 99, data: 0}, // default + ), + } + handleMsgSysOperateRegister(session, pkt) + + if server.raviente.register.nextTime != 1 { + t.Errorf("nextTime = %d, want 1", server.raviente.register.nextTime) + } + if server.raviente.register.startTime != 2 { + t.Errorf("startTime = %d, want 2", server.raviente.register.startTime) + } + if server.raviente.register.killedTime != 3 { + t.Errorf("killedTime = %d, want 3", server.raviente.register.killedTime) + } + if server.raviente.register.postTime != 4 { + t.Errorf("postTime = %d, want 4", server.raviente.register.postTime) + } + if server.raviente.register.register[0] != 5 { + t.Errorf("register[0] = %d, want 5", server.raviente.register.register[0]) + } + if server.raviente.register.carveQuest != 6 { + t.Errorf("carveQuest = %d, want 6", server.raviente.register.carveQuest) + } + if server.raviente.register.register[1] != 7 { + t.Errorf("register[1] = %d, want 7", server.raviente.register.register[1]) + } + if server.raviente.register.register[2] != 8 { + t.Errorf("register[2] = %d, want 8", server.raviente.register.register[2]) + } + if server.raviente.register.register[3] != 9 { + t.Errorf("register[3] = %d, want 9", server.raviente.register.register[3]) + } + if server.raviente.register.maxPlayers != 10 { + t.Errorf("maxPlayers = %d, want 10", server.raviente.register.maxPlayers) + } + if server.raviente.register.ravienteType != 11 { + t.Errorf("ravienteType = %d, want 11", server.raviente.register.ravienteType) + } + if server.raviente.register.register[4] != 12 { + t.Errorf("register[4] = %d, want 12", server.raviente.register.register[4]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +// --- getRaviSemaphore --- + +func TestGetRaviSemaphore_NoMatch(t *testing.T) { + server := createMockServerWithRaviente() result := getRaviSemaphore(server) - if result != nil { - t.Error("Expected nil when no raviente semaphore exists") + t.Error("should return nil when no semaphore matches") } } -func TestGetRaviSemaphore_Found(t *testing.T) { - server := createMockServer() - server.semaphore = make(map[string]*Semaphore) - - // Create a raviente semaphore (matches prefix hs_l0u3B5 and suffix 3) - sema := NewSemaphore(server, "hs_l0u3B53", 32) - server.semaphore["hs_l0u3B53"] = sema - - result := getRaviSemaphore(server) - - if result == nil { - t.Error("Expected to find raviente semaphore") +func TestGetRaviSemaphore_WrongPrefix(t *testing.T) { + server := createMockServerWithRaviente() + server.semaphore["wrong"] = &Semaphore{ + id_semaphore: "wrong_prefix_3", + id: 7, + clients: make(map[*Session]uint32), } - if result.id_semaphore != "hs_l0u3B53" { - t.Errorf("Wrong semaphore returned: %s", result.id_semaphore) + result := getRaviSemaphore(server) + if result != nil { + t.Error("should return nil when no semaphore has correct prefix") } } func TestGetRaviSemaphore_WrongSuffix(t *testing.T) { - server := createMockServer() - server.semaphore = make(map[string]*Semaphore) - - // Create a semaphore with wrong suffix - sema := NewSemaphore(server, "hs_l0u3B51", 32) - server.semaphore["hs_l0u3B51"] = sema - + server := createMockServerWithRaviente() + server.semaphore["wrong"] = &Semaphore{ + id_semaphore: "hs_l0u3B5test_4", + id: 7, + clients: make(map[*Session]uint32), + } result := getRaviSemaphore(server) - if result != nil { - t.Error("Should not match semaphore with wrong suffix") + t.Error("should return nil when semaphore has wrong suffix") + } +} + +func TestGetRaviSemaphore_Match(t *testing.T) { + server := createMockServerWithRaviente() + expected := &Semaphore{ + id_semaphore: "hs_l0u3B5some_data_3", + id: 7, + clients: make(map[*Session]uint32), + } + server.semaphore["ravi"] = expected + result := getRaviSemaphore(server) + if result != expected { + t.Error("should return matching semaphore") + } +} + +func TestGetRaviSemaphore_ExactMinimal(t *testing.T) { + server := createMockServerWithRaviente() + expected := &Semaphore{ + id_semaphore: "hs_l0u3B53", + id: 7, + clients: make(map[*Session]uint32), + } + server.semaphore["ravi"] = expected + result := getRaviSemaphore(server) + if result != expected { + t.Error("should match when prefix immediately followed by suffix '3'") + } +} + +func TestGetRaviSemaphore_MultipleOnlyOneMatches(t *testing.T) { + server := createMockServerWithRaviente() + server.semaphore["wrong1"] = &Semaphore{ + id_semaphore: "something_else", + id: 8, + clients: make(map[*Session]uint32), + } + expected := &Semaphore{ + id_semaphore: "hs_l0u3B5ravi_3", + id: 9, + clients: make(map[*Session]uint32), + } + server.semaphore["ravi"] = expected + server.semaphore["wrong2"] = &Semaphore{ + id_semaphore: "hs_l0u3B5_no_suffix", + id: 10, + clients: make(map[*Session]uint32), + } + result := getRaviSemaphore(server) + if result != expected { + t.Error("should return the one matching semaphore") + } +} + +// --- notifyRavi --- + +func TestNotifyRavi_NoSemaphore(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + // Should not panic when no ravi semaphore exists + session.notifyRavi() + + // No clients to receive notifications, so nothing should be queued + select { + case <-session.sendPackets: + t.Error("no packet should be queued on the calling session") + default: + // Expected + } +} + +func TestNotifyRavi_WithSemaphoreAndClients(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + client1 := createMockSession(10, server) + client2 := createMockSession(20, server) + + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + sema.clients[client1] = client1.charID + sema.clients[client2] = client2.charID + server.semaphore["ravi"] = sema + + session.notifyRavi() + + // Both clients on the semaphore should receive notification packets + receivedCount := 0 + select { + case p := <-client1.sendPackets: + if len(p.data) > 0 { + receivedCount++ + } + default: + t.Error("client1 should have received notification") + } + select { + case p := <-client2.sendPackets: + if len(p.data) > 0 { + receivedCount++ + } + default: + t.Error("client2 should have received notification") + } + if receivedCount != 2 { + t.Errorf("received %d notifications, want 2", receivedCount) + } +} + +// --- resetRavi --- + +func TestResetRavi(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + // Set various values + server.raviente.register.nextTime = 12345 + server.raviente.register.startTime = 67890 + server.raviente.register.killedTime = 11111 + server.raviente.register.postTime = 22222 + server.raviente.register.ravienteType = 3 + server.raviente.register.maxPlayers = 32 + server.raviente.register.carveQuest = 44444 + server.raviente.register.register[0] = 100 + server.raviente.register.register[1] = 200 + server.raviente.register.register[2] = 300 + server.raviente.register.register[3] = 400 + server.raviente.register.register[4] = 500 + server.raviente.state.stateData[0] = 999 + server.raviente.state.stateData[14] = 888 + server.raviente.state.stateData[28] = 777 + server.raviente.support.supportData[0] = 666 + server.raviente.support.supportData[12] = 555 + server.raviente.support.supportData[24] = 444 + + resetRavi(session) + + // Verify all register fields reset + if server.raviente.register.nextTime != 0 { + t.Errorf("nextTime = %d, want 0", server.raviente.register.nextTime) + } + if server.raviente.register.startTime != 0 { + t.Errorf("startTime = %d, want 0", server.raviente.register.startTime) + } + if server.raviente.register.killedTime != 0 { + t.Errorf("killedTime = %d, want 0", server.raviente.register.killedTime) + } + if server.raviente.register.postTime != 0 { + t.Errorf("postTime = %d, want 0", server.raviente.register.postTime) + } + if server.raviente.register.ravienteType != 0 { + t.Errorf("ravienteType = %d, want 0", server.raviente.register.ravienteType) + } + if server.raviente.register.maxPlayers != 0 { + t.Errorf("maxPlayers = %d, want 0", server.raviente.register.maxPlayers) + } + if server.raviente.register.carveQuest != 0 { + t.Errorf("carveQuest = %d, want 0", server.raviente.register.carveQuest) + } + + // Verify register array reset + for i, v := range server.raviente.register.register { + if v != 0 { + t.Errorf("register[%d] = %d, want 0", i, v) + } + } + if len(server.raviente.register.register) != 5 { + t.Errorf("register length = %d, want 5", len(server.raviente.register.register)) + } + + // Verify stateData reset + for i, v := range server.raviente.state.stateData { + if v != 0 { + t.Errorf("stateData[%d] = %d, want 0", i, v) + } + } + if len(server.raviente.state.stateData) != 29 { + t.Errorf("stateData length = %d, want 29", len(server.raviente.state.stateData)) + } + + // Verify supportData reset + for i, v := range server.raviente.support.supportData { + if v != 0 { + t.Errorf("supportData[%d] = %d, want 0", i, v) + } + } + if len(server.raviente.support.supportData) != 25 { + t.Errorf("supportData length = %d, want 25", len(server.raviente.support.supportData)) + } +} + +// --- GetRaviMultiplier --- + +func TestGetRaviMultiplier_NoSemaphore(t *testing.T) { + server := createMockServerWithRaviente() + result := server.raviente.GetRaviMultiplier(server) + if result != 0 { + t.Errorf("expected 0, got %f", result) + } +} + +func TestGetRaviMultiplier_LargeRavi_EnoughPlayers(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.maxPlayers = 32 // > 8, so minPlayers=24 + + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + // Need > 24 clients + for i := 0; i < 25; i++ { + s := createMockSession(uint32(100+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + result := server.raviente.GetRaviMultiplier(server) + if result != 1 { + t.Errorf("expected 1, got %f", result) + } +} + +func TestGetRaviMultiplier_LargeRavi_NotEnoughPlayers(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.maxPlayers = 32 // > 8, so minPlayers=24 + + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + // 12 clients < 24 minPlayers => multiplier = 24/12 = 2 + for i := 0; i < 12; i++ { + s := createMockSession(uint32(100+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + result := server.raviente.GetRaviMultiplier(server) + expected := float64(24 / 12) // integer division: 2 + if result != expected { + t.Errorf("expected %f, got %f", expected, result) + } +} + +func TestGetRaviMultiplier_SmallRavi_EnoughPlayers(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.maxPlayers = 4 // <= 8, so minPlayers=4 + + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + // Need > 4 clients + for i := 0; i < 5; i++ { + s := createMockSession(uint32(100+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + result := server.raviente.GetRaviMultiplier(server) + if result != 1 { + t.Errorf("expected 1, got %f", result) + } +} + +func TestGetRaviMultiplier_SmallRavi_NotEnoughPlayers(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.register.maxPlayers = 4 // <= 8, so minPlayers=4 + + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + // 2 clients < 4 minPlayers => multiplier = 4/2 = 2 + for i := 0; i < 2; i++ { + s := createMockSession(uint32(100+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + result := server.raviente.GetRaviMultiplier(server) + expected := float64(4 / 2) // integer division: 2 + if result != expected { + t.Errorf("expected %f, got %f", expected, result) + } +} + +// --- handleMsgSysNotifyRegister (empty handler) --- + +func TestHandleMsgSysNotifyRegister(t *testing.T) { + server := createMockServerWithRaviente() + session := createMockSession(1, server) + + // Should not panic - handler is empty + handleMsgSysNotifyRegister(session, &mhfpacket.MsgSysNotifyRegister{ + RegisterID: 4, + }) + + // No response expected from empty handler + select { + case <-session.sendPackets: + t.Error("empty handler should not queue packets") + default: + // Expected + } +} + +// --- State op2 with multiplier applied (normal dest, not 17 or 28) --- + +func TestHandleMsgSysOperateRegister_State_Op2_WithMultiplier(t *testing.T) { + server := createMockServerWithRaviente() + server.raviente.state.stateData[5] = 100 + server.raviente.register.maxPlayers = 32 // large ravi, minPlayers=24 + + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + // 12 clients < 24 minPlayers => multiplier = 24/12 = 2 + for i := 0; i < 12; i++ { + s := createMockSession(uint32(200+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + session := createMockSession(1, server) + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 5, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + // data = uint32(float64(50) * 2.0) = 100 + // stateData[5] = 100 + 100 = 200 + if server.raviente.state.stateData[5] != 200 { + t.Errorf("stateData[5] = %d, want 200", server.raviente.state.stateData[5]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +func TestHandleMsgSysOperateRegister_State_Op2_Dest28_WithMultiplier(t *testing.T) { + // dest=28 should ignore multiplier regardless + server := createMockServerWithRaviente() + server.raviente.state.stateData[28] = 100 + server.raviente.register.maxPlayers = 32 + + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + for i := 0; i < 12; i++ { + s := createMockSession(uint32(200+i), server) + sema.clients[s] = s.charID + } + server.semaphore["ravi"] = sema + + session := createMockSession(1, server) + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 4, + RawDataPayload: buildPayload(opEntry{op: 2, dest: 28, data: 50}), + } + handleMsgSysOperateRegister(session, pkt) + + // dest=28 always adds raw data, ignoring multiplier + if server.raviente.state.stateData[28] != 150 { + t.Errorf("stateData[28] = %d, want 150", server.raviente.state.stateData[28]) + } + + select { + case <-session.sendPackets: + default: + t.Error("no response queued") + } +} + +// Test that notifyRavi is called as part of handleMsgSysOperateRegister +// by verifying that clients on the ravi semaphore get notifications. +func TestHandleMsgSysOperateRegister_NotifiesRaviClients(t *testing.T) { + server := createMockServerWithRaviente() + + raviClient := createMockSession(50, server) + sema := &Semaphore{ + id_semaphore: "hs_l0u3B5test_3", + id: 7, + clients: make(map[*Session]uint32), + } + sema.clients[raviClient] = raviClient.charID + server.semaphore["ravi"] = sema + + session := createMockSession(1, server) + pkt := &mhfpacket.MsgSysOperateRegister{ + AckHandle: 1, + SemaphoreID: 5, + RawDataPayload: buildPayload(opEntry{op: 14, dest: 0, data: 1}), + } + handleMsgSysOperateRegister(session, pkt) + + // The calling session should receive the ack response + select { + case <-session.sendPackets: + default: + t.Error("calling session should receive ack response") + } + + // The ravi client should receive a notification + select { + case p := <-raviClient.sendPackets: + if len(p.data) == 0 { + t.Error("ravi client should receive non-empty notification") + } + default: + t.Error("ravi client should receive notification") } } diff --git a/server/channelserver/handlers_reserve_test.go b/server/channelserver/handlers_reserve_test.go index 72fb6fe24..f031fb15f 100644 --- a/server/channelserver/handlers_reserve_test.go +++ b/server/channelserver/handlers_reserve_test.go @@ -6,116 +6,98 @@ import ( "erupe-ce/network/mhfpacket" ) -// Test that reserve handlers with AckHandle respond correctly - -func TestHandleMsgSysReserve188(t *testing.T) { +func TestReserveHandlersWithAck(t *testing.T) { server := createMockServer() session := createMockSession(1, server) - pkt := &mhfpacket.MsgSysReserve188{ - AckHandle: 12345, - } - - handleMsgSysReserve188(session, pkt) - - // Verify response packet was queued + // Test handleMsgSysReserve188 + handleMsgSysReserve188(session, &mhfpacket.MsgSysReserve188{AckHandle: 12345}) select { case p := <-session.sendPackets: if len(p.data) == 0 { - t.Error("Response packet should have data") + t.Error("Reserve188: response should have data") } default: - t.Error("No response packet queued") - } -} - -func TestHandleMsgSysReserve18B(t *testing.T) { - server := createMockServer() - session := createMockSession(1, server) - - pkt := &mhfpacket.MsgSysReserve18B{ - AckHandle: 12345, + t.Error("Reserve188: no response queued") } - handleMsgSysReserve18B(session, pkt) - - // Verify response packet was queued + // Test handleMsgSysReserve18B + handleMsgSysReserve18B(session, &mhfpacket.MsgSysReserve18B{AckHandle: 12345}) select { case p := <-session.sendPackets: if len(p.data) == 0 { - t.Error("Response packet should have data") + t.Error("Reserve18B: response should have data") } default: - t.Error("No response packet queued") + t.Error("Reserve18B: no response queued") } } -// Test that empty reserve handlers don't panic - -func TestEmptyReserveHandlers(t *testing.T) { +func TestReserveEmptyHandlers(t *testing.T) { server := createMockServer() session := createMockSession(1, server) tests := []struct { name string handler func(s *Session, p mhfpacket.MHFPacket) + pkt mhfpacket.MHFPacket }{ - {"handleMsgSysReserve55", handleMsgSysReserve55}, - {"handleMsgSysReserve56", handleMsgSysReserve56}, - {"handleMsgSysReserve57", handleMsgSysReserve57}, - {"handleMsgSysReserve01", handleMsgSysReserve01}, - {"handleMsgSysReserve02", handleMsgSysReserve02}, - {"handleMsgSysReserve03", handleMsgSysReserve03}, - {"handleMsgSysReserve04", handleMsgSysReserve04}, - {"handleMsgSysReserve05", handleMsgSysReserve05}, - {"handleMsgSysReserve06", handleMsgSysReserve06}, - {"handleMsgSysReserve07", handleMsgSysReserve07}, - {"handleMsgSysReserve0C", handleMsgSysReserve0C}, - {"handleMsgSysReserve0D", handleMsgSysReserve0D}, - {"handleMsgSysReserve0E", handleMsgSysReserve0E}, - {"handleMsgSysReserve4A", handleMsgSysReserve4A}, - {"handleMsgSysReserve4B", handleMsgSysReserve4B}, - {"handleMsgSysReserve4C", handleMsgSysReserve4C}, - {"handleMsgSysReserve4D", handleMsgSysReserve4D}, - {"handleMsgSysReserve4E", handleMsgSysReserve4E}, - {"handleMsgSysReserve4F", handleMsgSysReserve4F}, - {"handleMsgSysReserve5C", handleMsgSysReserve5C}, - {"handleMsgSysReserve5E", handleMsgSysReserve5E}, - {"handleMsgSysReserve5F", handleMsgSysReserve5F}, - {"handleMsgSysReserve71", handleMsgSysReserve71}, - {"handleMsgSysReserve72", handleMsgSysReserve72}, - {"handleMsgSysReserve73", handleMsgSysReserve73}, - {"handleMsgSysReserve74", handleMsgSysReserve74}, - {"handleMsgSysReserve75", handleMsgSysReserve75}, - {"handleMsgSysReserve76", handleMsgSysReserve76}, - {"handleMsgSysReserve77", handleMsgSysReserve77}, - {"handleMsgSysReserve78", handleMsgSysReserve78}, - {"handleMsgSysReserve79", handleMsgSysReserve79}, - {"handleMsgSysReserve7A", handleMsgSysReserve7A}, - {"handleMsgSysReserve7B", handleMsgSysReserve7B}, - {"handleMsgSysReserve7C", handleMsgSysReserve7C}, - {"handleMsgSysReserve7E", handleMsgSysReserve7E}, - {"handleMsgMhfReserve10F", handleMsgMhfReserve10F}, - {"handleMsgSysReserve180", handleMsgSysReserve180}, - {"handleMsgSysReserve18E", handleMsgSysReserve18E}, - {"handleMsgSysReserve18F", handleMsgSysReserve18F}, - {"handleMsgSysReserve19E", handleMsgSysReserve19E}, - {"handleMsgSysReserve19F", handleMsgSysReserve19F}, - {"handleMsgSysReserve1A4", handleMsgSysReserve1A4}, - {"handleMsgSysReserve1A6", handleMsgSysReserve1A6}, - {"handleMsgSysReserve1A7", handleMsgSysReserve1A7}, - {"handleMsgSysReserve1A8", handleMsgSysReserve1A8}, - {"handleMsgSysReserve1A9", handleMsgSysReserve1A9}, - {"handleMsgSysReserve1AA", handleMsgSysReserve1AA}, - {"handleMsgSysReserve1AB", handleMsgSysReserve1AB}, - {"handleMsgSysReserve1AC", handleMsgSysReserve1AC}, - {"handleMsgSysReserve1AD", handleMsgSysReserve1AD}, - {"handleMsgSysReserve1AE", handleMsgSysReserve1AE}, - {"handleMsgSysReserve1AF", handleMsgSysReserve1AF}, - {"handleMsgSysReserve19B", handleMsgSysReserve19B}, - {"handleMsgSysReserve192", handleMsgSysReserve192}, - {"handleMsgSysReserve193", handleMsgSysReserve193}, - {"handleMsgSysReserve194", handleMsgSysReserve194}, + {"Reserve55", handleMsgSysReserve55, &mhfpacket.MsgSysReserve55{}}, + {"Reserve56", handleMsgSysReserve56, &mhfpacket.MsgSysReserve56{}}, + {"Reserve57", handleMsgSysReserve57, &mhfpacket.MsgSysReserve57{}}, + {"Reserve01", handleMsgSysReserve01, &mhfpacket.MsgSysReserve01{}}, + {"Reserve02", handleMsgSysReserve02, &mhfpacket.MsgSysReserve02{}}, + {"Reserve03", handleMsgSysReserve03, &mhfpacket.MsgSysReserve03{}}, + {"Reserve04", handleMsgSysReserve04, &mhfpacket.MsgSysReserve04{}}, + {"Reserve05", handleMsgSysReserve05, &mhfpacket.MsgSysReserve05{}}, + {"Reserve06", handleMsgSysReserve06, &mhfpacket.MsgSysReserve06{}}, + {"Reserve07", handleMsgSysReserve07, &mhfpacket.MsgSysReserve07{}}, + {"Reserve0C", handleMsgSysReserve0C, &mhfpacket.MsgSysReserve0C{}}, + {"Reserve0D", handleMsgSysReserve0D, &mhfpacket.MsgSysReserve0D{}}, + {"Reserve0E", handleMsgSysReserve0E, &mhfpacket.MsgSysReserve0E{}}, + {"Reserve4A", handleMsgSysReserve4A, &mhfpacket.MsgSysReserve4A{}}, + {"Reserve4B", handleMsgSysReserve4B, &mhfpacket.MsgSysReserve4B{}}, + {"Reserve4C", handleMsgSysReserve4C, &mhfpacket.MsgSysReserve4C{}}, + {"Reserve4D", handleMsgSysReserve4D, &mhfpacket.MsgSysReserve4D{}}, + {"Reserve4E", handleMsgSysReserve4E, &mhfpacket.MsgSysReserve4E{}}, + {"Reserve4F", handleMsgSysReserve4F, &mhfpacket.MsgSysReserve4F{}}, + {"Reserve5C", handleMsgSysReserve5C, &mhfpacket.MsgSysReserve5C{}}, + {"Reserve5E", handleMsgSysReserve5E, &mhfpacket.MsgSysReserve5E{}}, + {"Reserve5F", handleMsgSysReserve5F, &mhfpacket.MsgSysReserve5F{}}, + {"Reserve71", handleMsgSysReserve71, &mhfpacket.MsgSysReserve71{}}, + {"Reserve72", handleMsgSysReserve72, &mhfpacket.MsgSysReserve72{}}, + {"Reserve73", handleMsgSysReserve73, &mhfpacket.MsgSysReserve73{}}, + {"Reserve74", handleMsgSysReserve74, &mhfpacket.MsgSysReserve74{}}, + {"Reserve75", handleMsgSysReserve75, &mhfpacket.MsgSysReserve75{}}, + {"Reserve76", handleMsgSysReserve76, &mhfpacket.MsgSysReserve76{}}, + {"Reserve77", handleMsgSysReserve77, &mhfpacket.MsgSysReserve77{}}, + {"Reserve78", handleMsgSysReserve78, &mhfpacket.MsgSysReserve78{}}, + {"Reserve79", handleMsgSysReserve79, &mhfpacket.MsgSysReserve79{}}, + {"Reserve7A", handleMsgSysReserve7A, &mhfpacket.MsgSysReserve7A{}}, + {"Reserve7B", handleMsgSysReserve7B, &mhfpacket.MsgSysReserve7B{}}, + {"Reserve7C", handleMsgSysReserve7C, &mhfpacket.MsgSysReserve7C{}}, + {"Reserve7E", handleMsgSysReserve7E, &mhfpacket.MsgSysReserve7E{}}, + {"Reserve10F", handleMsgMhfReserve10F, &mhfpacket.MsgMhfReserve10F{}}, + {"Reserve180", handleMsgSysReserve180, &mhfpacket.MsgSysReserve180{}}, + {"Reserve18E", handleMsgSysReserve18E, &mhfpacket.MsgSysReserve18E{}}, + {"Reserve18F", handleMsgSysReserve18F, &mhfpacket.MsgSysReserve18F{}}, + {"Reserve19E", handleMsgSysReserve19E, &mhfpacket.MsgSysReserve19E{}}, + {"Reserve19F", handleMsgSysReserve19F, &mhfpacket.MsgSysReserve19F{}}, + {"Reserve1A4", handleMsgSysReserve1A4, &mhfpacket.MsgSysReserve1A4{}}, + {"Reserve1A6", handleMsgSysReserve1A6, &mhfpacket.MsgSysReserve1A6{}}, + {"Reserve1A7", handleMsgSysReserve1A7, &mhfpacket.MsgSysReserve1A7{}}, + {"Reserve1A8", handleMsgSysReserve1A8, &mhfpacket.MsgSysReserve1A8{}}, + {"Reserve1A9", handleMsgSysReserve1A9, &mhfpacket.MsgSysReserve1A9{}}, + {"Reserve1AA", handleMsgSysReserve1AA, &mhfpacket.MsgSysReserve1AA{}}, + {"Reserve1AB", handleMsgSysReserve1AB, &mhfpacket.MsgSysReserve1AB{}}, + {"Reserve1AC", handleMsgSysReserve1AC, &mhfpacket.MsgSysReserve1AC{}}, + {"Reserve1AD", handleMsgSysReserve1AD, &mhfpacket.MsgSysReserve1AD{}}, + {"Reserve1AE", handleMsgSysReserve1AE, &mhfpacket.MsgSysReserve1AE{}}, + {"Reserve1AF", handleMsgSysReserve1AF, &mhfpacket.MsgSysReserve1AF{}}, + {"Reserve19B", handleMsgSysReserve19B, &mhfpacket.MsgSysReserve19B{}}, + {"Reserve192", handleMsgSysReserve192, &mhfpacket.MsgSysReserve192{}}, + {"Reserve193", handleMsgSysReserve193, &mhfpacket.MsgSysReserve193{}}, + {"Reserve194", handleMsgSysReserve194, &mhfpacket.MsgSysReserve194{}}, } for _, tt := range tests { @@ -125,29 +107,7 @@ func TestEmptyReserveHandlers(t *testing.T) { t.Errorf("%s panicked: %v", tt.name, r) } }() - - // Call with nil packet - empty handlers should handle this - tt.handler(session, nil) + tt.handler(session, tt.pkt) }) } } - -// Test reserve handlers are registered in handler table - -func TestReserveHandlersRegistered(t *testing.T) { - if handlerTable == nil { - t.Fatal("handlerTable should be initialized") - } - - // Check that reserve handlers exist in the table - reserveHandlerCount := 0 - for _, handler := range handlerTable { - if handler != nil { - reserveHandlerCount++ - } - } - - if reserveHandlerCount < 50 { - t.Errorf("Expected at least 50 handlers registered, got %d", reserveHandlerCount) - } -} diff --git a/server/entranceserver/entrance_server_test.go b/server/entranceserver/entrance_server_test.go index 0bb260c1b..116378143 100644 --- a/server/entranceserver/entrance_server_test.go +++ b/server/entranceserver/entrance_server_test.go @@ -1,9 +1,13 @@ package entranceserver import ( + "net" "testing" + "time" "erupe-ce/config" + + "go.uber.org/zap" ) func TestNewServer(t *testing.T) { @@ -292,3 +296,248 @@ func TestServerMutexLocking(t *testing.T) { t.Error("Mutex should protect isShuttingDown flag") } } + +func TestServerStartAndShutdown(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Entrance: config.Entrance{ + Enabled: true, + Port: 0, // Use port 0 to get a random available port + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + + // Verify listener is set + if s.listener == nil { + t.Error("Server listener should not be nil after Start()") + } + + // Verify not shutting down initially + s.Lock() + if s.isShuttingDown { + t.Error("Server should not be shutting down after Start()") + } + s.Unlock() + + // Shutdown + s.Shutdown() + + // Verify shutdown flag is set + s.Lock() + if !s.isShuttingDown { + t.Error("Server should be shutting down after Shutdown()") + } + s.Unlock() +} + +func TestServerStartWithInvalidPort(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Entrance: config.Entrance{ + Port: 1, // Privileged port, should fail to bind + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err == nil { + s.Shutdown() + t.Error("Start() should fail with invalid port") + } +} + +func TestServerListenerAddress(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Entrance: config.Entrance{ + Enabled: true, + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + addr := s.listener.Addr() + if addr == nil { + t.Error("Listener address should not be nil") + } + + tcpAddr, ok := addr.(*net.TCPAddr) + if !ok { + t.Error("Listener address should be a TCP address") + } + + if tcpAddr.Port == 0 { + t.Error("Listener port should be assigned") + } +} + +func TestServerAcceptClientsExitsOnShutdown(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + Entrance: config.Entrance{ + Enabled: true, + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + + // Give acceptClients goroutine time to start + time.Sleep(10 * time.Millisecond) + + // Shutdown should cause acceptClients to exit + s.Shutdown() + + // Give time for graceful exit + time.Sleep(10 * time.Millisecond) + + s.Lock() + if !s.isShuttingDown { + t.Error("Server should be marked as shutting down") + } + s.Unlock() +} + +func TestServerHandleConnectionImmediateClose(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + Entrance: config.Entrance{ + Enabled: true, + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + // Connect and immediately close - handleEntranceServerConnection should handle gracefully + addr := s.listener.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial() error: %v", err) + } + conn.Close() + + // Give time for handleEntranceServerConnection to process the error + time.Sleep(50 * time.Millisecond) +} + +func TestServerHandleConnectionShortInit(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + Entrance: config.Entrance{ + Enabled: true, + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + // Send only 4 bytes instead of 8, then close + addr := s.listener.Addr().String() + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial() error: %v", err) + } + _, _ = conn.Write([]byte{0, 0, 0, 0}) + conn.Close() + + time.Sleep(50 * time.Millisecond) +} + +func TestServerMultipleConnections(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + Entrance: config.Entrance{ + Enabled: true, + Port: 0, + }, + } + + cfg := &Config{ + Logger: logger, + ErupeConfig: erupeConfig, + } + + s := NewServer(cfg) + err := s.Start() + if err != nil { + t.Fatalf("Start() error: %v", err) + } + defer s.Shutdown() + + addr := s.listener.Addr().String() + + // Create multiple connections and close them + conns := make([]net.Conn, 3) + for i := range conns { + conn, err := net.Dial("tcp", addr) + if err != nil { + t.Fatalf("Dial() %d error: %v", i, err) + } + conns[i] = conn + } + + time.Sleep(50 * time.Millisecond) + + for _, conn := range conns { + conn.Close() + } +} diff --git a/server/signserver/dbutils_test.go b/server/signserver/dbutils_test.go index ffd8974ef..0a80ce41e 100644 --- a/server/signserver/dbutils_test.go +++ b/server/signserver/dbutils_test.go @@ -916,3 +916,352 @@ func TestGetGuildmatesNotInGuild(t *testing.T) { t.Errorf("unfulfilled expectations: %v", err) } } + +// TestGetFriendsForCharactersDBError tests getFriendsForCharacters when DB query fails +func TestGetFriendsForCharactersDBError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + } + + // Get friends CSV for character - DB error + mock.ExpectQuery("SELECT friends FROM characters WHERE id=\\$1"). + WithArgs(uint32(1)). + WillReturnError(sql.ErrNoRows) + + // Even on error, still produces the friend query (with empty/error friendsCSV) + // The function calls Scan which fails, then continues to build a query + // with the empty string. The query then fails as well. + mock.ExpectQuery("SELECT id, name FROM characters"). + WillReturnError(sql.ErrConnDone) + + friends := server.getFriendsForCharacters(chars) + // Should return 0 friends on error + if len(friends) != 0 { + t.Errorf("getFriendsForCharacters() with DB error = %d, want 0", len(friends)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestGetGuildmatesForCharactersGuildQueryError tests guild ID query failure +func TestGetGuildmatesForCharactersGuildQueryError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + } + + // Check if in guild - yes + mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Get guild ID - error + mock.ExpectQuery("SELECT guild_id FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnError(sql.ErrConnDone) + + guildmates := server.getGuildmatesForCharacters(chars) + if len(guildmates) != 0 { + t.Errorf("getGuildmatesForCharacters() with guild query error = %d, want 0", len(guildmates)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestGetGuildmatesForCharactersGuildmatesQueryError tests guildmates query failure +func TestGetGuildmatesForCharactersGuildmatesQueryError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + } + + // Check if in guild - yes + mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Get guild ID + mock.ExpectQuery("SELECT guild_id FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"guild_id"}).AddRow(100)) + + // Get guildmates - error + mock.ExpectQuery("SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=\\$1 AND character_id!=\\$2"). + WithArgs(100, uint32(1)). + WillReturnError(sql.ErrConnDone) + + guildmates := server.getGuildmatesForCharacters(chars) + if len(guildmates) != 0 { + t.Errorf("getGuildmatesForCharacters() with guildmates query error = %d, want 0", len(guildmates)) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestDeleteCharacterDeleteError tests deleteCharacter when the delete/update query fails +func TestDeleteCharacterDeleteError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Token verification + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). + WithArgs("validtoken"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Check if new character + mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1"). + WithArgs(123). + WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(false)) + + // Soft delete fails + mock.ExpectExec("UPDATE characters SET deleted = true WHERE id = \\$1"). + WithArgs(123). + WillReturnError(sql.ErrConnDone) + + err := server.deleteCharacter(123, "validtoken") + if err == nil { + t.Error("deleteCharacter() should return error when update fails") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestNewUserCharaInsertError tests newUserChara when the INSERT fails +func TestNewUserCharaInsertError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // Check for existing new characters + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + // Insert new character - error + mock.ExpectExec("INSERT INTO characters"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnError(sql.ErrConnDone) + + err := server.newUserChara("testuser") + if err == nil { + t.Error("newUserChara() should return error when insert fails") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestNewUserCharaCountError tests newUserChara when the COUNT query fails +func TestNewUserCharaCountError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // Check for existing new characters - error + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). + WithArgs(1). + WillReturnError(sql.ErrConnDone) + + err := server.newUserChara("testuser") + if err == nil { + t.Error("newUserChara() should return error when count query fails") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestRegisterDBAccountGetIDError tests registerDBAccount when getting the new user ID fails +func TestRegisterDBAccountGetIDError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Insert user succeeds + mock.ExpectExec("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\)"). + WithArgs("newuser", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Get user ID - error + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("newuser"). + WillReturnError(sql.ErrConnDone) + + err := server.registerDBAccount("newuser", "password123") + if err == nil { + t.Error("registerDBAccount() should return error when getting ID fails") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestRegisterDBAccountCharacterInsertError tests registerDBAccount when character insert fails +func TestRegisterDBAccountCharacterInsertError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Insert user + mock.ExpectExec("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\)"). + WithArgs("newuser", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("newuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // Insert character - error + mock.ExpectExec("INSERT INTO characters"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnError(sql.ErrConnDone) + + err := server.registerDBAccount("newuser", "password123") + if err == nil { + t.Error("registerDBAccount() should return error when character insert fails") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestGetReturnExpiryDBError tests getReturnExpiry when the return_expires query fails +func TestGetReturnExpiryDBError(t *testing.T) { + server, mock := newTestServerWithMock(t) + + // Get last login - recent + recentLogin := time.Now().Add(-time.Hour * 24) + mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(recentLogin)) + + // Get return expiry - error + mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnError(sql.ErrNoRows) + + // Should set return_expires to now + mock.ExpectExec("UPDATE users SET return_expires=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Update last login + mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + expiry := server.getReturnExpiry(1) + + // Should still return a valid time (approximately now) + if expiry.IsZero() { + t.Error("getReturnExpiry() should return non-zero time even on error") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestGetFriendsForCharactersMultipleChars tests with multiple characters +func TestGetFriendsForCharactersMultipleChars(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + {ID: 2, Name: "Hunter2"}, + } + + // First character friends + mock.ExpectQuery("SELECT friends FROM characters WHERE id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"friends"}).AddRow("10")) + + mock.ExpectQuery("SELECT id, name FROM characters WHERE id=10"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "Friend1")) + + // Second character friends + mock.ExpectQuery("SELECT friends FROM characters WHERE id=\\$1"). + WithArgs(uint32(2)). + WillReturnRows(sqlmock.NewRows([]string{"friends"}).AddRow("20")) + + mock.ExpectQuery("SELECT id, name FROM characters WHERE id=20"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(20, "Friend2")) + + friends := server.getFriendsForCharacters(chars) + if len(friends) != 2 { + t.Errorf("getFriendsForCharacters() = %d, want 2", len(friends)) + } + + // Verify CID assignment + if len(friends) >= 2 { + if friends[0].CID != 1 { + t.Errorf("friends[0].CID = %d, want 1", friends[0].CID) + } + if friends[1].CID != 2 { + t.Errorf("friends[1].CID = %d, want 2", friends[1].CID) + } + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestGetGuildmatesForCharactersMultipleChars tests with multiple characters in guilds +func TestGetGuildmatesForCharactersMultipleChars(t *testing.T) { + server, mock := newTestServerWithMock(t) + + chars := []character{ + {ID: 1, Name: "Hunter1"}, + {ID: 2, Name: "Hunter2"}, + } + + // First character in guild + mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + mock.ExpectQuery("SELECT guild_id FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"guild_id"}).AddRow(100)) + + mock.ExpectQuery("SELECT character_id AS id, c.name FROM guild_characters gc JOIN characters c ON c.id = gc.character_id WHERE guild_id=\\$1 AND character_id!=\\$2"). + WithArgs(100, uint32(1)). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(10, "Guildmate1")) + + // Second character not in guild + mock.ExpectQuery("SELECT count\\(\\*\\) FROM guild_characters WHERE character_id=\\$1"). + WithArgs(uint32(2)). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + guildmates := server.getGuildmatesForCharacters(chars) + if len(guildmates) != 1 { + t.Errorf("getGuildmatesForCharacters() = %d, want 1", len(guildmates)) + } + + if len(guildmates) >= 1 && guildmates[0].CID != 1 { + t.Errorf("guildmates[0].CID = %d, want 1", guildmates[0].CID) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} diff --git a/server/signserver/session_test.go b/server/signserver/session_test.go index b408954af..651ddd7cf 100644 --- a/server/signserver/session_test.go +++ b/server/signserver/session_test.go @@ -2,6 +2,7 @@ package signserver import ( "bytes" + "database/sql" "io" "net" "sync" @@ -12,7 +13,10 @@ import ( "erupe-ce/config" "erupe-ce/network" + "github.com/DATA-DOG/go-sqlmock" + "github.com/jmoiron/sqlx" "go.uber.org/zap" + "golang.org/x/crypto/bcrypt" ) // mockConn implements net.Conn for testing @@ -446,7 +450,781 @@ func TestSessionWorkWithEmptyRead(t *testing.T) { session.work() } -// Note: Tests for handleDSGNRequest require a database connection. -// The function immediately queries the database for user authentication. -// These tests should be implemented as integration tests with a test database -// or using sqlmock for database mocking. +// TestHandlePacketDSGNRequest tests the DSGN:100 path with a mocked database. +func TestHandlePacketDSGNRequest(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + // Use net.Pipe for bidirectional communication + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Create a DSGN:100 packet with username "testuser" and password "testpass" + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:100")) + bf.WriteNullTerminatedBytes([]byte("testuser")) + bf.WriteNullTerminatedBytes([]byte("testpass")) + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: user not found, auto-create off + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnError(sql.ErrNoRows) + + // Read the response in a goroutine + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + // Allow response to be sent + time.Sleep(50 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDLTSKEYSIGN tests the DLTSKEYSIGN:100 path (falls through to DSGN:100) +func TestHandlePacketDLTSKEYSIGN(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Create a DLTSKEYSIGN:100 packet + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DLTSKEYSIGN:100")) + bf.WriteNullTerminatedBytes([]byte("testuser")) + bf.WriteNullTerminatedBytes([]byte("testpass")) + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: user not found + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnError(sql.ErrNoRows) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + time.Sleep(50 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDELETE tests the DELETE:100 path +func TestHandlePacketDELETE(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Create a DELETE:100 packet + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DELETE:100")) + bf.WriteNullTerminatedBytes([]byte("login-token-abc")) + bf.WriteUint32(123) // characterID + bf.WriteUint32(456) // login_token_number + + // Mock DB: Token verification + mock.ExpectQuery("SELECT count\\(\\*\\) FROM sign_sessions WHERE token = \\$1"). + WithArgs("login-token-abc"). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(1)) + + // Check if new character + mock.ExpectQuery("SELECT is_new_character FROM characters WHERE id = \\$1"). + WithArgs(123). + WillReturnRows(sqlmock.NewRows([]string{"is_new_character"}).AddRow(false)) + + // Soft delete + mock.ExpectExec("UPDATE characters SET deleted = true WHERE id = \\$1"). + WithArgs(123). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // Read all response data in a goroutine (SendPacket writes header + encrypted data) + done := make(chan []byte, 1) + go func() { + var all []byte + buf := make([]byte, 4096) + for { + n, readErr := clientConn.Read(buf) + if n > 0 { + all = append(all, buf[:n]...) + } + if readErr != nil { + break + } + } + done <- all + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + // Close server side so the reader goroutine finishes + serverConn.Close() + + select { + case <-done: + // Response received successfully + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for response") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDSGNWithAutoCreate tests DSGN:100 with auto-create account enabled +func TestHandlePacketDSGNWithAutoCreate(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: true, + DevModeOptions: config.DevModeOptions{ + AutoCreateAccount: true, + }, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:100")) + bf.WriteNullTerminatedBytes([]byte("newuser")) + bf.WriteNullTerminatedBytes([]byte("newpass")) + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: user not found + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("newuser"). + WillReturnError(sql.ErrNoRows) + + // Auto-create: insert user + mock.ExpectExec("INSERT INTO users \\(username, password, return_expires\\) VALUES \\(\\$1, \\$2, \\$3\\)"). + WithArgs("newuser", sqlmock.AnyArg(), sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Auto-create: get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("newuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // Auto-create: insert character + mock.ExpectExec("INSERT INTO characters"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // Now get new user ID for makeSignInResp + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("newuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // makeSignInResp calls getReturnExpiry + mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(time.Now())) + + // getReturnExpiry: get return_expires + mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"return_expires"}).AddRow(time.Now().Add(time.Hour * 24 * 30))) + + // getReturnExpiry: update last_login + mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // getCharactersForUser + mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id ASC"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hrp", "gr", "weapon_type", "last_login"})) + + // registerToken + mock.ExpectExec("INSERT INTO sign_sessions \\(user_id, token\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // getLastCID + mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_character"}).AddRow(0)) + + // getUserRights + mock.ExpectQuery("SELECT rights FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"rights"}).AddRow(2)) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + time.Sleep(100 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDSGNWithValidPassword tests DSGN:100 with correct password +func TestHandlePacketDSGNWithValidPassword(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Generate a bcrypt hash for "testpass" + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("testpass"), bcrypt.MinCost) + + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:100")) + bf.WriteNullTerminatedBytes([]byte("existinguser")) + bf.WriteNullTerminatedBytes([]byte("testpass")) + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: user found with correct password + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("existinguser"). + WillReturnRows(sqlmock.NewRows([]string{"id", "password"}).AddRow(1, string(hashedPassword))) + + // makeSignInResp calls getReturnExpiry + mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(time.Now())) + + // getReturnExpiry: get return_expires + mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"return_expires"}).AddRow(time.Now().Add(time.Hour * 24 * 30))) + + // getReturnExpiry: update last_login + mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + // getCharactersForUser + mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id ASC"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hrp", "gr", "weapon_type", "last_login"})) + + // registerToken + mock.ExpectExec("INSERT INTO sign_sessions \\(user_id, token\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // getLastCID + mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_character"}).AddRow(0)) + + // getUserRights + mock.ExpectQuery("SELECT rights FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"rights"}).AddRow(2)) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + time.Sleep(100 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDSGNWrongPassword tests DSGN:100 with wrong password +func TestHandlePacketDSGNWrongPassword(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Generate a bcrypt hash for "correctpass" + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("correctpass"), bcrypt.MinCost) + + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:100")) + bf.WriteNullTerminatedBytes([]byte("testuser")) + bf.WriteNullTerminatedBytes([]byte("wrongpass")) // Wrong password + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: user found but password will not match + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnRows(sqlmock.NewRows([]string{"id", "password"}).AddRow(1, string(hashedPassword))) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + time.Sleep(100 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDSGNWithDBError tests DSGN:100 with a database error +func TestHandlePacketDSGNWithDBError(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:100")) + bf.WriteNullTerminatedBytes([]byte("testuser")) + bf.WriteNullTerminatedBytes([]byte("testpass")) + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: generic error + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnError(sql.ErrConnDone) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + time.Sleep(100 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDSGNNewCharaRequest tests DSGN:100 with the '+' suffix for new character +func TestHandlePacketDSGNNewCharaRequest(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: false, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + // Generate a bcrypt hash for "testpass" + hashedPassword, _ := bcrypt.GenerateFromPassword([]byte("testpass"), bcrypt.MinCost) + + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:100")) + bf.WriteNullTerminatedBytes([]byte("testuser+")) // '+' suffix means new character request + bf.WriteNullTerminatedBytes([]byte("testpass")) + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: user found + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnRows(sqlmock.NewRows([]string{"id", "password"}).AddRow(1, string(hashedPassword))) + + // newUserChara: get user ID + mock.ExpectQuery("SELECT id FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnRows(sqlmock.NewRows([]string{"id"}).AddRow(1)) + + // newUserChara: check existing new chars + mock.ExpectQuery("SELECT COUNT\\(\\*\\) FROM characters WHERE user_id = \\$1 AND is_new_character = true"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(0)) + + // newUserChara: insert character + mock.ExpectExec("INSERT INTO characters"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + // makeSignInResp calls + mock.ExpectQuery("SELECT COALESCE\\(last_login, now\\(\\)\\) FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_login"}).AddRow(time.Now())) + + mock.ExpectQuery("SELECT return_expires FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"return_expires"}).AddRow(time.Now().Add(time.Hour * 24 * 30))) + + mock.ExpectExec("UPDATE users SET last_login=\\$1 WHERE id=\\$2"). + WithArgs(sqlmock.AnyArg(), 1). + WillReturnResult(sqlmock.NewResult(0, 1)) + + mock.ExpectQuery("SELECT id, is_female, is_new_character, name, unk_desc_string, hrp, gr, weapon_type, last_login FROM characters WHERE user_id = \\$1 AND deleted = false ORDER BY id ASC"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"id", "is_female", "is_new_character", "name", "unk_desc_string", "hrp", "gr", "weapon_type", "last_login"})) + + mock.ExpectExec("INSERT INTO sign_sessions \\(user_id, token\\) VALUES \\(\\$1, \\$2\\)"). + WithArgs(1, sqlmock.AnyArg()). + WillReturnResult(sqlmock.NewResult(1, 1)) + + mock.ExpectQuery("SELECT last_character FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"last_character"}).AddRow(0)) + + mock.ExpectQuery("SELECT rights FROM users WHERE id=\\$1"). + WithArgs(1). + WillReturnRows(sqlmock.NewRows([]string{"rights"}).AddRow(2)) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + time.Sleep(100 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} + +// TestHandlePacketDSGNWithDevModeOutboundLogging tests dev mode outbound logging +func TestHandlePacketDSGNWithDevModeOutboundLogging(t *testing.T) { + logger := zap.NewNop() + erupeConfig := &config.Config{ + DevMode: true, + DevModeOptions: config.DevModeOptions{ + LogOutboundMessages: true, + }, + } + + db, mock, err := sqlmock.New() + if err != nil { + t.Fatalf("failed to create sqlmock: %v", err) + } + sqlxDB := sqlx.NewDb(db, "sqlmock") + + server := &Server{ + logger: logger, + erupeConfig: erupeConfig, + db: sqlxDB, + } + + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + + session := &Session{ + logger: logger, + server: server, + rawConn: serverConn, + cryptConn: network.NewCryptConn(serverConn), + } + + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:100")) + bf.WriteNullTerminatedBytes([]byte("testuser")) + bf.WriteNullTerminatedBytes([]byte("testpass")) + bf.WriteNullTerminatedBytes([]byte("unk")) + + // Mock DB: user not found, dev mode but no auto create + mock.ExpectQuery("SELECT id, password FROM users WHERE username = \\$1"). + WithArgs("testuser"). + WillReturnError(sql.ErrNoRows) + + done := make(chan struct{}) + go func() { + defer close(done) + buf := make([]byte, 4096) + for { + _, err := clientConn.Read(buf) + if err != nil { + return + } + } + }() + + err = session.handlePacket(bf.Data()) + if err != nil { + t.Errorf("handlePacket() returned error: %v", err) + } + + time.Sleep(100 * time.Millisecond) + clientConn.Close() + <-done + + if err := mock.ExpectationsWereMet(); err != nil { + t.Errorf("unfulfilled expectations: %v", err) + } +} diff --git a/tools/usercheck/db_test.go b/tools/usercheck/db_test.go new file mode 100644 index 000000000..31d05baa3 --- /dev/null +++ b/tools/usercheck/db_test.go @@ -0,0 +1,1036 @@ +package main + +import ( + "database/sql" + "flag" + "os" + "path/filepath" + "testing" + "time" +) + +// --------------------------------------------------------------------------- +// escapeConnStringValue +// --------------------------------------------------------------------------- + +func TestEscapeConnStringValue_Empty(t *testing.T) { + got := escapeConnStringValue("") + if got != "" { + t.Errorf("expected empty string, got %q", got) + } +} + +func TestEscapeConnStringValue_NoSpecialChars(t *testing.T) { + got := escapeConnStringValue("hello world") + if got != "hello world" { + t.Errorf("expected %q, got %q", "hello world", got) + } +} + +func TestEscapeConnStringValue_SingleQuote(t *testing.T) { + got := escapeConnStringValue("it's") + want := "it''s" + if got != want { + t.Errorf("expected %q, got %q", want, got) + } +} + +func TestEscapeConnStringValue_Backslash(t *testing.T) { + got := escapeConnStringValue(`path\to\file`) + want := `path\\to\\file` + if got != want { + t.Errorf("expected %q, got %q", want, got) + } +} + +func TestEscapeConnStringValue_BothQuoteAndBackslash(t *testing.T) { + got := escapeConnStringValue(`it's\path`) + want := `it''s\\path` + if got != want { + t.Errorf("expected %q, got %q", want, got) + } +} + +func TestEscapeConnStringValue_MultipleConsecutiveQuotes(t *testing.T) { + got := escapeConnStringValue("a'''b") + want := "a''''''b" + if got != want { + t.Errorf("expected %q, got %q", want, got) + } +} + +func TestEscapeConnStringValue_OnlySpecialChars(t *testing.T) { + got := escapeConnStringValue(`'\`) + want := `''\\` + if got != want { + t.Errorf("expected %q, got %q", want, got) + } +} + +func TestEscapeConnStringValue_Table(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + {"unicode", "p@$$w\u00f6rd", "p@$$w\u00f6rd"}, + {"spaces only", " ", " "}, + {"leading quote", "'start", "''start"}, + {"trailing quote", "end'", "end''"}, + {"trailing backslash", `end\`, `end\\`}, + {"multiple backslashes", `a\\b`, `a\\\\b`}, + {"mixed complex", `x'y\z'w\`, `x''y\\z''w\\`}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := escapeConnStringValue(tt.input) + if got != tt.want { + t.Errorf("escapeConnStringValue(%q) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +// --------------------------------------------------------------------------- +// loadConfigFile +// --------------------------------------------------------------------------- + +func TestLoadConfigFile_Valid(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + content := `{ + "Database": { + "Host": "myhost", + "Port": 1234, + "User": "myuser", + "Password": "mypass", + "Database": "mydb" + } + }` + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := loadConfigFile(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Database.Host != "myhost" { + t.Errorf("Host = %q, want %q", cfg.Database.Host, "myhost") + } + if cfg.Database.Port != 1234 { + t.Errorf("Port = %d, want %d", cfg.Database.Port, 1234) + } + if cfg.Database.User != "myuser" { + t.Errorf("User = %q, want %q", cfg.Database.User, "myuser") + } + if cfg.Database.Password != "mypass" { + t.Errorf("Password = %q, want %q", cfg.Database.Password, "mypass") + } + if cfg.Database.Database != "mydb" { + t.Errorf("Database = %q, want %q", cfg.Database.Database, "mydb") + } +} + +func TestLoadConfigFile_NonExistent(t *testing.T) { + _, err := loadConfigFile("/tmp/nonexistent_config_test_12345.json") + if err == nil { + t.Fatal("expected error for non-existent file, got nil") + } +} + +func TestLoadConfigFile_InvalidJSON(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + if err := os.WriteFile(path, []byte("not valid json {{{"), 0644); err != nil { + t.Fatal(err) + } + + _, err := loadConfigFile(path) + if err == nil { + t.Fatal("expected error for invalid JSON, got nil") + } +} + +func TestLoadConfigFile_EmptyFile(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + if err := os.WriteFile(path, []byte(""), 0644); err != nil { + t.Fatal(err) + } + + _, err := loadConfigFile(path) + if err == nil { + t.Fatal("expected error for empty file, got nil") + } +} + +func TestLoadConfigFile_NoDatabaseField(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + if err := os.WriteFile(path, []byte(`{"SomeOther": "field"}`), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := loadConfigFile(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + // Database fields should be zero-values + if cfg.Database.Host != "" { + t.Errorf("expected empty Host, got %q", cfg.Database.Host) + } + if cfg.Database.Port != 0 { + t.Errorf("expected Port 0, got %d", cfg.Database.Port) + } + if cfg.Database.Password != "" { + t.Errorf("expected empty Password, got %q", cfg.Database.Password) + } +} + +func TestLoadConfigFile_PartialDatabase(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "config.json") + content := `{"Database": {"Host": "partial", "Port": 9999}}` + if err := os.WriteFile(path, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + cfg, err := loadConfigFile(path) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Database.Host != "partial" { + t.Errorf("Host = %q, want %q", cfg.Database.Host, "partial") + } + if cfg.Database.Port != 9999 { + t.Errorf("Port = %d, want %d", cfg.Database.Port, 9999) + } + if cfg.Database.User != "" { + t.Errorf("User = %q, want empty", cfg.Database.User) + } +} + +// --------------------------------------------------------------------------- +// findConfigFile +// --------------------------------------------------------------------------- + +func TestFindConfigFile_NotFound(t *testing.T) { + // Run in a temp directory where no config.json exists. + // We save and restore the working directory so other tests are not affected. + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + + dir := t.TempDir() + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + + result := findConfigFile() + if result != "" { + t.Errorf("expected empty string when no config.json exists, got %q", result) + } +} + +func TestFindConfigFile_InCurrentDir(t *testing.T) { + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{}`), 0644); err != nil { + t.Fatal(err) + } + + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + + result := findConfigFile() + if result == "" { + t.Error("expected findConfigFile to find config.json in current directory") + } +} + +func TestFindConfigFile_TwoLevelsUp(t *testing.T) { + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + + // Simulate tools/usercheck/ structure: config.json is ../../config.json + dir := t.TempDir() + subDir := filepath.Join(dir, "tools", "usercheck") + if err := os.MkdirAll(subDir, 0755); err != nil { + t.Fatal(err) + } + configPath := filepath.Join(dir, "config.json") + if err := os.WriteFile(configPath, []byte(`{}`), 0644); err != nil { + t.Fatal(err) + } + + if err := os.Chdir(subDir); err != nil { + t.Fatal(err) + } + + result := findConfigFile() + if result == "" { + t.Error("expected findConfigFile to find config.json two levels up") + } +} + +// --------------------------------------------------------------------------- +// addDBFlags +// --------------------------------------------------------------------------- + +func TestAddDBFlags_RegistersAllFlags(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + cfg := &DBConfig{} + addDBFlags(fs, cfg) + + expectedFlags := []string{"config", "host", "port", "user", "password", "dbname"} + for _, name := range expectedFlags { + f := fs.Lookup(name) + if f == nil { + t.Errorf("expected flag %q to be registered, but it was not found", name) + } + } +} + +func TestAddDBFlags_ParseFlags(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + cfg := &DBConfig{} + addDBFlags(fs, cfg) + + args := []string{ + "-host", "dbhost", + "-port", "6543", + "-user", "dbuser", + "-password", "dbpass", + "-dbname", "testdb", + "-config", "/some/path.json", + } + if err := fs.Parse(args); err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + + if cfg.Host != "dbhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "dbhost") + } + if cfg.Port != 6543 { + t.Errorf("Port = %d, want %d", cfg.Port, 6543) + } + if cfg.User != "dbuser" { + t.Errorf("User = %q, want %q", cfg.User, "dbuser") + } + if cfg.Password != "dbpass" { + t.Errorf("Password = %q, want %q", cfg.Password, "dbpass") + } + if cfg.DBName != "testdb" { + t.Errorf("DBName = %q, want %q", cfg.DBName, "testdb") + } + if cfg.ConfigPath != "/some/path.json" { + t.Errorf("ConfigPath = %q, want %q", cfg.ConfigPath, "/some/path.json") + } +} + +func TestAddDBFlags_DefaultValues(t *testing.T) { + fs := flag.NewFlagSet("test", flag.ContinueOnError) + cfg := &DBConfig{} + addDBFlags(fs, cfg) + + // Parse with no arguments + if err := fs.Parse(nil); err != nil { + t.Fatalf("unexpected parse error: %v", err) + } + + // All fields should be zero values (defaults from flag package) + if cfg.Host != "" { + t.Errorf("Host = %q, want empty", cfg.Host) + } + if cfg.Port != 0 { + t.Errorf("Port = %d, want 0", cfg.Port) + } + if cfg.User != "" { + t.Errorf("User = %q, want empty", cfg.User) + } + if cfg.Password != "" { + t.Errorf("Password = %q, want empty", cfg.Password) + } +} + +// --------------------------------------------------------------------------- +// resolveDBConfig +// --------------------------------------------------------------------------- + +func TestResolveDBConfig_AllPreset(t *testing.T) { + // Clear environment variables that could interfere + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_PASSWORD") + os.Unsetenv("ERUPE_DB_NAME") + + cfg := &DBConfig{ + Host: "myhost", + Port: 5555, + User: "myuser", + Password: "mypass", + DBName: "mydb", + } + + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Host != "myhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "myhost") + } + if cfg.Port != 5555 { + t.Errorf("Port = %d, want %d", cfg.Port, 5555) + } + if cfg.User != "myuser" { + t.Errorf("User = %q, want %q", cfg.User, "myuser") + } + if cfg.Password != "mypass" { + t.Errorf("Password = %q, want %q", cfg.Password, "mypass") + } + if cfg.DBName != "mydb" { + t.Errorf("DBName = %q, want %q", cfg.DBName, "mydb") + } +} + +func TestResolveDBConfig_MissingPassword(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_PASSWORD") + os.Unsetenv("ERUPE_DB_NAME") + + // Change to a temp dir so no config.json is found + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + if err := os.Chdir(t.TempDir()); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{} + err = resolveDBConfig(cfg) + if err == nil { + t.Fatal("expected error for missing password, got nil") + } + if got := err.Error(); got != "database password is required (set in config.json, use -password flag, or ERUPE_DB_PASSWORD env var)" { + t.Errorf("unexpected error message: %q", got) + } +} + +func TestResolveDBConfig_PasswordFromEnv(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_NAME") + + os.Setenv("ERUPE_DB_PASSWORD", "envpass") + + // Change to a temp dir so no config.json is found + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + if err := os.Chdir(t.TempDir()); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{} + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Password != "envpass" { + t.Errorf("Password = %q, want %q", cfg.Password, "envpass") + } +} + +func TestResolveDBConfig_HostFromEnv(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_NAME") + + os.Setenv("ERUPE_DB_HOST", "envhost") + os.Setenv("ERUPE_DB_PASSWORD", "envpass") + + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + if err := os.Chdir(t.TempDir()); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{} + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.Host != "envhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "envhost") + } +} + +func TestResolveDBConfig_UserFromEnv(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_NAME") + + os.Setenv("ERUPE_DB_USER", "envuser") + os.Setenv("ERUPE_DB_PASSWORD", "envpass") + + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + if err := os.Chdir(t.TempDir()); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{} + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.User != "envuser" { + t.Errorf("User = %q, want %q", cfg.User, "envuser") + } +} + +func TestResolveDBConfig_DBNameFromEnv(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + + os.Setenv("ERUPE_DB_PASSWORD", "envpass") + os.Setenv("ERUPE_DB_NAME", "envdb") + + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + if err := os.Chdir(t.TempDir()); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{} + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + if cfg.DBName != "envdb" { + t.Errorf("DBName = %q, want %q", cfg.DBName, "envdb") + } +} + +func TestResolveDBConfig_DefaultsApplied(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_PASSWORD") + os.Unsetenv("ERUPE_DB_NAME") + + // Change to a temp dir so no config.json is found + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + if err := os.Chdir(t.TempDir()); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{Password: "provided"} + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } + if cfg.Port != 5432 { + t.Errorf("Port = %d, want %d", cfg.Port, 5432) + } + if cfg.User != "postgres" { + t.Errorf("User = %q, want %q", cfg.User, "postgres") + } + if cfg.DBName != "erupe" { + t.Errorf("DBName = %q, want %q", cfg.DBName, "erupe") + } +} + +func TestResolveDBConfig_ConfigFileOverrides(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_PASSWORD") + os.Unsetenv("ERUPE_DB_NAME") + + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + content := `{ + "Database": { + "Host": "filehost", + "Port": 7777, + "User": "fileuser", + "Password": "filepass", + "Database": "filedb" + } + }` + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{ConfigPath: configPath} + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Host != "filehost" { + t.Errorf("Host = %q, want %q", cfg.Host, "filehost") + } + if cfg.Port != 7777 { + t.Errorf("Port = %d, want %d", cfg.Port, 7777) + } + if cfg.User != "fileuser" { + t.Errorf("User = %q, want %q", cfg.User, "fileuser") + } + if cfg.Password != "filepass" { + t.Errorf("Password = %q, want %q", cfg.Password, "filepass") + } + if cfg.DBName != "filedb" { + t.Errorf("DBName = %q, want %q", cfg.DBName, "filedb") + } +} + +func TestResolveDBConfig_FlagsOverrideConfigFile(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_PASSWORD") + os.Unsetenv("ERUPE_DB_NAME") + + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + content := `{ + "Database": { + "Host": "filehost", + "Port": 7777, + "User": "fileuser", + "Password": "filepass", + "Database": "filedb" + } + }` + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + // CLI flags take priority since they're already set in cfg + cfg := &DBConfig{ + ConfigPath: configPath, + Host: "clihost", + Port: 8888, + User: "cliuser", + Password: "clipass", + DBName: "clidb", + } + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if cfg.Host != "clihost" { + t.Errorf("Host = %q, want %q", cfg.Host, "clihost") + } + if cfg.Port != 8888 { + t.Errorf("Port = %d, want %d", cfg.Port, 8888) + } + if cfg.User != "cliuser" { + t.Errorf("User = %q, want %q", cfg.User, "cliuser") + } + if cfg.Password != "clipass" { + t.Errorf("Password = %q, want %q", cfg.Password, "clipass") + } + if cfg.DBName != "clidb" { + t.Errorf("DBName = %q, want %q", cfg.DBName, "clidb") + } +} + +func TestResolveDBConfig_ExplicitConfigPathInvalid(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_PASSWORD") + os.Unsetenv("ERUPE_DB_NAME") + + cfg := &DBConfig{ + ConfigPath: "/nonexistent/path/config.json", + Password: "pass", + } + err := resolveDBConfig(cfg) + if err == nil { + t.Fatal("expected error when explicitly specifying non-existent config path") + } +} + +func TestResolveDBConfig_AutoDetectedConfigPathInvalid(t *testing.T) { + // When config.json is found by findConfigFile but is invalid JSON, + // resolveDBConfig should silently ignore it (because user didn't explicitly specify it) + origDir, err := os.Getwd() + if err != nil { + t.Fatal(err) + } + defer func() { _ = os.Chdir(origDir) }() + + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_USER") + os.Unsetenv("ERUPE_DB_PASSWORD") + os.Unsetenv("ERUPE_DB_NAME") + + dir := t.TempDir() + // Write a broken config.json + if err := os.WriteFile(filepath.Join(dir, "config.json"), []byte("BROKEN"), 0644); err != nil { + t.Fatal(err) + } + if err := os.Chdir(dir); err != nil { + t.Fatal(err) + } + + cfg := &DBConfig{Password: "pass"} + // Should not error -- broken auto-detected config is silently ignored + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Defaults should be applied + if cfg.Host != "localhost" { + t.Errorf("Host = %q, want %q", cfg.Host, "localhost") + } + if cfg.Port != 5432 { + t.Errorf("Port = %d, want %d", cfg.Port, 5432) + } +} + +func TestResolveDBConfig_EnvOverridesConfig(t *testing.T) { + origHost := os.Getenv("ERUPE_DB_HOST") + origUser := os.Getenv("ERUPE_DB_USER") + origPass := os.Getenv("ERUPE_DB_PASSWORD") + origName := os.Getenv("ERUPE_DB_NAME") + defer func() { + os.Setenv("ERUPE_DB_HOST", origHost) + os.Setenv("ERUPE_DB_USER", origUser) + os.Setenv("ERUPE_DB_PASSWORD", origPass) + os.Setenv("ERUPE_DB_NAME", origName) + }() + + // Config file provides some values, env provides others + dir := t.TempDir() + configPath := filepath.Join(dir, "config.json") + content := `{"Database": {"Host": "filehost", "Port": 7777}}` + if err := os.WriteFile(configPath, []byte(content), 0644); err != nil { + t.Fatal(err) + } + + // Env provides password and user, which config file doesn't set + os.Setenv("ERUPE_DB_PASSWORD", "envpass") + os.Setenv("ERUPE_DB_USER", "envuser") + os.Unsetenv("ERUPE_DB_HOST") + os.Unsetenv("ERUPE_DB_NAME") + + cfg := &DBConfig{ConfigPath: configPath} + if err := resolveDBConfig(cfg); err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // Host comes from config file + if cfg.Host != "filehost" { + t.Errorf("Host = %q, want %q", cfg.Host, "filehost") + } + // User comes from env (config file didn't set it) + if cfg.User != "envuser" { + t.Errorf("User = %q, want %q", cfg.User, "envuser") + } + // Password comes from env + if cfg.Password != "envpass" { + t.Errorf("Password = %q, want %q", cfg.Password, "envpass") + } +} + +// --------------------------------------------------------------------------- +// Struct types construction +// --------------------------------------------------------------------------- + +func TestConnectedUser_Construction(t *testing.T) { + now := time.Now() + u := ConnectedUser{ + CharID: 42, + CharName: "Hunter", + ServerID: 1, + ServerName: "World1", + UserID: 10, + Username: "player1", + LastLogin: sql.NullTime{Time: now, Valid: true}, + HR: 999, + GR: 50, + } + + if u.CharID != 42 { + t.Errorf("CharID = %d, want 42", u.CharID) + } + if u.CharName != "Hunter" { + t.Errorf("CharName = %q, want %q", u.CharName, "Hunter") + } + if u.ServerID != 1 { + t.Errorf("ServerID = %d, want 1", u.ServerID) + } + if u.ServerName != "World1" { + t.Errorf("ServerName = %q, want %q", u.ServerName, "World1") + } + if u.UserID != 10 { + t.Errorf("UserID = %d, want 10", u.UserID) + } + if u.Username != "player1" { + t.Errorf("Username = %q, want %q", u.Username, "player1") + } + if !u.LastLogin.Valid { + t.Error("LastLogin.Valid = false, want true") + } + if u.HR != 999 { + t.Errorf("HR = %d, want 999", u.HR) + } + if u.GR != 50 { + t.Errorf("GR = %d, want 50", u.GR) + } +} + +func TestConnectedUser_NullLastLogin(t *testing.T) { + u := ConnectedUser{ + CharID: 1, + CharName: "Test", + } + if u.LastLogin.Valid { + t.Error("LastLogin.Valid = true, want false for zero value") + } +} + +func TestServerStatus_Construction(t *testing.T) { + s := ServerStatus{ + ServerID: 5, + WorldName: "Frontier", + WorldDesc: "A great server", + Land: 2, + CurrentPlayers: 100, + Season: 1, + } + + if s.ServerID != 5 { + t.Errorf("ServerID = %d, want 5", s.ServerID) + } + if s.WorldName != "Frontier" { + t.Errorf("WorldName = %q, want %q", s.WorldName, "Frontier") + } + if s.WorldDesc != "A great server" { + t.Errorf("WorldDesc = %q, want %q", s.WorldDesc, "A great server") + } + if s.Land != 2 { + t.Errorf("Land = %d, want 2", s.Land) + } + if s.CurrentPlayers != 100 { + t.Errorf("CurrentPlayers = %d, want 100", s.CurrentPlayers) + } + if s.Season != 1 { + t.Errorf("Season = %d, want 1", s.Season) + } +} + +func TestLoginHistory_Construction(t *testing.T) { + now := time.Now() + h := LoginHistory{ + CharID: 7, + CharName: "Veteran", + LastLogin: sql.NullTime{Time: now, Valid: true}, + HR: 500, + GR: 25, + Username: "vet_player", + } + + if h.CharID != 7 { + t.Errorf("CharID = %d, want 7", h.CharID) + } + if h.CharName != "Veteran" { + t.Errorf("CharName = %q, want %q", h.CharName, "Veteran") + } + if !h.LastLogin.Valid { + t.Error("LastLogin.Valid = false, want true") + } + if h.HR != 500 { + t.Errorf("HR = %d, want 500", h.HR) + } + if h.GR != 25 { + t.Errorf("GR = %d, want 25", h.GR) + } + if h.Username != "vet_player" { + t.Errorf("Username = %q, want %q", h.Username, "vet_player") + } +} + +func TestLoginHistory_NullLastLogin(t *testing.T) { + h := LoginHistory{ + CharID: 1, + CharName: "NewPlayer", + } + if h.LastLogin.Valid { + t.Error("LastLogin.Valid = true, want false for zero value") + } +} + +// --------------------------------------------------------------------------- +// DBConfig and ErupeConfig struct construction +// --------------------------------------------------------------------------- + +func TestDBConfig_ZeroValue(t *testing.T) { + cfg := DBConfig{} + if cfg.Host != "" { + t.Errorf("Host = %q, want empty", cfg.Host) + } + if cfg.Port != 0 { + t.Errorf("Port = %d, want 0", cfg.Port) + } + if cfg.User != "" { + t.Errorf("User = %q, want empty", cfg.User) + } + if cfg.Password != "" { + t.Errorf("Password = %q, want empty", cfg.Password) + } + if cfg.DBName != "" { + t.Errorf("DBName = %q, want empty", cfg.DBName) + } + if cfg.ConfigPath != "" { + t.Errorf("ConfigPath = %q, want empty", cfg.ConfigPath) + } +} + +func TestErupeConfig_ZeroValue(t *testing.T) { + cfg := ErupeConfig{} + if cfg.Database.Host != "" { + t.Errorf("Database.Host = %q, want empty", cfg.Database.Host) + } + if cfg.Database.Port != 0 { + t.Errorf("Database.Port = %d, want 0", cfg.Database.Port) + } +}