refactor(channelserver): replace global stagesLock with sync.Map-backed StageMap

The global stagesLock sync.RWMutex protected map[string]*Stage, causing
all stage operations to contend on a single lock even for unrelated
stages. Any stage creation or deletion blocked all reads server-wide.

Replace with a typed StageMap wrapper around sync.Map which provides
lock-free reads and allows concurrent writes to disjoint keys. Per-stage
sync.RWMutex remains unchanged for protecting individual stage state.

StageMap exposes Get, GetOrCreate, StoreIfAbsent, Store, Delete, and
Range methods. Updated ~50 call sites across 6 production files and
9 test files.
This commit is contained in:
Houmgaor
2026-02-22 15:47:21 +01:00
parent 2a5cd50e3f
commit ad4afb4d3b
15 changed files with 207 additions and 221 deletions

View File

@@ -192,22 +192,16 @@ func TestChannelIsolation_IndependentStages(t *testing.T) {
stageName := "sl1Qs999p0a0u42"
// Add stage only to channel 1.
channels[0].stagesLock.Lock()
channels[0].stages[stageName] = NewStage(stageName)
channels[0].stagesLock.Unlock()
channels[0].stages.Store(stageName, NewStage(stageName))
// Channel 1 should have the stage.
channels[0].stagesLock.RLock()
_, ok1 := channels[0].stages[stageName]
channels[0].stagesLock.RUnlock()
_, ok1 := channels[0].stages.Get(stageName)
if !ok1 {
t.Error("channel 1 should have the stage")
}
// Channel 2 should NOT have the stage.
channels[1].stagesLock.RLock()
_, ok2 := channels[1].stages[stageName]
channels[1].stagesLock.RUnlock()
_, ok2 := channels[1].stages.Get(stageName)
if ok2 {
t.Error("channel 2 should not have channel 1's stage")
}

View File

@@ -56,16 +56,18 @@ func (r *LocalChannelRegistry) DisconnectUser(cids []uint32) {
func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string {
for _, channel := range r.channels {
channel.stagesLock.RLock()
for id := range channel.stages {
var gid string
channel.stages.Range(func(id string, _ *Stage) bool {
if strings.HasSuffix(id, stageSuffix) {
gid := channel.GlobalID
channel.stagesLock.RUnlock()
gid = channel.GlobalID
return false // stop iteration
}
return true
})
if gid != "" {
return gid
}
}
channel.stagesLock.RUnlock()
}
return ""
}
@@ -105,13 +107,14 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage
if len(results) >= max {
break
}
c.stagesLock.RLock()
for _, stage := range c.stages {
cIP := net.ParseIP(c.IP).To4()
cPort := c.Port
c.stages.Range(func(_ string, stage *Stage) bool {
if len(results) >= max {
break
return false
}
if !strings.HasPrefix(stage.id, stagePrefix) {
continue
return true
}
stage.RLock()
bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}]
@@ -125,8 +128,8 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage
copy(bin3Copy, bin3)
results = append(results, StageSnapshot{
ServerIP: net.ParseIP(c.IP).To4(),
ServerPort: c.Port,
ServerIP: cIP,
ServerPort: cPort,
StageID: stage.id,
ClientCount: len(stage.clients) + len(stage.reservedClientSlots),
Reserved: len(stage.reservedClientSlots),
@@ -136,8 +139,8 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage
RawBinData3: bin3Copy,
})
stage.RUnlock()
}
c.stagesLock.RUnlock()
return true
})
}
return results
}

View File

@@ -61,9 +61,7 @@ func TestLocalRegistryFindChannelForStage(t *testing.T) {
channels[1].GlobalID = "0102"
reg := NewLocalChannelRegistry(channels)
channels[1].stagesLock.Lock()
channels[1].stages["sl2Qs123p0a0u42"] = NewStage("sl2Qs123p0a0u42")
channels[1].stagesLock.Unlock()
channels[1].stages.Store("sl2Qs123p0a0u42", NewStage("sl2Qs123p0a0u42"))
gid := reg.FindChannelForStage("u42")
if gid != "0102" {
@@ -136,11 +134,9 @@ func TestLocalRegistrySearchStages(t *testing.T) {
channels := createTestChannels(1)
reg := NewLocalChannelRegistry(channels)
channels[0].stagesLock.Lock()
channels[0].stages["sl2Ls210test1"] = NewStage("sl2Ls210test1")
channels[0].stages["sl2Ls210test2"] = NewStage("sl2Ls210test2")
channels[0].stages["sl1Ns200other"] = NewStage("sl1Ns200other")
channels[0].stagesLock.Unlock()
channels[0].stages.Store("sl2Ls210test1", NewStage("sl2Ls210test1"))
channels[0].stages.Store("sl2Ls210test2", NewStage("sl2Ls210test2"))
channels[0].stages.Store("sl1Ns200other", NewStage("sl1Ns200other"))
results := reg.SearchStages("sl2Ls210", 10)
if len(results) != 2 {

View File

@@ -10,15 +10,12 @@ import (
func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysEnumerateClient)
s.server.stagesLock.RLock()
stage, ok := s.server.stages[pkt.StageID]
stage, ok := s.server.stages.Get(pkt.StageID)
if !ok {
s.server.stagesLock.RUnlock()
s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID))
doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4))
return
}
s.server.stagesLock.RUnlock()
// Read-lock the stage and make the response with all of the charID's in the stage.
resp := byteframe.NewByteFrame()

