Files
Erupe/server/channelserver/handlers_clients_test.go
Houmgaor ad4afb4d3b refactor(channelserver): replace global stagesLock with sync.Map-backed StageMap
The global stagesLock sync.RWMutex protected map[string]*Stage, causing
all stage operations to contend on a single lock even for unrelated
stages. Any stage creation or deletion blocked all reads server-wide.

Replace with a typed StageMap wrapper around sync.Map which provides
lock-free reads and allows concurrent writes to disjoint keys. Per-stage
sync.RWMutex remains unchanged for protecting individual stage state.

StageMap exposes Get, GetOrCreate, StoreIfAbsent, Store, Delete, and
Range methods. Updated ~50 call sites across 6 production files and
9 test files.
2026-02-22 15:47:21 +01:00

586 lines
14 KiB
Go

package channelserver
import (
"fmt"
"testing"
cfg "erupe-ce/config"
"erupe-ce/common/byteframe"
"erupe-ce/network/mhfpacket"
"go.uber.org/zap"
)
// TestHandleMsgSysEnumerateClient tests client enumeration in stages
func TestHandleMsgSysEnumerateClient(t *testing.T) {
tests := []struct {
name string
stageID string
getType uint8
setupStage func(*Server, string)
wantClientCount int
wantFailure bool
}{
{
name: "enumerate_all_clients",
stageID: "test_stage_1",
getType: 0, // All clients
setupStage: func(server *Server, stageID string) {
stage := NewStage(stageID)
mock1 := &MockCryptConn{sentPackets: make([][]byte, 0)}
mock2 := &MockCryptConn{sentPackets: make([][]byte, 0)}
s1 := createTestSession(mock1)
s2 := createTestSession(mock2)
s1.charID = 100
s2.charID = 200
stage.clients[s1] = 100
stage.clients[s2] = 200
server.stages.Store(stageID, stage)
},
wantClientCount: 2,
wantFailure: false,
},
{
name: "enumerate_not_ready_clients",
stageID: "test_stage_2",
getType: 1, // Not ready
setupStage: func(server *Server, stageID string) {
stage := NewStage(stageID)
stage.reservedClientSlots[100] = false // Not ready
stage.reservedClientSlots[200] = true // Ready
stage.reservedClientSlots[300] = false // Not ready
server.stages.Store(stageID, stage)
},
wantClientCount: 2, // Only not-ready clients
wantFailure: false,
},
{
name: "enumerate_ready_clients",
stageID: "test_stage_3",
getType: 2, // Ready
setupStage: func(server *Server, stageID string) {
stage := NewStage(stageID)
stage.reservedClientSlots[100] = false // Not ready
stage.reservedClientSlots[200] = true // Ready
stage.reservedClientSlots[300] = true // Ready
server.stages.Store(stageID, stage)
},
wantClientCount: 2, // Only ready clients
wantFailure: false,
},
{
name: "enumerate_empty_stage",
stageID: "test_stage_empty",
getType: 0,
setupStage: func(server *Server, stageID string) {
stage := NewStage(stageID)
server.stages.Store(stageID, stage)
},
wantClientCount: 0,
wantFailure: false,
},
{
name: "enumerate_nonexistent_stage",
stageID: "nonexistent_stage",
getType: 0,
setupStage: func(server *Server, stageID string) {
// Don't create the stage
},
wantClientCount: 0,
wantFailure: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test session (which creates a server with erupeConfig)
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
// Setup stage
tt.setupStage(s.server, tt.stageID)
pkt := &mhfpacket.MsgSysEnumerateClient{
AckHandle: 1234,
StageID: tt.stageID,
Get: tt.getType,
}
handleMsgSysEnumerateClient(s, pkt)
// Check if ACK was sent
if len(s.sendPackets) == 0 {
t.Fatal("no ACK packet was sent")
}
// Read the ACK packet
ackPkt := <-s.sendPackets
if tt.wantFailure {
// For failures, we can't easily check the exact format
// Just verify something was sent
return
}
// Parse the response to count clients
// The ackPkt.data contains the full packet structure:
// [opcode:2 bytes][ack_handle:4 bytes][is_buffer:1 byte][error_code:1 byte][payload_size:2 bytes][data...]
// Total header size: 2 + 4 + 1 + 1 + 2 = 10 bytes
if len(ackPkt.data) < 10 {
t.Fatal("ACK packet too small")
}
// The response data starts after the 10-byte header
// Response format is: [count:uint16][charID1:uint32][charID2:uint32]...
bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header
count := bf.ReadUint16()
if int(count) != tt.wantClientCount {
t.Errorf("client count = %d, want %d", count, tt.wantClientCount)
}
})
}
}
// TestHandleMsgMhfListMember tests listing blacklisted members
func TestHandleMsgMhfListMember_Integration(t *testing.T) {
db := SetupTestDB(t)
defer TeardownTestDB(t, db)
tests := []struct {
name string
wantBlockCount int
}{
{
name: "no_blocked_users",
wantBlockCount: 0,
},
{
name: "single_blocked_user",
wantBlockCount: 1,
},
{
name: "multiple_blocked_users",
wantBlockCount: 3,
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test user and character (use short names to avoid 15 char limit)
userID := CreateTestUser(t, db, "user_"+tt.name)
charName := fmt.Sprintf("Char%d", i)
charID := CreateTestCharacter(t, db, userID, charName)
// Create blocked characters and build CSV from their actual IDs
blockedCSV := ""
for j := 0; j < tt.wantBlockCount; j++ {
blockedUserID := CreateTestUser(t, db, fmt.Sprintf("blk_%s_%d", tt.name, j))
blockedCharID := CreateTestCharacter(t, db, blockedUserID, fmt.Sprintf("Blk%d_%d", i, j))
if blockedCSV != "" {
blockedCSV += ","
}
blockedCSV += fmt.Sprintf("%d", blockedCharID)
}
// Set blocked list
_, err := db.Exec("UPDATE characters SET blocked = $1 WHERE id = $2", blockedCSV, charID)
if err != nil {
t.Fatalf("Failed to update blocked list: %v", err)
}
// Create test session
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.charID = charID
SetTestDB(s.server, db)
pkt := &mhfpacket.MsgMhfListMember{
AckHandle: 5678,
}
handleMsgMhfListMember(s, pkt)
// Verify ACK was sent
if len(s.sendPackets) == 0 {
t.Fatal("no ACK packet was sent")
}
// Parse response
// The ackPkt.data contains the full packet structure:
// [opcode:2 bytes][ack_handle:4 bytes][is_buffer:1 byte][error_code:1 byte][payload_size:2 bytes][data...]
// Total header size: 2 + 4 + 1 + 1 + 2 = 10 bytes
ackPkt := <-s.sendPackets
if len(ackPkt.data) < 10 {
t.Fatal("ACK packet too small")
}
bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header
count := bf.ReadUint32()
if int(count) != tt.wantBlockCount {
t.Errorf("blocked count = %d, want %d", count, tt.wantBlockCount)
}
})
}
}
// TestHandleMsgMhfOprMember tests blacklist/friendlist operations
func TestHandleMsgMhfOprMember_Integration(t *testing.T) {
db := SetupTestDB(t)
defer TeardownTestDB(t, db)
tests := []struct {
name string
isBlacklist bool
operation bool // true = remove, false = add
initialList string
targetCharIDs []uint32
wantList string
}{
{
name: "add_to_blacklist",
isBlacklist: true,
operation: false,
initialList: "",
targetCharIDs: []uint32{2},
wantList: "2",
},
{
name: "remove_from_blacklist",
isBlacklist: true,
operation: true,
initialList: "2,3,4",
targetCharIDs: []uint32{3},
wantList: "2,4",
},
{
name: "add_to_friendlist",
isBlacklist: false,
operation: false,
initialList: "10",
targetCharIDs: []uint32{20},
wantList: "10,20",
},
{
name: "remove_from_friendlist",
isBlacklist: false,
operation: true,
initialList: "10,20,30",
targetCharIDs: []uint32{20},
wantList: "10,30",
},
{
name: "add_multiple_to_blacklist",
isBlacklist: true,
operation: false,
initialList: "1",
targetCharIDs: []uint32{2, 3},
wantList: "1,2,3",
},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test user and character (use short names to avoid 15 char limit)
userID := CreateTestUser(t, db, "user_"+tt.name)
charName := fmt.Sprintf("OpChar%d", i)
charID := CreateTestCharacter(t, db, userID, charName)
// Set initial list
column := "blocked"
if !tt.isBlacklist {
column = "friends"
}
_, err := db.Exec("UPDATE characters SET "+column+" = $1 WHERE id = $2", tt.initialList, charID)
if err != nil {
t.Fatalf("Failed to set initial list: %v", err)
}
// Create test session
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.charID = charID
SetTestDB(s.server, db)
pkt := &mhfpacket.MsgMhfOprMember{
AckHandle: 9999,
Blacklist: tt.isBlacklist,
Operation: tt.operation,
CharIDs: tt.targetCharIDs,
}
handleMsgMhfOprMember(s, pkt)
// Verify ACK was sent
if len(s.sendPackets) == 0 {
t.Fatal("no ACK packet was sent")
}
<-s.sendPackets
// Verify the list was updated
var gotList string
err = db.QueryRow("SELECT "+column+" FROM characters WHERE id = $1", charID).Scan(&gotList)
if err != nil {
t.Fatalf("Failed to query updated list: %v", err)
}
if gotList != tt.wantList {
t.Errorf("list = %q, want %q", gotList, tt.wantList)
}
})
}
}
// TestHandleMsgMhfShutClient tests the shut client handler
func TestHandleMsgMhfShutClient(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
pkt := &mhfpacket.MsgMhfShutClient{}
// Should not panic (handler is empty)
handleMsgMhfShutClient(s, pkt)
}
// TestHandleMsgSysHideClient tests the hide client handler
func TestHandleMsgSysHideClient(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
tests := []struct {
name string
hide bool
}{
{
name: "hide_client",
hide: true,
},
{
name: "show_client",
hide: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pkt := &mhfpacket.MsgSysHideClient{
Hide: tt.hide,
}
// Should not panic (handler is empty)
handleMsgSysHideClient(s, pkt)
})
}
}
// TestEnumerateClient_ConcurrentAccess tests concurrent stage access
func TestEnumerateClient_ConcurrentAccess(t *testing.T) {
logger, _ := zap.NewDevelopment()
server := &Server{
logger: logger,
erupeConfig: &cfg.Config{
DebugOptions: cfg.DebugOptions{
LogOutboundMessages: false,
},
},
}
stageID := "concurrent_test_stage"
stage := NewStage(stageID)
// Add some clients to the stage
for i := uint32(1); i <= 10; i++ {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
sess := createTestSession(mock)
sess.charID = i * 100
stage.clients[sess] = i * 100
}
server.stages.Store(stageID, stage)
// Run concurrent enumerations
done := make(chan bool, 5)
for i := 0; i < 5; i++ {
go func() {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server = server
pkt := &mhfpacket.MsgSysEnumerateClient{
AckHandle: 3333,
StageID: stageID,
Get: 0, // All clients
}
handleMsgSysEnumerateClient(s, pkt)
done <- true
}()
}
// Wait for all goroutines to complete
for i := 0; i < 5; i++ {
<-done
}
}
// TestListMember_EmptyDatabase tests listing members when database is empty
func TestListMember_EmptyDatabase_Integration(t *testing.T) {
db := SetupTestDB(t)
defer TeardownTestDB(t, db)
// Create test user and character
userID := CreateTestUser(t, db, "emptytest")
charID := CreateTestCharacter(t, db, userID, "EmptyChar")
// Create test session
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.charID = charID
SetTestDB(s.server, db)
pkt := &mhfpacket.MsgMhfListMember{
AckHandle: 4444,
}
handleMsgMhfListMember(s, pkt)
// Verify ACK was sent
if len(s.sendPackets) == 0 {
t.Fatal("no ACK packet was sent")
}
ackPkt := <-s.sendPackets
if len(ackPkt.data) < 10 {
t.Fatal("ACK packet too small")
}
bf := byteframe.NewByteFrameFromBytes(ackPkt.data[10:]) // Skip full ACK header
count := bf.ReadUint32()
if count != 0 {
t.Errorf("empty blocked list should have count 0, got %d", count)
}
}
// TestOprMember_EdgeCases tests edge cases for member operations
func TestOprMember_EdgeCases_Integration(t *testing.T) {
db := SetupTestDB(t)
defer TeardownTestDB(t, db)
tests := []struct {
name string
initialList string
operation bool
targetCharIDs []uint32
wantList string
}{
{
name: "add_duplicate_to_list",
initialList: "1,2,3",
operation: false, // add
targetCharIDs: []uint32{2},
wantList: "1,2,3", // CSV helper deduplicates
},
{
name: "remove_nonexistent_from_list",
initialList: "1,2,3",
operation: true, // remove
targetCharIDs: []uint32{99},
wantList: "1,2,3",
},
{
name: "operate_on_empty_list",
initialList: "",
operation: false,
targetCharIDs: []uint32{1},
wantList: "1",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create test user and character
userID := CreateTestUser(t, db, "edge_"+tt.name)
charID := CreateTestCharacter(t, db, userID, "EdgeChar")
// Set initial blocked list
_, err := db.Exec("UPDATE characters SET blocked = $1 WHERE id = $2", tt.initialList, charID)
if err != nil {
t.Fatalf("Failed to set initial list: %v", err)
}
// Create test session
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.charID = charID
SetTestDB(s.server, db)
pkt := &mhfpacket.MsgMhfOprMember{
AckHandle: 7777,
Blacklist: true,
Operation: tt.operation,
CharIDs: tt.targetCharIDs,
}
handleMsgMhfOprMember(s, pkt)
// Verify ACK was sent
if len(s.sendPackets) == 0 {
t.Fatal("no ACK packet was sent")
}
<-s.sendPackets
// Verify the list
var gotList string
err = db.QueryRow("SELECT blocked FROM characters WHERE id = $1", charID).Scan(&gotList)
if err != nil {
t.Fatalf("Failed to query list: %v", err)
}
if gotList != tt.wantList {
t.Errorf("list = %q, want %q", gotList, tt.wantList)
}
})
}
}
// BenchmarkEnumerateClients benchmarks client enumeration
func BenchmarkEnumerateClients(b *testing.B) {
logger, _ := zap.NewDevelopment()
server := &Server{
logger: logger,
}
stageID := "bench_stage"
stage := NewStage(stageID)
// Add 100 clients to the stage
for i := uint32(1); i <= 100; i++ {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
sess := createTestSession(mock)
sess.charID = i
stage.clients[sess] = i
}
server.stages.Store(stageID, stage)
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server = server
pkt := &mhfpacket.MsgSysEnumerateClient{
AckHandle: 8888,
StageID: stageID,
Get: 0,
}
b.ResetTimer()
for i := 0; i < b.N; i++ {
// Clear the packet channel
select {
case <-s.sendPackets:
default:
}
handleMsgSysEnumerateClient(s, pkt)
<-s.sendPackets
}
}