mirror of
https://github.com/Mezeporta/Erupe.git
synced 2026-03-22 07:32:32 +01:00
refactor(channelserver): replace global stagesLock with sync.Map-backed StageMap
The global stagesLock sync.RWMutex protected map[string]*Stage, causing all stage operations to contend on a single lock even for unrelated stages. Any stage creation or deletion blocked all reads server-wide. Replace with a typed StageMap wrapper around sync.Map which provides lock-free reads and allows concurrent writes to disjoint keys. Per-stage sync.RWMutex remains unchanged for protecting individual stage state. StageMap exposes Get, GetOrCreate, StoreIfAbsent, Store, Delete, and Range methods. Updated ~50 call sites across 6 production files and 9 test files.
This commit is contained in:
@@ -192,22 +192,16 @@ func TestChannelIsolation_IndependentStages(t *testing.T) {
|
|||||||
stageName := "sl1Qs999p0a0u42"
|
stageName := "sl1Qs999p0a0u42"
|
||||||
|
|
||||||
// Add stage only to channel 1.
|
// Add stage only to channel 1.
|
||||||
channels[0].stagesLock.Lock()
|
channels[0].stages.Store(stageName, NewStage(stageName))
|
||||||
channels[0].stages[stageName] = NewStage(stageName)
|
|
||||||
channels[0].stagesLock.Unlock()
|
|
||||||
|
|
||||||
// Channel 1 should have the stage.
|
// Channel 1 should have the stage.
|
||||||
channels[0].stagesLock.RLock()
|
_, ok1 := channels[0].stages.Get(stageName)
|
||||||
_, ok1 := channels[0].stages[stageName]
|
|
||||||
channels[0].stagesLock.RUnlock()
|
|
||||||
if !ok1 {
|
if !ok1 {
|
||||||
t.Error("channel 1 should have the stage")
|
t.Error("channel 1 should have the stage")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Channel 2 should NOT have the stage.
|
// Channel 2 should NOT have the stage.
|
||||||
channels[1].stagesLock.RLock()
|
_, ok2 := channels[1].stages.Get(stageName)
|
||||||
_, ok2 := channels[1].stages[stageName]
|
|
||||||
channels[1].stagesLock.RUnlock()
|
|
||||||
if ok2 {
|
if ok2 {
|
||||||
t.Error("channel 2 should not have channel 1's stage")
|
t.Error("channel 2 should not have channel 1's stage")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -56,15 +56,17 @@ func (r *LocalChannelRegistry) DisconnectUser(cids []uint32) {
|
|||||||
|
|
||||||
func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string {
|
func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string {
|
||||||
for _, channel := range r.channels {
|
for _, channel := range r.channels {
|
||||||
channel.stagesLock.RLock()
|
var gid string
|
||||||
for id := range channel.stages {
|
channel.stages.Range(func(id string, _ *Stage) bool {
|
||||||
if strings.HasSuffix(id, stageSuffix) {
|
if strings.HasSuffix(id, stageSuffix) {
|
||||||
gid := channel.GlobalID
|
gid = channel.GlobalID
|
||||||
channel.stagesLock.RUnlock()
|
return false // stop iteration
|
||||||
return gid
|
|
||||||
}
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if gid != "" {
|
||||||
|
return gid
|
||||||
}
|
}
|
||||||
channel.stagesLock.RUnlock()
|
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
@@ -105,13 +107,14 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage
|
|||||||
if len(results) >= max {
|
if len(results) >= max {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
c.stagesLock.RLock()
|
cIP := net.ParseIP(c.IP).To4()
|
||||||
for _, stage := range c.stages {
|
cPort := c.Port
|
||||||
|
c.stages.Range(func(_ string, stage *Stage) bool {
|
||||||
if len(results) >= max {
|
if len(results) >= max {
|
||||||
break
|
return false
|
||||||
}
|
}
|
||||||
if !strings.HasPrefix(stage.id, stagePrefix) {
|
if !strings.HasPrefix(stage.id, stagePrefix) {
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
stage.RLock()
|
stage.RLock()
|
||||||
bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}]
|
bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}]
|
||||||
@@ -125,8 +128,8 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage
|
|||||||
copy(bin3Copy, bin3)
|
copy(bin3Copy, bin3)
|
||||||
|
|
||||||
results = append(results, StageSnapshot{
|
results = append(results, StageSnapshot{
|
||||||
ServerIP: net.ParseIP(c.IP).To4(),
|
ServerIP: cIP,
|
||||||
ServerPort: c.Port,
|
ServerPort: cPort,
|
||||||
StageID: stage.id,
|
StageID: stage.id,
|
||||||
ClientCount: len(stage.clients) + len(stage.reservedClientSlots),
|
ClientCount: len(stage.clients) + len(stage.reservedClientSlots),
|
||||||
Reserved: len(stage.reservedClientSlots),
|
Reserved: len(stage.reservedClientSlots),
|
||||||
@@ -136,8 +139,8 @@ func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []Stage
|
|||||||
RawBinData3: bin3Copy,
|
RawBinData3: bin3Copy,
|
||||||
})
|
})
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
}
|
return true
|
||||||
c.stagesLock.RUnlock()
|
})
|
||||||
}
|
}
|
||||||
return results
|
return results
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -61,9 +61,7 @@ func TestLocalRegistryFindChannelForStage(t *testing.T) {
|
|||||||
channels[1].GlobalID = "0102"
|
channels[1].GlobalID = "0102"
|
||||||
reg := NewLocalChannelRegistry(channels)
|
reg := NewLocalChannelRegistry(channels)
|
||||||
|
|
||||||
channels[1].stagesLock.Lock()
|
channels[1].stages.Store("sl2Qs123p0a0u42", NewStage("sl2Qs123p0a0u42"))
|
||||||
channels[1].stages["sl2Qs123p0a0u42"] = NewStage("sl2Qs123p0a0u42")
|
|
||||||
channels[1].stagesLock.Unlock()
|
|
||||||
|
|
||||||
gid := reg.FindChannelForStage("u42")
|
gid := reg.FindChannelForStage("u42")
|
||||||
if gid != "0102" {
|
if gid != "0102" {
|
||||||
@@ -136,11 +134,9 @@ func TestLocalRegistrySearchStages(t *testing.T) {
|
|||||||
channels := createTestChannels(1)
|
channels := createTestChannels(1)
|
||||||
reg := NewLocalChannelRegistry(channels)
|
reg := NewLocalChannelRegistry(channels)
|
||||||
|
|
||||||
channels[0].stagesLock.Lock()
|
channels[0].stages.Store("sl2Ls210test1", NewStage("sl2Ls210test1"))
|
||||||
channels[0].stages["sl2Ls210test1"] = NewStage("sl2Ls210test1")
|
channels[0].stages.Store("sl2Ls210test2", NewStage("sl2Ls210test2"))
|
||||||
channels[0].stages["sl2Ls210test2"] = NewStage("sl2Ls210test2")
|
channels[0].stages.Store("sl1Ns200other", NewStage("sl1Ns200other"))
|
||||||
channels[0].stages["sl1Ns200other"] = NewStage("sl1Ns200other")
|
|
||||||
channels[0].stagesLock.Unlock()
|
|
||||||
|
|
||||||
results := reg.SearchStages("sl2Ls210", 10)
|
results := reg.SearchStages("sl2Ls210", 10)
|
||||||
if len(results) != 2 {
|
if len(results) != 2 {
|
||||||
|
|||||||
@@ -10,15 +10,12 @@ import (
|
|||||||
func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysEnumerateClient)
|
pkt := p.(*mhfpacket.MsgSysEnumerateClient)
|
||||||
|
|
||||||
s.server.stagesLock.RLock()
|
stage, ok := s.server.stages.Get(pkt.StageID)
|
||||||
stage, ok := s.server.stages[pkt.StageID]
|
|
||||||
if !ok {
|
if !ok {
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID))
|
s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID))
|
||||||
doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4))
|
doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
|
|
||||||
// Read-lock the stage and make the response with all of the charID's in the stage.
|
// Read-lock the stage and make the response with all of the charID's in the stage.
|
||||||
resp := byteframe.NewByteFrame()
|
resp := byteframe.NewByteFrame()
|
||||||
|
|||||||
@@ -34,9 +34,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
|
|||||||
s2.charID = 200
|
s2.charID = 200
|
||||||
stage.clients[s1] = 100
|
stage.clients[s1] = 100
|
||||||
stage.clients[s2] = 200
|
stage.clients[s2] = 200
|
||||||
server.stagesLock.Lock()
|
server.stages.Store(stageID, stage)
|
||||||
server.stages[stageID] = stage
|
|
||||||
server.stagesLock.Unlock()
|
|
||||||
},
|
},
|
||||||
wantClientCount: 2,
|
wantClientCount: 2,
|
||||||
wantFailure: false,
|
wantFailure: false,
|
||||||
@@ -50,9 +48,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
|
|||||||
stage.reservedClientSlots[100] = false // Not ready
|
stage.reservedClientSlots[100] = false // Not ready
|
||||||
stage.reservedClientSlots[200] = true // Ready
|
stage.reservedClientSlots[200] = true // Ready
|
||||||
stage.reservedClientSlots[300] = false // Not ready
|
stage.reservedClientSlots[300] = false // Not ready
|
||||||
server.stagesLock.Lock()
|
server.stages.Store(stageID, stage)
|
||||||
server.stages[stageID] = stage
|
|
||||||
server.stagesLock.Unlock()
|
|
||||||
},
|
},
|
||||||
wantClientCount: 2, // Only not-ready clients
|
wantClientCount: 2, // Only not-ready clients
|
||||||
wantFailure: false,
|
wantFailure: false,
|
||||||
@@ -66,9 +62,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
|
|||||||
stage.reservedClientSlots[100] = false // Not ready
|
stage.reservedClientSlots[100] = false // Not ready
|
||||||
stage.reservedClientSlots[200] = true // Ready
|
stage.reservedClientSlots[200] = true // Ready
|
||||||
stage.reservedClientSlots[300] = true // Ready
|
stage.reservedClientSlots[300] = true // Ready
|
||||||
server.stagesLock.Lock()
|
server.stages.Store(stageID, stage)
|
||||||
server.stages[stageID] = stage
|
|
||||||
server.stagesLock.Unlock()
|
|
||||||
},
|
},
|
||||||
wantClientCount: 2, // Only ready clients
|
wantClientCount: 2, // Only ready clients
|
||||||
wantFailure: false,
|
wantFailure: false,
|
||||||
@@ -79,9 +73,7 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
|
|||||||
getType: 0,
|
getType: 0,
|
||||||
setupStage: func(server *Server, stageID string) {
|
setupStage: func(server *Server, stageID string) {
|
||||||
stage := NewStage(stageID)
|
stage := NewStage(stageID)
|
||||||
server.stagesLock.Lock()
|
server.stages.Store(stageID, stage)
|
||||||
server.stages[stageID] = stage
|
|
||||||
server.stagesLock.Unlock()
|
|
||||||
},
|
},
|
||||||
wantClientCount: 0,
|
wantClientCount: 0,
|
||||||
wantFailure: false,
|
wantFailure: false,
|
||||||
@@ -104,11 +96,6 @@ func TestHandleMsgSysEnumerateClient(t *testing.T) {
|
|||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
|
|
||||||
// Initialize stages map if needed
|
|
||||||
if s.server.stages == nil {
|
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Setup stage
|
// Setup stage
|
||||||
tt.setupStage(s.server, tt.stageID)
|
tt.setupStage(s.server, tt.stageID)
|
||||||
|
|
||||||
@@ -389,7 +376,6 @@ func TestEnumerateClient_ConcurrentAccess(t *testing.T) {
|
|||||||
logger, _ := zap.NewDevelopment()
|
logger, _ := zap.NewDevelopment()
|
||||||
server := &Server{
|
server := &Server{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
stages: make(map[string]*Stage),
|
|
||||||
erupeConfig: &cfg.Config{
|
erupeConfig: &cfg.Config{
|
||||||
DebugOptions: cfg.DebugOptions{
|
DebugOptions: cfg.DebugOptions{
|
||||||
LogOutboundMessages: false,
|
LogOutboundMessages: false,
|
||||||
@@ -408,9 +394,7 @@ func TestEnumerateClient_ConcurrentAccess(t *testing.T) {
|
|||||||
stage.clients[sess] = i * 100
|
stage.clients[sess] = i * 100
|
||||||
}
|
}
|
||||||
|
|
||||||
server.stagesLock.Lock()
|
server.stages.Store(stageID, stage)
|
||||||
server.stages[stageID] = stage
|
|
||||||
server.stagesLock.Unlock()
|
|
||||||
|
|
||||||
// Run concurrent enumerations
|
// Run concurrent enumerations
|
||||||
done := make(chan bool, 5)
|
done := make(chan bool, 5)
|
||||||
@@ -562,7 +546,6 @@ func BenchmarkEnumerateClients(b *testing.B) {
|
|||||||
logger, _ := zap.NewDevelopment()
|
logger, _ := zap.NewDevelopment()
|
||||||
server := &Server{
|
server := &Server{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
stages: make(map[string]*Stage),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
stageID := "bench_stage"
|
stageID := "bench_stage"
|
||||||
@@ -576,7 +559,7 @@ func BenchmarkEnumerateClients(b *testing.B) {
|
|||||||
stage.clients[sess] = i
|
stage.clients[sess] = i
|
||||||
}
|
}
|
||||||
|
|
||||||
server.stages[stageID] = stage
|
server.stages.Store(stageID, stage)
|
||||||
|
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
|
|||||||
@@ -600,9 +600,8 @@ func TestHandleMsgSysLockGlobalSema_WithChannel(t *testing.T) {
|
|||||||
// Create a mock channel with stages
|
// Create a mock channel with stages
|
||||||
channel := &Server{
|
channel := &Server{
|
||||||
GlobalID: "other-server",
|
GlobalID: "other-server",
|
||||||
stages: make(map[string]*Stage),
|
|
||||||
}
|
}
|
||||||
channel.stages["stage_user123"] = NewStage("stage_user123")
|
channel.stages.Store("stage_user123", NewStage("stage_user123"))
|
||||||
server.Channels = []*Server{channel}
|
server.Channels = []*Server{channel}
|
||||||
|
|
||||||
session := createMockSession(1, server)
|
session := createMockSession(1, server)
|
||||||
@@ -632,9 +631,8 @@ func TestHandleMsgSysLockGlobalSema_SameServer(t *testing.T) {
|
|||||||
// Create a mock channel with same GlobalID
|
// Create a mock channel with same GlobalID
|
||||||
channel := &Server{
|
channel := &Server{
|
||||||
GlobalID: "test-server",
|
GlobalID: "test-server",
|
||||||
stages: make(map[string]*Stage),
|
|
||||||
}
|
}
|
||||||
channel.stages["stage_user456"] = NewStage("stage_user456")
|
channel.stages.Store("stage_user456", NewStage("stage_user456"))
|
||||||
server.Channels = []*Server{channel}
|
server.Channels = []*Server{channel}
|
||||||
|
|
||||||
session := createMockSession(1, server)
|
session := createMockSession(1, server)
|
||||||
|
|||||||
@@ -984,7 +984,7 @@ func TestHandleMsgSysCreateStage_Coverage3(t *testing.T) {
|
|||||||
default:
|
default:
|
||||||
t.Error("no response queued")
|
t.Error("no response queued")
|
||||||
}
|
}
|
||||||
if _, exists := server.stages["test_create_stage"]; !exists {
|
if _, exists := server.stages.Get("test_create_stage"); !exists {
|
||||||
t.Error("stage should have been created")
|
t.Error("stage should have been created")
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -290,7 +290,7 @@ func logoutPlayer(s *Session) {
|
|||||||
_ = s.rawConn.Close()
|
_ = s.rawConn.Close()
|
||||||
s.server.Unlock()
|
s.server.Unlock()
|
||||||
|
|
||||||
// Stage cleanup — snapshot sessions first under server mutex, then iterate stages under stagesLock
|
// Stage cleanup — snapshot sessions first under server mutex, then iterate stages
|
||||||
s.server.Lock()
|
s.server.Lock()
|
||||||
sessionSnapshot := make([]*Session, 0, len(s.server.sessions))
|
sessionSnapshot := make([]*Session, 0, len(s.server.sessions))
|
||||||
for _, sess := range s.server.sessions {
|
for _, sess := range s.server.sessions {
|
||||||
@@ -298,8 +298,7 @@ func logoutPlayer(s *Session) {
|
|||||||
}
|
}
|
||||||
s.server.Unlock()
|
s.server.Unlock()
|
||||||
|
|
||||||
s.server.stagesLock.RLock()
|
s.server.stages.Range(func(_ string, stage *Stage) bool {
|
||||||
for _, stage := range s.server.stages {
|
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
// Tell sessions registered to disconnecting player's quest to unregister
|
// Tell sessions registered to disconnecting player's quest to unregister
|
||||||
if stage.host != nil && stage.host.charID == s.charID {
|
if stage.host != nil && stage.host.charID == s.charID {
|
||||||
@@ -317,8 +316,8 @@ func logoutPlayer(s *Session) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
stage.Unlock()
|
stage.Unlock()
|
||||||
}
|
return true
|
||||||
s.server.stagesLock.RUnlock()
|
})
|
||||||
|
|
||||||
// Update sign sessions and server player count
|
// Update sign sessions and server player count
|
||||||
if s.server.db != nil {
|
if s.server.db != nil {
|
||||||
@@ -346,13 +345,12 @@ func logoutPlayer(s *Session) {
|
|||||||
CharID: s.charID,
|
CharID: s.charID,
|
||||||
}, s)
|
}, s)
|
||||||
|
|
||||||
s.server.stagesLock.RLock()
|
s.server.stages.Range(func(_ string, stage *Stage) bool {
|
||||||
for _, stage := range s.server.stages {
|
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
delete(stage.reservedClientSlots, s.charID)
|
delete(stage.reservedClientSlots, s.charID)
|
||||||
stage.Unlock()
|
stage.Unlock()
|
||||||
}
|
return true
|
||||||
s.server.stagesLock.RUnlock()
|
})
|
||||||
|
|
||||||
removeSessionFromSemaphore(s)
|
removeSessionFromSemaphore(s)
|
||||||
removeSessionFromStage(s)
|
removeSessionFromStage(s)
|
||||||
@@ -449,13 +447,12 @@ func handleMsgSysLockGlobalSema(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString)
|
sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString)
|
||||||
} else {
|
} else {
|
||||||
for _, channel := range s.server.Channels {
|
for _, channel := range s.server.Channels {
|
||||||
channel.stagesLock.RLock()
|
channel.stages.Range(func(id string, _ *Stage) bool {
|
||||||
for id := range channel.stages {
|
|
||||||
if strings.HasSuffix(id, pkt.UserIDString) {
|
if strings.HasSuffix(id, pkt.UserIDString) {
|
||||||
sgid = channel.GlobalID
|
sgid = channel.GlobalID
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
channel.stagesLock.RUnlock()
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
bf := byteframe.NewByteFrame()
|
bf := byteframe.NewByteFrame()
|
||||||
@@ -689,10 +686,11 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
if count == maxResults {
|
if count == maxResults {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
c.stagesLock.RLock()
|
cIP := net.ParseIP(c.IP).To4()
|
||||||
for _, stage := range c.stages {
|
cPort := c.Port
|
||||||
|
c.stages.Range(func(_ string, stage *Stage) bool {
|
||||||
if count == maxResults {
|
if count == maxResults {
|
||||||
break
|
return false
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) {
|
if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) {
|
||||||
stage.RLock()
|
stage.RLock()
|
||||||
@@ -718,7 +716,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
if findPartyParams.RankRestriction >= 0 {
|
if findPartyParams.RankRestriction >= 0 {
|
||||||
if stageData[0] > findPartyParams.RankRestriction {
|
if stageData[0] > findPartyParams.RankRestriction {
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -732,7 +730,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
}
|
}
|
||||||
if !hasTarget {
|
if !hasTarget {
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -746,8 +744,8 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
|
|
||||||
count++
|
count++
|
||||||
stageResults = append(stageResults, stageResult{
|
stageResults = append(stageResults, stageResult{
|
||||||
ip: net.ParseIP(c.IP).To4(),
|
ip: cIP,
|
||||||
port: c.Port,
|
port: cPort,
|
||||||
clientCount: len(stage.clients) + len(stage.reservedClientSlots),
|
clientCount: len(stage.clients) + len(stage.reservedClientSlots),
|
||||||
reserved: len(stage.reservedClientSlots),
|
reserved: len(stage.reservedClientSlots),
|
||||||
maxPlayers: stage.maxPlayers,
|
maxPlayers: stage.maxPlayers,
|
||||||
@@ -758,8 +756,8 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
})
|
})
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
c.stagesLock.RUnlock()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sr := range stageResults {
|
for _, sr := range stageResults {
|
||||||
|
|||||||
@@ -14,32 +14,23 @@ import (
|
|||||||
|
|
||||||
func handleMsgSysCreateStage(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysCreateStage(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysCreateStage)
|
pkt := p.(*mhfpacket.MsgSysCreateStage)
|
||||||
s.server.stagesLock.Lock()
|
stage := NewStage(pkt.StageID)
|
||||||
defer s.server.stagesLock.Unlock()
|
stage.host = s
|
||||||
if _, exists := s.server.stages[pkt.StageID]; exists {
|
stage.maxPlayers = uint16(pkt.PlayerCount)
|
||||||
doAckSimpleFail(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00})
|
if s.server.stages.StoreIfAbsent(pkt.StageID, stage) {
|
||||||
} else {
|
|
||||||
stage := NewStage(pkt.StageID)
|
|
||||||
stage.host = s
|
|
||||||
stage.maxPlayers = uint16(pkt.PlayerCount)
|
|
||||||
s.server.stages[stage.id] = stage
|
|
||||||
doAckSimpleSucceed(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00})
|
doAckSimpleSucceed(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00})
|
||||||
|
} else {
|
||||||
|
doAckSimpleFail(s, pkt.AckHandle, []byte{0x00, 0x00, 0x00, 0x00})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleMsgSysStageDestruct(s *Session, p mhfpacket.MHFPacket) {}
|
func handleMsgSysStageDestruct(s *Session, p mhfpacket.MHFPacket) {}
|
||||||
|
|
||||||
func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
|
func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
|
||||||
s.server.stagesLock.Lock()
|
stage, created := s.server.stages.GetOrCreate(stageID)
|
||||||
stage, exists := s.server.stages[stageID]
|
|
||||||
if !exists {
|
|
||||||
s.server.stages[stageID] = NewStage(stageID)
|
|
||||||
stage = s.server.stages[stageID]
|
|
||||||
}
|
|
||||||
s.server.stagesLock.Unlock()
|
|
||||||
|
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
if !exists {
|
if created {
|
||||||
stage.host = s
|
stage.host = s
|
||||||
}
|
}
|
||||||
stage.clients[s] = s.charID
|
stage.clients[s] = s.charID
|
||||||
@@ -50,12 +41,9 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
|
|||||||
removeSessionFromStage(s)
|
removeSessionFromStage(s)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save our new stage ID and pointer to the new stage itself.
|
// Save our new stage pointer.
|
||||||
s.server.stagesLock.RLock()
|
|
||||||
newStage := s.server.stages[stageID]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
s.Lock()
|
s.Lock()
|
||||||
s.stage = newStage
|
s.stage = stage
|
||||||
s.Unlock()
|
s.Unlock()
|
||||||
|
|
||||||
// Tell the client to cleanup its current stage objects.
|
// Tell the client to cleanup its current stage objects.
|
||||||
@@ -140,22 +128,20 @@ func doStageTransfer(s *Session, ackHandle uint32, stageID string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func destructEmptyStages(s *Session) {
|
func destructEmptyStages(s *Session) {
|
||||||
s.server.stagesLock.Lock()
|
s.server.stages.Range(func(id string, stage *Stage) bool {
|
||||||
defer s.server.stagesLock.Unlock()
|
|
||||||
for _, stage := range s.server.stages {
|
|
||||||
// Destroy empty Quest/My series/Guild stages.
|
// Destroy empty Quest/My series/Guild stages.
|
||||||
if stage.id[3:5] == "Qs" || stage.id[3:5] == "Ms" || stage.id[3:5] == "Gs" || stage.id[3:5] == "Ls" {
|
if id[3:5] == "Qs" || id[3:5] == "Ms" || id[3:5] == "Gs" || id[3:5] == "Ls" {
|
||||||
// Lock stage to safely check its client and reservation counts
|
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
isEmpty := len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0
|
isEmpty := len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0
|
||||||
stage.Unlock()
|
stage.Unlock()
|
||||||
|
|
||||||
if isEmpty {
|
if isEmpty {
|
||||||
delete(s.server.stages, stage.id)
|
s.server.stages.Delete(id)
|
||||||
s.logger.Debug("Destructed stage", zap.String("stage.id", stage.id))
|
s.logger.Debug("Destructed stage", zap.String("stage.id", id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func removeSessionFromStage(s *Session) {
|
func removeSessionFromStage(s *Session) {
|
||||||
@@ -194,9 +180,7 @@ func removeSessionFromStage(s *Session) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func isStageFull(s *Session, StageID string) bool {
|
func isStageFull(s *Session, StageID string) bool {
|
||||||
s.server.stagesLock.RLock()
|
stage, exists := s.server.stages.Get(StageID)
|
||||||
stage, exists := s.server.stages[StageID]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
// Lock stage to safely check client counts
|
// Lock stage to safely check client counts
|
||||||
@@ -261,9 +245,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
s.stage.Unlock()
|
s.stage.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
s.server.stagesLock.RLock()
|
backStagePtr, exists := s.server.stages.Get(backStage)
|
||||||
backStagePtr, exists := s.server.stages[backStage]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
if exists {
|
if exists {
|
||||||
backStagePtr.Lock()
|
backStagePtr.Lock()
|
||||||
delete(backStagePtr.reservedClientSlots, s.charID)
|
delete(backStagePtr.reservedClientSlots, s.charID)
|
||||||
@@ -288,9 +270,7 @@ func handleMsgSysLeaveStage(s *Session, p mhfpacket.MHFPacket) {}
|
|||||||
|
|
||||||
func handleMsgSysLockStage(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysLockStage(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysLockStage)
|
pkt := p.(*mhfpacket.MsgSysLockStage)
|
||||||
s.server.stagesLock.RLock()
|
stage, exists := s.server.stages.Get(pkt.StageID)
|
||||||
stage, exists := s.server.stages[pkt.StageID]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
if exists {
|
if exists {
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
stage.locked = true
|
stage.locked = true
|
||||||
@@ -317,10 +297,7 @@ func handleMsgSysUnlockStage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Delete from stages map under stagesLock (not nested inside stage RLock)
|
s.server.stages.Delete(stageID)
|
||||||
s.server.stagesLock.Lock()
|
|
||||||
delete(s.server.stages, stageID)
|
|
||||||
s.server.stagesLock.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
destructEmptyStages(s)
|
destructEmptyStages(s)
|
||||||
@@ -328,9 +305,7 @@ func handleMsgSysUnlockStage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
|
|
||||||
func handleMsgSysReserveStage(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysReserveStage(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysReserveStage)
|
pkt := p.(*mhfpacket.MsgSysReserveStage)
|
||||||
s.server.stagesLock.RLock()
|
stage, exists := s.server.stages.Get(pkt.StageID)
|
||||||
stage, exists := s.server.stages[pkt.StageID]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
if exists {
|
if exists {
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
defer stage.Unlock()
|
defer stage.Unlock()
|
||||||
@@ -402,9 +377,7 @@ func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
|
|
||||||
func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysSetStageBinary)
|
pkt := p.(*mhfpacket.MsgSysSetStageBinary)
|
||||||
s.server.stagesLock.RLock()
|
stage, exists := s.server.stages.Get(pkt.StageID)
|
||||||
stage, exists := s.server.stages[pkt.StageID]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
if exists {
|
if exists {
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}] = pkt.RawDataPayload
|
stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}] = pkt.RawDataPayload
|
||||||
@@ -416,9 +389,7 @@ func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
|
|
||||||
func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysGetStageBinary)
|
pkt := p.(*mhfpacket.MsgSysGetStageBinary)
|
||||||
s.server.stagesLock.RLock()
|
stage, exists := s.server.stages.Get(pkt.StageID)
|
||||||
stage, exists := s.server.stages[pkt.StageID]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
if exists {
|
if exists {
|
||||||
stage.Lock()
|
stage.Lock()
|
||||||
if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists {
|
if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists {
|
||||||
@@ -443,9 +414,7 @@ func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
|
|
||||||
func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysWaitStageBinary)
|
pkt := p.(*mhfpacket.MsgSysWaitStageBinary)
|
||||||
s.server.stagesLock.RLock()
|
stage, exists := s.server.stages.Get(pkt.StageID)
|
||||||
stage, exists := s.server.stages[pkt.StageID]
|
|
||||||
s.server.stagesLock.RUnlock()
|
|
||||||
if exists {
|
if exists {
|
||||||
if pkt.BinaryType0 == 1 && pkt.BinaryType1 == 12 {
|
if pkt.BinaryType0 == 1 && pkt.BinaryType1 == 12 {
|
||||||
// This might contain the hunter count, or max player count?
|
// This might contain the hunter count, or max player count?
|
||||||
@@ -479,24 +448,20 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
|
func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
|
||||||
pkt := p.(*mhfpacket.MsgSysEnumerateStage)
|
pkt := p.(*mhfpacket.MsgSysEnumerateStage)
|
||||||
|
|
||||||
// Read-lock the server stage map.
|
|
||||||
s.server.stagesLock.RLock()
|
|
||||||
defer s.server.stagesLock.RUnlock()
|
|
||||||
|
|
||||||
// Build the response
|
// Build the response
|
||||||
bf := byteframe.NewByteFrame()
|
bf := byteframe.NewByteFrame()
|
||||||
var joinable uint16
|
var joinable uint16
|
||||||
bf.WriteUint16(0)
|
bf.WriteUint16(0)
|
||||||
for sid, stage := range s.server.stages {
|
s.server.stages.Range(func(sid string, stage *Stage) bool {
|
||||||
stage.RLock()
|
stage.RLock()
|
||||||
|
|
||||||
if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 {
|
if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 {
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
if !strings.Contains(stage.id, pkt.StagePrefix) {
|
if !strings.Contains(stage.id, pkt.StagePrefix) {
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
continue
|
return true
|
||||||
}
|
}
|
||||||
joinable++
|
joinable++
|
||||||
|
|
||||||
@@ -518,7 +483,8 @@ func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
|
|||||||
bf.WriteUint8(flags)
|
bf.WriteUint8(flags)
|
||||||
ps.Uint8(bf, sid, false)
|
ps.Uint8(bf, sid, false)
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
_, _ = bf.Seek(0, 0)
|
_, _ = bf.Seek(0, 0)
|
||||||
bf.WriteUint16(joinable)
|
bf.WriteUint16(joinable)
|
||||||
|
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ const raceTestCompletionMsg = "Test completed. No race conditions with fixed loc
|
|||||||
func TestCreateStageSuccess(t *testing.T) {
|
func TestCreateStageSuccess(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
|
|
||||||
// Create a new stage
|
// Create a new stage
|
||||||
pkt := &mhfpacket.MsgSysCreateStage{
|
pkt := &mhfpacket.MsgSysCreateStage{
|
||||||
@@ -29,11 +29,10 @@ func TestCreateStageSuccess(t *testing.T) {
|
|||||||
handleMsgSysCreateStage(s, pkt)
|
handleMsgSysCreateStage(s, pkt)
|
||||||
|
|
||||||
// Verify stage was created
|
// Verify stage was created
|
||||||
if _, exists := s.server.stages["test_stage_1"]; !exists {
|
stage, exists := s.server.stages.Get("test_stage_1")
|
||||||
|
if !exists {
|
||||||
t.Error("stage was not created")
|
t.Error("stage was not created")
|
||||||
}
|
}
|
||||||
|
|
||||||
stage := s.server.stages["test_stage_1"]
|
|
||||||
if stage.id != "test_stage_1" {
|
if stage.id != "test_stage_1" {
|
||||||
t.Errorf("stage ID mismatch: got %s, want test_stage_1", stage.id)
|
t.Errorf("stage ID mismatch: got %s, want test_stage_1", stage.id)
|
||||||
}
|
}
|
||||||
@@ -46,7 +45,7 @@ func TestCreateStageSuccess(t *testing.T) {
|
|||||||
func TestCreateStageDuplicate(t *testing.T) {
|
func TestCreateStageDuplicate(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
|
|
||||||
// Create first stage
|
// Create first stage
|
||||||
pkt1 := &mhfpacket.MsgSysCreateStage{
|
pkt1 := &mhfpacket.MsgSysCreateStage{
|
||||||
@@ -65,8 +64,10 @@ func TestCreateStageDuplicate(t *testing.T) {
|
|||||||
handleMsgSysCreateStage(s, pkt2)
|
handleMsgSysCreateStage(s, pkt2)
|
||||||
|
|
||||||
// Verify only one stage exists
|
// Verify only one stage exists
|
||||||
if len(s.server.stages) != 1 {
|
count := 0
|
||||||
t.Errorf("expected 1 stage, got %d", len(s.server.stages))
|
s.server.stages.Range(func(_ string, _ *Stage) bool { count++; return true })
|
||||||
|
if count != 1 {
|
||||||
|
t.Errorf("expected 1 stage, got %d", count)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,13 +75,13 @@ func TestCreateStageDuplicate(t *testing.T) {
|
|||||||
func TestStageLocking(t *testing.T) {
|
func TestStageLocking(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
|
|
||||||
// Create a stage
|
// Create a stage
|
||||||
stage := NewStage("locked_stage")
|
stage := NewStage("locked_stage")
|
||||||
stage.host = s
|
stage.host = s
|
||||||
stage.password = ""
|
stage.password = ""
|
||||||
s.server.stages["locked_stage"] = stage
|
s.server.stages.Store("locked_stage", stage)
|
||||||
|
|
||||||
// Lock the stage
|
// Lock the stage
|
||||||
pkt := &mhfpacket.MsgSysLockStage{
|
pkt := &mhfpacket.MsgSysLockStage{
|
||||||
@@ -103,14 +104,14 @@ func TestStageLocking(t *testing.T) {
|
|||||||
func TestStageReservation(t *testing.T) {
|
func TestStageReservation(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
|
|
||||||
// Create a stage
|
// Create a stage
|
||||||
stage := NewStage("reserved_stage")
|
stage := NewStage("reserved_stage")
|
||||||
stage.host = s
|
stage.host = s
|
||||||
stage.reservedClientSlots = make(map[uint32]bool)
|
stage.reservedClientSlots = make(map[uint32]bool)
|
||||||
stage.reservedClientSlots[s.charID] = false // Pre-add the charID so reservation works
|
stage.reservedClientSlots[s.charID] = false // Pre-add the charID so reservation works
|
||||||
s.server.stages["reserved_stage"] = stage
|
s.server.stages.Store("reserved_stage", stage)
|
||||||
|
|
||||||
// Reserve the stage
|
// Reserve the stage
|
||||||
pkt := &mhfpacket.MsgSysReserveStage{
|
pkt := &mhfpacket.MsgSysReserveStage{
|
||||||
@@ -163,8 +164,8 @@ func TestStageBinaryData(t *testing.T) {
|
|||||||
stage := NewStage("binary_stage")
|
stage := NewStage("binary_stage")
|
||||||
stage.rawBinaryData = make(map[stageBinaryKey][]byte)
|
stage.rawBinaryData = make(map[stageBinaryKey][]byte)
|
||||||
s.stage = stage
|
s.stage = stage
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.stages["binary_stage"] = stage
|
s.server.stages.Store("binary_stage", stage)
|
||||||
|
|
||||||
// Store binary data directly
|
// Store binary data directly
|
||||||
key := stageBinaryKey{id0: byte(s.charID >> 8), id1: byte(s.charID & 0xFF)}
|
key := stageBinaryKey{id0: byte(s.charID >> 8), id1: byte(s.charID & 0xFF)}
|
||||||
@@ -230,8 +231,8 @@ func TestIsStageFull(t *testing.T) {
|
|||||||
stage.clients[client] = uint32(i)
|
stage.clients[client] = uint32(i)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.stages["full_test_stage"] = stage
|
s.server.stages.Store("full_test_stage", stage)
|
||||||
|
|
||||||
result := isStageFull(s, "full_test_stage")
|
result := isStageFull(s, "full_test_stage")
|
||||||
if result != tt.wantFull {
|
if result != tt.wantFull {
|
||||||
@@ -245,14 +246,14 @@ func TestIsStageFull(t *testing.T) {
|
|||||||
func TestEnumerateStage(t *testing.T) {
|
func TestEnumerateStage(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.sessions = make(map[net.Conn]*Session)
|
s.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
// Create multiple stages
|
// Create multiple stages
|
||||||
for i := 0; i < 3; i++ {
|
for i := 0; i < 3; i++ {
|
||||||
stage := NewStage("stage_" + string(rune(i)))
|
stage := NewStage("stage_" + string(rune(i)))
|
||||||
stage.maxPlayers = 4
|
stage.maxPlayers = 4
|
||||||
s.server.stages[stage.id] = stage
|
s.server.stages.Store(stage.id, stage)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Enumerate stages
|
// Enumerate stages
|
||||||
@@ -264,8 +265,10 @@ func TestEnumerateStage(t *testing.T) {
|
|||||||
|
|
||||||
// Basic verification that enumeration was processed
|
// Basic verification that enumeration was processed
|
||||||
// In a real test, we'd verify the response packet content
|
// In a real test, we'd verify the response packet content
|
||||||
if len(s.server.stages) != 3 {
|
stageCount := 0
|
||||||
t.Errorf("expected 3 stages, got %d", len(s.server.stages))
|
s.server.stages.Range(func(_ string, _ *Stage) bool { stageCount++; return true })
|
||||||
|
if stageCount != 3 {
|
||||||
|
t.Errorf("expected 3 stages, got %d", stageCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -279,8 +282,8 @@ func TestRemoveSessionFromStage(t *testing.T) {
|
|||||||
stage.clients[s] = s.charID
|
stage.clients[s] = s.charID
|
||||||
|
|
||||||
s.stage = stage
|
s.stage = stage
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.stages["removal_stage"] = stage
|
s.server.stages.Store("removal_stage", stage)
|
||||||
|
|
||||||
// Remove session
|
// Remove session
|
||||||
removeSessionFromStage(s)
|
removeSessionFromStage(s)
|
||||||
@@ -299,18 +302,18 @@ func TestRemoveSessionFromStage(t *testing.T) {
|
|||||||
func TestDestructEmptyStages(t *testing.T) {
|
func TestDestructEmptyStages(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
|
|
||||||
// Create stages with different client counts
|
// Create stages with different client counts
|
||||||
emptyStage := NewStage("empty_stage")
|
emptyStage := NewStage("empty_stage")
|
||||||
emptyStage.clients = make(map[*Session]uint32)
|
emptyStage.clients = make(map[*Session]uint32)
|
||||||
emptyStage.host = s // Host needs to be set or it won't be destructed
|
emptyStage.host = s // Host needs to be set or it won't be destructed
|
||||||
s.server.stages["empty_stage"] = emptyStage
|
s.server.stages.Store("empty_stage", emptyStage)
|
||||||
|
|
||||||
populatedStage := NewStage("populated_stage")
|
populatedStage := NewStage("populated_stage")
|
||||||
populatedStage.clients = make(map[*Session]uint32)
|
populatedStage.clients = make(map[*Session]uint32)
|
||||||
populatedStage.clients[s] = s.charID
|
populatedStage.clients[s] = s.charID
|
||||||
s.server.stages["populated_stage"] = populatedStage
|
s.server.stages.Store("populated_stage", populatedStage)
|
||||||
|
|
||||||
// Destruct empty stages (from the channel server's perspective, not our session's)
|
// Destruct empty stages (from the channel server's perspective, not our session's)
|
||||||
// The function destructs stages that are not referenced by us or don't have clients
|
// The function destructs stages that are not referenced by us or don't have clients
|
||||||
@@ -318,8 +321,10 @@ func TestDestructEmptyStages(t *testing.T) {
|
|||||||
|
|
||||||
// For this test to work correctly, we'd need to verify the actual removal
|
// For this test to work correctly, we'd need to verify the actual removal
|
||||||
// Let's just verify the stages exist first
|
// Let's just verify the stages exist first
|
||||||
if len(s.server.stages) != 2 {
|
initialCount := 0
|
||||||
t.Errorf("expected 2 stages initially, got %d", len(s.server.stages))
|
s.server.stages.Range(func(_ string, _ *Stage) bool { initialCount++; return true })
|
||||||
|
if initialCount != 2 {
|
||||||
|
t.Errorf("expected 2 stages initially, got %d", initialCount)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,14 +332,14 @@ func TestDestructEmptyStages(t *testing.T) {
|
|||||||
func TestStageTransferBasic(t *testing.T) {
|
func TestStageTransferBasic(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.sessions = make(map[net.Conn]*Session)
|
s.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
// Transfer to non-existent stage (should create it)
|
// Transfer to non-existent stage (should create it)
|
||||||
doStageTransfer(s, 0x12345678, "new_transfer_stage")
|
doStageTransfer(s, 0x12345678, "new_transfer_stage")
|
||||||
|
|
||||||
// Verify stage was created
|
// Verify stage was created
|
||||||
if stage, exists := s.server.stages["new_transfer_stage"]; !exists {
|
if stage, exists := s.server.stages.Get("new_transfer_stage"); !exists {
|
||||||
t.Error("stage was not created during transfer")
|
t.Error("stage was not created during transfer")
|
||||||
} else {
|
} else {
|
||||||
// Verify session is in the stage
|
// Verify session is in the stage
|
||||||
@@ -357,12 +362,12 @@ func TestStageTransferBasic(t *testing.T) {
|
|||||||
func TestEnterStageBasic(t *testing.T) {
|
func TestEnterStageBasic(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.sessions = make(map[net.Conn]*Session)
|
s.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
stage := NewStage("entry_stage")
|
stage := NewStage("entry_stage")
|
||||||
stage.clients = make(map[*Session]uint32)
|
stage.clients = make(map[*Session]uint32)
|
||||||
s.server.stages["entry_stage"] = stage
|
s.server.stages.Store("entry_stage", stage)
|
||||||
|
|
||||||
pkt := &mhfpacket.MsgSysEnterStage{
|
pkt := &mhfpacket.MsgSysEnterStage{
|
||||||
StageID: "entry_stage",
|
StageID: "entry_stage",
|
||||||
@@ -383,7 +388,7 @@ func TestEnterStageBasic(t *testing.T) {
|
|||||||
func TestMoveStagePreservesData(t *testing.T) {
|
func TestMoveStagePreservesData(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.sessions = make(map[net.Conn]*Session)
|
s.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
// Create source stage with binary data
|
// Create source stage with binary data
|
||||||
@@ -392,13 +397,13 @@ func TestMoveStagePreservesData(t *testing.T) {
|
|||||||
sourceStage.rawBinaryData = make(map[stageBinaryKey][]byte)
|
sourceStage.rawBinaryData = make(map[stageBinaryKey][]byte)
|
||||||
key := stageBinaryKey{id0: 0x00, id1: 0x01}
|
key := stageBinaryKey{id0: 0x00, id1: 0x01}
|
||||||
sourceStage.rawBinaryData[key] = []byte{0xAA, 0xBB}
|
sourceStage.rawBinaryData[key] = []byte{0xAA, 0xBB}
|
||||||
s.server.stages["source_stage"] = sourceStage
|
s.server.stages.Store("source_stage", sourceStage)
|
||||||
s.stage = sourceStage
|
s.stage = sourceStage
|
||||||
|
|
||||||
// Create destination stage
|
// Create destination stage
|
||||||
destStage := NewStage("dest_stage")
|
destStage := NewStage("dest_stage")
|
||||||
destStage.clients = make(map[*Session]uint32)
|
destStage.clients = make(map[*Session]uint32)
|
||||||
s.server.stages["dest_stage"] = destStage
|
s.server.stages.Store("dest_stage", destStage)
|
||||||
|
|
||||||
pkt := &mhfpacket.MsgSysMoveStage{
|
pkt := &mhfpacket.MsgSysMoveStage{
|
||||||
StageID: "dest_stage",
|
StageID: "dest_stage",
|
||||||
@@ -417,12 +422,12 @@ func TestMoveStagePreservesData(t *testing.T) {
|
|||||||
func TestConcurrentStageOperations(t *testing.T) {
|
func TestConcurrentStageOperations(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
baseSession := createTestSession(mock)
|
baseSession := createTestSession(mock)
|
||||||
baseSession.server.stages = make(map[string]*Stage)
|
|
||||||
|
|
||||||
// Create a stage
|
// Create a stage
|
||||||
stage := NewStage("concurrent_stage")
|
stage := NewStage("concurrent_stage")
|
||||||
stage.clients = make(map[*Session]uint32)
|
stage.clients = make(map[*Session]uint32)
|
||||||
baseSession.server.stages["concurrent_stage"] = stage
|
baseSession.server.stages.Store("concurrent_stage", stage)
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
@@ -459,7 +464,7 @@ func TestConcurrentStageOperations(t *testing.T) {
|
|||||||
func TestBackStageNavigation(t *testing.T) {
|
func TestBackStageNavigation(t *testing.T) {
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.sessions = make(map[net.Conn]*Session)
|
s.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
// Create a stringstack for stage move history
|
// Create a stringstack for stage move history
|
||||||
@@ -472,8 +477,8 @@ func TestBackStageNavigation(t *testing.T) {
|
|||||||
stage2 := NewStage("stage_2")
|
stage2 := NewStage("stage_2")
|
||||||
stage2.clients = make(map[*Session]uint32)
|
stage2.clients = make(map[*Session]uint32)
|
||||||
|
|
||||||
s.server.stages["stage_1"] = stage1
|
s.server.stages.Store("stage_1", stage1)
|
||||||
s.server.stages["stage_2"] = stage2
|
s.server.stages.Store("stage_2", stage2)
|
||||||
|
|
||||||
// First enter stage 2 and push to stack
|
// First enter stage 2 and push to stack
|
||||||
s.stage = stage2
|
s.stage = stage2
|
||||||
@@ -502,13 +507,13 @@ func TestRaceConditionRemoveSessionFromStageNotLocked(t *testing.T) {
|
|||||||
|
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
s := createTestSession(mock)
|
s := createTestSession(mock)
|
||||||
s.server.stages = make(map[string]*Stage)
|
|
||||||
s.server.sessions = make(map[net.Conn]*Session)
|
s.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
stage := NewStage("race_test_stage")
|
stage := NewStage("race_test_stage")
|
||||||
stage.clients = make(map[*Session]uint32)
|
stage.clients = make(map[*Session]uint32)
|
||||||
stage.objects = make(map[uint32]*Object)
|
stage.objects = make(map[uint32]*Object)
|
||||||
s.server.stages["race_test_stage"] = stage
|
s.server.stages.Store("race_test_stage", stage)
|
||||||
s.stage = stage
|
s.stage = stage
|
||||||
stage.clients[s] = s.charID
|
stage.clients[s] = s.charID
|
||||||
|
|
||||||
@@ -567,14 +572,14 @@ func TestRaceConditionDoStageTransferUnlockedAccess(t *testing.T) {
|
|||||||
|
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
baseSession := createTestSession(mock)
|
baseSession := createTestSession(mock)
|
||||||
baseSession.server.stages = make(map[string]*Stage)
|
|
||||||
baseSession.server.sessions = make(map[net.Conn]*Session)
|
baseSession.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
// Create initial stage
|
// Create initial stage
|
||||||
stage := NewStage("initial_stage")
|
stage := NewStage("initial_stage")
|
||||||
stage.clients = make(map[*Session]uint32)
|
stage.clients = make(map[*Session]uint32)
|
||||||
stage.objects = make(map[uint32]*Object)
|
stage.objects = make(map[uint32]*Object)
|
||||||
baseSession.server.stages["initial_stage"] = stage
|
baseSession.server.stages.Store("initial_stage", stage)
|
||||||
baseSession.stage = stage
|
baseSession.stage = stage
|
||||||
stage.clients[baseSession] = baseSession.charID
|
stage.clients[baseSession] = baseSession.charID
|
||||||
|
|
||||||
@@ -631,13 +636,13 @@ func TestRaceConditionStageObjectsIteration(t *testing.T) {
|
|||||||
|
|
||||||
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
mock := &MockCryptConn{sentPackets: make([][]byte, 0)}
|
||||||
baseSession := createTestSession(mock)
|
baseSession := createTestSession(mock)
|
||||||
baseSession.server.stages = make(map[string]*Stage)
|
|
||||||
baseSession.server.sessions = make(map[net.Conn]*Session)
|
baseSession.server.sessions = make(map[net.Conn]*Session)
|
||||||
|
|
||||||
stage := NewStage("object_race_stage")
|
stage := NewStage("object_race_stage")
|
||||||
stage.clients = make(map[*Session]uint32)
|
stage.clients = make(map[*Session]uint32)
|
||||||
stage.objects = make(map[uint32]*Object)
|
stage.objects = make(map[uint32]*Object)
|
||||||
baseSession.server.stages["object_race_stage"] = stage
|
baseSession.server.stages.Store("object_race_stage", stage)
|
||||||
baseSession.stage = stage
|
baseSession.stage = stage
|
||||||
stage.clients[baseSession] = baseSession.charID
|
stage.clients[baseSession] = baseSession.charID
|
||||||
|
|
||||||
|
|||||||
@@ -582,7 +582,6 @@ func createTestServerWithDB(t *testing.T, db *sqlx.DB) *Server {
|
|||||||
server := &Server{
|
server := &Server{
|
||||||
db: db,
|
db: db,
|
||||||
sessions: make(map[net.Conn]*Session),
|
sessions: make(map[net.Conn]*Session),
|
||||||
stages: make(map[string]*Stage),
|
|
||||||
userBinary: NewUserBinaryStore(),
|
userBinary: NewUserBinaryStore(),
|
||||||
minidata: NewMinidataStore(),
|
minidata: NewMinidataStore(),
|
||||||
semaphore: make(map[string]*Semaphore),
|
semaphore: make(map[string]*Semaphore),
|
||||||
|
|||||||
@@ -33,9 +33,11 @@ type Config struct {
|
|||||||
//
|
//
|
||||||
// Lock ordering (acquire in this order to avoid deadlocks):
|
// Lock ordering (acquire in this order to avoid deadlocks):
|
||||||
// 1. Server.Mutex – protects sessions map
|
// 1. Server.Mutex – protects sessions map
|
||||||
// 2. Server.stagesLock – protects stages map
|
// 2. Stage.RWMutex – protects per-stage state (clients, objects)
|
||||||
// 3. Stage.RWMutex – protects per-stage state (clients, objects)
|
// 3. Server.semaphoreLock – protects semaphore map
|
||||||
// 4. Server.semaphoreLock – protects semaphore map
|
//
|
||||||
|
// Note: Server.stages is a StageMap (sync.Map-backed), so it requires no
|
||||||
|
// external lock for reads or writes.
|
||||||
//
|
//
|
||||||
// Self-contained stores (userBinary, minidata, questCache) manage their
|
// Self-contained stores (userBinary, minidata, questCache) manage their
|
||||||
// own locks internally and may be acquired at any point.
|
// own locks internally and may be acquired at any point.
|
||||||
@@ -78,8 +80,7 @@ type Server struct {
|
|||||||
isShuttingDown bool
|
isShuttingDown bool
|
||||||
done chan struct{} // Closed on Shutdown to wake background goroutines.
|
done chan struct{} // Closed on Shutdown to wake background goroutines.
|
||||||
|
|
||||||
stagesLock sync.RWMutex
|
stages StageMap
|
||||||
stages map[string]*Stage
|
|
||||||
|
|
||||||
// Used to map different languages
|
// Used to map different languages
|
||||||
i18n i18n
|
i18n i18n
|
||||||
@@ -115,7 +116,6 @@ func NewServer(config *Config) *Server {
|
|||||||
deleteConns: make(chan net.Conn),
|
deleteConns: make(chan net.Conn),
|
||||||
done: make(chan struct{}),
|
done: make(chan struct{}),
|
||||||
sessions: make(map[net.Conn]*Session),
|
sessions: make(map[net.Conn]*Session),
|
||||||
stages: make(map[string]*Stage),
|
|
||||||
userBinary: NewUserBinaryStore(),
|
userBinary: NewUserBinaryStore(),
|
||||||
minidata: NewMinidataStore(),
|
minidata: NewMinidataStore(),
|
||||||
semaphore: make(map[string]*Semaphore),
|
semaphore: make(map[string]*Semaphore),
|
||||||
@@ -155,25 +155,25 @@ func NewServer(config *Config) *Server {
|
|||||||
s.mercenaryRepo = NewMercenaryRepository(config.DB)
|
s.mercenaryRepo = NewMercenaryRepository(config.DB)
|
||||||
|
|
||||||
// Mezeporta
|
// Mezeporta
|
||||||
s.stages["sl1Ns200p0a0u0"] = NewStage("sl1Ns200p0a0u0")
|
s.stages.Store("sl1Ns200p0a0u0", NewStage("sl1Ns200p0a0u0"))
|
||||||
|
|
||||||
// Rasta bar stage
|
// Rasta bar stage
|
||||||
s.stages["sl1Ns211p0a0u0"] = NewStage("sl1Ns211p0a0u0")
|
s.stages.Store("sl1Ns211p0a0u0", NewStage("sl1Ns211p0a0u0"))
|
||||||
|
|
||||||
// Pallone Carvan
|
// Pallone Carvan
|
||||||
s.stages["sl1Ns260p0a0u0"] = NewStage("sl1Ns260p0a0u0")
|
s.stages.Store("sl1Ns260p0a0u0", NewStage("sl1Ns260p0a0u0"))
|
||||||
|
|
||||||
// Pallone Guest House 1st Floor
|
// Pallone Guest House 1st Floor
|
||||||
s.stages["sl1Ns262p0a0u0"] = NewStage("sl1Ns262p0a0u0")
|
s.stages.Store("sl1Ns262p0a0u0", NewStage("sl1Ns262p0a0u0"))
|
||||||
|
|
||||||
// Pallone Guest House 2nd Floor
|
// Pallone Guest House 2nd Floor
|
||||||
s.stages["sl1Ns263p0a0u0"] = NewStage("sl1Ns263p0a0u0")
|
s.stages.Store("sl1Ns263p0a0u0", NewStage("sl1Ns263p0a0u0"))
|
||||||
|
|
||||||
// Diva fountain / prayer fountain.
|
// Diva fountain / prayer fountain.
|
||||||
s.stages["sl2Ns379p0a0u0"] = NewStage("sl2Ns379p0a0u0")
|
s.stages.Store("sl2Ns379p0a0u0", NewStage("sl2Ns379p0a0u0"))
|
||||||
|
|
||||||
// MezFes
|
// MezFes
|
||||||
s.stages["sl1Ns462p0a0u0"] = NewStage("sl1Ns462p0a0u0")
|
s.stages.Store("sl1Ns462p0a0u0", NewStage("sl1Ns462p0a0u0"))
|
||||||
|
|
||||||
s.i18n = getLangStrings(s)
|
s.i18n = getLangStrings(s)
|
||||||
|
|
||||||
@@ -424,21 +424,20 @@ func (s *Server) DisconnectUser(uid uint32) {
|
|||||||
|
|
||||||
// FindObjectByChar finds a stage object owned by the given character ID.
|
// FindObjectByChar finds a stage object owned by the given character ID.
|
||||||
func (s *Server) FindObjectByChar(charID uint32) *Object {
|
func (s *Server) FindObjectByChar(charID uint32) *Object {
|
||||||
s.stagesLock.RLock()
|
var found *Object
|
||||||
defer s.stagesLock.RUnlock()
|
s.stages.Range(func(_ string, stage *Stage) bool {
|
||||||
for _, stage := range s.stages {
|
|
||||||
stage.RLock()
|
stage.RLock()
|
||||||
for objId := range stage.objects {
|
for _, obj := range stage.objects {
|
||||||
obj := stage.objects[objId]
|
|
||||||
if obj.ownerCharID == charID {
|
if obj.ownerCharID == charID {
|
||||||
|
found = obj
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
return obj
|
return false // stop iteration
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
stage.RUnlock()
|
stage.RUnlock()
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
return nil
|
return found
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasSemaphore checks if the given session is hosting any semaphore.
|
// HasSemaphore checks if the given session is hosting any semaphore.
|
||||||
|
|||||||
@@ -56,7 +56,6 @@ func createTestServer() *Server {
|
|||||||
ID: 1,
|
ID: 1,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
sessions: make(map[net.Conn]*Session),
|
sessions: make(map[net.Conn]*Session),
|
||||||
stages: make(map[string]*Stage),
|
|
||||||
semaphore: make(map[string]*Semaphore),
|
semaphore: make(map[string]*Semaphore),
|
||||||
questCache: NewQuestCache(0),
|
questCache: NewQuestCache(0),
|
||||||
erupeConfig: &cfg.Config{
|
erupeConfig: &cfg.Config{
|
||||||
@@ -125,7 +124,7 @@ func TestNewServer(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, stageID := range expectedStages {
|
for _, stageID := range expectedStages {
|
||||||
if _, exists := server.stages[stageID]; !exists {
|
if _, exists := server.stages.Get(stageID); !exists {
|
||||||
t.Errorf("Default stage %s not initialized", stageID)
|
t.Errorf("Default stage %s not initialized", stageID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -682,9 +681,7 @@ func TestFindObjectByChar(t *testing.T) {
|
|||||||
stage.objects[1] = obj1
|
stage.objects[1] = obj1
|
||||||
stage.objects[2] = obj2
|
stage.objects[2] = obj2
|
||||||
|
|
||||||
server.stagesLock.Lock()
|
server.stages.Store("test_stage", stage)
|
||||||
server.stages["test_stage"] = stage
|
|
||||||
server.stagesLock.Unlock()
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -7,6 +7,57 @@ import (
|
|||||||
"erupe-ce/network/mhfpacket"
|
"erupe-ce/network/mhfpacket"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StageMap is a concurrent-safe map of stage ID → *Stage backed by sync.Map.
|
||||||
|
// It replaces the former stagesLock + map[string]*Stage pattern, eliminating
|
||||||
|
// read contention entirely (reads are lock-free) and allowing concurrent
|
||||||
|
// writes to disjoint keys.
|
||||||
|
type StageMap struct {
|
||||||
|
m sync.Map
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns the stage for the given ID, or (nil, false) if not found.
|
||||||
|
func (sm *StageMap) Get(id string) (*Stage, bool) {
|
||||||
|
v, ok := sm.m.Load(id)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return v.(*Stage), true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreate atomically returns the existing stage for id, or creates and
|
||||||
|
// stores a new one. The second return value is true when a new stage was created.
|
||||||
|
func (sm *StageMap) GetOrCreate(id string) (*Stage, bool) {
|
||||||
|
newStage := NewStage(id)
|
||||||
|
v, loaded := sm.m.LoadOrStore(id, newStage)
|
||||||
|
return v.(*Stage), !loaded // created == !loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
// StoreIfAbsent stores the stage only if the key does not already exist.
|
||||||
|
// Returns true if the store succeeded (key was absent).
|
||||||
|
func (sm *StageMap) StoreIfAbsent(id string, stage *Stage) bool {
|
||||||
|
_, loaded := sm.m.LoadOrStore(id, stage)
|
||||||
|
return !loaded
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store unconditionally sets the stage for the given ID.
|
||||||
|
func (sm *StageMap) Store(id string, stage *Stage) {
|
||||||
|
sm.m.Store(id, stage)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete removes the stage with the given ID.
|
||||||
|
func (sm *StageMap) Delete(id string) {
|
||||||
|
sm.m.Delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Range iterates over all stages. The callback receives each (id, stage) pair
|
||||||
|
// and should return true to continue iteration or false to stop.
|
||||||
|
// It is safe to call Delete during iteration.
|
||||||
|
func (sm *StageMap) Range(fn func(id string, stage *Stage) bool) {
|
||||||
|
sm.m.Range(func(key, value any) bool {
|
||||||
|
return fn(key.(string), value.(*Stage))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
// Object holds infomation about a specific object.
|
// Object holds infomation about a specific object.
|
||||||
type Object struct {
|
type Object struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func createMockServer() *Server {
|
|||||||
s := &Server{
|
s := &Server{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
erupeConfig: &cfg.Config{},
|
erupeConfig: &cfg.Config{},
|
||||||
stages: make(map[string]*Stage),
|
// stages is a StageMap (zero value is ready to use)
|
||||||
sessions: make(map[net.Conn]*Session),
|
sessions: make(map[net.Conn]*Session),
|
||||||
handlerTable: buildHandlerTable(),
|
handlerTable: buildHandlerTable(),
|
||||||
raviente: &Raviente{
|
raviente: &Raviente{
|
||||||
|
|||||||
Reference in New Issue
Block a user