mirror of
https://github.com/Mezeporta/Erupe.git
synced 2026-03-21 23:22:34 +01:00
tests(integration): more complete tests with integration of a test database,
This commit is contained in:
592
server/channelserver/handlers_character_test.go
Normal file
592
server/channelserver/handlers_character_test.go
Normal file
@@ -0,0 +1,592 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
_config "erupe-ce/config"
|
||||
"erupe-ce/network/mhfpacket"
|
||||
"erupe-ce/server/channelserver/compression/nullcomp"
|
||||
)
|
||||
|
||||
// TestGetPointers tests the pointer map generation for different game versions
|
||||
func TestGetPointers(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
clientMode _config.Mode
|
||||
wantGender int
|
||||
wantHR int
|
||||
}{
|
||||
{
|
||||
name: "ZZ_version",
|
||||
clientMode: _config.ZZ,
|
||||
wantGender: 81,
|
||||
wantHR: 130550,
|
||||
},
|
||||
{
|
||||
name: "Z2_version",
|
||||
clientMode: _config.Z2,
|
||||
wantGender: 81,
|
||||
wantHR: 94550,
|
||||
},
|
||||
{
|
||||
name: "G10_version",
|
||||
clientMode: _config.G10,
|
||||
wantGender: 81,
|
||||
wantHR: 94550,
|
||||
},
|
||||
{
|
||||
name: "F5_version",
|
||||
clientMode: _config.F5,
|
||||
wantGender: 81,
|
||||
wantHR: 62550,
|
||||
},
|
||||
{
|
||||
name: "S6_version",
|
||||
clientMode: _config.S6,
|
||||
wantGender: 81,
|
||||
wantHR: 14550,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Save and restore original config
|
||||
originalMode := _config.ErupeConfig.RealClientMode
|
||||
defer func() { _config.ErupeConfig.RealClientMode = originalMode }()
|
||||
|
||||
_config.ErupeConfig.RealClientMode = tt.clientMode
|
||||
pointers := getPointers()
|
||||
|
||||
if pointers[pGender] != tt.wantGender {
|
||||
t.Errorf("pGender = %d, want %d", pointers[pGender], tt.wantGender)
|
||||
}
|
||||
|
||||
if pointers[pHR] != tt.wantHR {
|
||||
t.Errorf("pHR = %d, want %d", pointers[pHR], tt.wantHR)
|
||||
}
|
||||
|
||||
// Verify all required pointers exist
|
||||
requiredPointers := []SavePointer{pGender, pRP, pHouseTier, pHouseData, pBookshelfData,
|
||||
pGalleryData, pToreData, pGardenData, pPlaytime, pWeaponType, pWeaponID, pHR, lBookshelfData}
|
||||
|
||||
for _, ptr := range requiredPointers {
|
||||
if _, exists := pointers[ptr]; !exists {
|
||||
t.Errorf("pointer %v not found in map", ptr)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCharacterSaveData_Compress tests savedata compression
|
||||
func TestCharacterSaveData_Compress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid_small_data",
|
||||
data: []byte{0x01, 0x02, 0x03, 0x04},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid_large_data",
|
||||
data: bytes.Repeat([]byte{0xAA}, 10000),
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "empty_data",
|
||||
data: []byte{},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
save := &CharacterSaveData{
|
||||
decompSave: tt.data,
|
||||
}
|
||||
|
||||
err := save.Compress()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Compress() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if !tt.wantErr && len(save.compSave) == 0 {
|
||||
t.Error("compressed save is empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCharacterSaveData_Decompress tests savedata decompression
|
||||
func TestCharacterSaveData_Decompress(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
setup func() []byte
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid_compressed_data",
|
||||
setup: func() []byte {
|
||||
data := []byte{0x01, 0x02, 0x03, 0x04}
|
||||
compressed, _ := nullcomp.Compress(data)
|
||||
return compressed
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid_large_compressed_data",
|
||||
setup: func() []byte {
|
||||
data := bytes.Repeat([]byte{0xBB}, 5000)
|
||||
compressed, _ := nullcomp.Compress(data)
|
||||
return compressed
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
save := &CharacterSaveData{
|
||||
compSave: tt.setup(),
|
||||
}
|
||||
|
||||
err := save.Decompress()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Decompress() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if !tt.wantErr && len(save.decompSave) == 0 {
|
||||
t.Error("decompressed save is empty")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCharacterSaveData_RoundTrip tests compression and decompression
|
||||
func TestCharacterSaveData_RoundTrip(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
}{
|
||||
{
|
||||
name: "small_data",
|
||||
data: []byte{0x01, 0x02, 0x03, 0x04, 0x05},
|
||||
},
|
||||
{
|
||||
name: "repeating_pattern",
|
||||
data: bytes.Repeat([]byte{0xCC}, 1000),
|
||||
},
|
||||
{
|
||||
name: "mixed_data",
|
||||
data: []byte{0x00, 0xFF, 0x01, 0xFE, 0x02, 0xFD, 0x03, 0xFC},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
save := &CharacterSaveData{
|
||||
decompSave: tt.data,
|
||||
}
|
||||
|
||||
// Compress
|
||||
if err := save.Compress(); err != nil {
|
||||
t.Fatalf("Compress() failed: %v", err)
|
||||
}
|
||||
|
||||
// Clear decompressed data
|
||||
save.decompSave = nil
|
||||
|
||||
// Decompress
|
||||
if err := save.Decompress(); err != nil {
|
||||
t.Fatalf("Decompress() failed: %v", err)
|
||||
}
|
||||
|
||||
// Verify round trip
|
||||
if !bytes.Equal(save.decompSave, tt.data) {
|
||||
t.Errorf("round trip failed: got %v, want %v", save.decompSave, tt.data)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCharacterSaveData_updateStructWithSaveData tests parsing save data
|
||||
func TestCharacterSaveData_updateStructWithSaveData(t *testing.T) {
|
||||
originalMode := _config.ErupeConfig.RealClientMode
|
||||
defer func() { _config.ErupeConfig.RealClientMode = originalMode }()
|
||||
_config.ErupeConfig.RealClientMode = _config.Z2
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
isNewCharacter bool
|
||||
setupSaveData func() []byte
|
||||
wantName string
|
||||
wantGender bool
|
||||
}{
|
||||
{
|
||||
name: "male_character",
|
||||
isNewCharacter: false,
|
||||
setupSaveData: func() []byte {
|
||||
data := make([]byte, 150000)
|
||||
copy(data[88:], []byte("TestChar\x00"))
|
||||
data[81] = 0 // Male
|
||||
return data
|
||||
},
|
||||
wantName: "TestChar",
|
||||
wantGender: false,
|
||||
},
|
||||
{
|
||||
name: "female_character",
|
||||
isNewCharacter: false,
|
||||
setupSaveData: func() []byte {
|
||||
data := make([]byte, 150000)
|
||||
copy(data[88:], []byte("FemaleChar\x00"))
|
||||
data[81] = 1 // Female
|
||||
return data
|
||||
},
|
||||
wantName: "FemaleChar",
|
||||
wantGender: true,
|
||||
},
|
||||
{
|
||||
name: "new_character_skips_parsing",
|
||||
isNewCharacter: true,
|
||||
setupSaveData: func() []byte {
|
||||
data := make([]byte, 150000)
|
||||
copy(data[88:], []byte("NewChar\x00"))
|
||||
return data
|
||||
},
|
||||
wantName: "NewChar",
|
||||
wantGender: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
save := &CharacterSaveData{
|
||||
Pointers: getPointers(),
|
||||
decompSave: tt.setupSaveData(),
|
||||
IsNewCharacter: tt.isNewCharacter,
|
||||
}
|
||||
|
||||
save.updateStructWithSaveData()
|
||||
|
||||
if save.Name != tt.wantName {
|
||||
t.Errorf("Name = %q, want %q", save.Name, tt.wantName)
|
||||
}
|
||||
|
||||
if save.Gender != tt.wantGender {
|
||||
t.Errorf("Gender = %v, want %v", save.Gender, tt.wantGender)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCharacterSaveData_updateSaveDataWithStruct tests writing struct to save data
|
||||
func TestCharacterSaveData_updateSaveDataWithStruct(t *testing.T) {
|
||||
originalMode := _config.ErupeConfig.RealClientMode
|
||||
defer func() { _config.ErupeConfig.RealClientMode = originalMode }()
|
||||
_config.ErupeConfig.RealClientMode = _config.G10
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
rp uint16
|
||||
kqf []byte
|
||||
wantRP uint16
|
||||
}{
|
||||
{
|
||||
name: "update_rp_value",
|
||||
rp: 1234,
|
||||
kqf: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
|
||||
wantRP: 1234,
|
||||
},
|
||||
{
|
||||
name: "zero_rp_value",
|
||||
rp: 0,
|
||||
kqf: []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00},
|
||||
wantRP: 0,
|
||||
},
|
||||
{
|
||||
name: "max_rp_value",
|
||||
rp: 65535,
|
||||
kqf: []byte{0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF},
|
||||
wantRP: 65535,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
save := &CharacterSaveData{
|
||||
Pointers: getPointers(),
|
||||
decompSave: make([]byte, 150000),
|
||||
RP: tt.rp,
|
||||
KQF: tt.kqf,
|
||||
}
|
||||
|
||||
save.updateSaveDataWithStruct()
|
||||
|
||||
// Verify RP was written correctly
|
||||
rpOffset := save.Pointers[pRP]
|
||||
gotRP := binary.LittleEndian.Uint16(save.decompSave[rpOffset : rpOffset+2])
|
||||
if gotRP != tt.wantRP {
|
||||
t.Errorf("RP in save data = %d, want %d", gotRP, tt.wantRP)
|
||||
}
|
||||
|
||||
// Verify KQF was written correctly
|
||||
kqfOffset := save.Pointers[pKQF]
|
||||
gotKQF := save.decompSave[kqfOffset : kqfOffset+8]
|
||||
if !bytes.Equal(gotKQF, tt.kqf) {
|
||||
t.Errorf("KQF in save data = %v, want %v", gotKQF, tt.kqf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMsgMhfSexChanger tests the sex changer handler
|
||||
func TestHandleMsgMhfSexChanger(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
ackHandle uint32
|
||||
}{
|
||||
{
|
||||
name: "basic_sex_change",
|
||||
ackHandle: 1234,
|
||||
},
|
||||
{
|
||||
name: "different_ack_handle",
|
||||
ackHandle: 9999,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
|
||||
pkt := &mhfpacket.MsgMhfSexChanger{
|
||||
AckHandle: tt.ackHandle,
|
||||
}
|
||||
|
||||
handleMsgMhfSexChanger(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Fatal("no ACK packet was sent")
|
||||
}
|
||||
|
||||
// Drain the channel
|
||||
<-s.sendPackets
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetCharacterSaveData_Integration tests retrieving character save data from database
|
||||
func TestGetCharacterSaveData_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Save original config mode
|
||||
originalMode := _config.ErupeConfig.RealClientMode
|
||||
defer func() { _config.ErupeConfig.RealClientMode = originalMode }()
|
||||
_config.ErupeConfig.RealClientMode = _config.Z2
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
charName string
|
||||
isNewCharacter bool
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "existing_character",
|
||||
charName: "TestChar",
|
||||
isNewCharacter: false,
|
||||
wantError: false,
|
||||
},
|
||||
{
|
||||
name: "new_character",
|
||||
charName: "NewChar",
|
||||
isNewCharacter: true,
|
||||
wantError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "testuser_"+tt.name)
|
||||
charID := CreateTestCharacter(t, db, userID, tt.charName)
|
||||
|
||||
// Update is_new_character flag
|
||||
_, err := db.Exec("UPDATE characters SET is_new_character = $1 WHERE id = $2", tt.isNewCharacter, charID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update character: %v", err)
|
||||
}
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
// Get character save data
|
||||
saveData, err := GetCharacterSaveData(s, charID)
|
||||
if (err != nil) != tt.wantError {
|
||||
t.Errorf("GetCharacterSaveData() error = %v, wantErr %v", err, tt.wantError)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantError {
|
||||
if saveData == nil {
|
||||
t.Fatal("saveData is nil")
|
||||
}
|
||||
|
||||
if saveData.CharID != charID {
|
||||
t.Errorf("CharID = %d, want %d", saveData.CharID, charID)
|
||||
}
|
||||
|
||||
if saveData.Name != tt.charName {
|
||||
t.Errorf("Name = %q, want %q", saveData.Name, tt.charName)
|
||||
}
|
||||
|
||||
if saveData.IsNewCharacter != tt.isNewCharacter {
|
||||
t.Errorf("IsNewCharacter = %v, want %v", saveData.IsNewCharacter, tt.isNewCharacter)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestCharacterSaveData_Save_Integration tests saving character data to database
|
||||
func TestCharacterSaveData_Save_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Save original config mode
|
||||
originalMode := _config.ErupeConfig.RealClientMode
|
||||
defer func() { _config.ErupeConfig.RealClientMode = originalMode }()
|
||||
_config.ErupeConfig.RealClientMode = _config.Z2
|
||||
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "savetest")
|
||||
charID := CreateTestCharacter(t, db, userID, "SaveChar")
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
// Load character save data
|
||||
saveData, err := GetCharacterSaveData(s, charID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get save data: %v", err)
|
||||
}
|
||||
|
||||
// Modify save data
|
||||
saveData.HR = 999
|
||||
saveData.GR = 100
|
||||
saveData.Gender = true
|
||||
saveData.WeaponType = 5
|
||||
saveData.WeaponID = 1234
|
||||
|
||||
// Save it
|
||||
saveData.Save(s)
|
||||
|
||||
// Reload and verify
|
||||
var hr, gr uint16
|
||||
var gender bool
|
||||
var weaponType uint8
|
||||
var weaponID uint16
|
||||
|
||||
err = db.QueryRow("SELECT hr, gr, is_female, weapon_type, weapon_id FROM characters WHERE id = $1",
|
||||
charID).Scan(&hr, &gr, &gender, &weaponType, &weaponID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query updated character: %v", err)
|
||||
}
|
||||
|
||||
if hr != 999 {
|
||||
t.Errorf("HR = %d, want 999", hr)
|
||||
}
|
||||
if gr != 100 {
|
||||
t.Errorf("GR = %d, want 100", gr)
|
||||
}
|
||||
if !gender {
|
||||
t.Error("Gender should be true (female)")
|
||||
}
|
||||
if weaponType != 5 {
|
||||
t.Errorf("WeaponType = %d, want 5", weaponType)
|
||||
}
|
||||
if weaponID != 1234 {
|
||||
t.Errorf("WeaponID = %d, want 1234", weaponID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGRPtoGR tests the GRP to GR conversion function
|
||||
func TestGRPtoGR(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
grp int
|
||||
wantGR uint16
|
||||
}{
|
||||
{
|
||||
name: "zero_grp",
|
||||
grp: 0,
|
||||
wantGR: 1, // Function returns 1 for 0 GRP
|
||||
},
|
||||
{
|
||||
name: "low_grp",
|
||||
grp: 10000,
|
||||
wantGR: 10, // Function returns 10 for 10000 GRP
|
||||
},
|
||||
{
|
||||
name: "mid_grp",
|
||||
grp: 500000,
|
||||
wantGR: 88, // Function returns 88 for 500000 GRP
|
||||
},
|
||||
{
|
||||
name: "high_grp",
|
||||
grp: 2000000,
|
||||
wantGR: 265, // Function returns 265 for 2000000 GRP
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotGR := grpToGR(tt.grp)
|
||||
if gotGR != tt.wantGR {
|
||||
t.Errorf("grpToGR(%d) = %d, want %d", tt.grp, gotGR, tt.wantGR)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkCompress benchmarks savedata compression
|
||||
func BenchmarkCompress(b *testing.B) {
|
||||
data := bytes.Repeat([]byte{0xAA, 0xBB, 0xCC, 0xDD}, 25000) // 100KB
|
||||
save := &CharacterSaveData{
|
||||
decompSave: data,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
save.Compress()
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkDecompress benchmarks savedata decompression
|
||||
func BenchmarkDecompress(b *testing.B) {
|
||||
data := bytes.Repeat([]byte{0xAA, 0xBB, 0xCC, 0xDD}, 25000)
|
||||
compressed, _ := nullcomp.Compress(data)
|
||||
|
||||
save := &CharacterSaveData{
|
||||
compSave: compressed,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
save.Decompress()
|
||||
}
|
||||
}
|
||||
604
server/channelserver/handlers_clients_test.go
Normal file
604
server/channelserver/handlers_clients_test.go
Normal file
@@ -0,0 +1,604 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
_config "erupe-ce/config"
|
||||
"erupe-ce/common/byteframe"
|
||||
"erupe-ce/network/mhfpacket"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// TestHandleMsgSysEnumerateClient tests client enumeration in stages
|
||||
func TestHandleMsgSysEnumerateClient(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stageID string
|
||||
getType uint8
|
||||
setupStage func(*Server, string)
|
||||
wantClientCount int
|
||||
wantFailure bool
|
||||
}{
|
||||
{
|
||||
name: "enumerate_all_clients",
|
||||
stageID: "test_stage_1",
|
||||
getType: 0, // All clients
|
||||
setupStage: func(server *Server, stageID string) {
|
||||
stage := NewStage(stageID)
|
||||
mock1 := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
mock2 := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s1 := createTestSession(mock1)
|
||||
s2 := createTestSession(mock2)
|
||||
s1.charID = 100
|
||||
s2.charID = 200
|
||||
stage.clients[s1] = 100
|
||||
stage.clients[s2] = 200
|
||||
server.stagesLock.Lock()
|
||||
server.stages[stageID] = stage
|
||||
server.stagesLock.Unlock()
|
||||
},
|
||||
wantClientCount: 2,
|
||||
wantFailure: false,
|
||||
},
|
||||
{
|
||||
name: "enumerate_not_ready_clients",
|
||||
stageID: "test_stage_2",
|
||||
getType: 1, // Not ready
|
||||
setupStage: func(server *Server, stageID string) {
|
||||
stage := NewStage(stageID)
|
||||
stage.reservedClientSlots[100] = false // Not ready
|
||||
stage.reservedClientSlots[200] = true // Ready
|
||||
stage.reservedClientSlots[300] = false // Not ready
|
||||
server.stagesLock.Lock()
|
||||
server.stages[stageID] = stage
|
||||
server.stagesLock.Unlock()
|
||||
},
|
||||
wantClientCount: 2, // Only not-ready clients
|
||||
wantFailure: false,
|
||||
},
|
||||
{
|
||||
name: "enumerate_ready_clients",
|
||||
stageID: "test_stage_3",
|
||||
getType: 2, // Ready
|
||||
setupStage: func(server *Server, stageID string) {
|
||||
stage := NewStage(stageID)
|
||||
stage.reservedClientSlots[100] = false // Not ready
|
||||
stage.reservedClientSlots[200] = true // Ready
|
||||
stage.reservedClientSlots[300] = true // Ready
|
||||
server.stagesLock.Lock()
|
||||
server.stages[stageID] = stage
|
||||
server.stagesLock.Unlock()
|
||||
},
|
||||
wantClientCount: 2, // Only ready clients
|
||||
wantFailure: false,
|
||||
},
|
||||
{
|
||||
name: "enumerate_empty_stage",
|
||||
stageID: "test_stage_empty",
|
||||
getType: 0,
|
||||
setupStage: func(server *Server, stageID string) {
|
||||
stage := NewStage(stageID)
|
||||
server.stagesLock.Lock()
|
||||
server.stages[stageID] = stage
|
||||
server.stagesLock.Unlock()
|
||||
},
|
||||
wantClientCount: 0,
|
||||
wantFailure: false,
|
||||
},
|
||||
{
|
||||
name: "enumerate_nonexistent_stage",
|
||||
stageID: "nonexistent_stage",
|
||||
getType: 0,
|
||||
setupStage: func(server *Server, stageID string) {
|
||||
// Don't create the stage
|
||||
},
|
||||
wantClientCount: 0,
|
||||
wantFailure: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test session (which creates a server with erupeConfig)
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
|
||||
// Initialize stages map if needed
|
||||
if s.server.stages == nil {
|
||||
s.server.stages = make(map[string]*Stage)
|
||||
}
|
||||
|
||||
// Setup stage
|
||||
tt.setupStage(s.server, tt.stageID)
|
||||
|
||||
pkt := &mhfpacket.MsgSysEnumerateClient{
|
||||
AckHandle: 1234,
|
||||
StageID: tt.stageID,
|
||||
Get: tt.getType,
|
||||
}
|
||||
|
||||
handleMsgSysEnumerateClient(s, pkt)
|
||||
|
||||
// Check if ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Fatal("no ACK packet was sent")
|
||||
}
|
||||
|
||||
// Read the ACK packet
|
||||
ackPkt := <-s.sendPackets
|
||||
if tt.wantFailure {
|
||||
// For failures, we can't easily check the exact format
|
||||
// Just verify something was sent
|
||||
return
|
||||
}
|
||||
|
||||
// Parse the response to count clients
|
||||
// The ackPkt.data contains the full packet structure:
|
||||
// [opcode:2 bytes][ack_handle:4 bytes][is_buffer:1 byte][error_code:1 byte][payload_size:2 bytes][data...]
|
||||
// Total header size: 2 + 4 + 1 + 1 + 2 = 10 bytes
|
||||
if len(ackPkt.data) < 10 {
|
||||
t.Fatal("ACK packet too small")
|
||||
}
|
||||
|
||||
// The response data starts after the 10-byte header
|
||||
// Response format is: [count:uint16][charID1:uint32][charID2:uint32]...
|
||||
bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header
|
||||
count := bf.ReadUint16()
|
||||
|
||||
if int(count) != tt.wantClientCount {
|
||||
t.Errorf("client count = %d, want %d", count, tt.wantClientCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMsgMhfListMember tests listing blacklisted members
|
||||
func TestHandleMsgMhfListMember_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
blockedCSV string
|
||||
wantBlockCount int
|
||||
}{
|
||||
{
|
||||
name: "no_blocked_users",
|
||||
blockedCSV: "",
|
||||
wantBlockCount: 0,
|
||||
},
|
||||
{
|
||||
name: "single_blocked_user",
|
||||
blockedCSV: "2",
|
||||
wantBlockCount: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple_blocked_users",
|
||||
blockedCSV: "2,3,4",
|
||||
wantBlockCount: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test user and character (use short names to avoid 15 char limit)
|
||||
userID := CreateTestUser(t, db, "user_"+tt.name)
|
||||
charName := fmt.Sprintf("Char%d", i)
|
||||
charID := CreateTestCharacter(t, db, userID, charName)
|
||||
|
||||
// Create blocked characters
|
||||
if tt.blockedCSV != "" {
|
||||
// Create the blocked users
|
||||
for i := 2; i <= 4; i++ {
|
||||
blockedUserID := CreateTestUser(t, db, "blocked_user_"+tt.name+"_"+string(rune(i)))
|
||||
CreateTestCharacter(t, db, blockedUserID, "BlockedChar_"+string(rune(i)))
|
||||
}
|
||||
}
|
||||
|
||||
// Set blocked list
|
||||
_, err := db.Exec("UPDATE characters SET blocked = $1 WHERE id = $2", tt.blockedCSV, charID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to update blocked list: %v", err)
|
||||
}
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
pkt := &mhfpacket.MsgMhfListMember{
|
||||
AckHandle: 5678,
|
||||
}
|
||||
|
||||
handleMsgMhfListMember(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Fatal("no ACK packet was sent")
|
||||
}
|
||||
|
||||
// Parse response
|
||||
// The ackPkt.data contains the full packet structure:
|
||||
// [opcode:2 bytes][ack_handle:4 bytes][is_buffer:1 byte][error_code:1 byte][payload_size:2 bytes][data...]
|
||||
// Total header size: 2 + 4 + 1 + 1 + 2 = 10 bytes
|
||||
ackPkt := <-s.sendPackets
|
||||
if len(ackPkt.data) < 10 {
|
||||
t.Fatal("ACK packet too small")
|
||||
}
|
||||
bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header
|
||||
count := bf.ReadUint32()
|
||||
|
||||
if int(count) != tt.wantBlockCount {
|
||||
t.Errorf("blocked count = %d, want %d", count, tt.wantBlockCount)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMsgMhfOprMember tests blacklist/friendlist operations
|
||||
func TestHandleMsgMhfOprMember_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
isBlacklist bool
|
||||
operation bool // true = remove, false = add
|
||||
initialList string
|
||||
targetCharIDs []uint32
|
||||
wantList string
|
||||
}{
|
||||
{
|
||||
name: "add_to_blacklist",
|
||||
isBlacklist: true,
|
||||
operation: false,
|
||||
initialList: "",
|
||||
targetCharIDs: []uint32{2},
|
||||
wantList: "2",
|
||||
},
|
||||
{
|
||||
name: "remove_from_blacklist",
|
||||
isBlacklist: true,
|
||||
operation: true,
|
||||
initialList: "2,3,4",
|
||||
targetCharIDs: []uint32{3},
|
||||
wantList: "2,4",
|
||||
},
|
||||
{
|
||||
name: "add_to_friendlist",
|
||||
isBlacklist: false,
|
||||
operation: false,
|
||||
initialList: "10",
|
||||
targetCharIDs: []uint32{20},
|
||||
wantList: "10,20",
|
||||
},
|
||||
{
|
||||
name: "remove_from_friendlist",
|
||||
isBlacklist: false,
|
||||
operation: true,
|
||||
initialList: "10,20,30",
|
||||
targetCharIDs: []uint32{20},
|
||||
wantList: "10,30",
|
||||
},
|
||||
{
|
||||
name: "add_multiple_to_blacklist",
|
||||
isBlacklist: true,
|
||||
operation: false,
|
||||
initialList: "1",
|
||||
targetCharIDs: []uint32{2, 3},
|
||||
wantList: "1,2,3",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test user and character (use short names to avoid 15 char limit)
|
||||
userID := CreateTestUser(t, db, "user_"+tt.name)
|
||||
charName := fmt.Sprintf("OpChar%d", i)
|
||||
charID := CreateTestCharacter(t, db, userID, charName)
|
||||
|
||||
// Set initial list
|
||||
column := "blocked"
|
||||
if !tt.isBlacklist {
|
||||
column = "friends"
|
||||
}
|
||||
_, err := db.Exec("UPDATE characters SET "+column+" = $1 WHERE id = $2", tt.initialList, charID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set initial list: %v", err)
|
||||
}
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
pkt := &mhfpacket.MsgMhfOprMember{
|
||||
AckHandle: 9999,
|
||||
Blacklist: tt.isBlacklist,
|
||||
Operation: tt.operation,
|
||||
CharIDs: tt.targetCharIDs,
|
||||
}
|
||||
|
||||
handleMsgMhfOprMember(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Fatal("no ACK packet was sent")
|
||||
}
|
||||
<-s.sendPackets
|
||||
|
||||
// Verify the list was updated
|
||||
var gotList string
|
||||
err = db.QueryRow("SELECT "+column+" FROM characters WHERE id = $1", charID).Scan(&gotList)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query updated list: %v", err)
|
||||
}
|
||||
|
||||
if gotList != tt.wantList {
|
||||
t.Errorf("list = %q, want %q", gotList, tt.wantList)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMsgMhfShutClient tests the shut client handler
|
||||
func TestHandleMsgMhfShutClient(t *testing.T) {
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
|
||||
pkt := &mhfpacket.MsgMhfShutClient{}
|
||||
|
||||
// Should not panic (handler is empty)
|
||||
handleMsgMhfShutClient(s, pkt)
|
||||
}
|
||||
|
||||
// TestHandleMsgSysHideClient tests the hide client handler
|
||||
func TestHandleMsgSysHideClient(t *testing.T) {
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
hide bool
|
||||
}{
|
||||
{
|
||||
name: "hide_client",
|
||||
hide: true,
|
||||
},
|
||||
{
|
||||
name: "show_client",
|
||||
hide: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pkt := &mhfpacket.MsgSysHideClient{
|
||||
Hide: tt.hide,
|
||||
}
|
||||
|
||||
// Should not panic (handler is empty)
|
||||
handleMsgSysHideClient(s, pkt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestEnumerateClient_ConcurrentAccess tests concurrent stage access
|
||||
func TestEnumerateClient_ConcurrentAccess(t *testing.T) {
|
||||
logger, _ := zap.NewDevelopment()
|
||||
server := &Server{
|
||||
logger: logger,
|
||||
stages: make(map[string]*Stage),
|
||||
erupeConfig: &_config.Config{
|
||||
DebugOptions: _config.DebugOptions{
|
||||
LogOutboundMessages: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
stageID := "concurrent_test_stage"
|
||||
stage := NewStage(stageID)
|
||||
|
||||
// Add some clients to the stage
|
||||
for i := uint32(1); i <= 10; i++ {
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
sess := createTestSession(mock)
|
||||
sess.charID = i * 100
|
||||
stage.clients[sess] = i * 100
|
||||
}
|
||||
|
||||
server.stagesLock.Lock()
|
||||
server.stages[stageID] = stage
|
||||
server.stagesLock.Unlock()
|
||||
|
||||
// Run concurrent enumerations
|
||||
done := make(chan bool, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func() {
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.server = server
|
||||
|
||||
pkt := &mhfpacket.MsgSysEnumerateClient{
|
||||
AckHandle: 3333,
|
||||
StageID: stageID,
|
||||
Get: 0, // All clients
|
||||
}
|
||||
|
||||
handleMsgSysEnumerateClient(s, pkt)
|
||||
done <- true
|
||||
}()
|
||||
}
|
||||
|
||||
// Wait for all goroutines to complete
|
||||
for i := 0; i < 5; i++ {
|
||||
<-done
|
||||
}
|
||||
}
|
||||
|
||||
// TestListMember_EmptyDatabase tests listing members when database is empty
|
||||
func TestListMember_EmptyDatabase_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "emptytest")
|
||||
charID := CreateTestCharacter(t, db, userID, "EmptyChar")
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
pkt := &mhfpacket.MsgMhfListMember{
|
||||
AckHandle: 4444,
|
||||
}
|
||||
|
||||
handleMsgMhfListMember(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Fatal("no ACK packet was sent")
|
||||
}
|
||||
|
||||
ackPkt := <-s.sendPackets
|
||||
if len(ackPkt.data) < 10 {
|
||||
t.Fatal("ACK packet too small")
|
||||
}
|
||||
bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header
|
||||
count := bf.ReadUint32()
|
||||
|
||||
if count != 0 {
|
||||
t.Errorf("empty blocked list should have count 0, got %d", count)
|
||||
}
|
||||
}
|
||||
|
||||
// TestOprMember_EdgeCases tests edge cases for member operations
|
||||
func TestOprMember_EdgeCases_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
initialList string
|
||||
operation bool
|
||||
targetCharIDs []uint32
|
||||
wantList string
|
||||
}{
|
||||
{
|
||||
name: "add_duplicate_to_list",
|
||||
initialList: "1,2,3",
|
||||
operation: false, // add
|
||||
targetCharIDs: []uint32{2},
|
||||
wantList: "1,2,3,2", // CSV helper adds duplicates
|
||||
},
|
||||
{
|
||||
name: "remove_nonexistent_from_list",
|
||||
initialList: "1,2,3",
|
||||
operation: true, // remove
|
||||
targetCharIDs: []uint32{99},
|
||||
wantList: "1,2,3",
|
||||
},
|
||||
{
|
||||
name: "operate_on_empty_list",
|
||||
initialList: "",
|
||||
operation: false,
|
||||
targetCharIDs: []uint32{1},
|
||||
wantList: "1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "edge_"+tt.name)
|
||||
charID := CreateTestCharacter(t, db, userID, "EdgeChar")
|
||||
|
||||
// Set initial blocked list
|
||||
_, err := db.Exec("UPDATE characters SET blocked = $1 WHERE id = $2", tt.initialList, charID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to set initial list: %v", err)
|
||||
}
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
pkt := &mhfpacket.MsgMhfOprMember{
|
||||
AckHandle: 7777,
|
||||
Blacklist: true,
|
||||
Operation: tt.operation,
|
||||
CharIDs: tt.targetCharIDs,
|
||||
}
|
||||
|
||||
handleMsgMhfOprMember(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Fatal("no ACK packet was sent")
|
||||
}
|
||||
<-s.sendPackets
|
||||
|
||||
// Verify the list
|
||||
var gotList string
|
||||
err = db.QueryRow("SELECT blocked FROM characters WHERE id = $1", charID).Scan(&gotList)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query list: %v", err)
|
||||
}
|
||||
|
||||
if gotList != tt.wantList {
|
||||
t.Errorf("list = %q, want %q", gotList, tt.wantList)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// BenchmarkEnumerateClients benchmarks client enumeration
|
||||
func BenchmarkEnumerateClients(b *testing.B) {
|
||||
logger, _ := zap.NewDevelopment()
|
||||
server := &Server{
|
||||
logger: logger,
|
||||
stages: make(map[string]*Stage),
|
||||
}
|
||||
|
||||
stageID := "bench_stage"
|
||||
stage := NewStage(stageID)
|
||||
|
||||
// Add 100 clients to the stage
|
||||
for i := uint32(1); i <= 100; i++ {
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
sess := createTestSession(mock)
|
||||
sess.charID = i
|
||||
stage.clients[sess] = i
|
||||
}
|
||||
|
||||
server.stages[stageID] = stage
|
||||
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.server = server
|
||||
|
||||
pkt := &mhfpacket.MsgSysEnumerateClient{
|
||||
AckHandle: 8888,
|
||||
StageID: stageID,
|
||||
Get: 0,
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Clear the packet channel
|
||||
select {
|
||||
case <-s.sendPackets:
|
||||
default:
|
||||
}
|
||||
|
||||
handleMsgSysEnumerateClient(s, pkt)
|
||||
<-s.sendPackets
|
||||
}
|
||||
}
|
||||
@@ -3,9 +3,12 @@ package channelserver
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"erupe-ce/common/byteframe"
|
||||
"erupe-ce/network"
|
||||
"erupe-ce/network/clientctx"
|
||||
"erupe-ce/network/mhfpacket"
|
||||
"erupe-ce/server/channelserver/compression/nullcomp"
|
||||
"testing"
|
||||
)
|
||||
@@ -334,3 +337,318 @@ func BenchmarkPacketQueueing(b *testing.B) {
|
||||
// The current architecture doesn't easily support interface-based testing
|
||||
b.Skip("benchmark requires interface-based CryptConn mock")
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Integration Tests (require test database)
|
||||
// Run with: docker-compose -f docker/docker-compose.test.yml up -d
|
||||
// ============================================================================
|
||||
|
||||
// TestHandleMsgMhfSavedata_Integration tests the actual save data handler with database
|
||||
func TestHandleMsgMhfSavedata_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "testuser")
|
||||
charID := CreateTestCharacter(t, db, userID, "TestChar")
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.Name = "TestChar"
|
||||
s.server.db = db
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
saveType uint8
|
||||
payloadFunc func() []byte
|
||||
wantSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "blob_save",
|
||||
saveType: 0,
|
||||
payloadFunc: func() []byte {
|
||||
// Create minimal valid savedata (large enough for all game mode pointers)
|
||||
data := make([]byte, 150000)
|
||||
copy(data[88:], []byte("TestChar\x00")) // Name at offset 88
|
||||
compressed, _ := nullcomp.Compress(data)
|
||||
return compressed
|
||||
},
|
||||
wantSuccess: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
payload := tt.payloadFunc()
|
||||
pkt := &mhfpacket.MsgMhfSavedata{
|
||||
SaveType: tt.saveType,
|
||||
AckHandle: 1234,
|
||||
AllocMemSize: uint32(len(payload)),
|
||||
DataSize: uint32(len(payload)),
|
||||
RawDataPayload: payload,
|
||||
}
|
||||
|
||||
handleMsgMhfSavedata(s, pkt)
|
||||
|
||||
// Check if ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Error("no ACK packet was sent")
|
||||
} else {
|
||||
// Drain the channel
|
||||
<-s.sendPackets
|
||||
}
|
||||
|
||||
// Verify database was updated (for success case)
|
||||
if tt.wantSuccess {
|
||||
var savedData []byte
|
||||
err := db.QueryRow("SELECT savedata FROM characters WHERE id = $1", charID).Scan(&savedData)
|
||||
if err != nil {
|
||||
t.Errorf("failed to query saved data: %v", err)
|
||||
}
|
||||
if len(savedData) == 0 {
|
||||
t.Error("savedata was not written to database")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMsgMhfLoaddata_Integration tests loading character data
|
||||
func TestHandleMsgMhfLoaddata_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "testuser")
|
||||
|
||||
// Create savedata
|
||||
saveData := make([]byte, 200)
|
||||
copy(saveData[88:], []byte("LoadTest\x00"))
|
||||
compressed, _ := nullcomp.Compress(saveData)
|
||||
|
||||
var charID uint32
|
||||
err := db.QueryRow(`
|
||||
INSERT INTO characters (user_id, is_female, is_new_character, name, unk_desc_string, gr, hr, weapon_type, last_login, savedata, decomyset, savemercenary)
|
||||
VALUES ($1, false, false, 'LoadTest', '', 0, 0, 0, 0, $2, '', '')
|
||||
RETURNING id
|
||||
`, userID, compressed).Scan(&charID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test character: %v", err)
|
||||
}
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
s.server.userBinaryParts = make(map[userBinaryPartID][]byte)
|
||||
s.server.userBinaryPartsLock.Lock()
|
||||
defer s.server.userBinaryPartsLock.Unlock()
|
||||
|
||||
pkt := &mhfpacket.MsgMhfLoaddata{
|
||||
AckHandle: 5678,
|
||||
}
|
||||
|
||||
handleMsgMhfLoaddata(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Error("no ACK packet was sent")
|
||||
}
|
||||
|
||||
// Verify name was extracted
|
||||
if s.Name != "LoadTest" {
|
||||
t.Errorf("character name not loaded, got %q, want %q", s.Name, "LoadTest")
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMsgMhfSaveScenarioData_Integration tests scenario data saving
|
||||
func TestHandleMsgMhfSaveScenarioData_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "testuser")
|
||||
charID := CreateTestCharacter(t, db, userID, "ScenarioTest")
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
scenarioData := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
|
||||
|
||||
pkt := &mhfpacket.MsgMhfSaveScenarioData{
|
||||
AckHandle: 9999,
|
||||
DataSize: uint32(len(scenarioData)),
|
||||
RawDataPayload: scenarioData,
|
||||
}
|
||||
|
||||
handleMsgMhfSaveScenarioData(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Error("no ACK packet was sent")
|
||||
} else {
|
||||
<-s.sendPackets
|
||||
}
|
||||
|
||||
// Verify scenario data was saved
|
||||
var saved []byte
|
||||
err := db.QueryRow("SELECT scenariodata FROM characters WHERE id = $1", charID).Scan(&saved)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to query scenario data: %v", err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(saved, scenarioData) {
|
||||
t.Errorf("scenario data mismatch: got %v, want %v", saved, scenarioData)
|
||||
}
|
||||
}
|
||||
|
||||
// TestHandleMsgMhfLoadScenarioData_Integration tests scenario data loading
|
||||
func TestHandleMsgMhfLoadScenarioData_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "testuser")
|
||||
|
||||
scenarioData := []byte{0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF, 0x11, 0x22, 0x33, 0x44}
|
||||
|
||||
var charID uint32
|
||||
err := db.QueryRow(`
|
||||
INSERT INTO characters (user_id, is_female, is_new_character, name, unk_desc_string, gr, hr, weapon_type, last_login, savedata, decomyset, savemercenary, scenariodata)
|
||||
VALUES ($1, false, false, 'ScenarioLoad', '', 0, 0, 0, 0, $2, '', '', $3)
|
||||
RETURNING id
|
||||
`, userID, []byte{0x00, 0x00, 0x00, 0x00}, scenarioData).Scan(&charID)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test character: %v", err)
|
||||
}
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.server.db = db
|
||||
|
||||
pkt := &mhfpacket.MsgMhfLoadScenarioData{
|
||||
AckHandle: 1111,
|
||||
}
|
||||
|
||||
handleMsgMhfLoadScenarioData(s, pkt)
|
||||
|
||||
// Verify ACK was sent
|
||||
if len(s.sendPackets) == 0 {
|
||||
t.Fatal("no ACK packet was sent")
|
||||
}
|
||||
|
||||
// The ACK should contain the scenario data
|
||||
ackPkt := <-s.sendPackets
|
||||
if len(ackPkt.data) < len(scenarioData) {
|
||||
t.Errorf("ACK packet too small: got %d bytes, expected at least %d", len(ackPkt.data), len(scenarioData))
|
||||
}
|
||||
}
|
||||
|
||||
// TestSaveDataCorruptionDetection_Integration tests that corrupted saves are rejected
|
||||
func TestSaveDataCorruptionDetection_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Create test user and character
|
||||
userID := CreateTestUser(t, db, "testuser")
|
||||
charID := CreateTestCharacter(t, db, userID, "OriginalName")
|
||||
|
||||
// Create test session
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charID
|
||||
s.Name = "OriginalName"
|
||||
s.server.db = db
|
||||
s.server.erupeConfig.DeleteOnSaveCorruption = false
|
||||
|
||||
// Create save data with a DIFFERENT name (corruption)
|
||||
corruptedData := make([]byte, 200)
|
||||
copy(corruptedData[88:], []byte("HackedName\x00"))
|
||||
compressed, _ := nullcomp.Compress(corruptedData)
|
||||
|
||||
pkt := &mhfpacket.MsgMhfSavedata{
|
||||
SaveType: 0,
|
||||
AckHandle: 4444,
|
||||
AllocMemSize: uint32(len(compressed)),
|
||||
DataSize: uint32(len(compressed)),
|
||||
RawDataPayload: compressed,
|
||||
}
|
||||
|
||||
handleMsgMhfSavedata(s, pkt)
|
||||
|
||||
// The save should be rejected, connection should be closed
|
||||
// In a real scenario, s.rawConn.Close() is called
|
||||
// We can't easily test that, but we can verify the data wasn't saved
|
||||
|
||||
// Check that database wasn't updated with corrupted data
|
||||
var savedName string
|
||||
db.QueryRow("SELECT name FROM characters WHERE id = $1", charID).Scan(&savedName)
|
||||
if savedName == "HackedName" {
|
||||
t.Error("corrupted save data was incorrectly written to database")
|
||||
}
|
||||
}
|
||||
|
||||
// TestConcurrentSaveData_Integration tests concurrent save operations
|
||||
func TestConcurrentSaveData_Integration(t *testing.T) {
|
||||
db := SetupTestDB(t)
|
||||
defer TeardownTestDB(t, db)
|
||||
|
||||
// Create test user and multiple characters
|
||||
userID := CreateTestUser(t, db, "testuser")
|
||||
charIDs := make([]uint32, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
charIDs[i] = CreateTestCharacter(t, db, userID, fmt.Sprintf("Char%d", i))
|
||||
}
|
||||
|
||||
// Run concurrent saves
|
||||
done := make(chan bool, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
go func(index int) {
|
||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||
s := createTestSession(mock)
|
||||
s.charID = charIDs[index]
|
||||
s.Name = fmt.Sprintf("Char%d", index)
|
||||
s.server.db = db
|
||||
|
||||
saveData := make([]byte, 200)
|
||||
copy(saveData[88:], []byte(fmt.Sprintf("Char%d\x00", index)))
|
||||
compressed, _ := nullcomp.Compress(saveData)
|
||||
|
||||
pkt := &mhfpacket.MsgMhfSavedata{
|
||||
SaveType: 0,
|
||||
AckHandle: uint32(index),
|
||||
AllocMemSize: uint32(len(compressed)),
|
||||
DataSize: uint32(len(compressed)),
|
||||
RawDataPayload: compressed,
|
||||
}
|
||||
|
||||
handleMsgMhfSavedata(s, pkt)
|
||||
done <- true
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all saves to complete
|
||||
for i := 0; i < 5; i++ {
|
||||
<-done
|
||||
}
|
||||
|
||||
// Verify all characters were saved
|
||||
for i := 0; i < 5; i++ {
|
||||
var saveData []byte
|
||||
err := db.QueryRow("SELECT savedata FROM characters WHERE id = $1", charIDs[i]).Scan(&saveData)
|
||||
if err != nil {
|
||||
t.Errorf("character %d: failed to load savedata: %v", i, err)
|
||||
}
|
||||
if len(saveData) == 0 {
|
||||
t.Errorf("character %d: savedata is empty", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
260
server/channelserver/testhelpers_db.go
Normal file
260
server/channelserver/testhelpers_db.go
Normal file
@@ -0,0 +1,260 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"erupe-ce/server/channelserver/compression/nullcomp"
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
// TestDBConfig holds the configuration for the test database
|
||||
type TestDBConfig struct {
|
||||
Host string
|
||||
Port string
|
||||
User string
|
||||
Password string
|
||||
DBName string
|
||||
}
|
||||
|
||||
// DefaultTestDBConfig returns the default test database configuration
|
||||
// that matches docker-compose.test.yml
|
||||
func DefaultTestDBConfig() *TestDBConfig {
|
||||
return &TestDBConfig{
|
||||
Host: getEnv("TEST_DB_HOST", "localhost"),
|
||||
Port: getEnv("TEST_DB_PORT", "5433"),
|
||||
User: getEnv("TEST_DB_USER", "test"),
|
||||
Password: getEnv("TEST_DB_PASSWORD", "test"),
|
||||
DBName: getEnv("TEST_DB_NAME", "erupe_test"),
|
||||
}
|
||||
}
|
||||
|
||||
func getEnv(key, defaultValue string) string {
|
||||
if value := os.Getenv(key); value != "" {
|
||||
return value
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
// SetupTestDB creates a connection to the test database and applies the schema
|
||||
func SetupTestDB(t *testing.T) *sqlx.DB {
|
||||
t.Helper()
|
||||
|
||||
config := DefaultTestDBConfig()
|
||||
connStr := fmt.Sprintf(
|
||||
"host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
|
||||
config.Host, config.Port, config.User, config.Password, config.DBName,
|
||||
)
|
||||
|
||||
db, err := sqlx.Open("postgres", connStr)
|
||||
if err != nil {
|
||||
t.Skipf("Failed to connect to test database: %v. Run: docker compose -f docker/docker-compose.test.yml up -d", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
db.Close()
|
||||
t.Skipf("Test database not available: %v. Run: docker compose -f docker/docker-compose.test.yml up -d", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Clean the database before tests
|
||||
CleanTestDB(t, db)
|
||||
|
||||
// Apply schema
|
||||
ApplyTestSchema(t, db)
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
// CleanTestDB drops all tables to ensure a clean state
|
||||
func CleanTestDB(t *testing.T, db *sqlx.DB) {
|
||||
t.Helper()
|
||||
|
||||
// Drop all tables in the public schema
|
||||
_, err := db.Exec(`
|
||||
DO $$ DECLARE
|
||||
r RECORD;
|
||||
BEGIN
|
||||
FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = 'public') LOOP
|
||||
EXECUTE 'DROP TABLE IF EXISTS ' || quote_ident(r.tablename) || ' CASCADE';
|
||||
END LOOP;
|
||||
END $$;
|
||||
`)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to clean database: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// ApplyTestSchema applies the database schema from init.sql using pg_restore
|
||||
func ApplyTestSchema(t *testing.T, db *sqlx.DB) {
|
||||
t.Helper()
|
||||
|
||||
// Find the project root (where schemas/ directory is located)
|
||||
projectRoot := findProjectRoot(t)
|
||||
schemaPath := filepath.Join(projectRoot, "schemas", "init.sql")
|
||||
|
||||
// Get the connection config
|
||||
config := DefaultTestDBConfig()
|
||||
|
||||
// Use pg_restore to load the schema dump
|
||||
// The init.sql file is a pg_dump custom format, so we need pg_restore
|
||||
cmd := exec.Command("pg_restore",
|
||||
"-h", config.Host,
|
||||
"-p", config.Port,
|
||||
"-U", config.User,
|
||||
"-d", config.DBName,
|
||||
"--no-owner",
|
||||
"--no-acl",
|
||||
"-c", // clean (drop) before recreating
|
||||
schemaPath,
|
||||
)
|
||||
cmd.Env = append(os.Environ(), fmt.Sprintf("PGPASSWORD=%s", config.Password))
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
// pg_restore may error on first run (no tables to drop), that's usually ok
|
||||
t.Logf("pg_restore output: %s", string(output))
|
||||
// Check if it's a fatal error
|
||||
if !strings.Contains(string(output), "does not exist") {
|
||||
t.Logf("pg_restore error (may be non-fatal): %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Apply patch schemas in order
|
||||
applyPatchSchemas(t, db, projectRoot)
|
||||
}
|
||||
|
||||
// applyPatchSchemas applies all patch schema files in numeric order
|
||||
func applyPatchSchemas(t *testing.T, db *sqlx.DB, projectRoot string) {
|
||||
t.Helper()
|
||||
|
||||
patchDir := filepath.Join(projectRoot, "schemas", "patch-schema")
|
||||
entries, err := os.ReadDir(patchDir)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Could not read patch-schema directory: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Sort patch files numerically
|
||||
var patchFiles []string
|
||||
for _, entry := range entries {
|
||||
if !entry.IsDir() && strings.HasSuffix(entry.Name(), ".sql") {
|
||||
patchFiles = append(patchFiles, entry.Name())
|
||||
}
|
||||
}
|
||||
sort.Strings(patchFiles)
|
||||
|
||||
// Apply each patch in its own transaction
|
||||
for _, filename := range patchFiles {
|
||||
patchPath := filepath.Join(patchDir, filename)
|
||||
patchSQL, err := os.ReadFile(patchPath)
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to read patch file %s: %v", filename, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Start a new transaction for each patch
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
t.Logf("Warning: Failed to start transaction for patch %s: %v", filename, err)
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = tx.Exec(string(patchSQL))
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
t.Logf("Warning: Failed to apply patch %s: %v", filename, err)
|
||||
// Continue with other patches even if one fails
|
||||
} else {
|
||||
tx.Commit()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// findProjectRoot finds the project root directory by looking for the schemas directory
|
||||
func findProjectRoot(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
// Start from current directory and walk up
|
||||
dir, err := os.Getwd()
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to get working directory: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
schemasPath := filepath.Join(dir, "schemas")
|
||||
if stat, err := os.Stat(schemasPath); err == nil && stat.IsDir() {
|
||||
return dir
|
||||
}
|
||||
|
||||
parent := filepath.Dir(dir)
|
||||
if parent == dir {
|
||||
t.Fatal("Could not find project root (schemas directory not found)")
|
||||
}
|
||||
dir = parent
|
||||
}
|
||||
}
|
||||
|
||||
// TeardownTestDB closes the database connection
|
||||
func TeardownTestDB(t *testing.T, db *sqlx.DB) {
|
||||
t.Helper()
|
||||
if db != nil {
|
||||
db.Close()
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTestUser creates a test user and returns the user ID
|
||||
func CreateTestUser(t *testing.T, db *sqlx.DB, username string) uint32 {
|
||||
t.Helper()
|
||||
|
||||
var userID uint32
|
||||
err := db.QueryRow(`
|
||||
INSERT INTO users (username, password, rights)
|
||||
VALUES ($1, 'test_password_hash', 0)
|
||||
RETURNING id
|
||||
`, username).Scan(&userID)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test user: %v", err)
|
||||
}
|
||||
|
||||
return userID
|
||||
}
|
||||
|
||||
// CreateTestCharacter creates a test character and returns the character ID
|
||||
func CreateTestCharacter(t *testing.T, db *sqlx.DB, userID uint32, name string) uint32 {
|
||||
t.Helper()
|
||||
|
||||
// Create minimal valid savedata (needs to be large enough for the game to parse)
|
||||
// The name is at offset 88, and various game mode pointers extend up to ~147KB for ZZ mode
|
||||
// We need at least 150KB to accommodate all possible pointer offsets
|
||||
saveData := make([]byte, 150000) // Large enough for all game modes
|
||||
copy(saveData[88:], append([]byte(name), 0x00)) // Name at offset 88 with null terminator
|
||||
|
||||
// Import the nullcomp package for compression
|
||||
compressed, err := nullcomp.Compress(saveData)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to compress savedata: %v", err)
|
||||
}
|
||||
|
||||
var charID uint32
|
||||
err = db.QueryRow(`
|
||||
INSERT INTO characters (user_id, is_female, is_new_character, name, unk_desc_string, gr, hr, weapon_type, last_login, savedata, decomyset, savemercenary)
|
||||
VALUES ($1, false, false, $2, '', 0, 0, 0, 0, $3, '', '')
|
||||
RETURNING id
|
||||
`, userID, name, compressed).Scan(&charID)
|
||||
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create test character: %v", err)
|
||||
}
|
||||
|
||||
return charID
|
||||
}
|
||||
Reference in New Issue
Block a user