diff --git a/config.example.json b/config.example.json index a1951e791..e92a7a92f 100644 --- a/config.example.json +++ b/config.example.json @@ -219,34 +219,34 @@ { "Name": "Newbie", "Description": "", "IP": "", "Type": 3, "Recommended": 2, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54001, "MaxPlayers": 100 }, - { "Port": 54002, "MaxPlayers": 100 } + { "Port": 54001, "MaxPlayers": 100, "Enabled": true }, + { "Port": 54002, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Normal", "Description": "", "IP": "", "Type": 1, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54003, "MaxPlayers": 100 }, - { "Port": 54004, "MaxPlayers": 100 } + { "Port": 54003, "MaxPlayers": 100, "Enabled": true }, + { "Port": 54004, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Cities", "Description": "", "IP": "", "Type": 2, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54005, "MaxPlayers": 100 } + { "Port": 54005, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Tavern", "Description": "", "IP": "", "Type": 4, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54006, "MaxPlayers": 100 } + { "Port": 54006, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Return", "Description": "", "IP": "", "Type": 5, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54007, "MaxPlayers": 100 } + { "Port": 54007, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "MezFes", "Description": "", "IP": "", "Type": 6, "Recommended": 6, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54008, "MaxPlayers": 100 } + { "Port": 54008, "MaxPlayers": 100, "Enabled": true } ] } ] diff --git a/config/config.go b/config/config.go index e30cbcd12..6c26798aa 100644 --- a/config/config.go +++ b/config/config.go @@ -297,6 +297,15 @@ type EntranceChannelInfo struct { Port uint16 MaxPlayers uint16 CurrentPlayers uint16 + Enabled *bool // nil defaults to true for backward compatibility +} + +// IsEnabled returns whether this channel is enabled. Defaults to true if Enabled is nil. +func (c *EntranceChannelInfo) IsEnabled() bool { + if c.Enabled == nil { + return true + } + return *c.Enabled } var ErupeConfig *Config diff --git a/config/config_test.go b/config/config_test.go index 782b3ef89..cbad553ec 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -536,6 +536,34 @@ func TestEntranceChannelInfo(t *testing.T) { } } +// TestEntranceChannelInfoIsEnabled tests the Enabled field and IsEnabled helper +func TestEntranceChannelInfoIsEnabled(t *testing.T) { + trueVal := true + falseVal := false + + tests := []struct { + name string + enabled *bool + want bool + }{ + {"nil defaults to true", nil, true}, + {"explicit true", &trueVal, true}, + {"explicit false", &falseVal, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := EntranceChannelInfo{ + Port: 10001, + Enabled: tt.enabled, + } + if got := info.IsEnabled(); got != tt.want { + t.Errorf("IsEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + // TestDiscord verifies Discord struct func TestDiscord(t *testing.T) { discord := Discord{ diff --git a/main.go b/main.go index de44e0693..c6da1c977 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "erupe-ce/server/discordbot" "erupe-ce/server/entranceserver" "erupe-ce/server/signserver" + "strings" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -129,11 +130,30 @@ func main() { } logger.Info("Database: Started successfully") - // Clear stale data - if config.DebugOptions.ProxyPort == 0 { - _ = db.MustExec("DELETE FROM sign_sessions") + // Pre-compute all server IDs this instance will own, so we only + // delete our own rows (safe for multi-instance on the same DB). + var ownedServerIDs []string + { + si := 0 + for _, ee := range config.Entrance.Entries { + ci := 0 + for range ee.Channels { + sid := (4096 + si*256) + (16 + ci) + ownedServerIDs = append(ownedServerIDs, fmt.Sprint(sid)) + ci++ + } + si++ + } + } + + // Clear stale data scoped to this instance's server IDs + if len(ownedServerIDs) > 0 { + idList := strings.Join(ownedServerIDs, ",") + if config.DebugOptions.ProxyPort == 0 { + _ = db.MustExec("DELETE FROM sign_sessions WHERE server_id IN (" + idList + ")") + } + _ = db.MustExec("DELETE FROM servers WHERE server_id IN (" + idList + ")") } - _ = db.MustExec("DELETE FROM servers") _ = db.MustExec(`UPDATE guild_characters SET treasure_hunt=NULL`) // Clean the DB if the option is on. @@ -213,6 +233,12 @@ func main() { for j, ee := range config.Entrance.Entries { for i, ce := range ee.Channels { sid := (4096 + si*256) + (16 + ci) + if !ce.IsEnabled() { + logger.Info(fmt.Sprintf("Channel %d (%d): Disabled via config", count, ce.Port)) + ci++ + count++ + continue + } c := *channelserver.NewServer(&channelserver.Config{ ID: uint16(sid), Logger: logger.Named("channel-" + fmt.Sprint(count)), @@ -237,9 +263,9 @@ func main() { ) channels = append(channels, &c) logger.Info(fmt.Sprintf("Channel %d (%d): Started successfully", count, ce.Port)) - ci++ count++ } + ci++ } ci = 0 si++ @@ -248,8 +274,10 @@ func main() { // Register all servers in DB _ = db.MustExec(channelQuery) + registry := channelserver.NewLocalChannelRegistry(channels) for _, c := range channels { c.Channels = channels + c.Registry = registry } } diff --git a/server/channelserver/channel_registry.go b/server/channelserver/channel_registry.go new file mode 100644 index 000000000..af391a727 --- /dev/null +++ b/server/channelserver/channel_registry.go @@ -0,0 +1,58 @@ +package channelserver + +import ( + "erupe-ce/network/mhfpacket" + "net" +) + +// ChannelRegistry abstracts cross-channel operations behind an interface. +// The default LocalChannelRegistry wraps the in-process []*Server slice. +// Future implementations may use DB/Redis/NATS for multi-process deployments. +type ChannelRegistry interface { + // Worldcast broadcasts a packet to all sessions across all channels. + Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) + + // FindSessionByCharID looks up a session by character ID across all channels. + FindSessionByCharID(charID uint32) *Session + + // DisconnectUser disconnects all sessions belonging to the given character IDs. + DisconnectUser(cids []uint32) + + // FindChannelForStage searches all channels for a stage whose ID has the + // given suffix and returns the owning channel's GlobalID, or "" if not found. + FindChannelForStage(stageSuffix string) string + + // SearchSessions searches sessions across all channels using a predicate, + // returning up to max snapshot results. + SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot + + // SearchStages searches stages across all channels with a prefix filter, + // returning up to max snapshot results. + SearchStages(stagePrefix string, max int) []StageSnapshot + + // NotifyMailToCharID finds the session for charID and sends a mail notification. + NotifyMailToCharID(charID uint32, sender *Session, mail *Mail) +} + +// SessionSnapshot is an immutable copy of session data taken under lock. +type SessionSnapshot struct { + CharID uint32 + Name string + StageID string + ServerIP net.IP + ServerPort uint16 + UserBinary3 []byte // Copy of userBinaryParts index 3 +} + +// StageSnapshot is an immutable copy of stage data taken under lock. +type StageSnapshot struct { + ServerIP net.IP + ServerPort uint16 + StageID string + ClientCount int + Reserved int + MaxPlayers uint16 + RawBinData0 []byte + RawBinData1 []byte + RawBinData3 []byte +} diff --git a/server/channelserver/channel_registry_local.go b/server/channelserver/channel_registry_local.go new file mode 100644 index 000000000..a04651d0e --- /dev/null +++ b/server/channelserver/channel_registry_local.go @@ -0,0 +1,156 @@ +package channelserver + +import ( + "erupe-ce/network/mhfpacket" + "net" + "strings" +) + +// LocalChannelRegistry is the in-process ChannelRegistry backed by []*Server. +type LocalChannelRegistry struct { + channels []*Server +} + +// NewLocalChannelRegistry creates a LocalChannelRegistry wrapping the given channels. +func NewLocalChannelRegistry(channels []*Server) *LocalChannelRegistry { + return &LocalChannelRegistry{channels: channels} +} + +func (r *LocalChannelRegistry) Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) { + for _, c := range r.channels { + if c == ignoredChannel { + continue + } + c.BroadcastMHF(pkt, ignoredSession) + } +} + +func (r *LocalChannelRegistry) FindSessionByCharID(charID uint32) *Session { + for _, c := range r.channels { + c.Lock() + for _, session := range c.sessions { + if session.charID == charID { + c.Unlock() + return session + } + } + c.Unlock() + } + return nil +} + +func (r *LocalChannelRegistry) DisconnectUser(cids []uint32) { + for _, c := range r.channels { + c.Lock() + for _, session := range c.sessions { + for _, cid := range cids { + if session.charID == cid { + _ = session.rawConn.Close() + break + } + } + } + c.Unlock() + } +} + +func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string { + for _, channel := range r.channels { + channel.stagesLock.RLock() + for id := range channel.stages { + if strings.HasSuffix(id, stageSuffix) { + gid := channel.GlobalID + channel.stagesLock.RUnlock() + return gid + } + } + channel.stagesLock.RUnlock() + } + return "" +} + +func (r *LocalChannelRegistry) SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot { + var results []SessionSnapshot + for _, c := range r.channels { + if len(results) >= max { + break + } + c.Lock() + c.userBinaryPartsLock.RLock() + for _, session := range c.sessions { + if len(results) >= max { + break + } + snap := SessionSnapshot{ + CharID: session.charID, + Name: session.Name, + ServerIP: net.ParseIP(c.IP).To4(), + ServerPort: c.Port, + } + if session.stage != nil { + snap.StageID = session.stage.id + } + ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}] + if len(ub3) > 0 { + snap.UserBinary3 = make([]byte, len(ub3)) + copy(snap.UserBinary3, ub3) + } + if predicate(snap) { + results = append(results, snap) + } + } + c.userBinaryPartsLock.RUnlock() + c.Unlock() + } + return results +} + +func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []StageSnapshot { + var results []StageSnapshot + for _, c := range r.channels { + if len(results) >= max { + break + } + c.stagesLock.RLock() + for _, stage := range c.stages { + if len(results) >= max { + break + } + if !strings.HasPrefix(stage.id, stagePrefix) { + continue + } + stage.RLock() + bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}] + bin0Copy := make([]byte, len(bin0)) + copy(bin0Copy, bin0) + bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}] + bin1Copy := make([]byte, len(bin1)) + copy(bin1Copy, bin1) + bin3 := stage.rawBinaryData[stageBinaryKey{1, 3}] + bin3Copy := make([]byte, len(bin3)) + copy(bin3Copy, bin3) + + results = append(results, StageSnapshot{ + ServerIP: net.ParseIP(c.IP).To4(), + ServerPort: c.Port, + StageID: stage.id, + ClientCount: len(stage.clients) + len(stage.reservedClientSlots), + Reserved: len(stage.reservedClientSlots), + MaxPlayers: stage.maxPlayers, + RawBinData0: bin0Copy, + RawBinData1: bin1Copy, + RawBinData3: bin3Copy, + }) + stage.RUnlock() + } + c.stagesLock.RUnlock() + } + return results +} + +func (r *LocalChannelRegistry) NotifyMailToCharID(charID uint32, sender *Session, mail *Mail) { + session := r.FindSessionByCharID(charID) + if session != nil { + SendMailNotification(sender, mail, session) + } +} diff --git a/server/channelserver/channel_registry_test.go b/server/channelserver/channel_registry_test.go new file mode 100644 index 000000000..bdeccffbf --- /dev/null +++ b/server/channelserver/channel_registry_test.go @@ -0,0 +1,190 @@ +package channelserver + +import ( + "net" + "sync" + "testing" +) + +func createTestChannels(count int) []*Server { + channels := make([]*Server, count) + for i := 0; i < count; i++ { + s := createTestServer() + s.ID = uint16(0x1010 + i) + s.IP = "10.0.0.1" + s.Port = uint16(54001 + i) + s.GlobalID = "0101" + s.userBinaryParts = make(map[userBinaryPartID][]byte) + channels[i] = s + } + return channels +} + +func TestLocalRegistryFindSessionByCharID(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + conn1 := &mockConn{} + sess1 := createTestSessionForServer(channels[0], conn1, 100, "Alice") + channels[0].Lock() + channels[0].sessions[conn1] = sess1 + channels[0].Unlock() + + conn2 := &mockConn{} + sess2 := createTestSessionForServer(channels[1], conn2, 200, "Bob") + channels[1].Lock() + channels[1].sessions[conn2] = sess2 + channels[1].Unlock() + + // Find on first channel + found := reg.FindSessionByCharID(100) + if found == nil || found.charID != 100 { + t.Errorf("FindSessionByCharID(100) = %v, want session with charID 100", found) + } + + // Find on second channel + found = reg.FindSessionByCharID(200) + if found == nil || found.charID != 200 { + t.Errorf("FindSessionByCharID(200) = %v, want session with charID 200", found) + } + + // Not found + found = reg.FindSessionByCharID(999) + if found != nil { + t.Errorf("FindSessionByCharID(999) = %v, want nil", found) + } +} + +func TestLocalRegistryFindChannelForStage(t *testing.T) { + channels := createTestChannels(2) + channels[0].GlobalID = "0101" + channels[1].GlobalID = "0102" + reg := NewLocalChannelRegistry(channels) + + channels[1].stagesLock.Lock() + channels[1].stages["sl2Qs123p0a0u42"] = NewStage("sl2Qs123p0a0u42") + channels[1].stagesLock.Unlock() + + gid := reg.FindChannelForStage("u42") + if gid != "0102" { + t.Errorf("FindChannelForStage(u42) = %q, want %q", gid, "0102") + } + + gid = reg.FindChannelForStage("u999") + if gid != "" { + t.Errorf("FindChannelForStage(u999) = %q, want empty", gid) + } +} + +func TestLocalRegistryDisconnectUser(t *testing.T) { + channels := createTestChannels(1) + reg := NewLocalChannelRegistry(channels) + + conn := &mockConn{} + sess := createTestSessionForServer(channels[0], conn, 42, "Target") + channels[0].Lock() + channels[0].sessions[conn] = sess + channels[0].Unlock() + + reg.DisconnectUser([]uint32{42}) + + if !conn.WasClosed() { + t.Error("DisconnectUser should have closed the connection for charID 42") + } +} + +func TestLocalRegistrySearchSessions(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + // Add 3 sessions across 2 channels + for i, ch := range channels { + conn := &mockConn{} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + conn3 := &mockConn{} + sess3 := createTestSessionForServer(channels[0], conn3, 3, "Player") + sess3.stage = NewStage("sl1Ns200p0a0u0") + channels[0].Lock() + channels[0].sessions[conn3] = sess3 + channels[0].Unlock() + + // Search all + results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10) + if len(results) != 3 { + t.Errorf("SearchSessions(all) returned %d results, want 3", len(results)) + } + + // Search with max + results = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 2) + if len(results) != 2 { + t.Errorf("SearchSessions(max=2) returned %d results, want 2", len(results)) + } + + // Search with predicate + results = reg.SearchSessions(func(s SessionSnapshot) bool { return s.CharID == 1 }, 10) + if len(results) != 1 { + t.Errorf("SearchSessions(charID==1) returned %d results, want 1", len(results)) + } +} + +func TestLocalRegistrySearchStages(t *testing.T) { + channels := createTestChannels(1) + reg := NewLocalChannelRegistry(channels) + + channels[0].stagesLock.Lock() + channels[0].stages["sl2Ls210test1"] = NewStage("sl2Ls210test1") + channels[0].stages["sl2Ls210test2"] = NewStage("sl2Ls210test2") + channels[0].stages["sl1Ns200other"] = NewStage("sl1Ns200other") + channels[0].stagesLock.Unlock() + + results := reg.SearchStages("sl2Ls210", 10) + if len(results) != 2 { + t.Errorf("SearchStages(sl2Ls210) returned %d results, want 2", len(results)) + } + + results = reg.SearchStages("sl2Ls210", 1) + if len(results) != 1 { + t.Errorf("SearchStages(sl2Ls210, max=1) returned %d results, want 1", len(results)) + } +} + +func TestLocalRegistryConcurrentAccess(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + // Populate some sessions + for _, ch := range channels { + for i := 0; i < 10; i++ { + conn := &mockConn{remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 50000 + i}} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + } + + // Run concurrent operations + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(3) + go func(id int) { + defer wg.Done() + _ = reg.FindSessionByCharID(uint32(id%10 + 1)) + }(i) + go func() { + defer wg.Done() + _ = reg.FindChannelForStage("u0") + }() + go func() { + defer wg.Done() + _ = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 5) + }() + } + wg.Wait() +} diff --git a/server/channelserver/handlers_guild_ops.go b/server/channelserver/handlers_guild_ops.go index 5a9a1b3e0..dc4086aea 100644 --- a/server/channelserver/handlers_guild_ops.go +++ b/server/channelserver/handlers_guild_ops.go @@ -304,11 +304,26 @@ func handleMsgMhfOperateGuildMember(s *Session, p mhfpacket.MHFPacket) { doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) } else { _ = mail.Send(s, nil) - for _, channel := range s.server.Channels { - for _, session := range channel.sessions { - if session.charID == pkt.CharID { - SendMailNotification(s, &mail, session) + if s.server.Registry != nil { + s.server.Registry.NotifyMailToCharID(pkt.CharID, s, &mail) + } else { + // Fallback: find the target session under lock, then notify outside the lock. + var targetSession *Session + for _, channel := range s.server.Channels { + channel.Lock() + for _, session := range channel.sessions { + if session.charID == pkt.CharID { + targetSession = session + break + } } + channel.Unlock() + if targetSession != nil { + break + } + } + if targetSession != nil { + SendMailNotification(s, &mail, targetSession) } } doAckSimpleSucceed(s, pkt.AckHandle, make([]byte, 4)) diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go index 6e805a7dc..1944a7902 100644 --- a/server/channelserver/handlers_session.go +++ b/server/channelserver/handlers_session.go @@ -399,11 +399,17 @@ func handleMsgSysEcho(s *Session, p mhfpacket.MHFPacket) {} func handleMsgSysLockGlobalSema(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysLockGlobalSema) var sgid string - for _, channel := range s.server.Channels { - for id := range channel.stages { - if strings.HasSuffix(id, pkt.UserIDString) { - sgid = channel.GlobalID + if s.server.Registry != nil { + sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString) + } else { + for _, channel := range s.server.Channels { + channel.stagesLock.RLock() + for id := range channel.stages { + if strings.HasSuffix(id, pkt.UserIDString) { + sgid = channel.GlobalID + } } + channel.stagesLock.RUnlock() } } bf := byteframe.NewByteFrame() @@ -468,7 +474,23 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { resp.WriteUint16(0) switch pkt.SearchType { case 1, 2, 3: // usersearchidx, usersearchname, lobbysearchname + // Snapshot matching sessions under lock, then build response outside locks. + type sessionResult struct { + charID uint32 + name []byte + stageID []byte + ip net.IP + port uint16 + userBin3 []byte + } + var results []sessionResult + for _, c := range s.server.Channels { + if count == maxResults { + break + } + c.Lock() + c.userBinaryPartsLock.RLock() for _, session := range c.sessions { if count == maxResults { break @@ -483,31 +505,45 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { continue } count++ - sessionName := stringsupport.UTF8ToSJIS(session.Name) - sessionStage := stringsupport.UTF8ToSJIS(session.stage.id) - if !local { - resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4())) - } else { - resp.WriteUint32(0x0100007F) - } - resp.WriteUint16(c.Port) - resp.WriteUint32(session.charID) - resp.WriteUint8(uint8(len(sessionStage) + 1)) - resp.WriteUint8(uint8(len(sessionName) + 1)) - resp.WriteUint16(uint16(len(c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}]))) - - // TODO: This case might be <=G2 - if _config.ErupeConfig.RealClientMode <= _config.G1 { - resp.WriteBytes(make([]byte, 8)) - } else { - resp.WriteBytes(make([]byte, 40)) - } - resp.WriteBytes(make([]byte, 8)) - - resp.WriteNullTerminatedBytes(sessionStage) - resp.WriteNullTerminatedBytes(sessionName) - resp.WriteBytes(c.userBinaryParts[userBinaryPartID{session.charID, 3}]) + ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}] + ub3Copy := make([]byte, len(ub3)) + copy(ub3Copy, ub3) + results = append(results, sessionResult{ + charID: session.charID, + name: stringsupport.UTF8ToSJIS(session.Name), + stageID: stringsupport.UTF8ToSJIS(session.stage.id), + ip: net.ParseIP(c.IP).To4(), + port: c.Port, + userBin3: ub3Copy, + }) } + c.userBinaryPartsLock.RUnlock() + c.Unlock() + } + + for _, r := range results { + if !local { + resp.WriteUint32(binary.LittleEndian.Uint32(r.ip)) + } else { + resp.WriteUint32(0x0100007F) + } + resp.WriteUint16(r.port) + resp.WriteUint32(r.charID) + resp.WriteUint8(uint8(len(r.stageID) + 1)) + resp.WriteUint8(uint8(len(r.name) + 1)) + resp.WriteUint16(uint16(len(r.userBin3))) + + // TODO: This case might be <=G2 + if _config.ErupeConfig.RealClientMode <= _config.G1 { + resp.WriteBytes(make([]byte, 8)) + } else { + resp.WriteBytes(make([]byte, 40)) + } + resp.WriteBytes(make([]byte, 8)) + + resp.WriteNullTerminatedBytes(r.stageID) + resp.WriteNullTerminatedBytes(r.name) + resp.WriteBytes(r.userBin3) } case 4: // lobbysearch type FindPartyParams struct { @@ -594,12 +630,31 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { } } } + // Snapshot matching stages under lock, then build response outside locks. + type stageResult struct { + ip net.IP + port uint16 + clientCount int + reserved int + maxPlayers uint16 + stageID string + stageData []int16 + rawBinData0 []byte + rawBinData1 []byte + } + var stageResults []stageResult + for _, c := range s.server.Channels { + if count == maxResults { + break + } + c.stagesLock.RLock() for _, stage := range c.stages { if count == maxResults { break } if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) { + stage.RLock() sb3 := byteframe.NewByteFrameFromBytes(stage.rawBinaryData[stageBinaryKey{1, 3}]) _, _ = sb3.Seek(4, 0) @@ -621,6 +676,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { if findPartyParams.RankRestriction >= 0 { if stageData[0] > findPartyParams.RankRestriction { + stage.RUnlock() continue } } @@ -634,47 +690,72 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { } } if !hasTarget { + stage.RUnlock() continue } } + // Copy binary data under lock + bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}] + bin0Copy := make([]byte, len(bin0)) + copy(bin0Copy, bin0) + bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}] + bin1Copy := make([]byte, len(bin1)) + copy(bin1Copy, bin1) + count++ - if !local { - resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4())) - } else { - resp.WriteUint32(0x0100007F) - } - resp.WriteUint16(c.Port) - - resp.WriteUint16(0) // Static? - resp.WriteUint16(0) // Unk, [0 1 2] - resp.WriteUint16(uint16(len(stage.clients) + len(stage.reservedClientSlots))) - resp.WriteUint16(stage.maxPlayers) - // TODO: Retail returned the number of clients in quests, not workshop/my series - resp.WriteUint16(uint16(len(stage.reservedClientSlots))) - - resp.WriteUint8(0) // Static? - resp.WriteUint8(uint8(stage.maxPlayers)) - resp.WriteUint8(1) // Static? - resp.WriteUint8(uint8(len(stage.id) + 1)) - resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 0}]))) - resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 1}]))) - - for i := range stageData { - if _config.ErupeConfig.RealClientMode >= _config.Z1 { - resp.WriteInt16(stageData[i]) - } else { - resp.WriteInt8(int8(stageData[i])) - } - } - resp.WriteUint8(0) // Unk - resp.WriteUint8(0) // Unk - - resp.WriteNullTerminatedBytes([]byte(stage.id)) - resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 0}]) - resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 1}]) + stageResults = append(stageResults, stageResult{ + ip: net.ParseIP(c.IP).To4(), + port: c.Port, + clientCount: len(stage.clients) + len(stage.reservedClientSlots), + reserved: len(stage.reservedClientSlots), + maxPlayers: stage.maxPlayers, + stageID: stage.id, + stageData: stageData, + rawBinData0: bin0Copy, + rawBinData1: bin1Copy, + }) + stage.RUnlock() } } + c.stagesLock.RUnlock() + } + + for _, sr := range stageResults { + if !local { + resp.WriteUint32(binary.LittleEndian.Uint32(sr.ip)) + } else { + resp.WriteUint32(0x0100007F) + } + resp.WriteUint16(sr.port) + + resp.WriteUint16(0) // Static? + resp.WriteUint16(0) // Unk, [0 1 2] + resp.WriteUint16(uint16(sr.clientCount)) + resp.WriteUint16(sr.maxPlayers) + // TODO: Retail returned the number of clients in quests, not workshop/my series + resp.WriteUint16(uint16(sr.reserved)) + + resp.WriteUint8(0) // Static? + resp.WriteUint8(uint8(sr.maxPlayers)) + resp.WriteUint8(1) // Static? + resp.WriteUint8(uint8(len(sr.stageID) + 1)) + resp.WriteUint8(uint8(len(sr.rawBinData0))) + resp.WriteUint8(uint8(len(sr.rawBinData1))) + + for i := range sr.stageData { + if _config.ErupeConfig.RealClientMode >= _config.Z1 { + resp.WriteInt16(sr.stageData[i]) + } else { + resp.WriteInt8(int8(sr.stageData[i])) + } + } + resp.WriteUint8(0) // Unk + resp.WriteUint8(0) // Unk + + resp.WriteNullTerminatedBytes([]byte(sr.stageID)) + resp.WriteBytes(sr.rawBinData0) + resp.WriteBytes(sr.rawBinData1) } } _, _ = resp.Seek(0, io.SeekStart) diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index d951f93d0..f042c6109 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -37,6 +37,7 @@ type userBinaryPartID struct { type Server struct { sync.Mutex Channels []*Server + Registry ChannelRegistry ID uint16 GlobalID string IP string @@ -271,6 +272,10 @@ func (s *Server) BroadcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session) // WorldcastMHF broadcasts a packet to all sessions across all channel servers. func (s *Server) WorldcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) { + if s.Registry != nil { + s.Registry.Worldcast(pkt, ignoredSession, ignoredChannel) + return + } for _, c := range s.Channels { if c == ignoredChannel { continue @@ -317,12 +322,18 @@ func (s *Server) DiscordScreenShotSend(charName string, title string, descriptio // FindSessionByCharID looks up a session by character ID across all channels. func (s *Server) FindSessionByCharID(charID uint32) *Session { + if s.Registry != nil { + return s.Registry.FindSessionByCharID(charID) + } for _, c := range s.Channels { + c.Lock() for _, session := range c.sessions { if session.charID == charID { + c.Unlock() return session } } + c.Unlock() } return nil } @@ -341,7 +352,12 @@ func (s *Server) DisconnectUser(uid uint32) { cids = append(cids, cid) } } + if s.Registry != nil { + s.Registry.DisconnectUser(cids) + return + } for _, c := range s.Channels { + c.Lock() for _, session := range c.sessions { for _, cid := range cids { if session.charID == cid { @@ -350,6 +366,7 @@ func (s *Server) DisconnectUser(uid uint32) { } } } + c.Unlock() } }