propose mutex rework

This commit is contained in:
wish
2023-11-11 19:39:42 +11:00
parent 6ff20858ed
commit e24c432d69
4 changed files with 28 additions and 35 deletions

View File

@@ -10,15 +10,14 @@ import (
func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysEnumerateClient) pkt := p.(*mhfpacket.MsgSysEnumerateClient)
s.server.stagesLock.RLock() s.server.RLock()
stage, ok := s.server.stages[pkt.StageID] stage, ok := s.server.stages[pkt.StageID]
s.server.RUnlock()
if !ok { if !ok {
s.server.stagesLock.RUnlock()
s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID)) s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID))
doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4))
return return
} }
s.server.stagesLock.RUnlock()
// Read-lock the stage and make the response with all of the charID's in the stage. // Read-lock the stage and make the response with all of the charID's in the stage.
resp := byteframe.NewByteFrame() resp := byteframe.NewByteFrame()
@@ -49,7 +48,6 @@ func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) {
stage.RUnlock() stage.RUnlock()
doAckBufSucceed(s, pkt.AckHandle, resp.Data()) doAckBufSucceed(s, pkt.AckHandle, resp.Data())
s.logger.Debug("MsgSysEnumerateClient Done!")
} }
func handleMsgMhfListMember(s *Session, p mhfpacket.MHFPacket) { func handleMsgMhfListMember(s *Session, p mhfpacket.MHFPacket) {

View File

@@ -181,6 +181,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) {
panic(err) panic(err)
} }
s.stage.Lock()
if _, exists := s.stage.reservedClientSlots[s.charID]; exists { if _, exists := s.stage.reservedClientSlots[s.charID]; exists {
delete(s.stage.reservedClientSlots, s.charID) delete(s.stage.reservedClientSlots, s.charID)
} }
@@ -188,6 +189,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) {
if _, exists := s.server.stages[backStage].reservedClientSlots[s.charID]; exists { if _, exists := s.server.stages[backStage].reservedClientSlots[s.charID]; exists {
delete(s.server.stages[backStage].reservedClientSlots, s.charID) delete(s.server.stages[backStage].reservedClientSlots, s.charID)
} }
s.stage.Unlock()
doStageTransfer(s, pkt.AckHandle, backStage) doStageTransfer(s, pkt.AckHandle, backStage)
} }
@@ -279,9 +281,9 @@ func handleMsgSysUnreserveStage(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysSetStagePass) pkt := p.(*mhfpacket.MsgSysSetStagePass)
s.Lock() s.RLock()
stage := s.reservationStage stage := s.reservationStage
s.Unlock() s.RUnlock()
if stage != nil { if stage != nil {
stage.Lock() stage.Lock()
// Will only exist if host. // Will only exist if host.
@@ -291,9 +293,9 @@ func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) {
stage.Unlock() stage.Unlock()
} else { } else {
// Store for use on next ReserveStage. // Store for use on next ReserveStage.
s.Lock() s.RLock()
s.stagePass = pkt.Password s.stagePass = pkt.Password
s.Unlock() s.RUnlock()
} }
} }
@@ -311,7 +313,7 @@ func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysGetStageBinary) pkt := p.(*mhfpacket.MsgSysGetStageBinary)
if stage, exists := s.server.stages[pkt.StageID]; exists { if stage, exists := s.server.stages[pkt.StageID]; exists {
stage.Lock() stage.RLock()
if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists { if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists {
doAckBufSucceed(s, pkt.AckHandle, binaryData) doAckBufSucceed(s, pkt.AckHandle, binaryData)
} else if pkt.BinaryType1 == 4 { } else if pkt.BinaryType1 == 4 {
@@ -323,7 +325,7 @@ func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
s.logger.Warn("Sending blank stage binary") s.logger.Warn("Sending blank stage binary")
doAckBufSucceed(s, pkt.AckHandle, []byte{}) doAckBufSucceed(s, pkt.AckHandle, []byte{})
} }
stage.Unlock() stage.RUnlock()
} else { } else {
s.logger.Warn("Failed to get stage", zap.String("StageID", pkt.StageID)) s.logger.Warn("Failed to get stage", zap.String("StageID", pkt.StageID))
} }
@@ -340,9 +342,9 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
} }
for { for {
s.logger.Debug("MsgSysWaitStageBinary before lock and get stage") s.logger.Debug("MsgSysWaitStageBinary before lock and get stage")
stage.Lock() stage.RLock()
stageBinary, gotBinary := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}] stageBinary, gotBinary := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]
stage.Unlock() stage.RUnlock()
s.logger.Debug("MsgSysWaitStageBinary after lock and get stage") s.logger.Debug("MsgSysWaitStageBinary after lock and get stage")
if gotBinary { if gotBinary {
doAckBufSucceed(s, pkt.AckHandle, stageBinary) doAckBufSucceed(s, pkt.AckHandle, stageBinary)
@@ -362,29 +364,27 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) { func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysEnumerateStage) pkt := p.(*mhfpacket.MsgSysEnumerateStage)
// Read-lock the server stage map.
s.server.stagesLock.RLock()
defer s.server.stagesLock.RUnlock()
// Build the response
bf := byteframe.NewByteFrame() bf := byteframe.NewByteFrame()
var joinable uint16 var stages []*Stage
bf.WriteUint16(0) s.server.RLock()
for sid, stage := range s.server.stages { for _, stage := range s.server.stages {
stage.RLock() stage.RLock()
defer stage.RUnlock()
if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 { if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 {
stage.RUnlock()
continue continue
} }
if !strings.Contains(stage.id, pkt.StagePrefix) { if !strings.Contains(stage.id, pkt.StagePrefix) {
stage.RUnlock()
continue continue
} }
joinable++ stages = append(stages, stage)
}
s.server.RUnlock()
bf.WriteUint16(uint16(len(stages)))
for _, stage := range stages {
stage.RLock()
bf.WriteUint16(uint16(len(stage.reservedClientSlots))) bf.WriteUint16(uint16(len(stage.reservedClientSlots)))
bf.WriteUint16(0) // Unk bf.WriteUint16(0)
if len(stage.clients) > 0 { if len(stage.clients) > 0 {
bf.WriteUint16(1) bf.WriteUint16(1)
} else { } else {
@@ -393,16 +393,14 @@ func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
bf.WriteUint16(stage.maxPlayers) bf.WriteUint16(stage.maxPlayers)
if len(stage.password) > 0 { if len(stage.password) > 0 {
// This byte has also been seen as 1 // This byte has also been seen as 1
// The quest is also recognised as locked when this is 2 // 2/3 = Locked, bitfield?
bf.WriteUint8(2) bf.WriteUint8(2)
} else { } else {
bf.WriteUint8(0) bf.WriteUint8(0)
} }
ps.Uint8(bf, sid, false) ps.Uint8(bf, stage.id, false)
stage.RUnlock() stage.RUnlock()
} }
bf.Seek(0, 0)
bf.WriteUint16(joinable)
doAckBufSucceed(s, pkt.AckHandle, bf.Data()) doAckBufSucceed(s, pkt.AckHandle, bf.Data())
} }

View File

@@ -36,7 +36,7 @@ type userBinaryPartID struct {
// Server is a MHF channel server. // Server is a MHF channel server.
type Server struct { type Server struct {
sync.Mutex sync.RWMutex
Channels []*Server Channels []*Server
ID uint16 ID uint16
GlobalID string GlobalID string
@@ -52,8 +52,7 @@ type Server struct {
listener net.Listener // Listener that is created when Server.Start is called. listener net.Listener // Listener that is created when Server.Start is called.
isShuttingDown bool isShuttingDown bool
stagesLock sync.RWMutex stages map[string]*Stage
stages map[string]*Stage
// Used to map different languages // Used to map different languages
dict map[string]string dict map[string]string
@@ -374,8 +373,6 @@ func (s *Server) FindSessionByCharID(charID uint32) *Session {
} }
func (s *Server) FindObjectByChar(charID uint32) *Object { func (s *Server) FindObjectByChar(charID uint32) *Object {
s.stagesLock.RLock()
defer s.stagesLock.RUnlock()
for _, stage := range s.stages { for _, stage := range s.stages {
stage.RLock() stage.RLock()
for objId := range stage.objects { for objId := range stage.objects {

View File

@@ -25,7 +25,7 @@ type packet struct {
// Session holds state for the channel server connection. // Session holds state for the channel server connection.
type Session struct { type Session struct {
sync.Mutex sync.RWMutex
logger *zap.Logger logger *zap.Logger
server *Server server *Server
rawConn net.Conn rawConn net.Conn