diff --git a/server/channelserver/channel_isolation_test.go b/server/channelserver/channel_isolation_test.go index 9020d35b6..b565982fd 100644 --- a/server/channelserver/channel_isolation_test.go +++ b/server/channelserver/channel_isolation_test.go @@ -192,22 +192,16 @@ func TestChannelIsolation_IndependentStages(t *testing.T) { stageName := "sl1Qs999p0a0u42" // Add stage only to channel 1. - channels[0].stagesLock.Lock() - channels[0].stages[stageName] = NewStage(stageName) - channels[0].stagesLock.Unlock() + channels[0].stages.Store(stageName, NewStage(stageName)) // Channel 1 should have the stage. - channels[0].stagesLock.RLock() - _, ok1 := channels[0].stages[stageName] - channels[0].stagesLock.RUnlock() + _, ok1 := channels[0].stages.Get(stageName) if !ok1 { t.Error("channel 1 should have the stage") } // Channel 2 should NOT have the stage. - channels[1].stagesLock.RLock() - _, ok2 := channels[1].stages[stageName] - channels[1].stagesLock.RUnlock() + _, ok2 := channels[1].stages.Get(stageName) if ok2 { t.Error("channel 2 should not have channel 1's stage") } diff --git a/server/channelserver/channel_registry_local.go b/server/channelserver/channel_registry_local.go index f0239ac6e..15985fb88 100644 --- a/server/channelserver/channel_registry_local.go +++ b/server/channelserver/channel_registry_local.go @@ -56,15 +56,17 @@ func (r *LocalChannelRegistry) DisconnectUser(cids []uint32) { func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string { for _, channel := range r.channels { - channel.stagesLock.RLock() - for id := range channel.stages { + var gid string + channel.stages.Range(func(id string, _ *Stage) bool { if strings.HasSuffix(id, stageSuffix) { - gid := channel.GlobalID - channel.stagesLock.RUnlock() - return gid + gid = channel.GlobalID + return false // stop iteration } + return true + }) + if gid != "" { + return gid } - channel.stagesLock.RUnlock() } return "" } @@ -105,13 +107,14 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage if len(results) >= max { break } - c.stagesLock.RLock() - for _, stage := range c.stages { + cIP := net.ParseIP(c.IP).To4() + cPort := c.Port + c.stages.Range(func(_ string, stage *Stage) bool { if len(results) >= max { - break + return false } if !strings.HasPrefix(stage.id, stagePrefix) { - continue + return true } stage.RLock() bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}] @@ -125,8 +128,8 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage copy(bin3Copy, bin3) results = append(results, StageSnapshot{ - ServerIP: net.ParseIP(c.IP).To4(), - ServerPort: c.Port, + ServerIP: cIP, + ServerPort: cPort, StageID: stage.id, ClientCount: len(stage.clients) + len(stage.reservedClientSlots), Reserved: len(stage.reservedClientSlots), @@ -136,8 +139,8 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage RawBinData3: bin3Copy, }) stage.RUnlock() - } - c.stagesLock.RUnlock() + return true + }) } return results } diff --git a/server/channelserver/channel_registry_test.go b/server/channelserver/channel_registry_test.go index 2fe6b296c..823320ead 100644 --- a/server/channelserver/channel_registry_test.go +++ b/server/channelserver/channel_registry_test.go @@ -61,9 +61,7 @@ func TestLocalRegistryFindChannelForStage(t *testing.T) { channels[1].GlobalID = "0102" reg := NewLocalChannelRegistry(channels) - channels[1].stagesLock.Lock() - channels[1].stages["sl2Qs123p0a0u42"] = NewStage("sl2Qs123p0a0u42") - channels[1].stagesLock.Unlock() + channels[1].stages.Store("sl2Qs123p0a0u42", NewStage("sl2Qs123p0a0u42")) gid := reg.FindChannelForStage("u42") if gid != "0102" { @@ -136,11 +134,9 @@ func TestLocalRegistrySearchStages(t *testing.T) { channels := createTestChannels(1) reg := NewLocalChannelRegistry(channels) - channels[0].stagesLock.Lock() - channels[0].stages["sl2Ls210test1"] = NewStage("sl2Ls210test1") - channels[0].stages["sl2Ls210test2"] = NewStage("sl2Ls210test2") - channels[0].stages["sl1Ns200other"] = NewStage("sl1Ns200other") - channels[0].stagesLock.Unlock() + channels[0].stages.Store("sl2Ls210test1", NewStage("sl2Ls210test1")) + channels[0].stages.Store("sl2Ls210test2", NewStage("sl2Ls210test2")) + channels[0].stages.Store("sl1Ns200other", NewStage("sl1Ns200other")) results := reg.SearchStages("sl2Ls210", 10) if len(results) != 2 { diff --git a/server/channelserver/handlers_clients.go b/server/channelserver/handlers_clients.go index 8f7c1ca17..23f881e79 100644 --- a/server/channelserver/handlers_clients.go +++ b/server/channelserver/handlers_clients.go @@ -10,15 +10,12 @@ import ( func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysEnumerateClient) - s.server.stagesLock.RLock() - stage, ok := s.server.stages[pkt.StageID] + stage, ok := s.server.stages.Get(pkt.StageID) if !ok { - s.server.stagesLock.RUnlock() s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID)) doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) return } - s.server.stagesLock.RUnlock() // Read-lock the stage and make the response with all of the charID's in the stage. resp := byteframe.NewByteFrame() diff --git a/server/channelserver/handlers_clients_test.go b/server/channelserver/handlers_clients_test.go index 11a82a112..65de21c36 100644 --- a/server/channelserver/handlers_clients_test.go +++ b/server/channelserver/handlers_clients_test.go @@ -34,9 +34,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) { s2.charID = 200 stage.clients[s1] = 100 stage.clients[s2] = 200 - server.stagesLock.Lock() - server.stages[stageID] = stage - server.stagesLock.Unlock() + server.stages.Store(stageID, stage) }, wantClientCount: 2, wantFailure: false, @@ -50,9 +48,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) { stage.reservedClientSlots[100] = false // Not ready stage.reservedClientSlots[200] = true // Ready stage.reservedClientSlots[300] = false // Not ready - server.stagesLock.Lock() - server.stages[stageID] = stage - server.stagesLock.Unlock() + server.stages.Store(stageID, stage) }, wantClientCount: 2, // Only not-ready clients wantFailure: false, @@ -66,9 +62,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) { stage.reservedClientSlots[100] = false // Not ready stage.reservedClientSlots[200] = true // Ready stage.reservedClientSlots[300] = true // Ready - server.stagesLock.Lock() - server.stages[stageID] = stage - server.stagesLock.Unlock() + server.stages.Store(stageID, stage) }, wantClientCount: 2, // Only ready clients wantFailure: false, @@ -79,9 +73,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) { getType: 0, setupStage: func(server *Server, stageID string) { stage := NewStage(stageID) - server.stagesLock.Lock() - server.stages[stageID] = stage - server.stagesLock.Unlock() + server.stages.Store(stageID, stage) }, wantClientCount: 0, wantFailure: false, @@ -104,11 +96,6 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - // Initialize stages map if needed - if s.server.stages == nil { - s.server.stages = make(map[string]*Stage) - } - // Setup stage tt.setupStage(s.server, tt.stageID) @@ -389,7 +376,6 @@ func TestEnumerateClient_ConcurrentAccess(t *testing.T) { logger, _ := zap.NewDevelopment() server := &Server{ logger: logger, - stages: make(map[string]*Stage), erupeConfig: &cfg.Config{ DebugOptions: cfg.DebugOptions{ LogOutboundMessages: false, @@ -408,9 +394,7 @@ func TestEnumerateClient_ConcurrentAccess(t *testing.T) { stage.clients[sess] = i * 100 } - server.stagesLock.Lock() - server.stages[stageID] = stage - server.stagesLock.Unlock() + server.stages.Store(stageID, stage) // Run concurrent enumerations done := make(chan bool, 5) @@ -562,7 +546,6 @@ func BenchmarkEnumerateClients(b *testing.B) { logger, _ := zap.NewDevelopment() server := &Server{ logger: logger, - stages: make(map[string]*Stage), } stageID := "bench_stage" @@ -576,7 +559,7 @@ func BenchmarkEnumerateClients(b *testing.B) { stage.clients[sess] = i } - server.stages[stageID] = stage + server.stages.Store(stageID, stage) mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) diff --git a/server/channelserver/handlers_core_test.go b/server/channelserver/handlers_core_test.go index f07a0d016..13862da8d 100644 --- a/server/channelserver/handlers_core_test.go +++ b/server/channelserver/handlers_core_test.go @@ -600,9 +600,8 @@ func TestHandleMsgSysLockGlobalSema_WithChannel(t *testing.T) { // Create a mock channel with stages channel := &Server{ GlobalID: "other-server", - stages: make(map[string]*Stage), } - channel.stages["stage_user123"] = NewStage("stage_user123") + channel.stages.Store("stage_user123", NewStage("stage_user123")) server.Channels = []*Server{channel} session := createMockSession(1, server) @@ -632,9 +631,8 @@ func TestHandleMsgSysLockGlobalSema_SameServer(t *testing.T) { // Create a mock channel with same GlobalID channel := &Server{ GlobalID: "test-server", - stages: make(map[string]*Stage), } - channel.stages["stage_user456"] = NewStage("stage_user456") + channel.stages.Store("stage_user456", NewStage("stage_user456")) server.Channels = []*Server{channel} session := createMockSession(1, server) diff --git a/server/channelserver/handlers_coverage3_test.go b/server/channelserver/handlers_coverage3_test.go index 495234723..fd2ec7c8d 100644 --- a/server/channelserver/handlers_coverage3_test.go +++ b/server/channelserver/handlers_coverage3_test.go @@ -984,7 +984,7 @@ func TestHandleMsgSysCreateStage_Coverage3(t *testing.T) { default: t.Error("no response queued") } - if _, exists := server.stages["test_create_stage"]; !exists { + if _, exists := server.stages.Get("test_create_stage"); !exists { t.Error("stage should have been created") } }) diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go index b31d532da..8717a9150 100644 --- a/server/channelserver/handlers_session.go +++ b/server/channelserver/handlers_session.go @@ -290,7 +290,7 @@ func logoutPlayer(s *Session) { _ = s.rawConn.Close() s.server.Unlock() - // Stage cleanup — snapshot sessions first under server mutex, then iterate stages under stagesLock + // Stage cleanup — snapshot sessions first under server mutex, then iterate stages s.server.Lock() sessionSnapshot := make([]*Session, 0, len(s.server.sessions)) for _, sess := range s.server.sessions { @@ -298,8 +298,7 @@ func logoutPlayer(s *Session) { } s.server.Unlock() - s.server.stagesLock.RLock() - for _, stage := range s.server.stages { + s.server.stages.Range(func(_ string, stage *Stage) bool { stage.Lock() // Tell sessions registered to disconnecting player's quest to unregister if stage.host != nil && stage.host.charID == s.charID { @@ -317,8 +316,8 @@ func logoutPlayer(s *Session) { } } stage.Unlock() - } - s.server.stagesLock.RUnlock() + return true + }) // Update sign sessions and server player count if s.server.db != nil { @@ -346,13 +345,12 @@ func logoutPlayer(s *Session) { CharID: s.charID, }, s) - s.server.stagesLock.RLock() - for _, stage := range s.server.stages { + s.server.stages.Range(func(_ string, stage *Stage) bool { stage.Lock() delete(stage.reservedClientSlots, s.charID) stage.Unlock() - } - s.server.stagesLock.RUnlock() + return true + }) removeSessionFromSemaphore(s) removeSessionFromStage(s) @@ -449,13 +447,12 @@ func handleMsgSysLockGlobalSema(s *Session, p mhfpacket.MHFPacket) { sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString) } else { for _, channel := range s.server.Channels { - channel.stagesLock.RLock() - for id := range channel.stages { + channel.stages.Range(func(id string, _ *Stage) bool { if strings.HasSuffix(id, pkt.UserIDString) { sgid = channel.GlobalID } - } - channel.stagesLock.RUnlock() + return true + }) } } bf := byteframe.NewByteFrame() @@ -689,10 +686,11 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { if count == maxResults { break } - c.stagesLock.RLock() - for _, stage := range c.stages { + cIP := net.ParseIP(c.IP).To4() + cPort := c.Port + c.stages.Range(func(_ string, stage *Stage) bool { if count == maxResults { - break + return false } if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) { stage.RLock() @@ -718,7 +716,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { if findPartyParams.RankRestriction >= 0 { if stageData[0] > findPartyParams.RankRestriction { stage.RUnlock() - continue + return true } } @@ -732,7 +730,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { } if !hasTarget { stage.RUnlock() - continue + return true } } @@ -746,8 +744,8 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { count++ stageResults = append(stageResults, stageResult{ - ip: net.ParseIP(c.IP).To4(), - port: c.Port, + ip: cIP, + port: cPort, clientCount: len(stage.clients) + len(stage.reservedClientSlots), reserved: len(stage.reservedClientSlots), maxPlayers: stage.maxPlayers, @@ -758,8 +756,8 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { }) stage.RUnlock() } - } - c.stagesLock.RUnlock() + return true + }) } for _, sr := range stageResults { diff --git a/server/channelserver/handlers_stage.go b/server/channelserver/handlers_stage.go index 65b7ce790..4ce68e4db 100644 --- a/server/channelserver/handlers_stage.go +++ b/server/channelserver/handlers_stage.go @@ -14,32 +14,23 @@ import ( func handleMsgSysCreateStage(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysCreateStage) - s.server.stagesLock.Lock() - defer s.server.stagesLock.Unlock() - if _, exists := s.server.stages[pkt.StageID]; exists { - doAckSimpleFail(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00}) - } else { - stage := NewStage(pkt.StageID) - stage.host = s - stage.maxPlayers = uint16(pkt.PlayerCount) - s.server.stages[stage.id] = stage + stage := NewStage(pkt.StageID) + stage.host = s + stage.maxPlayers = uint16(pkt.PlayerCount) + if s.server.stages.StoreIfAbsent(pkt.StageID, stage) { doAckSimpleSucceed(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00}) + } else { + doAckSimpleFail(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00}) } } func handleMsgSysStageDestruct(s *Session, p mhfpacket.MHFPacket) {} func doStageTransfer(s *Session, ackHandle uint32, stageID string) { - s.server.stagesLock.Lock() - stage, exists := s.server.stages[stageID] - if !exists { - s.server.stages[stageID] = NewStage(stageID) - stage = s.server.stages[stageID] - } - s.server.stagesLock.Unlock() + stage, created := s.server.stages.GetOrCreate(stageID) stage.Lock() - if !exists { + if created { stage.host = s } stage.clients[s] = s.charID @@ -50,12 +41,9 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) { removeSessionFromStage(s) } - // Save our new stage ID and pointer to the new stage itself. - s.server.stagesLock.RLock() - newStage := s.server.stages[stageID] - s.server.stagesLock.RUnlock() + // Save our new stage pointer. s.Lock() - s.stage = newStage + s.stage = stage s.Unlock() // Tell the client to cleanup its current stage objects. @@ -140,22 +128,20 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) { } func destructEmptyStages(s *Session) { - s.server.stagesLock.Lock() - defer s.server.stagesLock.Unlock() - for _, stage := range s.server.stages { + s.server.stages.Range(func(id string, stage *Stage) bool { // Destroy empty Quest/My series/Guild stages. - if stage.id[3:5] == "Qs" || stage.id[3:5] == "Ms" || stage.id[3:5] == "Gs" || stage.id[3:5] == "Ls" { - // Lock stage to safely check its client and reservation counts + if id[3:5] == "Qs" || id[3:5] == "Ms" || id[3:5] == "Gs" || id[3:5] == "Ls" { stage.Lock() isEmpty := len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 stage.Unlock() if isEmpty { - delete(s.server.stages, stage.id) - s.logger.Debug("Destructed stage", zap.String("stage.id", stage.id)) + s.server.stages.Delete(id) + s.logger.Debug("Destructed stage", zap.String("stage.id", id)) } } - } + return true + }) } func removeSessionFromStage(s *Session) { @@ -194,9 +180,7 @@ func removeSessionFromStage(s *Session) { } func isStageFull(s *Session, StageID string) bool { - s.server.stagesLock.RLock() - stage, exists := s.server.stages[StageID] - s.server.stagesLock.RUnlock() + stage, exists := s.server.stages.Get(StageID) if exists { // Lock stage to safely check client counts @@ -261,9 +245,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) { s.stage.Unlock() } - s.server.stagesLock.RLock() - backStagePtr, exists := s.server.stages[backStage] - s.server.stagesLock.RUnlock() + backStagePtr, exists := s.server.stages.Get(backStage) if exists { backStagePtr.Lock() delete(backStagePtr.reservedClientSlots, s.charID) @@ -288,9 +270,7 @@ func handleMsgSysLeaveStage(s *Session, p mhfpacket.MHFPacket) {} func handleMsgSysLockStage(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysLockStage) - s.server.stagesLock.RLock() - stage, exists := s.server.stages[pkt.StageID] - s.server.stagesLock.RUnlock() + stage, exists := s.server.stages.Get(pkt.StageID) if exists { stage.Lock() stage.locked = true @@ -317,10 +297,7 @@ func handleMsgSysUnlockStage(s *Session, p mhfpacket.MHFPacket) { } } - // Delete from stages map under stagesLock (not nested inside stage RLock) - s.server.stagesLock.Lock() - delete(s.server.stages, stageID) - s.server.stagesLock.Unlock() + s.server.stages.Delete(stageID) } destructEmptyStages(s) @@ -328,9 +305,7 @@ func handleMsgSysUnlockStage(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysReserveStage(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysReserveStage) - s.server.stagesLock.RLock() - stage, exists := s.server.stages[pkt.StageID] - s.server.stagesLock.RUnlock() + stage, exists := s.server.stages.Get(pkt.StageID) if exists { stage.Lock() defer stage.Unlock() @@ -402,9 +377,7 @@ func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysSetStageBinary) - s.server.stagesLock.RLock() - stage, exists := s.server.stages[pkt.StageID] - s.server.stagesLock.RUnlock() + stage, exists := s.server.stages.Get(pkt.StageID) if exists { stage.Lock() stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}] = pkt.RawDataPayload @@ -416,9 +389,7 @@ func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysGetStageBinary) - s.server.stagesLock.RLock() - stage, exists := s.server.stages[pkt.StageID] - s.server.stagesLock.RUnlock() + stage, exists := s.server.stages.Get(pkt.StageID) if exists { stage.Lock() if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists { @@ -443,9 +414,7 @@ func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysWaitStageBinary) - s.server.stagesLock.RLock() - stage, exists := s.server.stages[pkt.StageID] - s.server.stagesLock.RUnlock() + stage, exists := s.server.stages.Get(pkt.StageID) if exists { if pkt.BinaryType0 == 1 && pkt.BinaryType1 == 12 { // This might contain the hunter count, or max player count? @@ -479,24 +448,20 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysEnumerateStage) - // Read-lock the server stage map. - s.server.stagesLock.RLock() - defer s.server.stagesLock.RUnlock() - // Build the response bf := byteframe.NewByteFrame() var joinable uint16 bf.WriteUint16(0) - for sid, stage := range s.server.stages { + s.server.stages.Range(func(sid string, stage *Stage) bool { stage.RLock() if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 { stage.RUnlock() - continue + return true } if !strings.Contains(stage.id, pkt.StagePrefix) { stage.RUnlock() - continue + return true } joinable++ @@ -518,7 +483,8 @@ func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) { bf.WriteUint8(flags) ps.Uint8(bf, sid, false) stage.RUnlock() - } + return true + }) _, _ = bf.Seek(0, 0) bf.WriteUint16(joinable) diff --git a/server/channelserver/handlers_stage_test.go b/server/channelserver/handlers_stage_test.go index 79758222b..8bbe0bc5d 100644 --- a/server/channelserver/handlers_stage_test.go +++ b/server/channelserver/handlers_stage_test.go @@ -17,7 +17,7 @@ const raceTestCompletionMsg = "Test completed. No race conditions with fixed loc func TestCreateStageSuccess(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + // Create a new stage pkt := &mhfpacket.MsgSysCreateStage{ @@ -29,11 +29,10 @@ func TestCreateStageSuccess(t *testing.T) { handleMsgSysCreateStage(s, pkt) // Verify stage was created - if _, exists := s.server.stages["test_stage_1"]; !exists { + stage, exists := s.server.stages.Get("test_stage_1") + if !exists { t.Error("stage was not created") } - - stage := s.server.stages["test_stage_1"] if stage.id != "test_stage_1" { t.Errorf("stage ID mismatch: got %s, want test_stage_1", stage.id) } @@ -46,7 +45,7 @@ func TestCreateStageSuccess(t *testing.T) { func TestCreateStageDuplicate(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + // Create first stage pkt1 := &mhfpacket.MsgSysCreateStage{ @@ -65,8 +64,10 @@ func TestCreateStageDuplicate(t *testing.T) { handleMsgSysCreateStage(s, pkt2) // Verify only one stage exists - if len(s.server.stages) != 1 { - t.Errorf("expected 1 stage, got %d", len(s.server.stages)) + count := 0 + s.server.stages.Range(func(_ string, _ *Stage) bool { count++; return true }) + if count != 1 { + t.Errorf("expected 1 stage, got %d", count) } } @@ -74,13 +75,13 @@ func TestCreateStageDuplicate(t *testing.T) { func TestStageLocking(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + // Create a stage stage := NewStage("locked_stage") stage.host = s stage.password = "" - s.server.stages["locked_stage"] = stage + s.server.stages.Store("locked_stage", stage) // Lock the stage pkt := &mhfpacket.MsgSysLockStage{ @@ -103,14 +104,14 @@ func TestStageLocking(t *testing.T) { func TestStageReservation(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + // Create a stage stage := NewStage("reserved_stage") stage.host = s stage.reservedClientSlots = make(map[uint32]bool) stage.reservedClientSlots[s.charID] = false // Pre-add the charID so reservation works - s.server.stages["reserved_stage"] = stage + s.server.stages.Store("reserved_stage", stage) // Reserve the stage pkt := &mhfpacket.MsgSysReserveStage{ @@ -163,8 +164,8 @@ func TestStageBinaryData(t *testing.T) { stage := NewStage("binary_stage") stage.rawBinaryData = make(map[stageBinaryKey][]byte) s.stage = stage - s.server.stages = make(map[string]*Stage) - s.server.stages["binary_stage"] = stage + + s.server.stages.Store("binary_stage", stage) // Store binary data directly key := stageBinaryKey{id0: byte(s.charID >> 8), id1: byte(s.charID & 0xFF)} @@ -230,8 +231,8 @@ func TestIsStageFull(t *testing.T) { stage.clients[client] = uint32(i) } - s.server.stages = make(map[string]*Stage) - s.server.stages["full_test_stage"] = stage + + s.server.stages.Store("full_test_stage", stage) result := isStageFull(s, "full_test_stage") if result != tt.wantFull { @@ -245,14 +246,14 @@ func TestIsStageFull(t *testing.T) { func TestEnumerateStage(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + s.server.sessions = make(map[net.Conn]*Session) // Create multiple stages for i := 0; i < 3; i++ { stage := NewStage("stage_" + string(rune(i))) stage.maxPlayers = 4 - s.server.stages[stage.id] = stage + s.server.stages.Store(stage.id, stage) } // Enumerate stages @@ -264,8 +265,10 @@ func TestEnumerateStage(t *testing.T) { // Basic verification that enumeration was processed // In a real test, we'd verify the response packet content - if len(s.server.stages) != 3 { - t.Errorf("expected 3 stages, got %d", len(s.server.stages)) + stageCount := 0 + s.server.stages.Range(func(_ string, _ *Stage) bool { stageCount++; return true }) + if stageCount != 3 { + t.Errorf("expected 3 stages, got %d", stageCount) } } @@ -279,8 +282,8 @@ func TestRemoveSessionFromStage(t *testing.T) { stage.clients[s] = s.charID s.stage = stage - s.server.stages = make(map[string]*Stage) - s.server.stages["removal_stage"] = stage + + s.server.stages.Store("removal_stage", stage) // Remove session removeSessionFromStage(s) @@ -299,18 +302,18 @@ func TestRemoveSessionFromStage(t *testing.T) { func TestDestructEmptyStages(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + // Create stages with different client counts emptyStage := NewStage("empty_stage") emptyStage.clients = make(map[*Session]uint32) emptyStage.host = s // Host needs to be set or it won't be destructed - s.server.stages["empty_stage"] = emptyStage + s.server.stages.Store("empty_stage", emptyStage) populatedStage := NewStage("populated_stage") populatedStage.clients = make(map[*Session]uint32) populatedStage.clients[s] = s.charID - s.server.stages["populated_stage"] = populatedStage + s.server.stages.Store("populated_stage", populatedStage) // Destruct empty stages (from the channel server's perspective, not our session's) // The function destructs stages that are not referenced by us or don't have clients @@ -318,8 +321,10 @@ func TestDestructEmptyStages(t *testing.T) { // For this test to work correctly, we'd need to verify the actual removal // Let's just verify the stages exist first - if len(s.server.stages) != 2 { - t.Errorf("expected 2 stages initially, got %d", len(s.server.stages)) + initialCount := 0 + s.server.stages.Range(func(_ string, _ *Stage) bool { initialCount++; return true }) + if initialCount != 2 { + t.Errorf("expected 2 stages initially, got %d", initialCount) } } @@ -327,14 +332,14 @@ func TestDestructEmptyStages(t *testing.T) { func TestStageTransferBasic(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + s.server.sessions = make(map[net.Conn]*Session) // Transfer to non-existent stage (should create it) doStageTransfer(s, 0x12345678, "new_transfer_stage") // Verify stage was created - if stage, exists := s.server.stages["new_transfer_stage"]; !exists { + if stage, exists := s.server.stages.Get("new_transfer_stage"); !exists { t.Error("stage was not created during transfer") } else { // Verify session is in the stage @@ -357,12 +362,12 @@ func TestStageTransferBasic(t *testing.T) { func TestEnterStageBasic(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + s.server.sessions = make(map[net.Conn]*Session) stage := NewStage("entry_stage") stage.clients = make(map[*Session]uint32) - s.server.stages["entry_stage"] = stage + s.server.stages.Store("entry_stage", stage) pkt := &mhfpacket.MsgSysEnterStage{ StageID: "entry_stage", @@ -383,7 +388,7 @@ func TestEnterStageBasic(t *testing.T) { func TestMoveStagePreservesData(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + s.server.sessions = make(map[net.Conn]*Session) // Create source stage with binary data @@ -392,13 +397,13 @@ func TestMoveStagePreservesData(t *testing.T) { sourceStage.rawBinaryData = make(map[stageBinaryKey][]byte) key := stageBinaryKey{id0: 0x00, id1: 0x01} sourceStage.rawBinaryData[key] = []byte{0xAA, 0xBB} - s.server.stages["source_stage"] = sourceStage + s.server.stages.Store("source_stage", sourceStage) s.stage = sourceStage // Create destination stage destStage := NewStage("dest_stage") destStage.clients = make(map[*Session]uint32) - s.server.stages["dest_stage"] = destStage + s.server.stages.Store("dest_stage", destStage) pkt := &mhfpacket.MsgSysMoveStage{ StageID: "dest_stage", @@ -417,12 +422,12 @@ func TestMoveStagePreservesData(t *testing.T) { func TestConcurrentStageOperations(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} baseSession := createTestSession(mock) - baseSession.server.stages = make(map[string]*Stage) + // Create a stage stage := NewStage("concurrent_stage") stage.clients = make(map[*Session]uint32) - baseSession.server.stages["concurrent_stage"] = stage + baseSession.server.stages.Store("concurrent_stage", stage) var wg sync.WaitGroup @@ -459,7 +464,7 @@ func TestConcurrentStageOperations(t *testing.T) { func TestBackStageNavigation(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + s.server.sessions = make(map[net.Conn]*Session) // Create a stringstack for stage move history @@ -472,8 +477,8 @@ func TestBackStageNavigation(t *testing.T) { stage2 := NewStage("stage_2") stage2.clients = make(map[*Session]uint32) - s.server.stages["stage_1"] = stage1 - s.server.stages["stage_2"] = stage2 + s.server.stages.Store("stage_1", stage1) + s.server.stages.Store("stage_2", stage2) // First enter stage 2 and push to stack s.stage = stage2 @@ -502,13 +507,13 @@ func TestRaceConditionRemoveSessionFromStageNotLocked(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} s := createTestSession(mock) - s.server.stages = make(map[string]*Stage) + s.server.sessions = make(map[net.Conn]*Session) stage := NewStage("race_test_stage") stage.clients = make(map[*Session]uint32) stage.objects = make(map[uint32]*Object) - s.server.stages["race_test_stage"] = stage + s.server.stages.Store("race_test_stage", stage) s.stage = stage stage.clients[s] = s.charID @@ -567,14 +572,14 @@ func TestRaceConditionDoStageTransferUnlockedAccess(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} baseSession := createTestSession(mock) - baseSession.server.stages = make(map[string]*Stage) + baseSession.server.sessions = make(map[net.Conn]*Session) // Create initial stage stage := NewStage("initial_stage") stage.clients = make(map[*Session]uint32) stage.objects = make(map[uint32]*Object) - baseSession.server.stages["initial_stage"] = stage + baseSession.server.stages.Store("initial_stage", stage) baseSession.stage = stage stage.clients[baseSession] = baseSession.charID @@ -631,13 +636,13 @@ func TestRaceConditionStageObjectsIteration(t *testing.T) { mock := &MockCryptConn{sentPackets: make([][]byte, 0)} baseSession := createTestSession(mock) - baseSession.server.stages = make(map[string]*Stage) + baseSession.server.sessions = make(map[net.Conn]*Session) stage := NewStage("object_race_stage") stage.clients = make(map[*Session]uint32) stage.objects = make(map[uint32]*Object) - baseSession.server.stages["object_race_stage"] = stage + baseSession.server.stages.Store("object_race_stage", stage) baseSession.stage = stage stage.clients[baseSession] = baseSession.charID diff --git a/server/channelserver/session_lifecycle_integration_test.go b/server/channelserver/session_lifecycle_integration_test.go index 6e082d4d9..bfe1d36cb 100644 --- a/server/channelserver/session_lifecycle_integration_test.go +++ b/server/channelserver/session_lifecycle_integration_test.go @@ -582,7 +582,6 @@ func createTestServerWithDB(t *testing.T, db *sqlx.DB) *Server { server := &Server{ db: db, sessions: make(map[net.Conn]*Session), - stages: make(map[string]*Stage), userBinary: NewUserBinaryStore(), minidata: NewMinidataStore(), semaphore: make(map[string]*Semaphore), diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index 60b7b8a52..cdb695f08 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -33,9 +33,11 @@ type Config struct { // // Lock ordering (acquire in this order to avoid deadlocks): // 1. Server.Mutex – protects sessions map -// 2. Server.stagesLock – protects stages map -// 3. Stage.RWMutex – protects per-stage state (clients, objects) -// 4. Server.semaphoreLock – protects semaphore map +// 2. Stage.RWMutex – protects per-stage state (clients, objects) +// 3. Server.semaphoreLock – protects semaphore map +// +// Note: Server.stages is a StageMap (sync.Map-backed), so it requires no +// external lock for reads or writes. // // Self-contained stores (userBinary, minidata, questCache) manage their // own locks internally and may be acquired at any point. @@ -78,8 +80,7 @@ type Server struct { isShuttingDown bool done chan struct{} // Closed on Shutdown to wake background goroutines. - stagesLock sync.RWMutex - stages map[string]*Stage + stages StageMap // Used to map different languages i18n i18n @@ -115,7 +116,6 @@ func NewServer(config *Config) *Server { deleteConns: make(chan net.Conn), done: make(chan struct{}), sessions: make(map[net.Conn]*Session), - stages: make(map[string]*Stage), userBinary: NewUserBinaryStore(), minidata: NewMinidataStore(), semaphore: make(map[string]*Semaphore), @@ -155,25 +155,25 @@ func NewServer(config *Config) *Server { s.mercenaryRepo = NewMercenaryRepository(config.DB) // Mezeporta - s.stages["sl1Ns200p0a0u0"] = NewStage("sl1Ns200p0a0u0") + s.stages.Store("sl1Ns200p0a0u0", NewStage("sl1Ns200p0a0u0")) // Rasta bar stage - s.stages["sl1Ns211p0a0u0"] = NewStage("sl1Ns211p0a0u0") + s.stages.Store("sl1Ns211p0a0u0", NewStage("sl1Ns211p0a0u0")) // Pallone Carvan - s.stages["sl1Ns260p0a0u0"] = NewStage("sl1Ns260p0a0u0") + s.stages.Store("sl1Ns260p0a0u0", NewStage("sl1Ns260p0a0u0")) // Pallone Guest House 1st Floor - s.stages["sl1Ns262p0a0u0"] = NewStage("sl1Ns262p0a0u0") + s.stages.Store("sl1Ns262p0a0u0", NewStage("sl1Ns262p0a0u0")) // Pallone Guest House 2nd Floor - s.stages["sl1Ns263p0a0u0"] = NewStage("sl1Ns263p0a0u0") + s.stages.Store("sl1Ns263p0a0u0", NewStage("sl1Ns263p0a0u0")) // Diva fountain / prayer fountain. - s.stages["sl2Ns379p0a0u0"] = NewStage("sl2Ns379p0a0u0") + s.stages.Store("sl2Ns379p0a0u0", NewStage("sl2Ns379p0a0u0")) // MezFes - s.stages["sl1Ns462p0a0u0"] = NewStage("sl1Ns462p0a0u0") + s.stages.Store("sl1Ns462p0a0u0", NewStage("sl1Ns462p0a0u0")) s.i18n = getLangStrings(s) @@ -424,21 +424,20 @@ func (s *Server) DisconnectUser(uid uint32) { // FindObjectByChar finds a stage object owned by the given character ID. func (s *Server) FindObjectByChar(charID uint32) *Object { - s.stagesLock.RLock() - defer s.stagesLock.RUnlock() - for _, stage := range s.stages { + var found *Object + s.stages.Range(func(_ string, stage *Stage) bool { stage.RLock() - for objId := range stage.objects { - obj := stage.objects[objId] + for _, obj := range stage.objects { if obj.ownerCharID == charID { + found = obj stage.RUnlock() - return obj + return false // stop iteration } } stage.RUnlock() - } - - return nil + return true + }) + return found } // HasSemaphore checks if the given session is hosting any semaphore. diff --git a/server/channelserver/sys_channel_server_test.go b/server/channelserver/sys_channel_server_test.go index 18f539094..056d69a27 100644 --- a/server/channelserver/sys_channel_server_test.go +++ b/server/channelserver/sys_channel_server_test.go @@ -56,7 +56,6 @@ func createTestServer() *Server { ID: 1, logger: logger, sessions: make(map[net.Conn]*Session), - stages: make(map[string]*Stage), semaphore: make(map[string]*Semaphore), questCache: NewQuestCache(0), erupeConfig: &cfg.Config{ @@ -125,7 +124,7 @@ func TestNewServer(t *testing.T) { } for _, stageID := range expectedStages { - if _, exists := server.stages[stageID]; !exists { + if _, exists := server.stages.Get(stageID); !exists { t.Errorf("Default stage %s not initialized", stageID) } } @@ -682,9 +681,7 @@ func TestFindObjectByChar(t *testing.T) { stage.objects[1] = obj1 stage.objects[2] = obj2 - server.stagesLock.Lock() - server.stages["test_stage"] = stage - server.stagesLock.Unlock() + server.stages.Store("test_stage", stage) tests := []struct { name string diff --git a/server/channelserver/sys_stage.go b/server/channelserver/sys_stage.go index 54aea8909..b5ef3be35 100644 --- a/server/channelserver/sys_stage.go +++ b/server/channelserver/sys_stage.go @@ -7,6 +7,57 @@ import ( "erupe-ce/network/mhfpacket" ) +// StageMap is a concurrent-safe map of stage ID → *Stage backed by sync.Map. +// It replaces the former stagesLock + map[string]*Stage pattern, eliminating +// read contention entirely (reads are lock-free) and allowing concurrent +// writes to disjoint keys. +type StageMap struct { + m sync.Map +} + +// Get returns the stage for the given ID, or (nil, false) if not found. +func (sm *StageMap) Get(id string) (*Stage, bool) { + v, ok := sm.m.Load(id) + if !ok { + return nil, false + } + return v.(*Stage), true +} + +// GetOrCreate atomically returns the existing stage for id, or creates and +// stores a new one. The second return value is true when a new stage was created. +func (sm *StageMap) GetOrCreate(id string) (*Stage, bool) { + newStage := NewStage(id) + v, loaded := sm.m.LoadOrStore(id, newStage) + return v.(*Stage), !loaded // created == !loaded +} + +// StoreIfAbsent stores the stage only if the key does not already exist. +// Returns true if the store succeeded (key was absent). +func (sm *StageMap) StoreIfAbsent(id string, stage *Stage) bool { + _, loaded := sm.m.LoadOrStore(id, stage) + return !loaded +} + +// Store unconditionally sets the stage for the given ID. +func (sm *StageMap) Store(id string, stage *Stage) { + sm.m.Store(id, stage) +} + +// Delete removes the stage with the given ID. +func (sm *StageMap) Delete(id string) { + sm.m.Delete(id) +} + +// Range iterates over all stages. The callback receives each (id, stage) pair +// and should return true to continue iteration or false to stop. +// It is safe to call Delete during iteration. +func (sm *StageMap) Range(fn func(id string, stage *Stage) bool) { + sm.m.Range(func(key, value any) bool { + return fn(key.(string), value.(*Stage)) + }) +} + // Object holds infomation about a specific object. type Object struct { sync.RWMutex diff --git a/server/channelserver/test_helpers_test.go b/server/channelserver/test_helpers_test.go index 99bbbf08d..8b5513ef7 100644 --- a/server/channelserver/test_helpers_test.go +++ b/server/channelserver/test_helpers_test.go @@ -40,7 +40,7 @@ func createMockServer() *Server { s := &Server{ logger: logger, erupeConfig: &cfg.Config{}, - stages: make(map[string]*Stage), + // stages is a StageMap (zero value is ready to use) sessions: make(map[net.Conn]*Session), handlerTable: buildHandlerTable(), raviente: &Raviente{