From e24c432d698346960186700e2816848393bc15d3 Mon Sep 17 00:00:00 2001 From: wish Date: Sat, 11 Nov 2023 19:39:42 +1100 Subject: [PATCH] propose mutex rework --- server/channelserver/handlers_clients.go | 6 +-- server/channelserver/handlers_stage.go | 48 +++++++++++----------- server/channelserver/sys_channel_server.go | 7 +--- server/channelserver/sys_session.go | 2 +- 4 files changed, 28 insertions(+), 35 deletions(-) diff --git a/server/channelserver/handlers_clients.go b/server/channelserver/handlers_clients.go index 12e8540b3..a617cad2d 100644 --- a/server/channelserver/handlers_clients.go +++ b/server/channelserver/handlers_clients.go @@ -10,15 +10,14 @@ import ( func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysEnumerateClient) - s.server.stagesLock.RLock() + s.server.RLock() stage, ok := s.server.stages[pkt.StageID] + s.server.RUnlock() if !ok { - s.server.stagesLock.RUnlock() s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID)) doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) return } - s.server.stagesLock.RUnlock() // Read-lock the stage and make the response with all of the charID's in the stage. resp := byteframe.NewByteFrame() @@ -49,7 +48,6 @@ func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) { stage.RUnlock() doAckBufSucceed(s, pkt.AckHandle, resp.Data()) - s.logger.Debug("MsgSysEnumerateClient Done!") } func handleMsgMhfListMember(s *Session, p mhfpacket.MHFPacket) { diff --git a/server/channelserver/handlers_stage.go b/server/channelserver/handlers_stage.go index be7ab1267..c29ea0093 100644 --- a/server/channelserver/handlers_stage.go +++ b/server/channelserver/handlers_stage.go @@ -181,6 +181,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) { panic(err) } + s.stage.Lock() if _, exists := s.stage.reservedClientSlots[s.charID]; exists { delete(s.stage.reservedClientSlots, s.charID) } @@ -188,6 +189,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) { if _, exists := s.server.stages[backStage].reservedClientSlots[s.charID]; exists { delete(s.server.stages[backStage].reservedClientSlots, s.charID) } + s.stage.Unlock() doStageTransfer(s, pkt.AckHandle, backStage) } @@ -279,9 +281,9 @@ func handleMsgSysUnreserveStage(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysSetStagePass) - s.Lock() + s.RLock() stage := s.reservationStage - s.Unlock() + s.RUnlock() if stage != nil { stage.Lock() // Will only exist if host. @@ -291,9 +293,9 @@ func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) { stage.Unlock() } else { // Store for use on next ReserveStage. - s.Lock() + s.RLock() s.stagePass = pkt.Password - s.Unlock() + s.RUnlock() } } @@ -311,7 +313,7 @@ func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysGetStageBinary) if stage, exists := s.server.stages[pkt.StageID]; exists { - stage.Lock() + stage.RLock() if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists { doAckBufSucceed(s, pkt.AckHandle, binaryData) } else if pkt.BinaryType1 == 4 { @@ -323,7 +325,7 @@ func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) { s.logger.Warn("Sending blank stage binary") doAckBufSucceed(s, pkt.AckHandle, []byte{}) } - stage.Unlock() + stage.RUnlock() } else { s.logger.Warn("Failed to get stage", zap.String("StageID", pkt.StageID)) } @@ -340,9 +342,9 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) { } for { s.logger.Debug("MsgSysWaitStageBinary before lock and get stage") - stage.Lock() + stage.RLock() stageBinary, gotBinary := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}] - stage.Unlock() + stage.RUnlock() s.logger.Debug("MsgSysWaitStageBinary after lock and get stage") if gotBinary { doAckBufSucceed(s, pkt.AckHandle, stageBinary) @@ -362,29 +364,27 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysEnumerateStage) - // Read-lock the server stage map. - s.server.stagesLock.RLock() - defer s.server.stagesLock.RUnlock() - - // Build the response bf := byteframe.NewByteFrame() - var joinable uint16 - bf.WriteUint16(0) - for sid, stage := range s.server.stages { + var stages []*Stage + s.server.RLock() + for _, stage := range s.server.stages { stage.RLock() - + defer stage.RUnlock() if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 { - stage.RUnlock() continue } if !strings.Contains(stage.id, pkt.StagePrefix) { - stage.RUnlock() continue } - joinable++ + stages = append(stages, stage) + } + s.server.RUnlock() + bf.WriteUint16(uint16(len(stages))) + for _, stage := range stages { + stage.RLock() bf.WriteUint16(uint16(len(stage.reservedClientSlots))) - bf.WriteUint16(0) // Unk + bf.WriteUint16(0) if len(stage.clients) > 0 { bf.WriteUint16(1) } else { @@ -393,16 +393,14 @@ func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) { bf.WriteUint16(stage.maxPlayers) if len(stage.password) > 0 { // This byte has also been seen as 1 - // The quest is also recognised as locked when this is 2 + // 2/3 = Locked, bitfield? bf.WriteUint8(2) } else { bf.WriteUint8(0) } - ps.Uint8(bf, sid, false) + ps.Uint8(bf, stage.id, false) stage.RUnlock() } - bf.Seek(0, 0) - bf.WriteUint16(joinable) doAckBufSucceed(s, pkt.AckHandle, bf.Data()) } diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index 1dfef82d0..e16d6f95d 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -36,7 +36,7 @@ type userBinaryPartID struct { // Server is a MHF channel server. type Server struct { - sync.Mutex + sync.RWMutex Channels []*Server ID uint16 GlobalID string @@ -52,8 +52,7 @@ type Server struct { listener net.Listener // Listener that is created when Server.Start is called. isShuttingDown bool - stagesLock sync.RWMutex - stages map[string]*Stage + stages map[string]*Stage // Used to map different languages dict map[string]string @@ -374,8 +373,6 @@ func (s *Server) FindSessionByCharID(charID uint32) *Session { } func (s *Server) FindObjectByChar(charID uint32) *Object { - s.stagesLock.RLock() - defer s.stagesLock.RUnlock() for _, stage := range s.stages { stage.RLock() for objId := range stage.objects { diff --git a/server/channelserver/sys_session.go b/server/channelserver/sys_session.go index 406ea4384..b11847af2 100644 --- a/server/channelserver/sys_session.go +++ b/server/channelserver/sys_session.go @@ -25,7 +25,7 @@ type packet struct { // Session holds state for the channel server connection. type Session struct { - sync.Mutex + sync.RWMutex logger *zap.Logger server *Server rawConn net.Conn