diff --git a/docker/docker-compose.test.yml b/docker/docker-compose.test.yml new file mode 100644 index 000000000..9feb9ec01 --- /dev/null +++ b/docker/docker-compose.test.yml @@ -0,0 +1,25 @@ +version: "3.9" +# Docker Compose configuration for running integration tests +# Usage: docker-compose -f docker/docker-compose.test.yml up -d +services: + test-db: + image: postgres:15-alpine + container_name: erupe-test-db + environment: + POSTGRES_USER: test + POSTGRES_PASSWORD: test + POSTGRES_DB: erupe_test + ports: + - "5433:5432" # Different port to avoid conflicts with main DB + # Use tmpfs for faster tests (in-memory database) + tmpfs: + - /var/lib/postgresql/data + # Mount schema files for initialization + volumes: + - ../schemas/:/schemas/ + healthcheck: + test: ["CMD-SHELL", "pg_isready -U test -d erupe_test"] + interval: 2s + timeout: 2s + retries: 10 + start_period: 5s diff --git a/server/channelserver/handlers_character_test.go b/server/channelserver/handlers_character_test.go new file mode 100644 index 000000000..ed5ac086c --- /dev/null +++ b/server/channelserver/handlers_character_test.go @@ -0,0 +1,592 @@ +package channelserver + +import ( + "bytes" + "encoding/binary" + "testing" + + _config "erupe-ce/config" + "erupe-ce/network/mhfpacket" + "erupe-ce/server/channelserver/compression/nullcomp" +) + +// TestGetPointers tests the pointer map generation for different game versions +func TestGetPointers(t *testing.T) { + tests := []struct { + name string + clientMode _config.Mode + wantGender int + wantHR int + }{ + { + name: "ZZ_version", + clientMode: _config.ZZ, + wantGender: 81, + wantHR: 130550, + }, + { + name: "Z2_version", + clientMode: _config.Z2, + wantGender: 81, + wantHR: 94550, + }, + { + name: "G10_version", + clientMode: _config.G10, + wantGender: 81, + wantHR: 94550, + }, + { + name: "F5_version", + clientMode: _config.F5, + wantGender: 81, + wantHR: 62550, + }, + { + name: "S6_version", + clientMode: _config.S6, + wantGender: 81, + wantHR: 14550, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Save and restore original config + originalMode := _config.ErupeConfig.RealClientMode + defer func() { _config.ErupeConfig.RealClientMode = originalMode }() + + _config.ErupeConfig.RealClientMode = tt.clientMode + pointers := getPointers() + + if pointers[pGender] != tt.wantGender { + t.Errorf("pGender = %d, want %d", pointers[pGender], tt.wantGender) + } + + if pointers[pHR] != tt.wantHR { + t.Errorf("pHR = %d, want %d", pointers[pHR], tt.wantHR) + } + + // Verify all required pointers exist + requiredPointers := []SavePointer{pGender, pRP, pHouseTier, pHouseData, pBookshelfData, + pGalleryData, pToreData, pGardenData, pPlaytime, pWeaponType, pWeaponID, pHR, lBookshelfData} + + for _, ptr := range requiredPointers { + if _, exists := pointers[ptr]; !exists { + t.Errorf("pointer %v not found in map", ptr) + } + } + }) + } +} + +// TestCharacterSaveData_Compress tests savedata compression +func TestCharacterSaveData_Compress(t *testing.T) { + tests := []struct { + name string + data []byte + wantErr bool + }{ + { + name: "valid_small_data", + data: []byte{0x01, 0x02, 0x03, 0x04}, + wantErr: false, + }, + { + name: "valid_large_data", + data: bytes.Repeat([]byte{0xAA}, 10000), + wantErr: false, + }, + { + name: "empty_data", + data: []byte{}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + save := &CharacterSaveData{ + decompSave: tt.data, + } + + err := save.Compress() + if (err != nil) != tt.wantErr { + t.Errorf("Compress() error = %v, wantErr %v", err, tt.wantErr) + } + + if !tt.wantErr && len(save.compSave) == 0 { + t.Error("compressed save is empty") + } + }) + } +} + +// TestCharacterSaveData_Decompress tests savedata decompression +func TestCharacterSaveData_Decompress(t *testing.T) { + tests := []struct { + name string + setup func() []byte + wantErr bool + }{ + { + name: "valid_compressed_data", + setup: func() []byte { + data := []byte{0x01, 0x02, 0x03, 0x04} + compressed, _ := nullcomp.Compress(data) + return compressed + }, + wantErr: false, + }, + { + name: "valid_large_compressed_data", + setup: func() []byte { + data := bytes.Repeat([]byte{0xBB}, 5000) + compressed, _ := nullcomp.Compress(data) + return compressed + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + save := &CharacterSaveData{ + compSave: tt.setup(), + } + + err := save.Decompress() + if (err != nil) != tt.wantErr { + t.Errorf("Decompress() error = %v, wantErr %v", err, tt.wantErr) + } + + if !tt.wantErr && len(save.decompSave) == 0 { + t.Error("decompressed save is empty") + } + }) + } +} + +// TestCharacterSaveData_RoundTrip tests compression and decompression +func TestCharacterSaveData_RoundTrip(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + { + name: "small_data", + data: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + }, + { + name: "repeating_pattern", + data: bytes.Repeat([]byte{0xCC}, 1000), + }, + { + name: "mixed_data", + data: []byte{0x00, 0xFF, 0x01, 0xFE, 0x02, 0xFD, 0x03, 0xFC}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + save := &CharacterSaveData{ + decompSave: tt.data, + } + + // Compress + if err := save.Compress(); err != nil { + t.Fatalf("Compress() failed: %v", err) + } + + // Clear decompressed data + save.decompSave = nil + + // Decompress + if err := save.Decompress(); err != nil { + t.Fatalf("Decompress() failed: %v", err) + } + + // Verify round trip + if !bytes.Equal(save.decompSave, tt.data) { + t.Errorf("round trip failed: got %v, want %v", save.decompSave, tt.data) + } + }) + } +} + +// TestCharacterSaveData_updateStructWithSaveData tests parsing save data +func TestCharacterSaveData_updateStructWithSaveData(t *testing.T) { + originalMode := _config.ErupeConfig.RealClientMode + defer func() { _config.ErupeConfig.RealClientMode = originalMode }() + _config.ErupeConfig.RealClientMode = _config.Z2 + + tests := []struct { + name string + isNewCharacter bool + setupSaveData func() []byte + wantName string + wantGender bool + }{ + { + name: "male_character", + isNewCharacter: false, + setupSaveData: func() []byte { + data := make([]byte, 150000) + copy(data[88:], []byte("TestChar\x00")) + data[81] = 0 // Male + return data + }, + wantName: "TestChar", + wantGender: false, + }, + { + name: "female_character", + isNewCharacter: false, + setupSaveData: func() []byte { + data := make([]byte, 150000) + copy(data[88:], []byte("FemaleChar\x00")) + data[81] = 1 // Female + return data + }, + wantName: "FemaleChar", + wantGender: true, + }, + { + name: "new_character_skips_parsing", + isNewCharacter: true, + setupSaveData: func() []byte { + data := make([]byte, 150000) + copy(data[88:], []byte("NewChar\x00")) + return data + }, + wantName: "NewChar", + wantGender: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + save := &CharacterSaveData{ + Pointers: getPointers(), + decompSave: tt.setupSaveData(), + IsNewCharacter: tt.isNewCharacter, + } + + save.updateStructWithSaveData() + + if save.Name != tt.wantName { + t.Errorf("Name = %q, want %q", save.Name, tt.wantName) + } + + if save.Gender != tt.wantGender { + t.Errorf("Gender = %v, want %v", save.Gender, tt.wantGender) + } + }) + } +} + +// TestCharacterSaveData_updateSaveDataWithStruct tests writing struct to save data +func TestCharacterSaveData_updateSaveDataWithStruct(t *testing.T) { + originalMode := _config.ErupeConfig.RealClientMode + defer func() { _config.ErupeConfig.RealClientMode = originalMode }() + _config.ErupeConfig.RealClientMode = _config.G10 + + tests := []struct { + name string + rp uint16 + kqf []byte + wantRP uint16 + }{ + { + name: "update_rp_value", + rp: 1234, + kqf: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}, + wantRP: 1234, + }, + { + name: "zero_rp_value", + rp: 0, + kqf: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00}, + wantRP: 0, + }, + { + name: "max_rp_value", + rp: 65535, + kqf: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, + wantRP: 65535, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + save := &CharacterSaveData{ + Pointers: getPointers(), + decompSave: make([]byte, 150000), + RP: tt.rp, + KQF: tt.kqf, + } + + save.updateSaveDataWithStruct() + + // Verify RP was written correctly + rpOffset := save.Pointers[pRP] + gotRP := binary.LittleEndian.Uint16(save.decompSave[rpOffset : rpOffset+2]) + if gotRP != tt.wantRP { + t.Errorf("RP in save data = %d, want %d", gotRP, tt.wantRP) + } + + // Verify KQF was written correctly + kqfOffset := save.Pointers[pKQF] + gotKQF := save.decompSave[kqfOffset : kqfOffset+8] + if !bytes.Equal(gotKQF, tt.kqf) { + t.Errorf("KQF in save data = %v, want %v", gotKQF, tt.kqf) + } + }) + } +} + +// TestHandleMsgMhfSexChanger tests the sex changer handler +func TestHandleMsgMhfSexChanger(t *testing.T) { + tests := []struct { + name string + ackHandle uint32 + }{ + { + name: "basic_sex_change", + ackHandle: 1234, + }, + { + name: "different_ack_handle", + ackHandle: 9999, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + + pkt := &mhfpacket.MsgMhfSexChanger{ + AckHandle: tt.ackHandle, + } + + handleMsgMhfSexChanger(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Fatal("no ACK packet was sent") + } + + // Drain the channel + <-s.sendPackets + }) + } +} + +// TestGetCharacterSaveData_Integration tests retrieving character save data from database +func TestGetCharacterSaveData_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Save original config mode + originalMode := _config.ErupeConfig.RealClientMode + defer func() { _config.ErupeConfig.RealClientMode = originalMode }() + _config.ErupeConfig.RealClientMode = _config.Z2 + + tests := []struct { + name string + charName string + isNewCharacter bool + wantError bool + }{ + { + name: "existing_character", + charName: "TestChar", + isNewCharacter: false, + wantError: false, + }, + { + name: "new_character", + charName: "NewChar", + isNewCharacter: true, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test user and character + userID := CreateTestUser(t, db, "testuser_"+tt.name) + charID := CreateTestCharacter(t, db, userID, tt.charName) + + // Update is_new_character flag + _, err := db.Exec("UPDATE characters SET is_new_character = $1 WHERE id = $2", tt.isNewCharacter, charID) + if err != nil { + t.Fatalf("Failed to update character: %v", err) + } + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + // Get character save data + saveData, err := GetCharacterSaveData(s, charID) + if (err != nil) != tt.wantError { + t.Errorf("GetCharacterSaveData() error = %v, wantErr %v", err, tt.wantError) + return + } + + if !tt.wantError { + if saveData == nil { + t.Fatal("saveData is nil") + } + + if saveData.CharID != charID { + t.Errorf("CharID = %d, want %d", saveData.CharID, charID) + } + + if saveData.Name != tt.charName { + t.Errorf("Name = %q, want %q", saveData.Name, tt.charName) + } + + if saveData.IsNewCharacter != tt.isNewCharacter { + t.Errorf("IsNewCharacter = %v, want %v", saveData.IsNewCharacter, tt.isNewCharacter) + } + } + }) + } +} + +// TestCharacterSaveData_Save_Integration tests saving character data to database +func TestCharacterSaveData_Save_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Save original config mode + originalMode := _config.ErupeConfig.RealClientMode + defer func() { _config.ErupeConfig.RealClientMode = originalMode }() + _config.ErupeConfig.RealClientMode = _config.Z2 + + // Create test user and character + userID := CreateTestUser(t, db, "savetest") + charID := CreateTestCharacter(t, db, userID, "SaveChar") + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + // Load character save data + saveData, err := GetCharacterSaveData(s, charID) + if err != nil { + t.Fatalf("Failed to get save data: %v", err) + } + + // Modify save data + saveData.HR = 999 + saveData.GR = 100 + saveData.Gender = true + saveData.WeaponType = 5 + saveData.WeaponID = 1234 + + // Save it + saveData.Save(s) + + // Reload and verify + var hr, gr uint16 + var gender bool + var weaponType uint8 + var weaponID uint16 + + err = db.QueryRow("SELECT hr, gr, is_female, weapon_type, weapon_id FROM characters WHERE id = $1", + charID).Scan(&hr, &gr, &gender, &weaponType, &weaponID) + if err != nil { + t.Fatalf("Failed to query updated character: %v", err) + } + + if hr != 999 { + t.Errorf("HR = %d, want 999", hr) + } + if gr != 100 { + t.Errorf("GR = %d, want 100", gr) + } + if !gender { + t.Error("Gender should be true (female)") + } + if weaponType != 5 { + t.Errorf("WeaponType = %d, want 5", weaponType) + } + if weaponID != 1234 { + t.Errorf("WeaponID = %d, want 1234", weaponID) + } +} + +// TestGRPtoGR tests the GRP to GR conversion function +func TestGRPtoGR(t *testing.T) { + tests := []struct { + name string + grp int + wantGR uint16 + }{ + { + name: "zero_grp", + grp: 0, + wantGR: 1, // Function returns 1 for 0 GRP + }, + { + name: "low_grp", + grp: 10000, + wantGR: 10, // Function returns 10 for 10000 GRP + }, + { + name: "mid_grp", + grp: 500000, + wantGR: 88, // Function returns 88 for 500000 GRP + }, + { + name: "high_grp", + grp: 2000000, + wantGR: 265, // Function returns 265 for 2000000 GRP + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotGR := grpToGR(tt.grp) + if gotGR != tt.wantGR { + t.Errorf("grpToGR(%d) = %d, want %d", tt.grp, gotGR, tt.wantGR) + } + }) + } +} + +// BenchmarkCompress benchmarks savedata compression +func BenchmarkCompress(b *testing.B) { + data := bytes.Repeat([]byte{0xAA, 0xBB, 0xCC, 0xDD}, 25000) // 100KB + save := &CharacterSaveData{ + decompSave: data, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + save.Compress() + } +} + +// BenchmarkDecompress benchmarks savedata decompression +func BenchmarkDecompress(b *testing.B) { + data := bytes.Repeat([]byte{0xAA, 0xBB, 0xCC, 0xDD}, 25000) + compressed, _ := nullcomp.Compress(data) + + save := &CharacterSaveData{ + compSave: compressed, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + save.Decompress() + } +} diff --git a/server/channelserver/handlers_clients_test.go b/server/channelserver/handlers_clients_test.go new file mode 100644 index 000000000..15708cb51 --- /dev/null +++ b/server/channelserver/handlers_clients_test.go @@ -0,0 +1,604 @@ +package channelserver + +import ( + "fmt" + "testing" + + _config "erupe-ce/config" + "erupe-ce/common/byteframe" + "erupe-ce/network/mhfpacket" + "go.uber.org/zap" +) + +// TestHandleMsgSysEnumerateClient tests client enumeration in stages +func TestHandleMsgSysEnumerateClient(t *testing.T) { + tests := []struct { + name string + stageID string + getType uint8 + setupStage func(*Server, string) + wantClientCount int + wantFailure bool + }{ + { + name: "enumerate_all_clients", + stageID: "test_stage_1", + getType: 0, // All clients + setupStage: func(server *Server, stageID string) { + stage := NewStage(stageID) + mock1 := &MockCryptConn{sentPackets: make([][]byte, 0)} + mock2 := &MockCryptConn{sentPackets: make([][]byte, 0)} + s1 := createTestSession(mock1) + s2 := createTestSession(mock2) + s1.charID = 100 + s2.charID = 200 + stage.clients[s1] = 100 + stage.clients[s2] = 200 + server.stagesLock.Lock() + server.stages[stageID] = stage + server.stagesLock.Unlock() + }, + wantClientCount: 2, + wantFailure: false, + }, + { + name: "enumerate_not_ready_clients", + stageID: "test_stage_2", + getType: 1, // Not ready + setupStage: func(server *Server, stageID string) { + stage := NewStage(stageID) + stage.reservedClientSlots[100] = false // Not ready + stage.reservedClientSlots[200] = true // Ready + stage.reservedClientSlots[300] = false // Not ready + server.stagesLock.Lock() + server.stages[stageID] = stage + server.stagesLock.Unlock() + }, + wantClientCount: 2, // Only not-ready clients + wantFailure: false, + }, + { + name: "enumerate_ready_clients", + stageID: "test_stage_3", + getType: 2, // Ready + setupStage: func(server *Server, stageID string) { + stage := NewStage(stageID) + stage.reservedClientSlots[100] = false // Not ready + stage.reservedClientSlots[200] = true // Ready + stage.reservedClientSlots[300] = true // Ready + server.stagesLock.Lock() + server.stages[stageID] = stage + server.stagesLock.Unlock() + }, + wantClientCount: 2, // Only ready clients + wantFailure: false, + }, + { + name: "enumerate_empty_stage", + stageID: "test_stage_empty", + getType: 0, + setupStage: func(server *Server, stageID string) { + stage := NewStage(stageID) + server.stagesLock.Lock() + server.stages[stageID] = stage + server.stagesLock.Unlock() + }, + wantClientCount: 0, + wantFailure: false, + }, + { + name: "enumerate_nonexistent_stage", + stageID: "nonexistent_stage", + getType: 0, + setupStage: func(server *Server, stageID string) { + // Don't create the stage + }, + wantClientCount: 0, + wantFailure: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test session (which creates a server with erupeConfig) + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + + // Initialize stages map if needed + if s.server.stages == nil { + s.server.stages = make(map[string]*Stage) + } + + // Setup stage + tt.setupStage(s.server, tt.stageID) + + pkt := &mhfpacket.MsgSysEnumerateClient{ + AckHandle: 1234, + StageID: tt.stageID, + Get: tt.getType, + } + + handleMsgSysEnumerateClient(s, pkt) + + // Check if ACK was sent + if len(s.sendPackets) == 0 { + t.Fatal("no ACK packet was sent") + } + + // Read the ACK packet + ackPkt := <-s.sendPackets + if tt.wantFailure { + // For failures, we can't easily check the exact format + // Just verify something was sent + return + } + + // Parse the response to count clients + // The ackPkt.data contains the full packet structure: + // [opcode:2 bytes][ack_handle:4 bytes][is_buffer:1 byte][error_code:1 byte][payload_size:2 bytes][data...] + // Total header size: 2 + 4 + 1 + 1 + 2 = 10 bytes + if len(ackPkt.data) < 10 { + t.Fatal("ACK packet too small") + } + + // The response data starts after the 10-byte header + // Response format is: [count:uint16][charID1:uint32][charID2:uint32]... + bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header + count := bf.ReadUint16() + + if int(count) != tt.wantClientCount { + t.Errorf("client count = %d, want %d", count, tt.wantClientCount) + } + }) + } +} + +// TestHandleMsgMhfListMember tests listing blacklisted members +func TestHandleMsgMhfListMember_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + tests := []struct { + name string + blockedCSV string + wantBlockCount int + }{ + { + name: "no_blocked_users", + blockedCSV: "", + wantBlockCount: 0, + }, + { + name: "single_blocked_user", + blockedCSV: "2", + wantBlockCount: 1, + }, + { + name: "multiple_blocked_users", + blockedCSV: "2,3,4", + wantBlockCount: 3, + }, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test user and character (use short names to avoid 15 char limit) + userID := CreateTestUser(t, db, "user_"+tt.name) + charName := fmt.Sprintf("Char%d", i) + charID := CreateTestCharacter(t, db, userID, charName) + + // Create blocked characters + if tt.blockedCSV != "" { + // Create the blocked users + for i := 2; i <= 4; i++ { + blockedUserID := CreateTestUser(t, db, "blocked_user_"+tt.name+"_"+string(rune(i))) + CreateTestCharacter(t, db, blockedUserID, "BlockedChar_"+string(rune(i))) + } + } + + // Set blocked list + _, err := db.Exec("UPDATE characters SET blocked = $1 WHERE id = $2", tt.blockedCSV, charID) + if err != nil { + t.Fatalf("Failed to update blocked list: %v", err) + } + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + pkt := &mhfpacket.MsgMhfListMember{ + AckHandle: 5678, + } + + handleMsgMhfListMember(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Fatal("no ACK packet was sent") + } + + // Parse response + // The ackPkt.data contains the full packet structure: + // [opcode:2 bytes][ack_handle:4 bytes][is_buffer:1 byte][error_code:1 byte][payload_size:2 bytes][data...] + // Total header size: 2 + 4 + 1 + 1 + 2 = 10 bytes + ackPkt := <-s.sendPackets + if len(ackPkt.data) < 10 { + t.Fatal("ACK packet too small") + } + bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header + count := bf.ReadUint32() + + if int(count) != tt.wantBlockCount { + t.Errorf("blocked count = %d, want %d", count, tt.wantBlockCount) + } + }) + } +} + +// TestHandleMsgMhfOprMember tests blacklist/friendlist operations +func TestHandleMsgMhfOprMember_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + tests := []struct { + name string + isBlacklist bool + operation bool // true = remove, false = add + initialList string + targetCharIDs []uint32 + wantList string + }{ + { + name: "add_to_blacklist", + isBlacklist: true, + operation: false, + initialList: "", + targetCharIDs: []uint32{2}, + wantList: "2", + }, + { + name: "remove_from_blacklist", + isBlacklist: true, + operation: true, + initialList: "2,3,4", + targetCharIDs: []uint32{3}, + wantList: "2,4", + }, + { + name: "add_to_friendlist", + isBlacklist: false, + operation: false, + initialList: "10", + targetCharIDs: []uint32{20}, + wantList: "10,20", + }, + { + name: "remove_from_friendlist", + isBlacklist: false, + operation: true, + initialList: "10,20,30", + targetCharIDs: []uint32{20}, + wantList: "10,30", + }, + { + name: "add_multiple_to_blacklist", + isBlacklist: true, + operation: false, + initialList: "1", + targetCharIDs: []uint32{2, 3}, + wantList: "1,2,3", + }, + } + + for i, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test user and character (use short names to avoid 15 char limit) + userID := CreateTestUser(t, db, "user_"+tt.name) + charName := fmt.Sprintf("OpChar%d", i) + charID := CreateTestCharacter(t, db, userID, charName) + + // Set initial list + column := "blocked" + if !tt.isBlacklist { + column = "friends" + } + _, err := db.Exec("UPDATE characters SET "+column+" = $1 WHERE id = $2", tt.initialList, charID) + if err != nil { + t.Fatalf("Failed to set initial list: %v", err) + } + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + pkt := &mhfpacket.MsgMhfOprMember{ + AckHandle: 9999, + Blacklist: tt.isBlacklist, + Operation: tt.operation, + CharIDs: tt.targetCharIDs, + } + + handleMsgMhfOprMember(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Fatal("no ACK packet was sent") + } + <-s.sendPackets + + // Verify the list was updated + var gotList string + err = db.QueryRow("SELECT "+column+" FROM characters WHERE id = $1", charID).Scan(&gotList) + if err != nil { + t.Fatalf("Failed to query updated list: %v", err) + } + + if gotList != tt.wantList { + t.Errorf("list = %q, want %q", gotList, tt.wantList) + } + }) + } +} + +// TestHandleMsgMhfShutClient tests the shut client handler +func TestHandleMsgMhfShutClient(t *testing.T) { + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + + pkt := &mhfpacket.MsgMhfShutClient{} + + // Should not panic (handler is empty) + handleMsgMhfShutClient(s, pkt) +} + +// TestHandleMsgSysHideClient tests the hide client handler +func TestHandleMsgSysHideClient(t *testing.T) { + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + + tests := []struct { + name string + hide bool + }{ + { + name: "hide_client", + hide: true, + }, + { + name: "show_client", + hide: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pkt := &mhfpacket.MsgSysHideClient{ + Hide: tt.hide, + } + + // Should not panic (handler is empty) + handleMsgSysHideClient(s, pkt) + }) + } +} + +// TestEnumerateClient_ConcurrentAccess tests concurrent stage access +func TestEnumerateClient_ConcurrentAccess(t *testing.T) { + logger, _ := zap.NewDevelopment() + server := &Server{ + logger: logger, + stages: make(map[string]*Stage), + erupeConfig: &_config.Config{ + DebugOptions: _config.DebugOptions{ + LogOutboundMessages: false, + }, + }, + } + + stageID := "concurrent_test_stage" + stage := NewStage(stageID) + + // Add some clients to the stage + for i := uint32(1); i <= 10; i++ { + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + sess := createTestSession(mock) + sess.charID = i * 100 + stage.clients[sess] = i * 100 + } + + server.stagesLock.Lock() + server.stages[stageID] = stage + server.stagesLock.Unlock() + + // Run concurrent enumerations + done := make(chan bool, 5) + for i := 0; i < 5; i++ { + go func() { + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.server = server + + pkt := &mhfpacket.MsgSysEnumerateClient{ + AckHandle: 3333, + StageID: stageID, + Get: 0, // All clients + } + + handleMsgSysEnumerateClient(s, pkt) + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 5; i++ { + <-done + } +} + +// TestListMember_EmptyDatabase tests listing members when database is empty +func TestListMember_EmptyDatabase_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Create test user and character + userID := CreateTestUser(t, db, "emptytest") + charID := CreateTestCharacter(t, db, userID, "EmptyChar") + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + pkt := &mhfpacket.MsgMhfListMember{ + AckHandle: 4444, + } + + handleMsgMhfListMember(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Fatal("no ACK packet was sent") + } + + ackPkt := <-s.sendPackets + if len(ackPkt.data) < 10 { + t.Fatal("ACK packet too small") + } + bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header + count := bf.ReadUint32() + + if count != 0 { + t.Errorf("empty blocked list should have count 0, got %d", count) + } +} + +// TestOprMember_EdgeCases tests edge cases for member operations +func TestOprMember_EdgeCases_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + tests := []struct { + name string + initialList string + operation bool + targetCharIDs []uint32 + wantList string + }{ + { + name: "add_duplicate_to_list", + initialList: "1,2,3", + operation: false, // add + targetCharIDs: []uint32{2}, + wantList: "1,2,3,2", // CSV helper adds duplicates + }, + { + name: "remove_nonexistent_from_list", + initialList: "1,2,3", + operation: true, // remove + targetCharIDs: []uint32{99}, + wantList: "1,2,3", + }, + { + name: "operate_on_empty_list", + initialList: "", + operation: false, + targetCharIDs: []uint32{1}, + wantList: "1", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test user and character + userID := CreateTestUser(t, db, "edge_"+tt.name) + charID := CreateTestCharacter(t, db, userID, "EdgeChar") + + // Set initial blocked list + _, err := db.Exec("UPDATE characters SET blocked = $1 WHERE id = $2", tt.initialList, charID) + if err != nil { + t.Fatalf("Failed to set initial list: %v", err) + } + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + pkt := &mhfpacket.MsgMhfOprMember{ + AckHandle: 7777, + Blacklist: true, + Operation: tt.operation, + CharIDs: tt.targetCharIDs, + } + + handleMsgMhfOprMember(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Fatal("no ACK packet was sent") + } + <-s.sendPackets + + // Verify the list + var gotList string + err = db.QueryRow("SELECT blocked FROM characters WHERE id = $1", charID).Scan(&gotList) + if err != nil { + t.Fatalf("Failed to query list: %v", err) + } + + if gotList != tt.wantList { + t.Errorf("list = %q, want %q", gotList, tt.wantList) + } + }) + } +} + +// BenchmarkEnumerateClients benchmarks client enumeration +func BenchmarkEnumerateClients(b *testing.B) { + logger, _ := zap.NewDevelopment() + server := &Server{ + logger: logger, + stages: make(map[string]*Stage), + } + + stageID := "bench_stage" + stage := NewStage(stageID) + + // Add 100 clients to the stage + for i := uint32(1); i <= 100; i++ { + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + sess := createTestSession(mock) + sess.charID = i + stage.clients[sess] = i + } + + server.stages[stageID] = stage + + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.server = server + + pkt := &mhfpacket.MsgSysEnumerateClient{ + AckHandle: 8888, + StageID: stageID, + Get: 0, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Clear the packet channel + select { + case <-s.sendPackets: + default: + } + + handleMsgSysEnumerateClient(s, pkt) + <-s.sendPackets + } +} diff --git a/server/channelserver/handlers_data_test.go b/server/channelserver/handlers_data_test.go index 4283f9026..aad819ca9 100644 --- a/server/channelserver/handlers_data_test.go +++ b/server/channelserver/handlers_data_test.go @@ -3,9 +3,12 @@ package channelserver import ( "bytes" "encoding/binary" + "fmt" + "erupe-ce/common/byteframe" "erupe-ce/network" "erupe-ce/network/clientctx" + "erupe-ce/network/mhfpacket" "erupe-ce/server/channelserver/compression/nullcomp" "testing" ) @@ -334,3 +337,318 @@ func BenchmarkPacketQueueing(b *testing.B) { // The current architecture doesn't easily support interface-based testing b.Skip("benchmark requires interface-based CryptConn mock") } + +// ============================================================================ +// Integration Tests (require test database) +// Run with: docker-compose -f docker/docker-compose.test.yml up -d +// ============================================================================ + +// TestHandleMsgMhfSavedata_Integration tests the actual save data handler with database +func TestHandleMsgMhfSavedata_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Create test user and character + userID := CreateTestUser(t, db, "testuser") + charID := CreateTestCharacter(t, db, userID, "TestChar") + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.Name = "TestChar" + s.server.db = db + + tests := []struct { + name string + saveType uint8 + payloadFunc func() []byte + wantSuccess bool + }{ + { + name: "blob_save", + saveType: 0, + payloadFunc: func() []byte { + // Create minimal valid savedata (large enough for all game mode pointers) + data := make([]byte, 150000) + copy(data[88:], []byte("TestChar\x00")) // Name at offset 88 + compressed, _ := nullcomp.Compress(data) + return compressed + }, + wantSuccess: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + payload := tt.payloadFunc() + pkt := &mhfpacket.MsgMhfSavedata{ + SaveType: tt.saveType, + AckHandle: 1234, + AllocMemSize: uint32(len(payload)), + DataSize: uint32(len(payload)), + RawDataPayload: payload, + } + + handleMsgMhfSavedata(s, pkt) + + // Check if ACK was sent + if len(s.sendPackets) == 0 { + t.Error("no ACK packet was sent") + } else { + // Drain the channel + <-s.sendPackets + } + + // Verify database was updated (for success case) + if tt.wantSuccess { + var savedData []byte + err := db.QueryRow("SELECT savedata FROM characters WHERE id = $1", charID).Scan(&savedData) + if err != nil { + t.Errorf("failed to query saved data: %v", err) + } + if len(savedData) == 0 { + t.Error("savedata was not written to database") + } + } + }) + } +} + +// TestHandleMsgMhfLoaddata_Integration tests loading character data +func TestHandleMsgMhfLoaddata_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Create test user and character + userID := CreateTestUser(t, db, "testuser") + + // Create savedata + saveData := make([]byte, 200) + copy(saveData[88:], []byte("LoadTest\x00")) + compressed, _ := nullcomp.Compress(saveData) + + var charID uint32 + err := db.QueryRow(` + INSERT INTO characters (user_id, is_female, is_new_character, name, unk_desc_string, gr, hr, weapon_type, last_login, savedata, decomyset, savemercenary) + VALUES ($1, false, false, 'LoadTest', '', 0, 0, 0, 0, $2, '', '') + RETURNING id + `, userID, compressed).Scan(&charID) + if err != nil { + t.Fatalf("Failed to create test character: %v", err) + } + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + s.server.userBinaryParts = make(map[userBinaryPartID][]byte) + s.server.userBinaryPartsLock.Lock() + defer s.server.userBinaryPartsLock.Unlock() + + pkt := &mhfpacket.MsgMhfLoaddata{ + AckHandle: 5678, + } + + handleMsgMhfLoaddata(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Error("no ACK packet was sent") + } + + // Verify name was extracted + if s.Name != "LoadTest" { + t.Errorf("character name not loaded, got %q, want %q", s.Name, "LoadTest") + } +} + +// TestHandleMsgMhfSaveScenarioData_Integration tests scenario data saving +func TestHandleMsgMhfSaveScenarioData_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Create test user and character + userID := CreateTestUser(t, db, "testuser") + charID := CreateTestCharacter(t, db, userID, "ScenarioTest") + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + scenarioData := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A} + + pkt := &mhfpacket.MsgMhfSaveScenarioData{ + AckHandle: 9999, + DataSize: uint32(len(scenarioData)), + RawDataPayload: scenarioData, + } + + handleMsgMhfSaveScenarioData(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Error("no ACK packet was sent") + } else { + <-s.sendPackets + } + + // Verify scenario data was saved + var saved []byte + err := db.QueryRow("SELECT scenariodata FROM characters WHERE id = $1", charID).Scan(&saved) + if err != nil { + t.Fatalf("failed to query scenario data: %v", err) + } + + if !bytes.Equal(saved, scenarioData) { + t.Errorf("scenario data mismatch: got %v, want %v", saved, scenarioData) + } +} + +// TestHandleMsgMhfLoadScenarioData_Integration tests scenario data loading +func TestHandleMsgMhfLoadScenarioData_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Create test user and character + userID := CreateTestUser(t, db, "testuser") + + scenarioData := []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44} + + var charID uint32 + err := db.QueryRow(` + INSERT INTO characters (user_id, is_female, is_new_character, name, unk_desc_string, gr, hr, weapon_type, last_login, savedata, decomyset, savemercenary, scenariodata) + VALUES ($1, false, false, 'ScenarioLoad', '', 0, 0, 0, 0, $2, '', '', $3) + RETURNING id + `, userID, []byte{0x00, 0x00, 0x00, 0x00}, scenarioData).Scan(&charID) + if err != nil { + t.Fatalf("Failed to create test character: %v", err) + } + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.server.db = db + + pkt := &mhfpacket.MsgMhfLoadScenarioData{ + AckHandle: 1111, + } + + handleMsgMhfLoadScenarioData(s, pkt) + + // Verify ACK was sent + if len(s.sendPackets) == 0 { + t.Fatal("no ACK packet was sent") + } + + // The ACK should contain the scenario data + ackPkt := <-s.sendPackets + if len(ackPkt.data) < len(scenarioData) { + t.Errorf("ACK packet too small: got %d bytes, expected at least %d", len(ackPkt.data), len(scenarioData)) + } +} + +// TestSaveDataCorruptionDetection_Integration tests that corrupted saves are rejected +func TestSaveDataCorruptionDetection_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Create test user and character + userID := CreateTestUser(t, db, "testuser") + charID := CreateTestCharacter(t, db, userID, "OriginalName") + + // Create test session + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charID + s.Name = "OriginalName" + s.server.db = db + s.server.erupeConfig.DeleteOnSaveCorruption = false + + // Create save data with a DIFFERENT name (corruption) + corruptedData := make([]byte, 200) + copy(corruptedData[88:], []byte("HackedName\x00")) + compressed, _ := nullcomp.Compress(corruptedData) + + pkt := &mhfpacket.MsgMhfSavedata{ + SaveType: 0, + AckHandle: 4444, + AllocMemSize: uint32(len(compressed)), + DataSize: uint32(len(compressed)), + RawDataPayload: compressed, + } + + handleMsgMhfSavedata(s, pkt) + + // The save should be rejected, connection should be closed + // In a real scenario, s.rawConn.Close() is called + // We can't easily test that, but we can verify the data wasn't saved + + // Check that database wasn't updated with corrupted data + var savedName string + db.QueryRow("SELECT name FROM characters WHERE id = $1", charID).Scan(&savedName) + if savedName == "HackedName" { + t.Error("corrupted save data was incorrectly written to database") + } +} + +// TestConcurrentSaveData_Integration tests concurrent save operations +func TestConcurrentSaveData_Integration(t *testing.T) { + db := SetupTestDB(t) + defer TeardownTestDB(t, db) + + // Create test user and multiple characters + userID := CreateTestUser(t, db, "testuser") + charIDs := make([]uint32, 5) + for i := 0; i < 5; i++ { + charIDs[i] = CreateTestCharacter(t, db, userID, fmt.Sprintf("Char%d", i)) + } + + // Run concurrent saves + done := make(chan bool, 5) + for i := 0; i < 5; i++ { + go func(index int) { + mock := &MockCryptConn{sentPackets: make([][]byte, 0)} + s := createTestSession(mock) + s.charID = charIDs[index] + s.Name = fmt.Sprintf("Char%d", index) + s.server.db = db + + saveData := make([]byte, 200) + copy(saveData[88:], []byte(fmt.Sprintf("Char%d\x00", index))) + compressed, _ := nullcomp.Compress(saveData) + + pkt := &mhfpacket.MsgMhfSavedata{ + SaveType: 0, + AckHandle: uint32(index), + AllocMemSize: uint32(len(compressed)), + DataSize: uint32(len(compressed)), + RawDataPayload: compressed, + } + + handleMsgMhfSavedata(s, pkt) + done <- true + }(i) + } + + // Wait for all saves to complete + for i := 0; i < 5; i++ { + <-done + } + + // Verify all characters were saved + for i := 0; i < 5; i++ { + var saveData []byte + err := db.QueryRow("SELECT savedata FROM characters WHERE id = $1", charIDs[i]).Scan(&saveData) + if err != nil { + t.Errorf("character %d: failed to load savedata: %v", i, err) + } + if len(saveData) == 0 { + t.Errorf("character %d: savedata is empty", i) + } + } +} diff --git a/server/channelserver/testhelpers_db.go b/server/channelserver/testhelpers_db.go new file mode 100644 index 000000000..c9ec16639 --- /dev/null +++ b/server/channelserver/testhelpers_db.go @@ -0,0 +1,260 @@ +package channelserver + +import ( + "fmt" + "os" + "os/exec" + "path/filepath" + "sort" + "strings" + "testing" + + "erupe-ce/server/channelserver/compression/nullcomp" + "github.com/jmoiron/sqlx" + _ "github.com/lib/pq" +) + +// TestDBConfig holds the configuration for the test database +type TestDBConfig struct { + Host string + Port string + User string + Password string + DBName string +} + +// DefaultTestDBConfig returns the default test database configuration +// that matches docker-compose.test.yml +func DefaultTestDBConfig() *TestDBConfig { + return &TestDBConfig{ + Host: getEnv("TEST_DB_HOST", "localhost"), + Port: getEnv("TEST_DB_PORT", "5433"), + User: getEnv("TEST_DB_USER", "test"), + Password: getEnv("TEST_DB_PASSWORD", "test"), + DBName: getEnv("TEST_DB_NAME", "erupe_test"), + } +} + +func getEnv(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue +} + +// SetupTestDB creates a connection to the test database and applies the schema +func SetupTestDB(t *testing.T) *sqlx.DB { + t.Helper() + + config := DefaultTestDBConfig() + connStr := fmt.Sprintf( + "host=%s port=%s user=%s password=%s dbname=%s sslmode=disable", + config.Host, config.Port, config.User, config.Password, config.DBName, + ) + + db, err := sqlx.Open("postgres", connStr) + if err != nil { + t.Skipf("Failed to connect to test database: %v. Run: docker compose -f docker/docker-compose.test.yml up -d", err) + return nil + } + + // Test connection + if err := db.Ping(); err != nil { + db.Close() + t.Skipf("Test database not available: %v. Run: docker compose -f docker/docker-compose.test.yml up -d", err) + return nil + } + + // Clean the database before tests + CleanTestDB(t, db) + + // Apply schema + ApplyTestSchema(t, db) + + return db +} + +// CleanTestDB drops all tables to ensure a clean state +func CleanTestDB(t *testing.T, db *sqlx.DB) { + t.Helper() + + // Drop all tables in the public schema + _, err := db.Exec(` + DO $$ DECLARE + r RECORD; + BEGIN + FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP + EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE'; + END LOOP; + END $$; + `) + if err != nil { + t.Logf("Warning: Failed to clean database: %v", err) + } +} + +// ApplyTestSchema applies the database schema from init.sql using pg_restore +func ApplyTestSchema(t *testing.T, db *sqlx.DB) { + t.Helper() + + // Find the project root (where schemas/ directory is located) + projectRoot := findProjectRoot(t) + schemaPath := filepath.Join(projectRoot, "schemas", "init.sql") + + // Get the connection config + config := DefaultTestDBConfig() + + // Use pg_restore to load the schema dump + // The init.sql file is a pg_dump custom format, so we need pg_restore + cmd := exec.Command("pg_restore", + "-h", config.Host, + "-p", config.Port, + "-U", config.User, + "-d", config.DBName, + "--no-owner", + "--no-acl", + "-c", // clean (drop) before recreating + schemaPath, + ) + cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", config.Password)) + + output, err := cmd.CombinedOutput() + if err != nil { + // pg_restore may error on first run (no tables to drop), that's usually ok + t.Logf("pg_restore output: %s", string(output)) + // Check if it's a fatal error + if !strings.Contains(string(output), "does not exist") { + t.Logf("pg_restore error (may be non-fatal): %v", err) + } + } + + // Apply patch schemas in order + applyPatchSchemas(t, db, projectRoot) +} + +// applyPatchSchemas applies all patch schema files in numeric order +func applyPatchSchemas(t *testing.T, db *sqlx.DB, projectRoot string) { + t.Helper() + + patchDir := filepath.Join(projectRoot, "schemas", "patch-schema") + entries, err := os.ReadDir(patchDir) + if err != nil { + t.Logf("Warning: Could not read patch-schema directory: %v", err) + return + } + + // Sort patch files numerically + var patchFiles []string + for _, entry := range entries { + if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") { + patchFiles = append(patchFiles, entry.Name()) + } + } + sort.Strings(patchFiles) + + // Apply each patch in its own transaction + for _, filename := range patchFiles { + patchPath := filepath.Join(patchDir, filename) + patchSQL, err := os.ReadFile(patchPath) + if err != nil { + t.Logf("Warning: Failed to read patch file %s: %v", filename, err) + continue + } + + // Start a new transaction for each patch + tx, err := db.Begin() + if err != nil { + t.Logf("Warning: Failed to start transaction for patch %s: %v", filename, err) + continue + } + + _, err = tx.Exec(string(patchSQL)) + if err != nil { + tx.Rollback() + t.Logf("Warning: Failed to apply patch %s: %v", filename, err) + // Continue with other patches even if one fails + } else { + tx.Commit() + } + } +} + +// findProjectRoot finds the project root directory by looking for the schemas directory +func findProjectRoot(t *testing.T) string { + t.Helper() + + // Start from current directory and walk up + dir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get working directory: %v", err) + } + + for { + schemasPath := filepath.Join(dir, "schemas") + if stat, err := os.Stat(schemasPath); err == nil && stat.IsDir() { + return dir + } + + parent := filepath.Dir(dir) + if parent == dir { + t.Fatal("Could not find project root (schemas directory not found)") + } + dir = parent + } +} + +// TeardownTestDB closes the database connection +func TeardownTestDB(t *testing.T, db *sqlx.DB) { + t.Helper() + if db != nil { + db.Close() + } +} + +// CreateTestUser creates a test user and returns the user ID +func CreateTestUser(t *testing.T, db *sqlx.DB, username string) uint32 { + t.Helper() + + var userID uint32 + err := db.QueryRow(` + INSERT INTO users (username, password, rights) + VALUES ($1, 'test_password_hash', 0) + RETURNING id + `, username).Scan(&userID) + + if err != nil { + t.Fatalf("Failed to create test user: %v", err) + } + + return userID +} + +// CreateTestCharacter creates a test character and returns the character ID +func CreateTestCharacter(t *testing.T, db *sqlx.DB, userID uint32, name string) uint32 { + t.Helper() + + // Create minimal valid savedata (needs to be large enough for the game to parse) + // The name is at offset 88, and various game mode pointers extend up to ~147KB for ZZ mode + // We need at least 150KB to accommodate all possible pointer offsets + saveData := make([]byte, 150000) // Large enough for all game modes + copy(saveData[88:], append([]byte(name), 0x00)) // Name at offset 88 with null terminator + + // Import the nullcomp package for compression + compressed, err := nullcomp.Compress(saveData) + if err != nil { + t.Fatalf("Failed to compress savedata: %v", err) + } + + var charID uint32 + err = db.QueryRow(` + INSERT INTO characters (user_id, is_female, is_new_character, name, unk_desc_string, gr, hr, weapon_type, last_login, savedata, decomyset, savemercenary) + VALUES ($1, false, false, $2, '', 0, 0, 0, 0, $3, '', '') + RETURNING id + `, userID, name, compressed).Scan(&charID) + + if err != nil { + t.Fatalf("Failed to create test character: %v", err) + } + + return charID +}