propose mutex rework

This commit is contained in:
wish
2023-11-11 19:39:42 +11:00
parent 6ff20858ed
commit e24c432d69
4 changed files with 28 additions and 35 deletions

View File

@@ -10,15 +10,14 @@ import (
func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysEnumerateClient)
s.server.stagesLock.RLock()
s.server.RLock()
stage, ok := s.server.stages[pkt.StageID]
s.server.RUnlock()
if !ok {
s.server.stagesLock.RUnlock()
s.logger.Warn("Can't enumerate clients for stage that doesn't exist!", zap.String("stageID", pkt.StageID))
doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4))
return
}
s.server.stagesLock.RUnlock()
// Read-lock the stage and make the response with all of the charID's in the stage.
resp := byteframe.NewByteFrame()
@@ -49,7 +48,6 @@ func handleMsgSysEnumerateClient(s *Session, p mhfpacket.MHFPacket) {
stage.RUnlock()
doAckBufSucceed(s, pkt.AckHandle, resp.Data())
s.logger.Debug("MsgSysEnumerateClient Done!")
}
func handleMsgMhfListMember(s *Session, p mhfpacket.MHFPacket) {

View File

@@ -181,6 +181,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) {
panic(err)
}
s.stage.Lock()
if _, exists := s.stage.reservedClientSlots[s.charID]; exists {
delete(s.stage.reservedClientSlots, s.charID)
}
@@ -188,6 +189,7 @@ func handleMsgSysBackStage(s *Session, p mhfpacket.MHFPacket) {
if _, exists := s.server.stages[backStage].reservedClientSlots[s.charID]; exists {
delete(s.server.stages[backStage].reservedClientSlots, s.charID)
}
s.stage.Unlock()
doStageTransfer(s, pkt.AckHandle, backStage)
}
@@ -279,9 +281,9 @@ func handleMsgSysUnreserveStage(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysSetStagePass)
s.Lock()
s.RLock()
stage := s.reservationStage
s.Unlock()
s.RUnlock()
if stage != nil {
stage.Lock()
// Will only exist if host.
@@ -291,9 +293,9 @@ func handleMsgSysSetStagePass(s *Session, p mhfpacket.MHFPacket) {
stage.Unlock()
} else {
// Store for use on next ReserveStage.
s.Lock()
s.RLock()
s.stagePass = pkt.Password
s.Unlock()
s.RUnlock()
}
}
@@ -311,7 +313,7 @@ func handleMsgSysSetStageBinary(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysGetStageBinary)
if stage, exists := s.server.stages[pkt.StageID]; exists {
stage.Lock()
stage.RLock()
if binaryData, exists := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]; exists {
doAckBufSucceed(s, pkt.AckHandle, binaryData)
} else if pkt.BinaryType1 == 4 {
@@ -323,7 +325,7 @@ func handleMsgSysGetStageBinary(s *Session, p mhfpacket.MHFPacket) {
s.logger.Warn("Sending blank stage binary")
doAckBufSucceed(s, pkt.AckHandle, []byte{})
}
stage.Unlock()
stage.RUnlock()
} else {
s.logger.Warn("Failed to get stage", zap.String("StageID", pkt.StageID))
}
@@ -340,9 +342,9 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
}
for {
s.logger.Debug("MsgSysWaitStageBinary before lock and get stage")
stage.Lock()
stage.RLock()
stageBinary, gotBinary := stage.rawBinaryData[stageBinaryKey{pkt.BinaryType0, pkt.BinaryType1}]
stage.Unlock()
stage.RUnlock()
s.logger.Debug("MsgSysWaitStageBinary after lock and get stage")
if gotBinary {
doAckBufSucceed(s, pkt.AckHandle, stageBinary)
@@ -362,29 +364,27 @@ func handleMsgSysWaitStageBinary(s *Session, p mhfpacket.MHFPacket) {
func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
pkt := p.(*mhfpacket.MsgSysEnumerateStage)
// Read-lock the server stage map.
s.server.stagesLock.RLock()
defer s.server.stagesLock.RUnlock()
// Build the response
bf := byteframe.NewByteFrame()
var joinable uint16
bf.WriteUint16(0)
for sid, stage := range s.server.stages {
var stages []*Stage
s.server.RLock()
for _, stage := range s.server.stages {
stage.RLock()
defer stage.RUnlock()
if len(stage.reservedClientSlots) == 0 && len(stage.clients) == 0 {
stage.RUnlock()
continue
}
if !strings.Contains(stage.id, pkt.StagePrefix) {
stage.RUnlock()
continue
}
joinable++
stages = append(stages, stage)
}
s.server.RUnlock()
bf.WriteUint16(uint16(len(stages)))
for _, stage := range stages {
stage.RLock()
bf.WriteUint16(uint16(len(stage.reservedClientSlots)))
bf.WriteUint16(0) // Unk
bf.WriteUint16(0)
if len(stage.clients) > 0 {
bf.WriteUint16(1)
} else {
@@ -393,16 +393,14 @@ func handleMsgSysEnumerateStage(s *Session, p mhfpacket.MHFPacket) {
bf.WriteUint16(stage.maxPlayers)
if len(stage.password) > 0 {
// This byte has also been seen as 1
// The quest is also recognised as locked when this is 2
// 2/3 = Locked, bitfield?
bf.WriteUint8(2)
} else {
bf.WriteUint8(0)
}
ps.Uint8(bf, sid, false)
ps.Uint8(bf, stage.id, false)
stage.RUnlock()
}
bf.Seek(0, 0)
bf.WriteUint16(joinable)
doAckBufSucceed(s, pkt.AckHandle, bf.Data())
}

View File

@@ -36,7 +36,7 @@ type userBinaryPartID struct {
// Server is a MHF channel server.
type Server struct {
sync.Mutex
sync.RWMutex
Channels []*Server
ID uint16
GlobalID string
@@ -52,8 +52,7 @@ type Server struct {
listener net.Listener // Listener that is created when Server.Start is called.
isShuttingDown bool
stagesLock sync.RWMutex
stages map[string]*Stage
stages map[string]*Stage
// Used to map different languages
dict map[string]string
@@ -374,8 +373,6 @@ func (s *Server) FindSessionByCharID(charID uint32) *Session {
}
func (s *Server) FindObjectByChar(charID uint32) *Object {
s.stagesLock.RLock()
defer s.stagesLock.RUnlock()
for _, stage := range s.stages {
stage.RLock()
for objId := range stage.objects {

View File

@@ -25,7 +25,7 @@ type packet struct {
// Session holds state for the channel server connection.
type Session struct {
sync.Mutex
sync.RWMutex
logger *zap.Logger
server *Server
rawConn net.Conn