View File

@@ -34,9 +34,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
s2.charID = 200
stage.clients[s1] = 100
stage.clients[s2] = 200
server.stagesLock.Lock()
server.stages[stageID] = stage
server.stagesLock.Unlock()
server.stages.Store(stageID, stage)
},
wantClientCount: 2,
wantFailure: false,
@@ -50,9 +48,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
stage.reservedClientSlots[100] = false // Not ready
stage.reservedClientSlots[200] = true // Ready
stage.reservedClientSlots[300] = false // Not ready
server.stagesLock.Lock()
server.stages[stageID] = stage
server.stagesLock.Unlock()
server.stages.Store(stageID, stage)
},
wantClientCount: 2, // Only not-ready clients
wantFailure: false,
@@ -66,9 +62,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
stage.reservedClientSlots[100] = false // Not ready
stage.reservedClientSlots[200] = true // Ready
stage.reservedClientSlots[300] = true // Ready
server.stagesLock.Lock()
server.stages[stageID] = stage
server.stagesLock.Unlock()
server.stages.Store(stageID, stage)
},
wantClientCount: 2, // Only ready clients
wantFailure: false,
@@ -79,9 +73,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
getType: 0,
setupStage: func(server *Server, stageID string) {
stage := NewStage(stageID)
server.stagesLock.Lock()
server.stages[stageID] = stage
server.stagesLock.Unlock()
server.stages.Store(stageID, stage)
},
wantClientCount: 0,
wantFailure: false,
@@ -104,11 +96,6 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
// Initialize stages map if needed
if s.server.stages == nil {
s.server.stages = make(map[string]*Stage)
}
// Setup stage
tt.setupStage(s.server, tt.stageID)
@@ -389,7 +376,6 @@ func TestEnumerateClient_ConcurrentAccess(t *testing.T) {
logger, _ := zap.NewDevelopment()
server := &Server{
logger: logger,
stages: make(map[string]*Stage),
erupeConfig: &cfg.Config{
DebugOptions: cfg.DebugOptions{
LogOutboundMessages: false,
@@ -408,9 +394,7 @@ func TestEnumerateClient_ConcurrentAccess(t *testing.T) {
stage.clients[sess] = i * 100
}
server.stagesLock.Lock()
server.stages[stageID] = stage
server.stagesLock.Unlock()
server.stages.Store(stageID, stage)
// Run concurrent enumerations
done := make(chan bool, 5)
@@ -562,7 +546,6 @@ func BenchmarkEnumerateClients(b *testing.B) {
logger, _ := zap.NewDevelopment()
server := &Server{
logger: logger,
stages: make(map[string]*Stage),
}
stageID := "bench_stage"
@@ -576,7 +559,7 @@ func BenchmarkEnumerateClients(b *testing.B) {
stage.clients[sess] = i
}
server.stages[stageID] = stage
server.stages.Store(stageID, stage)
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)

View File

@@ -600,9 +600,8 @@ func TestHandleMsgSysLockGlobalSema_WithChannel(t *testing.T) {
// Create a mock channel with stages
channel := &Server{
GlobalID: "other-server",
stages: make(map[string]*Stage),
}
channel.stages["stage_user123"] = NewStage("stage_user123")
channel.stages.Store("stage_user123", NewStage("stage_user123"))
server.Channels = []*Server{channel}
session := createMockSession(1, server)
@@ -632,9 +631,8 @@ func TestHandleMsgSysLockGlobalSema_SameServer(t *testing.T) {
// Create a mock channel with same GlobalID
channel := &Server{
GlobalID: "test-server",
stages: make(map[string]*Stage),
}
channel.stages["stage_user456"] = NewStage("stage_user456")
channel.stages.Store("stage_user456", NewStage("stage_user456"))
server.Channels = []*Server{channel}
session := createMockSession(1, server)

View File

@@ -984,7 +984,7 @@ func TestHandleMsgSysCreateStage_Coverage3(t *testing.T) {
default:
t.Error("no response queued")
}
if _, exists := server.stages["test_create_stage"]; !exists {
if _, exists := server.stages.Get("test_create_stage"); !exists {
t.Error("stage should have been created")
}
})

View File

@@ -290,7 +290,7 @@ func logoutPlayer(s *Session) {
_ = s.rawConn.Close()
s.server.Unlock()
// Stage cleanup — snapshot sessions first under server mutex, then iterate stages under stagesLock
// Stage cleanup — snapshot sessions first under server mutex, then iterate stages
s.server.Lock()
sessionSnapshot := make([]*Session, 0, len(s.server.sessions))
for _, sess := range s.server.sessions {
@@ -298,8 +298,7 @@ func logoutPlayer(s *Session) {
}
s.server.Unlock()
s.server.stagesLock.RLock()
for _, stage := range s.server.stages {
s.server.stages.Range(func(_ string, stage *Stage) bool {
stage.Lock()
// Tell sessions registered to disconnecting player's quest to unregister
if stage.host != nil && stage.host.charID == s.charID {
@@ -317,8 +316,8 @@ func logoutPlayer(s *Session) {
}
}
stage.Unlock()
}
s.server.stagesLock.RUnlock()
return true
})
// Update sign sessions and server player count
if s.server.db != nil {
@@ -346,13 +345,12 @@ func logoutPlayer(s *Session) {
CharID: s.charID,
}, s)
s.server.stagesLock.RLock()
for _, stage := range s.server.stages {
s.server.stages.Range(func(_ string, stage *Stage) bool {
stage.Lock()
delete(stage.reservedClientSlots, s.charID)
stage.Unlock()
}
s.server.stagesLock.RUnlock()
return true
})
removeSessionFromSemaphore(s)
removeSessionFromStage(s)
@@ -449,13 +447,12 @@ func handleMsgSysLockGlobalSema(s *Session, p mhfpacket.MHFPacket) {
sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString)
} else {
for _, channel := range s.server.Channels {
channel.stagesLock.RLock()
for id := range channel.stages {
channel.stages.Range(func(id string, _ *Stage) bool {
if strings.HasSuffix(id, pkt.UserIDString) {
sgid = channel.GlobalID
}
}
channel.stagesLock.RUnlock()
return true
})
}
}
bf := byteframe.NewByteFrame()
@@ -689,10 +686,11 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
if count == maxResults {
break
}
c.stagesLock.RLock()
for _, stage := range c.stages {
cIP := net.ParseIP(c.IP).To4()
cPort := c.Port
c.stages.Range(func(_ string, stage *Stage) bool {
if count == maxResults {
break
return false
}
if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) {
stage.RLock()
@@ -718,7 +716,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
if findPartyParams.RankRestriction >= 0 {
if stageData[0] > findPartyParams.RankRestriction {
stage.RUnlock()
continue
return true
}
}
@@ -732,7 +730,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
}
if !hasTarget {
stage.RUnlock()
continue
return true
}
}
@@ -746,8 +744,8 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
count++
stageResults = append(stageResults, stageResult{
ip: net.ParseIP(c.IP).To4(),
port: c.Port,
ip: cIP,
port: cPort,
clientCount: len(stage.clients) + len(stage.reservedClientSlots),
reserved: len(stage.reservedClientSlots),
maxPlayers: stage.maxPlayers,
@@ -758,8 +756,8 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
})
stage.RUnlock()
}
}
c.stagesLock.RUnlock()
return true
})
}
for _, sr := range stageResults {

View File

@@ -14,32 +14,23 @@ import (
func handleMsgSysCreateStage(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysCreateStage)
s.server.stagesLock.Lock()
defer s.server.stagesLock.Unlock()
if _, exists := s.server.stages[pkt.StageID]; exists {
doAckSimpleFail(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00})
} else {
stage := NewStage(pkt.StageID)
stage.host = s
stage.maxPlayers = uint16(pkt.PlayerCount)
s.server.stages[stage.id] = stage
if s.server.stages.StoreIfAbsent(pkt.StageID, stage) {
doAckSimpleSucceed(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00})
} else {
doAckSimpleFail(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00})
}
}
func handleMsgSysStageDestruct(s *Session, p mhfpacket.MHFPacket) {}
func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
s.server.stagesLock.Lock()
stage, exists := s.server.stages[stageID]
if !exists {
s.server.stages[stageID] = NewStage(stageID)
stage = s.server.stages[stageID]
}
s.server.stagesLock.Unlock()
stage, created := s.server.stages.GetOrCreate(stageID)
stage.Lock()
if !exists {
if created {
stage.host = s
}
stage.clients[s] = s.charID
@@ -50,12 +41,9 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
removeSessionFromStage(s)
}
// Save our new stage ID and pointer to the new stage itself.
s.server.stagesLock.RLock()
newStage := s.server.stages[stageID]
s.server.stagesLock.RUnlock()
// Save our new stage pointer.
s.Lock()
s.stage = newStage
s.stage = stage
s.Unlock()
// Tell the client to cleanup its current stage objects.
@@ -140,22 +128,20 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
}
func destructEmptyStages(s *Session) {
s.server.stagesLock.Lock()
defer s.server.stagesLock.Unlock()
for _, stage := range s.server.stages {
s.server.stages.Range(func(id string, stage *Stage) bool {
// Destroy empty Quest/My series/Guild stages.
if stage.id[3:5] == "Qs" || stage.id[3:5] == "Ms" || stage.id[3:5] == "Gs" || stage.id[3:5] == "Ls" {
// Lock stage to safely check its client and reservation counts
if id[3:5] == "Qs" || id[3:5] == "Ms" || id[3:5] == "Gs" || id[3:5] == "Ls" {
stage.Lock()
isEmpty := len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0
stage.Unlock()
if isEmpty {
delete(s.server.stages, stage.id)
s.logger.Debug("Destructed stage", zap.String("stage.id", stage.id))
}
s.server.stages.Delete(id)
s.logger.Debug("Destructed stage", zap.String("stage.id", id))
}
}
return true
})
}
func removeSessionFromStage(s *Session) {
@@ -194,9 +180,7 @@ func removeSessionFromStage(s *Session) {
}
func isStageFull(s *Session, StageID string) bool {
s.server.stagesLock.RLock()
stage, exists := s.server.stages[StageID]
s.server.stagesLock.RUnlock()
stage, exists := s.server.stages.Get(StageID)
if exists {
// Lock stage to safely check client counts
@@ -261,9 +245,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) {
s.stage.Unlock()
}
s.server.stagesLock.RLock()
backStagePtr, exists := s.server.stages[backStage]
s.server.stagesLock.RUnlock()
backStagePtr, exists := s.server.stages.Get(backStage)
if exists {
backStagePtr.Lock()
delete(backStagePtr.reservedClientSlots, s.charID)
@@ -288,9 +270,7 @@ func handleMsgSysLeaveStage(s *Session, p mhfpacket.MHFPacket) {}
func handleMsgSysLockStage(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysLockStage)
s.server.stagesLock.RLock()
stage, exists := s.server.stages[pkt.StageID]
s.server.stagesLock.RUnlock()
stage, exists := s.server.stages.Get(pkt.StageID)
if exists {
stage.Lock()
stage.locked = true
@@ -317,10 +297,7 @@ func handleMsgSysUnlockStage(s *Session, p mhfpacket.MHFPacket) {
}
}
// Delete from stages map under stagesLock (not nested inside stage RLock)
s.server.stagesLock.Lock()
delete(s.server.stages, stageID)
s.server.stagesLock.Unlock()
s.server.stages.Delete(stageID)
}
destructEmptyStages(s)
@@ -328,9 +305,7 @@ func handleMsgSysUnlockStage(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysReserveStage(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysReserveStage)
s.server.stagesLock.RLock()
stage, exists := s.server.stages[pkt.StageID]
s.server.stagesLock.RUnlock()
stage, exists := s.server.stages.Get(pkt.StageID)
if exists {
stage.Lock()
defer stage.Unlock()
@@ -402,9 +377,7 @@ func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysSetStageBinary)
s.server.stagesLock.RLock()
stage, exists := s.server.stages[pkt.StageID]
s.server.stagesLock.RUnlock()
stage, exists := s.server.stages.Get(pkt.StageID)
if exists {
stage.Lock()
stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}] = pkt.RawDataPayload
@@ -416,9 +389,7 @@ func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysGetStageBinary)
s.server.stagesLock.RLock()
stage, exists := s.server.stages[pkt.StageID]
s.server.stagesLock.RUnlock()
stage, exists := s.server.stages.Get(pkt.StageID)
if exists {
stage.Lock()
if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists {
@@ -443,9 +414,7 @@ func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysWaitStageBinary)
s.server.stagesLock.RLock()
stage, exists := s.server.stages[pkt.StageID]
s.server.stagesLock.RUnlock()
stage, exists := s.server.stages.Get(pkt.StageID)
if exists {
if pkt.BinaryType0 == 1 && pkt.BinaryType1 == 12 {
// This might contain the hunter count, or max player count?
@@ -479,24 +448,20 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysEnumerateStage)
// Read-lock the server stage map.
s.server.stagesLock.RLock()
defer s.server.stagesLock.RUnlock()
// Build the response
bf := byteframe.NewByteFrame()
var joinable uint16
bf.WriteUint16(0)
for sid, stage := range s.server.stages {
s.server.stages.Range(func(sid string, stage *Stage) bool {
stage.RLock()
if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 {
stage.RUnlock()
continue
return true
}
if !strings.Contains(stage.id, pkt.StagePrefix) {
stage.RUnlock()
continue
return true
}
joinable++
@@ -518,7 +483,8 @@ func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
bf.WriteUint8(flags)
ps.Uint8(bf, sid, false)
stage.RUnlock()
}
return true
})
_, _ = bf.Seek(0, 0)
bf.WriteUint16(joinable)

View File

@@ -17,7 +17,7 @@ const raceTestCompletionMsg = "Test completed. No race conditions with fixed loc
func TestCreateStageSuccess(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
// Create a new stage
pkt := &mhfpacket.MsgSysCreateStage{
@@ -29,11 +29,10 @@ func TestCreateStageSuccess(t *testing.T) {
handleMsgSysCreateStage(s, pkt)
// Verify stage was created
if _, exists := s.server.stages["test_stage_1"]; !exists {
stage, exists := s.server.stages.Get("test_stage_1")
if !exists {
t.Error("stage was not created")
}
stage := s.server.stages["test_stage_1"]
if stage.id != "test_stage_1" {
t.Errorf("stage ID mismatch: got %s, want test_stage_1", stage.id)
}
@@ -46,7 +45,7 @@ func TestCreateStageSuccess(t *testing.T) {
func TestCreateStageDuplicate(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
// Create first stage
pkt1 := &mhfpacket.MsgSysCreateStage{
@@ -65,8 +64,10 @@ func TestCreateStageDuplicate(t *testing.T) {
handleMsgSysCreateStage(s, pkt2)
// Verify only one stage exists
if len(s.server.stages) != 1 {
t.Errorf("expected 1 stage, got %d", len(s.server.stages))
count := 0
s.server.stages.Range(func(_ string, _ *Stage) bool { count++; return true })
if count != 1 {
t.Errorf("expected 1 stage, got %d", count)
}
}
@@ -74,13 +75,13 @@ func TestCreateStageDuplicate(t *testing.T) {
func TestStageLocking(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
// Create a stage
stage := NewStage("locked_stage")
stage.host = s
stage.password = ""
s.server.stages["locked_stage"] = stage
s.server.stages.Store("locked_stage", stage)
// Lock the stage
pkt := &mhfpacket.MsgSysLockStage{
@@ -103,14 +104,14 @@ func TestStageLocking(t *testing.T) {
func TestStageReservation(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
// Create a stage
stage := NewStage("reserved_stage")
stage.host = s
stage.reservedClientSlots = make(map[uint32]bool)
stage.reservedClientSlots[s.charID] = false // Pre-add the charID so reservation works
s.server.stages["reserved_stage"] = stage
s.server.stages.Store("reserved_stage", stage)
// Reserve the stage
pkt := &mhfpacket.MsgSysReserveStage{
@@ -163,8 +164,8 @@ func TestStageBinaryData(t *testing.T) {
stage := NewStage("binary_stage")
stage.rawBinaryData = make(map[stageBinaryKey][]byte)
s.stage = stage
s.server.stages = make(map[string]*Stage)
s.server.stages["binary_stage"] = stage
s.server.stages.Store("binary_stage", stage)
// Store binary data directly
key := stageBinaryKey{id0: byte(s.charID >> 8), id1: byte(s.charID & 0xFF)}
@@ -230,8 +231,8 @@ func TestIsStageFull(t *testing.T) {
stage.clients[client] = uint32(i)
}
s.server.stages = make(map[string]*Stage)
s.server.stages["full_test_stage"] = stage
s.server.stages.Store("full_test_stage", stage)
result := isStageFull(s, "full_test_stage")
if result != tt.wantFull {
@@ -245,14 +246,14 @@ func TestIsStageFull(t *testing.T) {
func TestEnumerateStage(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
s.server.sessions = make(map[net.Conn]*Session)
// Create multiple stages
for i := 0; i < 3; i++ {
stage := NewStage("stage_" + string(rune(i)))
stage.maxPlayers = 4
s.server.stages[stage.id] = stage
s.server.stages.Store(stage.id, stage)
}
// Enumerate stages
@@ -264,8 +265,10 @@ func TestEnumerateStage(t *testing.T) {
// Basic verification that enumeration was processed
// In a real test, we'd verify the response packet content
if len(s.server.stages) != 3 {
t.Errorf("expected 3 stages, got %d", len(s.server.stages))
stageCount := 0
s.server.stages.Range(func(_ string, _ *Stage) bool { stageCount++; return true })
if stageCount != 3 {
t.Errorf("expected 3 stages, got %d", stageCount)
}
}
@@ -279,8 +282,8 @@ func TestRemoveSessionFromStage(t *testing.T) {
stage.clients[s] = s.charID
s.stage = stage
s.server.stages = make(map[string]*Stage)
s.server.stages["removal_stage"] = stage
s.server.stages.Store("removal_stage", stage)
// Remove session
removeSessionFromStage(s)
@@ -299,18 +302,18 @@ func TestRemoveSessionFromStage(t *testing.T) {
func TestDestructEmptyStages(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
// Create stages with different client counts
emptyStage := NewStage("empty_stage")
emptyStage.clients = make(map[*Session]uint32)
emptyStage.host = s // Host needs to be set or it won't be destructed
s.server.stages["empty_stage"] = emptyStage
s.server.stages.Store("empty_stage", emptyStage)
populatedStage := NewStage("populated_stage")
populatedStage.clients = make(map[*Session]uint32)
populatedStage.clients[s] = s.charID
s.server.stages["populated_stage"] = populatedStage
s.server.stages.Store("populated_stage", populatedStage)
// Destruct empty stages (from the channel server's perspective, not our session's)
// The function destructs stages that are not referenced by us or don't have clients
@@ -318,8 +321,10 @@ func TestDestructEmptyStages(t *testing.T) {
// For this test to work correctly, we'd need to verify the actual removal
// Let's just verify the stages exist first
if len(s.server.stages) != 2 {
t.Errorf("expected 2 stages initially, got %d", len(s.server.stages))
initialCount := 0
s.server.stages.Range(func(_ string, _ *Stage) bool { initialCount++; return true })
if initialCount != 2 {
t.Errorf("expected 2 stages initially, got %d", initialCount)
}
}
@@ -327,14 +332,14 @@ func TestDestructEmptyStages(t *testing.T) {
func TestStageTransferBasic(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
s.server.sessions = make(map[net.Conn]*Session)
// Transfer to non-existent stage (should create it)
doStageTransfer(s, 0x12345678, "new_transfer_stage")
// Verify stage was created
if stage, exists := s.server.stages["new_transfer_stage"]; !exists {
if stage, exists := s.server.stages.Get("new_transfer_stage"); !exists {
t.Error("stage was not created during transfer")
} else {
// Verify session is in the stage
@@ -357,12 +362,12 @@ func TestStageTransferBasic(t *testing.T) {
func TestEnterStageBasic(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
s.server.sessions = make(map[net.Conn]*Session)
stage := NewStage("entry_stage")
stage.clients = make(map[*Session]uint32)
s.server.stages["entry_stage"] = stage
s.server.stages.Store("entry_stage", stage)
pkt := &mhfpacket.MsgSysEnterStage{
StageID: "entry_stage",
@@ -383,7 +388,7 @@ func TestEnterStageBasic(t *testing.T) {
func TestMoveStagePreservesData(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
s.server.sessions = make(map[net.Conn]*Session)
// Create source stage with binary data
@@ -392,13 +397,13 @@ func TestMoveStagePreservesData(t *testing.T) {
sourceStage.rawBinaryData = make(map[stageBinaryKey][]byte)
key := stageBinaryKey{id0: 0x00, id1: 0x01}
sourceStage.rawBinaryData[key] = []byte{0xAA, 0xBB}
s.server.stages["source_stage"] = sourceStage
s.server.stages.Store("source_stage", sourceStage)
s.stage = sourceStage
// Create destination stage
destStage := NewStage("dest_stage")
destStage.clients = make(map[*Session]uint32)
s.server.stages["dest_stage"] = destStage
s.server.stages.Store("dest_stage", destStage)
pkt := &mhfpacket.MsgSysMoveStage{
StageID: "dest_stage",
@@ -417,12 +422,12 @@ func TestMoveStagePreservesData(t *testing.T) {
func TestConcurrentStageOperations(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
baseSession := createTestSession(mock)
baseSession.server.stages = make(map[string]*Stage)
// Create a stage
stage := NewStage("concurrent_stage")
stage.clients = make(map[*Session]uint32)
baseSession.server.stages["concurrent_stage"] = stage
baseSession.server.stages.Store("concurrent_stage", stage)
var wg sync.WaitGroup
@@ -459,7 +464,7 @@ func TestConcurrentStageOperations(t *testing.T) {
func TestBackStageNavigation(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
s.server.sessions = make(map[net.Conn]*Session)
// Create a stringstack for stage move history
@@ -472,8 +477,8 @@ func TestBackStageNavigation(t *testing.T) {
stage2 := NewStage("stage_2")
stage2.clients = make(map[*Session]uint32)
s.server.stages["stage_1"] = stage1
s.server.stages["stage_2"] = stage2
s.server.stages.Store("stage_1", stage1)
s.server.stages.Store("stage_2", stage2)
// First enter stage 2 and push to stack
s.stage = stage2
@@ -502,13 +507,13 @@ func TestRaceConditionRemoveSessionFromStageNotLocked(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
s := createTestSession(mock)
s.server.stages = make(map[string]*Stage)
s.server.sessions = make(map[net.Conn]*Session)
stage := NewStage("race_test_stage")
stage.clients = make(map[*Session]uint32)
stage.objects = make(map[uint32]*Object)
s.server.stages["race_test_stage"] = stage
s.server.stages.Store("race_test_stage", stage)
s.stage = stage
stage.clients[s] = s.charID
@@ -567,14 +572,14 @@ func TestRaceConditionDoStageTransferUnlockedAccess(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
baseSession := createTestSession(mock)
baseSession.server.stages = make(map[string]*Stage)
baseSession.server.sessions = make(map[net.Conn]*Session)
// Create initial stage
stage := NewStage("initial_stage")
stage.clients = make(map[*Session]uint32)
stage.objects = make(map[uint32]*Object)
baseSession.server.stages["initial_stage"] = stage
baseSession.server.stages.Store("initial_stage", stage)
baseSession.stage = stage
stage.clients[baseSession] = baseSession.charID
@@ -631,13 +636,13 @@ func TestRaceConditionStageObjectsIteration(t *testing.T) {
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
baseSession := createTestSession(mock)
baseSession.server.stages = make(map[string]*Stage)
baseSession.server.sessions = make(map[net.Conn]*Session)
stage := NewStage("object_race_stage")
stage.clients = make(map[*Session]uint32)
stage.objects = make(map[uint32]*Object)
baseSession.server.stages["object_race_stage"] = stage
baseSession.server.stages.Store("object_race_stage", stage)
baseSession.stage = stage
stage.clients[baseSession] = baseSession.charID

View File

@@ -582,7 +582,6 @@ func createTestServerWithDB(t *testing.T, db *sqlx.DB) *Server {
server := &Server{
db: db,
sessions: make(map[net.Conn]*Session),
stages: make(map[string]*Stage),
userBinary: NewUserBinaryStore(),
minidata: NewMinidataStore(),
semaphore: make(map[string]*Semaphore),

View File

@@ -33,9 +33,11 @@ type Config struct {
//
// Lock ordering (acquire in this order to avoid deadlocks):
// 1. Server.Mutex protects sessions map
// 2. Server.stagesLock protects stages map
// 3. Stage.RWMutex protects per-stage state (clients, objects)
// 4. Server.semaphoreLock protects semaphore map
// 2. Stage.RWMutex protects per-stage state (clients, objects)
// 3. Server.semaphoreLock protects semaphore map
//
// Note: Server.stages is a StageMap (sync.Map-backed), so it requires no
// external lock for reads or writes.
//
// Self-contained stores (userBinary, minidata, questCache) manage their
// own locks internally and may be acquired at any point.
@@ -78,8 +80,7 @@ type Server struct {
isShuttingDown bool
done chan struct{} // Closed on Shutdown to wake background goroutines.
stagesLock sync.RWMutex
stages map[string]*Stage
stages StageMap
// Used to map different languages
i18n i18n
@@ -115,7 +116,6 @@ func NewServer(config *Config) *Server {
deleteConns: make(chan net.Conn),
done: make(chan struct{}),
sessions: make(map[net.Conn]*Session),
stages: make(map[string]*Stage),
userBinary: NewUserBinaryStore(),
minidata: NewMinidataStore(),
semaphore: make(map[string]*Semaphore),
@@ -155,25 +155,25 @@ func NewServer(config *Config) *Server {
s.mercenaryRepo = NewMercenaryRepository(config.DB)
// Mezeporta
s.stages["sl1Ns200p0a0u0"] = NewStage("sl1Ns200p0a0u0")
s.stages.Store("sl1Ns200p0a0u0", NewStage("sl1Ns200p0a0u0"))
// Rasta bar stage
s.stages["sl1Ns211p0a0u0"] = NewStage("sl1Ns211p0a0u0")
s.stages.Store("sl1Ns211p0a0u0", NewStage("sl1Ns211p0a0u0"))
// Pallone Carvan
s.stages["sl1Ns260p0a0u0"] = NewStage("sl1Ns260p0a0u0")
s.stages.Store("sl1Ns260p0a0u0", NewStage("sl1Ns260p0a0u0"))
// Pallone Guest House 1st Floor
s.stages["sl1Ns262p0a0u0"] = NewStage("sl1Ns262p0a0u0")
s.stages.Store("sl1Ns262p0a0u0", NewStage("sl1Ns262p0a0u0"))
// Pallone Guest House 2nd Floor
s.stages["sl1Ns263p0a0u0"] = NewStage("sl1Ns263p0a0u0")
s.stages.Store("sl1Ns263p0a0u0", NewStage("sl1Ns263p0a0u0"))
// Diva fountain / prayer fountain.
s.stages["sl2Ns379p0a0u0"] = NewStage("sl2Ns379p0a0u0")
s.stages.Store("sl2Ns379p0a0u0", NewStage("sl2Ns379p0a0u0"))
// MezFes
s.stages["sl1Ns462p0a0u0"] = NewStage("sl1Ns462p0a0u0")
s.stages.Store("sl1Ns462p0a0u0", NewStage("sl1Ns462p0a0u0"))
s.i18n = getLangStrings(s)
@@ -424,21 +424,20 @@ func (s *Server) DisconnectUser(uid uint32) {
// FindObjectByChar finds a stage object owned by the given character ID.
func (s *Server) FindObjectByChar(charID uint32) *Object {
s.stagesLock.RLock()
defer s.stagesLock.RUnlock()
for _, stage := range s.stages {
var found *Object
s.stages.Range(func(_ string, stage *Stage) bool {
stage.RLock()
for objId := range stage.objects {
obj := stage.objects[objId]
for _, obj := range stage.objects {
if obj.ownerCharID == charID {
found = obj
stage.RUnlock()
return obj
return false // stop iteration
}
}
stage.RUnlock()
}
return nil
return true
})
return found
}
// HasSemaphore checks if the given session is hosting any semaphore.

View File

@@ -56,7 +56,6 @@ func createTestServer() *Server {
ID: 1,
logger: logger,
sessions: make(map[net.Conn]*Session),
stages: make(map[string]*Stage),
semaphore: make(map[string]*Semaphore),
questCache: NewQuestCache(0),
erupeConfig: &cfg.Config{
@@ -125,7 +124,7 @@ func TestNewServer(t *testing.T) {
}
for _, stageID := range expectedStages {
if _, exists := server.stages[stageID]; !exists {
if _, exists := server.stages.Get(stageID); !exists {
t.Errorf("Default stage %s not initialized", stageID)
}
}
@@ -682,9 +681,7 @@ func TestFindObjectByChar(t *testing.T) {
stage.objects[1] = obj1
stage.objects[2] = obj2
server.stagesLock.Lock()
server.stages["test_stage"] = stage
server.stagesLock.Unlock()
server.stages.Store("test_stage", stage)
tests := []struct {
name string

View File

@@ -7,6 +7,57 @@ import (
"erupe-ce/network/mhfpacket"
)
// StageMap is a concurrent-safe map of stage ID → *Stage backed by sync.Map.
// It replaces the former stagesLock + map[string]*Stage pattern, eliminating
// read contention entirely (reads are lock-free) and allowing concurrent
// writes to disjoint keys.
//
// The zero value is ready to use; no constructor is required.
type StageMap struct {
	m sync.Map // key: stage ID (string), value: *Stage
}
// Get returns the stage registered under id. The boolean reports whether
// the stage was present; when absent, the returned pointer is nil.
func (sm *StageMap) Get(id string) (*Stage, bool) {
	if v, loaded := sm.m.Load(id); loaded {
		return v.(*Stage), true
	}
	return nil, false
}
// GetOrCreate atomically returns the existing stage for id, or creates and
// stores a new one. The second return value is true when a new stage was
// created by this call.
//
// The common case (stage already exists) is served by a lock-free Load,
// avoiding the cost of constructing a throwaway Stage on every call the
// way an unconditional LoadOrStore would.
func (sm *StageMap) GetOrCreate(id string) (*Stage, bool) {
	// Fast path: stage already present; no allocation.
	if v, ok := sm.m.Load(id); ok {
		return v.(*Stage), false
	}
	// Slow path: allocate and race to insert; LoadOrStore picks one winner,
	// so concurrent callers all observe the same *Stage.
	v, loaded := sm.m.LoadOrStore(id, NewStage(id))
	return v.(*Stage), !loaded // created == !loaded
}
// StoreIfAbsent inserts stage under id unless an entry is already registered
// there. It reports whether the insert took place (true when the key was
// absent and the store succeeded).
func (sm *StageMap) StoreIfAbsent(id string, stage *Stage) bool {
	_, alreadyPresent := sm.m.LoadOrStore(id, stage)
	return !alreadyPresent
}
// Store unconditionally sets the stage for the given ID, replacing any
// existing entry. Prefer StoreIfAbsent when replacing an existing stage
// would be a bug.
func (sm *StageMap) Store(id string, stage *Stage) {
	sm.m.Store(id, stage)
}
// Delete removes the stage with the given ID. Deleting an ID that is not
// present is a no-op.
func (sm *StageMap) Delete(id string) {
	sm.m.Delete(id)
}
// Range calls fn once for every (id, stage) pair currently in the map.
// Iteration stops early as soon as fn returns false; returning true
// continues to the next entry. It is safe to call Delete during iteration.
func (sm *StageMap) Range(fn func(id string, stage *Stage) bool) {
	sm.m.Range(func(k, v any) bool {
		return fn(k.(string), v.(*Stage))
	})
}
// Object holds information about a specific object.
type Object struct {
sync.RWMutex

View File

@@ -40,7 +40,7 @@ func createMockServer() *Server {
s := &Server{
logger: logger,
erupeConfig: &cfg.Config{},
stages: make(map[string]*Stage),
// stages is a StageMap (zero value is ready to use)
sessions: make(map[net.Conn]*Session),
handlerTable: buildHandlerTable(),
raviente: &Raviente{