From 754b5a3bff9b4b287f76083a3eddd599bdc98d27 Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Thu, 19 Feb 2026 18:13:34 +0100 Subject: [PATCH 1/5] feat(channelserver): decouple channel servers for independent operation (#33) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Enable multiple Erupe instances to share a single PostgreSQL database without destroying each other's state, fix existing data races in cross-channel access, and lay groundwork for future distributed channel server deployments. Phase 1 — DB safety: - Scope DELETE FROM servers/sign_sessions to this instance's server IDs - Fix ci++ bug where failed channel start shifted subsequent IDs Phase 2 — Fix data races in cross-channel access: - Lock sessions map in FindSessionByCharID and DisconnectUser - Lock stagesLock in handleMsgSysLockGlobalSema - Snapshot sessions/stages under lock in TransitMessage types 1-4 - Lock channel when finding mail notification targets Phase 3 — ChannelRegistry interface: - Define ChannelRegistry interface with 7 cross-channel operations - Implement LocalChannelRegistry with proper locking - Add SessionSnapshot/StageSnapshot immutable copy types - Delegate WorldcastMHF, FindSessionByCharID, DisconnectUser to Registry - Migrate LockGlobalSema and guild mail handlers to use Registry - Add comprehensive tests including concurrent access Phase 4 — Per-channel enable/disable: - Add Enabled *bool to EntranceChannelInfo (nil defaults to true) - Skip disabled channels in startup loop, preserving ID stability - Add IsEnabled() helper with backward-compatible default - Update config.example.json with Enabled field --- config.example.json | 16 +- config/config.go | 9 + config/config_test.go | 28 +++ main.go | 38 +++- server/channelserver/channel_registry.go | 58 +++++ .../channelserver/channel_registry_local.go | 156 +++++++++++++ server/channelserver/channel_registry_test.go | 190 ++++++++++++++++ server/channelserver/handlers_guild_ops.go | 23 +- server/channelserver/handlers_session.go | 205 ++++++++++++------ server/channelserver/sys_channel_server.go | 17 ++ 10 files changed, 661 insertions(+), 79 deletions(-) create mode 100644 server/channelserver/channel_registry.go create mode 100644 server/channelserver/channel_registry_local.go create mode 100644 server/channelserver/channel_registry_test.go diff --git a/config.example.json b/config.example.json index a1951e791..e92a7a92f 100644 --- a/config.example.json +++ b/config.example.json @@ -219,34 +219,34 @@ { "Name": "Newbie", "Description": "", "IP": "", "Type": 3, "Recommended": 2, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54001, "MaxPlayers": 100 }, - { "Port": 54002, "MaxPlayers": 100 } + { "Port": 54001, "MaxPlayers": 100, "Enabled": true }, + { "Port": 54002, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Normal", "Description": "", "IP": "", "Type": 1, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54003, "MaxPlayers": 100 }, - { "Port": 54004, "MaxPlayers": 100 } + { "Port": 54003, "MaxPlayers": 100, "Enabled": true }, + { "Port": 54004, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Cities", "Description": "", "IP": "", "Type": 2, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54005, "MaxPlayers": 100 } + { "Port": 54005, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Tavern", "Description": "", "IP": "", "Type": 4, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54006, "MaxPlayers": 100 } + { "Port": 54006, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Return", "Description": "", "IP": "", "Type": 5, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54007, "MaxPlayers": 100 } + { "Port": 54007, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "MezFes", "Description": "", "IP": "", "Type": 6, "Recommended": 6, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54008, "MaxPlayers": 100 } + { "Port": 54008, "MaxPlayers": 100, "Enabled": true } ] } ] diff --git a/config/config.go b/config/config.go index e30cbcd12..6c26798aa 100644 --- a/config/config.go +++ b/config/config.go @@ -297,6 +297,15 @@ type EntranceChannelInfo struct { Port uint16 MaxPlayers uint16 CurrentPlayers uint16 + Enabled *bool // nil defaults to true for backward compatibility +} + +// IsEnabled returns whether this channel is enabled. Defaults to true if Enabled is nil. +func (c *EntranceChannelInfo) IsEnabled() bool { + if c.Enabled == nil { + return true + } + return *c.Enabled } var ErupeConfig *Config diff --git a/config/config_test.go b/config/config_test.go index 782b3ef89..cbad553ec 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -536,6 +536,34 @@ func TestEntranceChannelInfo(t *testing.T) { } } +// TestEntranceChannelInfoIsEnabled tests the Enabled field and IsEnabled helper +func TestEntranceChannelInfoIsEnabled(t *testing.T) { + trueVal := true + falseVal := false + + tests := []struct { + name string + enabled *bool + want bool + }{ + {"nil defaults to true", nil, true}, + {"explicit true", &trueVal, true}, + {"explicit false", &falseVal, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := EntranceChannelInfo{ + Port: 10001, + Enabled: tt.enabled, + } + if got := info.IsEnabled(); got != tt.want { + t.Errorf("IsEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + // TestDiscord verifies Discord struct func TestDiscord(t *testing.T) { discord := Discord{ diff --git a/main.go b/main.go index de44e0693..c6da1c977 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "erupe-ce/server/discordbot" "erupe-ce/server/entranceserver" "erupe-ce/server/signserver" + "strings" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -129,11 +130,30 @@ func main() { } logger.Info("Database: Started successfully") - // Clear stale data - if config.DebugOptions.ProxyPort == 0 { - _ = db.MustExec("DELETE FROM sign_sessions") + // Pre-compute all server IDs this instance will own, so we only + // delete our own rows (safe for multi-instance on the same DB). + var ownedServerIDs []string + { + si := 0 + for _, ee := range config.Entrance.Entries { + ci := 0 + for range ee.Channels { + sid := (4096 + si*256) + (16 + ci) + ownedServerIDs = append(ownedServerIDs, fmt.Sprint(sid)) + ci++ + } + si++ + } + } + + // Clear stale data scoped to this instance's server IDs + if len(ownedServerIDs) > 0 { + idList := strings.Join(ownedServerIDs, ",") + if config.DebugOptions.ProxyPort == 0 { + _ = db.MustExec("DELETE FROM sign_sessions WHERE server_id IN (" + idList + ")") + } + _ = db.MustExec("DELETE FROM servers WHERE server_id IN (" + idList + ")") } - _ = db.MustExec("DELETE FROM servers") _ = db.MustExec(`UPDATE guild_characters SET treasure_hunt=NULL`) // Clean the DB if the option is on. @@ -213,6 +233,12 @@ func main() { for j, ee := range config.Entrance.Entries { for i, ce := range ee.Channels { sid := (4096 + si*256) + (16 + ci) + if !ce.IsEnabled() { + logger.Info(fmt.Sprintf("Channel %d (%d): Disabled via config", count, ce.Port)) + ci++ + count++ + continue + } c := *channelserver.NewServer(&channelserver.Config{ ID: uint16(sid), Logger: logger.Named("channel-" + fmt.Sprint(count)), @@ -237,9 +263,9 @@ func main() { ) channels = append(channels, &c) logger.Info(fmt.Sprintf("Channel %d (%d): Started successfully", count, ce.Port)) - ci++ count++ } + ci++ } ci = 0 si++ @@ -248,8 +274,10 @@ func main() { // Register all servers in DB _ = db.MustExec(channelQuery) + registry := channelserver.NewLocalChannelRegistry(channels) for _, c := range channels { c.Channels = channels + c.Registry = registry } } diff --git a/server/channelserver/channel_registry.go b/server/channelserver/channel_registry.go new file mode 100644 index 000000000..af391a727 --- /dev/null +++ b/server/channelserver/channel_registry.go @@ -0,0 +1,58 @@ +package channelserver + +import ( + "erupe-ce/network/mhfpacket" + "net" +) + +// ChannelRegistry abstracts cross-channel operations behind an interface. +// The default LocalChannelRegistry wraps the in-process []*Server slice. +// Future implementations may use DB/Redis/NATS for multi-process deployments. +type ChannelRegistry interface { + // Worldcast broadcasts a packet to all sessions across all channels. + Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) + + // FindSessionByCharID looks up a session by character ID across all channels. + FindSessionByCharID(charID uint32) *Session + + // DisconnectUser disconnects all sessions belonging to the given character IDs. + DisconnectUser(cids []uint32) + + // FindChannelForStage searches all channels for a stage whose ID has the + // given suffix and returns the owning channel's GlobalID, or "" if not found. + FindChannelForStage(stageSuffix string) string + + // SearchSessions searches sessions across all channels using a predicate, + // returning up to max snapshot results. + SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot + + // SearchStages searches stages across all channels with a prefix filter, + // returning up to max snapshot results. + SearchStages(stagePrefix string, max int) []StageSnapshot + + // NotifyMailToCharID finds the session for charID and sends a mail notification. + NotifyMailToCharID(charID uint32, sender *Session, mail *Mail) +} + +// SessionSnapshot is an immutable copy of session data taken under lock. +type SessionSnapshot struct { + CharID uint32 + Name string + StageID string + ServerIP net.IP + ServerPort uint16 + UserBinary3 []byte // Copy of userBinaryParts index 3 +} + +// StageSnapshot is an immutable copy of stage data taken under lock. +type StageSnapshot struct { + ServerIP net.IP + ServerPort uint16 + StageID string + ClientCount int + Reserved int + MaxPlayers uint16 + RawBinData0 []byte + RawBinData1 []byte + RawBinData3 []byte +} diff --git a/server/channelserver/channel_registry_local.go b/server/channelserver/channel_registry_local.go new file mode 100644 index 000000000..a04651d0e --- /dev/null +++ b/server/channelserver/channel_registry_local.go @@ -0,0 +1,156 @@ +package channelserver + +import ( + "erupe-ce/network/mhfpacket" + "net" + "strings" +) + +// LocalChannelRegistry is the in-process ChannelRegistry backed by []*Server. +type LocalChannelRegistry struct { + channels []*Server +} + +// NewLocalChannelRegistry creates a LocalChannelRegistry wrapping the given channels. +func NewLocalChannelRegistry(channels []*Server) *LocalChannelRegistry { + return &LocalChannelRegistry{channels: channels} +} + +func (r *LocalChannelRegistry) Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) { + for _, c := range r.channels { + if c == ignoredChannel { + continue + } + c.BroadcastMHF(pkt, ignoredSession) + } +} + +func (r *LocalChannelRegistry) FindSessionByCharID(charID uint32) *Session { + for _, c := range r.channels { + c.Lock() + for _, session := range c.sessions { + if session.charID == charID { + c.Unlock() + return session + } + } + c.Unlock() + } + return nil +} + +func (r *LocalChannelRegistry) DisconnectUser(cids []uint32) { + for _, c := range r.channels { + c.Lock() + for _, session := range c.sessions { + for _, cid := range cids { + if session.charID == cid { + _ = session.rawConn.Close() + break + } + } + } + c.Unlock() + } +} + +func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string { + for _, channel := range r.channels { + channel.stagesLock.RLock() + for id := range channel.stages { + if strings.HasSuffix(id, stageSuffix) { + gid := channel.GlobalID + channel.stagesLock.RUnlock() + return gid + } + } + channel.stagesLock.RUnlock() + } + return "" +} + +func (r *LocalChannelRegistry) SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot { + var results []SessionSnapshot + for _, c := range r.channels { + if len(results) >= max { + break + } + c.Lock() + c.userBinaryPartsLock.RLock() + for _, session := range c.sessions { + if len(results) >= max { + break + } + snap := SessionSnapshot{ + CharID: session.charID, + Name: session.Name, + ServerIP: net.ParseIP(c.IP).To4(), + ServerPort: c.Port, + } + if session.stage != nil { + snap.StageID = session.stage.id + } + ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}] + if len(ub3) > 0 { + snap.UserBinary3 = make([]byte, len(ub3)) + copy(snap.UserBinary3, ub3) + } + if predicate(snap) { + results = append(results, snap) + } + } + c.userBinaryPartsLock.RUnlock() + c.Unlock() + } + return results +} + +func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []StageSnapshot { + var results []StageSnapshot + for _, c := range r.channels { + if len(results) >= max { + break + } + c.stagesLock.RLock() + for _, stage := range c.stages { + if len(results) >= max { + break + } + if !strings.HasPrefix(stage.id, stagePrefix) { + continue + } + stage.RLock() + bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}] + bin0Copy := make([]byte, len(bin0)) + copy(bin0Copy, bin0) + bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}] + bin1Copy := make([]byte, len(bin1)) + copy(bin1Copy, bin1) + bin3 := stage.rawBinaryData[stageBinaryKey{1, 3}] + bin3Copy := make([]byte, len(bin3)) + copy(bin3Copy, bin3) + + results = append(results, StageSnapshot{ + ServerIP: net.ParseIP(c.IP).To4(), + ServerPort: c.Port, + StageID: stage.id, + ClientCount: len(stage.clients) + len(stage.reservedClientSlots), + Reserved: len(stage.reservedClientSlots), + MaxPlayers: stage.maxPlayers, + RawBinData0: bin0Copy, + RawBinData1: bin1Copy, + RawBinData3: bin3Copy, + }) + stage.RUnlock() + } + c.stagesLock.RUnlock() + } + return results +} + +func (r *LocalChannelRegistry) NotifyMailToCharID(charID uint32, sender *Session, mail *Mail) { + session := r.FindSessionByCharID(charID) + if session != nil { + SendMailNotification(sender, mail, session) + } +} diff --git a/server/channelserver/channel_registry_test.go b/server/channelserver/channel_registry_test.go new file mode 100644 index 000000000..bdeccffbf --- /dev/null +++ b/server/channelserver/channel_registry_test.go @@ -0,0 +1,190 @@ +package channelserver + +import ( + "net" + "sync" + "testing" +) + +func createTestChannels(count int) []*Server { + channels := make([]*Server, count) + for i := 0; i < count; i++ { + s := createTestServer() + s.ID = uint16(0x1010 + i) + s.IP = "10.0.0.1" + s.Port = uint16(54001 + i) + s.GlobalID = "0101" + s.userBinaryParts = make(map[userBinaryPartID][]byte) + channels[i] = s + } + return channels +} + +func TestLocalRegistryFindSessionByCharID(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + conn1 := &mockConn{} + sess1 := createTestSessionForServer(channels[0], conn1, 100, "Alice") + channels[0].Lock() + channels[0].sessions[conn1] = sess1 + channels[0].Unlock() + + conn2 := &mockConn{} + sess2 := createTestSessionForServer(channels[1], conn2, 200, "Bob") + channels[1].Lock() + channels[1].sessions[conn2] = sess2 + channels[1].Unlock() + + // Find on first channel + found := reg.FindSessionByCharID(100) + if found == nil || found.charID != 100 { + t.Errorf("FindSessionByCharID(100) = %v, want session with charID 100", found) + } + + // Find on second channel + found = reg.FindSessionByCharID(200) + if found == nil || found.charID != 200 { + t.Errorf("FindSessionByCharID(200) = %v, want session with charID 200", found) + } + + // Not found + found = reg.FindSessionByCharID(999) + if found != nil { + t.Errorf("FindSessionByCharID(999) = %v, want nil", found) + } +} + +func TestLocalRegistryFindChannelForStage(t *testing.T) { + channels := createTestChannels(2) + channels[0].GlobalID = "0101" + channels[1].GlobalID = "0102" + reg := NewLocalChannelRegistry(channels) + + channels[1].stagesLock.Lock() + channels[1].stages["sl2Qs123p0a0u42"] = NewStage("sl2Qs123p0a0u42") + channels[1].stagesLock.Unlock() + + gid := reg.FindChannelForStage("u42") + if gid != "0102" { + t.Errorf("FindChannelForStage(u42) = %q, want %q", gid, "0102") + } + + gid = reg.FindChannelForStage("u999") + if gid != "" { + t.Errorf("FindChannelForStage(u999) = %q, want empty", gid) + } +} + +func TestLocalRegistryDisconnectUser(t *testing.T) { + channels := createTestChannels(1) + reg := NewLocalChannelRegistry(channels) + + conn := &mockConn{} + sess := createTestSessionForServer(channels[0], conn, 42, "Target") + channels[0].Lock() + channels[0].sessions[conn] = sess + channels[0].Unlock() + + reg.DisconnectUser([]uint32{42}) + + if !conn.WasClosed() { + t.Error("DisconnectUser should have closed the connection for charID 42") + } +} + +func TestLocalRegistrySearchSessions(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + // Add 3 sessions across 2 channels + for i, ch := range channels { + conn := &mockConn{} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + conn3 := &mockConn{} + sess3 := createTestSessionForServer(channels[0], conn3, 3, "Player") + sess3.stage = NewStage("sl1Ns200p0a0u0") + channels[0].Lock() + channels[0].sessions[conn3] = sess3 + channels[0].Unlock() + + // Search all + results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10) + if len(results) != 3 { + t.Errorf("SearchSessions(all) returned %d results, want 3", len(results)) + } + + // Search with max + results = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 2) + if len(results) != 2 { + t.Errorf("SearchSessions(max=2) returned %d results, want 2", len(results)) + } + + // Search with predicate + results = reg.SearchSessions(func(s SessionSnapshot) bool { return s.CharID == 1 }, 10) + if len(results) != 1 { + t.Errorf("SearchSessions(charID==1) returned %d results, want 1", len(results)) + } +} + +func TestLocalRegistrySearchStages(t *testing.T) { + channels := createTestChannels(1) + reg := NewLocalChannelRegistry(channels) + + channels[0].stagesLock.Lock() + channels[0].stages["sl2Ls210test1"] = NewStage("sl2Ls210test1") + channels[0].stages["sl2Ls210test2"] = NewStage("sl2Ls210test2") + channels[0].stages["sl1Ns200other"] = NewStage("sl1Ns200other") + channels[0].stagesLock.Unlock() + + results := reg.SearchStages("sl2Ls210", 10) + if len(results) != 2 { + t.Errorf("SearchStages(sl2Ls210) returned %d results, want 2", len(results)) + } + + results = reg.SearchStages("sl2Ls210", 1) + if len(results) != 1 { + t.Errorf("SearchStages(sl2Ls210, max=1) returned %d results, want 1", len(results)) + } +} + +func TestLocalRegistryConcurrentAccess(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + // Populate some sessions + for _, ch := range channels { + for i := 0; i < 10; i++ { + conn := &mockConn{remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 50000 + i}} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + } + + // Run concurrent operations + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(3) + go func(id int) { + defer wg.Done() + _ = reg.FindSessionByCharID(uint32(id%10 + 1)) + }(i) + go func() { + defer wg.Done() + _ = reg.FindChannelForStage("u0") + }() + go func() { + defer wg.Done() + _ = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 5) + }() + } + wg.Wait() +} diff --git a/server/channelserver/handlers_guild_ops.go b/server/channelserver/handlers_guild_ops.go index 5a9a1b3e0..dc4086aea 100644 --- a/server/channelserver/handlers_guild_ops.go +++ b/server/channelserver/handlers_guild_ops.go @@ -304,11 +304,26 @@ func handleMsgMhfOperateGuildMember(s *Session, p mhfpacket.MHFPacket) { doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) } else { _ = mail.Send(s, nil) - for _, channel := range s.server.Channels { - for _, session := range channel.sessions { - if session.charID == pkt.CharID { - SendMailNotification(s, &mail, session) + if s.server.Registry != nil { + s.server.Registry.NotifyMailToCharID(pkt.CharID, s, &mail) + } else { + // Fallback: find the target session under lock, then notify outside the lock. + var targetSession *Session + for _, channel := range s.server.Channels { + channel.Lock() + for _, session := range channel.sessions { + if session.charID == pkt.CharID { + targetSession = session + break + } } + channel.Unlock() + if targetSession != nil { + break + } + } + if targetSession != nil { + SendMailNotification(s, &mail, targetSession) } } doAckSimpleSucceed(s, pkt.AckHandle, make([]byte, 4)) diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go index 6e805a7dc..1944a7902 100644 --- a/server/channelserver/handlers_session.go +++ b/server/channelserver/handlers_session.go @@ -399,11 +399,17 @@ func handleMsgSysEcho(s *Session, p mhfpacket.MHFPacket) {} func handleMsgSysLockGlobalSema(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysLockGlobalSema) var sgid string - for _, channel := range s.server.Channels { - for id := range channel.stages { - if strings.HasSuffix(id, pkt.UserIDString) { - sgid = channel.GlobalID + if s.server.Registry != nil { + sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString) + } else { + for _, channel := range s.server.Channels { + channel.stagesLock.RLock() + for id := range channel.stages { + if strings.HasSuffix(id, pkt.UserIDString) { + sgid = channel.GlobalID + } } + channel.stagesLock.RUnlock() } } bf := byteframe.NewByteFrame() @@ -468,7 +474,23 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { resp.WriteUint16(0) switch pkt.SearchType { case 1, 2, 3: // usersearchidx, usersearchname, lobbysearchname + // Snapshot matching sessions under lock, then build response outside locks. + type sessionResult struct { + charID uint32 + name []byte + stageID []byte + ip net.IP + port uint16 + userBin3 []byte + } + var results []sessionResult + for _, c := range s.server.Channels { + if count == maxResults { + break + } + c.Lock() + c.userBinaryPartsLock.RLock() for _, session := range c.sessions { if count == maxResults { break @@ -483,31 +505,45 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { continue } count++ - sessionName := stringsupport.UTF8ToSJIS(session.Name) - sessionStage := stringsupport.UTF8ToSJIS(session.stage.id) - if !local { - resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4())) - } else { - resp.WriteUint32(0x0100007F) - } - resp.WriteUint16(c.Port) - resp.WriteUint32(session.charID) - resp.WriteUint8(uint8(len(sessionStage) + 1)) - resp.WriteUint8(uint8(len(sessionName) + 1)) - resp.WriteUint16(uint16(len(c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}]))) - - // TODO: This case might be <=G2 - if _config.ErupeConfig.RealClientMode <= _config.G1 { - resp.WriteBytes(make([]byte, 8)) - } else { - resp.WriteBytes(make([]byte, 40)) - } - resp.WriteBytes(make([]byte, 8)) - - resp.WriteNullTerminatedBytes(sessionStage) - resp.WriteNullTerminatedBytes(sessionName) - resp.WriteBytes(c.userBinaryParts[userBinaryPartID{session.charID, 3}]) + ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}] + ub3Copy := make([]byte, len(ub3)) + copy(ub3Copy, ub3) + results = append(results, sessionResult{ + charID: session.charID, + name: stringsupport.UTF8ToSJIS(session.Name), + stageID: stringsupport.UTF8ToSJIS(session.stage.id), + ip: net.ParseIP(c.IP).To4(), + port: c.Port, + userBin3: ub3Copy, + }) } + c.userBinaryPartsLock.RUnlock() + c.Unlock() + } + + for _, r := range results { + if !local { + resp.WriteUint32(binary.LittleEndian.Uint32(r.ip)) + } else { + resp.WriteUint32(0x0100007F) + } + resp.WriteUint16(r.port) + resp.WriteUint32(r.charID) + resp.WriteUint8(uint8(len(r.stageID) + 1)) + resp.WriteUint8(uint8(len(r.name) + 1)) + resp.WriteUint16(uint16(len(r.userBin3))) + + // TODO: This case might be <=G2 + if _config.ErupeConfig.RealClientMode <= _config.G1 { + resp.WriteBytes(make([]byte, 8)) + } else { + resp.WriteBytes(make([]byte, 40)) + } + resp.WriteBytes(make([]byte, 8)) + + resp.WriteNullTerminatedBytes(r.stageID) + resp.WriteNullTerminatedBytes(r.name) + resp.WriteBytes(r.userBin3) } case 4: // lobbysearch type FindPartyParams struct { @@ -594,12 +630,31 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { } } } + // Snapshot matching stages under lock, then build response outside locks. + type stageResult struct { + ip net.IP + port uint16 + clientCount int + reserved int + maxPlayers uint16 + stageID string + stageData []int16 + rawBinData0 []byte + rawBinData1 []byte + } + var stageResults []stageResult + for _, c := range s.server.Channels { + if count == maxResults { + break + } + c.stagesLock.RLock() for _, stage := range c.stages { if count == maxResults { break } if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) { + stage.RLock() sb3 := byteframe.NewByteFrameFromBytes(stage.rawBinaryData[stageBinaryKey{1, 3}]) _, _ = sb3.Seek(4, 0) @@ -621,6 +676,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { if findPartyParams.RankRestriction >= 0 { if stageData[0] > findPartyParams.RankRestriction { + stage.RUnlock() continue } } @@ -634,47 +690,72 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { } } if !hasTarget { + stage.RUnlock() continue } } + // Copy binary data under lock + bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}] + bin0Copy := make([]byte, len(bin0)) + copy(bin0Copy, bin0) + bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}] + bin1Copy := make([]byte, len(bin1)) + copy(bin1Copy, bin1) + count++ - if !local { - resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4())) - } else { - resp.WriteUint32(0x0100007F) - } - resp.WriteUint16(c.Port) - - resp.WriteUint16(0) // Static? - resp.WriteUint16(0) // Unk, [0 1 2] - resp.WriteUint16(uint16(len(stage.clients) + len(stage.reservedClientSlots))) - resp.WriteUint16(stage.maxPlayers) - // TODO: Retail returned the number of clients in quests, not workshop/my series - resp.WriteUint16(uint16(len(stage.reservedClientSlots))) - - resp.WriteUint8(0) // Static? - resp.WriteUint8(uint8(stage.maxPlayers)) - resp.WriteUint8(1) // Static? - resp.WriteUint8(uint8(len(stage.id) + 1)) - resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 0}]))) - resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 1}]))) - - for i := range stageData { - if _config.ErupeConfig.RealClientMode >= _config.Z1 { - resp.WriteInt16(stageData[i]) - } else { - resp.WriteInt8(int8(stageData[i])) - } - } - resp.WriteUint8(0) // Unk - resp.WriteUint8(0) // Unk - - resp.WriteNullTerminatedBytes([]byte(stage.id)) - resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 0}]) - resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 1}]) + stageResults = append(stageResults, stageResult{ + ip: net.ParseIP(c.IP).To4(), + port: c.Port, + clientCount: len(stage.clients) + len(stage.reservedClientSlots), + reserved: len(stage.reservedClientSlots), + maxPlayers: stage.maxPlayers, + stageID: stage.id, + stageData: stageData, + rawBinData0: bin0Copy, + rawBinData1: bin1Copy, + }) + stage.RUnlock() } } + c.stagesLock.RUnlock() + } + + for _, sr := range stageResults { + if !local { + resp.WriteUint32(binary.LittleEndian.Uint32(sr.ip)) + } else { + resp.WriteUint32(0x0100007F) + } + resp.WriteUint16(sr.port) + + resp.WriteUint16(0) // Static? + resp.WriteUint16(0) // Unk, [0 1 2] + resp.WriteUint16(uint16(sr.clientCount)) + resp.WriteUint16(sr.maxPlayers) + // TODO: Retail returned the number of clients in quests, not workshop/my series + resp.WriteUint16(uint16(sr.reserved)) + + resp.WriteUint8(0) // Static? + resp.WriteUint8(uint8(sr.maxPlayers)) + resp.WriteUint8(1) // Static? + resp.WriteUint8(uint8(len(sr.stageID) + 1)) + resp.WriteUint8(uint8(len(sr.rawBinData0))) + resp.WriteUint8(uint8(len(sr.rawBinData1))) + + for i := range sr.stageData { + if _config.ErupeConfig.RealClientMode >= _config.Z1 { + resp.WriteInt16(sr.stageData[i]) + } else { + resp.WriteInt8(int8(sr.stageData[i])) + } + } + resp.WriteUint8(0) // Unk + resp.WriteUint8(0) // Unk + + resp.WriteNullTerminatedBytes([]byte(sr.stageID)) + resp.WriteBytes(sr.rawBinData0) + resp.WriteBytes(sr.rawBinData1) } } _, _ = resp.Seek(0, io.SeekStart) diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index d951f93d0..f042c6109 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -37,6 +37,7 @@ type userBinaryPartID struct { type Server struct { sync.Mutex Channels []*Server + Registry ChannelRegistry ID uint16 GlobalID string IP string @@ -271,6 +272,10 @@ func (s *Server) BroadcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session) // WorldcastMHF broadcasts a packet to all sessions across all channel servers. func (s *Server) WorldcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) { + if s.Registry != nil { + s.Registry.Worldcast(pkt, ignoredSession, ignoredChannel) + return + } for _, c := range s.Channels { if c == ignoredChannel { continue @@ -317,12 +322,18 @@ func (s *Server) DiscordScreenShotSend(charName string, title string, descriptio // FindSessionByCharID looks up a session by character ID across all channels. func (s *Server) FindSessionByCharID(charID uint32) *Session { + if s.Registry != nil { + return s.Registry.FindSessionByCharID(charID) + } for _, c := range s.Channels { + c.Lock() for _, session := range c.sessions { if session.charID == charID { + c.Unlock() return session } } + c.Unlock() } return nil } @@ -341,7 +352,12 @@ func (s *Server) DisconnectUser(uid uint32) { cids = append(cids, cid) } } + if s.Registry != nil { + s.Registry.DisconnectUser(cids) + return + } for _, c := range s.Channels { + c.Lock() for _, session := range c.sessions { for _, cid := range cids { if session.charID == cid { @@ -350,6 +366,7 @@ func (s *Server) DisconnectUser(uid uint32) { } } } + c.Unlock() } } From 0e84377e21951ac051bd27a21192820b2ecee477 Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Fri, 20 Feb 2026 02:49:23 +0100 Subject: [PATCH 2/5] feat(protbot): add headless MHF protocol bot as cmd/protbot MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Copy MHBridge into the Erupe module as cmd/protbot/ so it can be built, tested, and maintained alongside the server. The bot implements the full sign → entrance → channel login flow and supports lobby entry, chat, session setup, and quest enumeration. The conn/ package keeps its own Blowfish crypto primitives to avoid importing erupe-ce/config (which requires a config file at init). --- cmd/protbot/conn/bin8.go | 37 +++ cmd/protbot/conn/bin8_test.go | 52 ++++ cmd/protbot/conn/conn.go | 52 ++++ cmd/protbot/conn/crypt_conn.go | 115 ++++++++ cmd/protbot/conn/crypt_conn_test.go | 152 ++++++++++ cmd/protbot/conn/crypt_packet.go | 78 +++++ cmd/protbot/main.go | 154 ++++++++++ cmd/protbot/protocol/channel.go | 190 ++++++++++++ cmd/protbot/protocol/entrance.go | 142 +++++++++ cmd/protbot/protocol/opcodes.go | 23 ++ cmd/protbot/protocol/packets.go | 229 +++++++++++++++ cmd/protbot/protocol/packets_test.go | 412 +++++++++++++++++++++++++++ cmd/protbot/protocol/sign.go | 105 +++++++ cmd/protbot/scenario/chat.go | 74 +++++ cmd/protbot/scenario/login.go | 82 ++++++ cmd/protbot/scenario/logout.go | 17 ++ cmd/protbot/scenario/quest.go | 31 ++ cmd/protbot/scenario/session.go | 50 ++++ cmd/protbot/scenario/stage.go | 111 ++++++++ 19 files changed, 2106 insertions(+) create mode 100644 cmd/protbot/conn/bin8.go create mode 100644 cmd/protbot/conn/bin8_test.go create mode 100644 cmd/protbot/conn/conn.go create mode 100644 cmd/protbot/conn/crypt_conn.go create mode 100644 cmd/protbot/conn/crypt_conn_test.go create mode 100644 cmd/protbot/conn/crypt_packet.go create mode 100644 cmd/protbot/main.go create mode 100644 cmd/protbot/protocol/channel.go create mode 100644 cmd/protbot/protocol/entrance.go create mode 100644 cmd/protbot/protocol/opcodes.go create mode 100644 cmd/protbot/protocol/packets.go create mode 100644 cmd/protbot/protocol/packets_test.go create mode 100644 cmd/protbot/protocol/sign.go create mode 100644 cmd/protbot/scenario/chat.go create mode 100644 cmd/protbot/scenario/login.go create mode 100644 cmd/protbot/scenario/logout.go create mode 100644 cmd/protbot/scenario/quest.go create mode 100644 cmd/protbot/scenario/session.go create mode 100644 cmd/protbot/scenario/stage.go diff --git a/cmd/protbot/conn/bin8.go b/cmd/protbot/conn/bin8.go new file mode 100644 index 000000000..4a1256fc5 --- /dev/null +++ b/cmd/protbot/conn/bin8.go @@ -0,0 +1,37 @@ +package conn + +import "encoding/binary" + +var ( + bin8Key = []byte{0x01, 0x23, 0x34, 0x45, 0x56, 0xAB, 0xCD, 0xEF} + sum32Table0 = []byte{0x35, 0x7A, 0xAA, 0x97, 0x53, 0x66, 0x12} + sum32Table1 = []byte{0x7A, 0xAA, 0x97, 0x53, 0x66, 0x12, 0xDE, 0xDE, 0x35} +) + +// CalcSum32 calculates the custom MHF "sum32" checksum. +func CalcSum32(data []byte) uint32 { + tableIdx0 := (len(data) + 1) & 0xFF + tableIdx1 := int((data[len(data)>>1] + 1) & 0xFF) + out := make([]byte, 4) + for i := 0; i < len(data); i++ { + key := data[i] ^ sum32Table0[(tableIdx0+i)%7] ^ sum32Table1[(tableIdx1+i)%9] + out[i&3] = (out[i&3] + key) & 0xFF + } + return binary.BigEndian.Uint32(out) +} + +func rotate(k *uint32) { + *k = uint32(((54323 * uint(*k)) + 1) & 0xFFFFFFFF) +} + +// DecryptBin8 decrypts MHF "binary8" data. +func DecryptBin8(data []byte, key byte) []byte { + k := uint32(key) + output := make([]byte, len(data)) + for i := 0; i < len(data); i++ { + rotate(&k) + tmp := data[i] ^ byte((k>>13)&0xFF) + output[i] = tmp ^ bin8Key[i&7] + } + return output +} diff --git a/cmd/protbot/conn/bin8_test.go b/cmd/protbot/conn/bin8_test.go new file mode 100644 index 000000000..fa820c030 --- /dev/null +++ b/cmd/protbot/conn/bin8_test.go @@ -0,0 +1,52 @@ +package conn + +import ( + "testing" +) + +// TestCalcSum32 verifies the checksum against a known input. +func TestCalcSum32(t *testing.T) { + // Verify determinism: same input gives same output. + data := []byte("Hello, MHF!") + sum1 := CalcSum32(data) + sum2 := CalcSum32(data) + if sum1 != sum2 { + t.Fatalf("CalcSum32 not deterministic: %08X != %08X", sum1, sum2) + } + + // Different inputs produce different outputs (basic sanity). + data2 := []byte("Hello, MHF?") + sum3 := CalcSum32(data2) + if sum1 == sum3 { + t.Fatalf("CalcSum32 collision on different inputs: both %08X", sum1) + } +} + +// TestDecryptBin8RoundTrip verifies that encrypting and decrypting with Bin8 +// produces the original data. We only have DecryptBin8, but we can verify +// the encrypt→decrypt path by implementing encrypt inline here. +func TestDecryptBin8RoundTrip(t *testing.T) { + original := []byte("Test data for Bin8 encryption round-trip") + key := byte(0x42) + + // Encrypt (inline copy of Erupe's EncryptBin8) + k := uint32(key) + encrypted := make([]byte, len(original)) + for i := 0; i < len(original); i++ { + rotate(&k) + tmp := bin8Key[i&7] ^ byte((k>>13)&0xFF) + encrypted[i] = original[i] ^ tmp + } + + // Decrypt + decrypted := DecryptBin8(encrypted, key) + + if len(decrypted) != len(original) { + t.Fatalf("length mismatch: got %d, want %d", len(decrypted), len(original)) + } + for i := range original { + if decrypted[i] != original[i] { + t.Fatalf("byte %d: got 0x%02X, want 0x%02X", i, decrypted[i], original[i]) + } + } +} diff --git a/cmd/protbot/conn/conn.go b/cmd/protbot/conn/conn.go new file mode 100644 index 000000000..b7ad33173 --- /dev/null +++ b/cmd/protbot/conn/conn.go @@ -0,0 +1,52 @@ +package conn + +import ( + "fmt" + "net" +) + +// MHFConn wraps a CryptConn and provides convenience methods for MHF connections. +type MHFConn struct { + *CryptConn + RawConn net.Conn +} + +// DialWithInit connects to addr and sends the 8 NULL byte initialization +// required by sign and entrance servers. +func DialWithInit(addr string) (*MHFConn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + + // Sign and entrance servers expect 8 NULL bytes to initialize the connection. + _, err = conn.Write(make([]byte, 8)) + if err != nil { + conn.Close() + return nil, fmt.Errorf("write init bytes to %s: %w", addr, err) + } + + return &MHFConn{ + CryptConn: NewCryptConn(conn), + RawConn: conn, + }, nil +} + +// DialDirect connects to addr without sending initialization bytes. +// Used for channel server connections. +func DialDirect(addr string) (*MHFConn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + + return &MHFConn{ + CryptConn: NewCryptConn(conn), + RawConn: conn, + }, nil +} + +// Close closes the underlying connection. +func (c *MHFConn) Close() error { + return c.RawConn.Close() +} diff --git a/cmd/protbot/conn/crypt_conn.go b/cmd/protbot/conn/crypt_conn.go new file mode 100644 index 000000000..e07bcf5f5 --- /dev/null +++ b/cmd/protbot/conn/crypt_conn.go @@ -0,0 +1,115 @@ +package conn + +import ( + "encoding/hex" + "errors" + "erupe-ce/network/crypto" + "fmt" + "io" + "net" +) + +// CryptConn is an MHF encrypted two-way connection. +// Adapted from Erupe's network/crypt_conn.go with config dependency removed. +// Hardcoded to ZZ mode (supports Pf0-based extended data size). +type CryptConn struct { + conn net.Conn + readKeyRot uint32 + sendKeyRot uint32 + sentPackets int32 + prevRecvPacketCombinedCheck uint16 + prevSendPacketCombinedCheck uint16 +} + +// NewCryptConn creates a new CryptConn with proper default values. +func NewCryptConn(conn net.Conn) *CryptConn { + return &CryptConn{ + conn: conn, + readKeyRot: 995117, + sendKeyRot: 995117, + } +} + +// ReadPacket reads a packet from the connection and returns the decrypted data. +func (cc *CryptConn) ReadPacket() ([]byte, error) { + headerData := make([]byte, CryptPacketHeaderLength) + _, err := io.ReadFull(cc.conn, headerData) + if err != nil { + return nil, err + } + + cph, err := NewCryptPacketHeader(headerData) + if err != nil { + return nil, err + } + + // ZZ mode: extended data size using Pf0 field. + encryptedPacketBody := make([]byte, uint32(cph.DataSize)+(uint32(cph.Pf0-0x03)*0x1000)) + _, err = io.ReadFull(cc.conn, encryptedPacketBody) + if err != nil { + return nil, err + } + + if cph.KeyRotDelta != 0 { + cc.readKeyRot = uint32(cph.KeyRotDelta) * (cc.readKeyRot + 1) + } + + out, combinedCheck, check0, check1, check2 := crypto.Crypto(encryptedPacketBody, cc.readKeyRot, false, nil) + if cph.Check0 != check0 || cph.Check1 != check1 || cph.Check2 != check2 { + fmt.Printf("got c0 %X, c1 %X, c2 %X\n", check0, check1, check2) + fmt.Printf("want c0 %X, c1 %X, c2 %X\n", cph.Check0, cph.Check1, cph.Check2) + fmt.Printf("headerData:\n%s\n", hex.Dump(headerData)) + fmt.Printf("encryptedPacketBody:\n%s\n", hex.Dump(encryptedPacketBody)) + + // Attempt bruteforce recovery. + fmt.Println("Crypto out of sync? Attempting bruteforce") + for key := byte(0); key < 255; key++ { + out, combinedCheck, check0, check1, check2 = crypto.Crypto(encryptedPacketBody, 0, false, &key) + if cph.Check0 == check0 && cph.Check1 == check1 && cph.Check2 == check2 { + fmt.Printf("Bruteforce successful, override key: 0x%X\n", key) + cc.prevRecvPacketCombinedCheck = combinedCheck + return out, nil + } + } + + return nil, errors.New("decrypted data checksum doesn't match header") + } + + cc.prevRecvPacketCombinedCheck = combinedCheck + return out, nil +} + +// SendPacket encrypts and sends a packet. +func (cc *CryptConn) SendPacket(data []byte) error { + keyRotDelta := byte(3) + + if keyRotDelta != 0 { + cc.sendKeyRot = uint32(keyRotDelta) * (cc.sendKeyRot + 1) + } + + encData, combinedCheck, check0, check1, check2 := crypto.Crypto(data, cc.sendKeyRot, true, nil) + + header := &CryptPacketHeader{} + header.Pf0 = byte(((uint(len(encData)) >> 12) & 0xF3) | 3) + header.KeyRotDelta = keyRotDelta + header.PacketNum = uint16(cc.sentPackets) + header.DataSize = uint16(len(encData)) + header.PrevPacketCombinedCheck = cc.prevSendPacketCombinedCheck + header.Check0 = check0 + header.Check1 = check1 + header.Check2 = check2 + + headerBytes, err := header.Encode() + if err != nil { + return err + } + + _, err = cc.conn.Write(append(headerBytes, encData...)) + if err != nil { + return err + } + cc.sentPackets++ + cc.prevSendPacketCombinedCheck = combinedCheck + + return nil +} diff --git a/cmd/protbot/conn/crypt_conn_test.go b/cmd/protbot/conn/crypt_conn_test.go new file mode 100644 index 000000000..cb03004db --- /dev/null +++ b/cmd/protbot/conn/crypt_conn_test.go @@ -0,0 +1,152 @@ +package conn + +import ( + "io" + "net" + "testing" +) + +// TestCryptConnRoundTrip verifies that encrypting and decrypting a packet +// through a pair of CryptConn instances produces the original data. +func TestCryptConnRoundTrip(t *testing.T) { + // Create an in-process TCP pipe. + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + sender := NewCryptConn(client) + receiver := NewCryptConn(server) + + testCases := [][]byte{ + {0x00, 0x14, 0x00, 0x00, 0x00, 0x01}, // Minimal login-like packet + {0xDE, 0xAD, 0xBE, 0xEF}, + make([]byte, 256), // Larger packet + } + + for i, original := range testCases { + // Send in a goroutine to avoid blocking. + errCh := make(chan error, 1) + go func() { + errCh <- sender.SendPacket(original) + }() + + received, err := receiver.ReadPacket() + if err != nil { + t.Fatalf("case %d: ReadPacket error: %v", i, err) + } + + if err := <-errCh; err != nil { + t.Fatalf("case %d: SendPacket error: %v", i, err) + } + + if len(received) != len(original) { + t.Fatalf("case %d: length mismatch: got %d, want %d", i, len(received), len(original)) + } + for j := range original { + if received[j] != original[j] { + t.Fatalf("case %d: byte %d mismatch: got 0x%02X, want 0x%02X", i, j, received[j], original[j]) + } + } + } +} + +// TestCryptPacketHeaderRoundTrip verifies header encode/decode. +func TestCryptPacketHeaderRoundTrip(t *testing.T) { + original := &CryptPacketHeader{ + Pf0: 0x03, + KeyRotDelta: 0x03, + PacketNum: 42, + DataSize: 100, + PrevPacketCombinedCheck: 0x1234, + Check0: 0xAAAA, + Check1: 0xBBBB, + Check2: 0xCCCC, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("Encode error: %v", err) + } + + if len(encoded) != CryptPacketHeaderLength { + t.Fatalf("encoded length: got %d, want %d", len(encoded), CryptPacketHeaderLength) + } + + decoded, err := NewCryptPacketHeader(encoded) + if err != nil { + t.Fatalf("NewCryptPacketHeader error: %v", err) + } + + if *decoded != *original { + t.Fatalf("header mismatch:\ngot %+v\nwant %+v", *decoded, *original) + } +} + +// TestMultiPacketSequence verifies that key rotation stays in sync across +// multiple sequential packets. +func TestMultiPacketSequence(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + sender := NewCryptConn(client) + receiver := NewCryptConn(server) + + for i := 0; i < 10; i++ { + data := []byte{byte(i), byte(i + 1), byte(i + 2), byte(i + 3)} + + errCh := make(chan error, 1) + go func() { + errCh <- sender.SendPacket(data) + }() + + received, err := receiver.ReadPacket() + if err != nil { + t.Fatalf("packet %d: ReadPacket error: %v", i, err) + } + + if err := <-errCh; err != nil { + t.Fatalf("packet %d: SendPacket error: %v", i, err) + } + + for j := range data { + if received[j] != data[j] { + t.Fatalf("packet %d byte %d: got 0x%02X, want 0x%02X", i, j, received[j], data[j]) + } + } + } +} + +// TestDialWithInit verifies that DialWithInit sends 8 NULL bytes on connect. +func TestDialWithInit(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + done := make(chan []byte, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 8) + _, _ = io.ReadFull(conn, buf) + done <- buf + }() + + c, err := DialWithInit(listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + initBytes := <-done + for i, b := range initBytes { + if b != 0 { + t.Fatalf("init byte %d: got 0x%02X, want 0x00", i, b) + } + } +} diff --git a/cmd/protbot/conn/crypt_packet.go b/cmd/protbot/conn/crypt_packet.go new file mode 100644 index 000000000..058a7e2bb --- /dev/null +++ b/cmd/protbot/conn/crypt_packet.go @@ -0,0 +1,78 @@ +// Package conn provides MHF encrypted connection primitives. +// +// This is adapted from Erupe's network/crypt_packet.go to avoid importing +// erupe-ce/config (whose init() calls os.Exit without a config file). +package conn + +import ( + "bytes" + "encoding/binary" +) + +const CryptPacketHeaderLength = 14 + +// CryptPacketHeader represents the parsed information of an encrypted packet header. +type CryptPacketHeader struct { + Pf0 byte + KeyRotDelta byte + PacketNum uint16 + DataSize uint16 + PrevPacketCombinedCheck uint16 + Check0 uint16 + Check1 uint16 + Check2 uint16 +} + +// NewCryptPacketHeader parses raw bytes into a CryptPacketHeader. +func NewCryptPacketHeader(data []byte) (*CryptPacketHeader, error) { + var c CryptPacketHeader + r := bytes.NewReader(data) + + if err := binary.Read(r, binary.BigEndian, &c.Pf0); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.KeyRotDelta); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.PacketNum); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.DataSize); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.PrevPacketCombinedCheck); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.Check0); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.Check1); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.Check2); err != nil { + return nil, err + } + + return &c, nil +} + +// Encode encodes the CryptPacketHeader into raw bytes. +func (c *CryptPacketHeader) Encode() ([]byte, error) { + buf := bytes.NewBuffer([]byte{}) + data := []interface{}{ + c.Pf0, + c.KeyRotDelta, + c.PacketNum, + c.DataSize, + c.PrevPacketCombinedCheck, + c.Check0, + c.Check1, + c.Check2, + } + for _, v := range data { + if err := binary.Write(buf, binary.BigEndian, v); err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} diff --git a/cmd/protbot/main.go b/cmd/protbot/main.go new file mode 100644 index 000000000..4b1e0f72b --- /dev/null +++ b/cmd/protbot/main.go @@ -0,0 +1,154 @@ +// protbot is a headless MHF protocol bot for testing Erupe server instances. +// +// Usage: +// +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action login +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action lobby +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action session +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action chat --message "Hello" +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action quests +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "erupe-ce/cmd/protbot/scenario" +) + +func main() { + signAddr := flag.String("sign-addr", "127.0.0.1:53312", "Sign server address (host:port)") + user := flag.String("user", "", "Username") + pass := flag.String("pass", "", "Password") + action := flag.String("action", "login", "Action to perform: login, lobby, session, chat, quests") + message := flag.String("message", "", "Chat message to send (used with --action chat)") + flag.Parse() + + if *user == "" || *pass == "" { + fmt.Fprintln(os.Stderr, "error: --user and --pass are required") + flag.Usage() + os.Exit(1) + } + + switch *action { + case "login": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + fmt.Println("[done] Login successful!") + result.Channel.Close() + + case "lobby": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + fmt.Println("[done] Lobby entry successful!") + result.Channel.Close() + + case "session": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + charID := result.Sign.CharIDs[0] + if _, err := scenario.SetupSession(result.Channel, charID); err != nil { + fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + fmt.Println("[session] Connected. Press Ctrl+C to disconnect.") + waitForSignal() + scenario.Logout(result.Channel) + + case "chat": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + charID := result.Sign.CharIDs[0] + if _, err := scenario.SetupSession(result.Channel, charID); err != nil { + fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + + // Register chat listener. + scenario.ListenChat(result.Channel, func(msg scenario.ChatMessage) { + fmt.Printf("[chat] <%s> (type=%d): %s\n", msg.SenderName, msg.ChatType, msg.Message) + }) + + // Send a message if provided. + if *message != "" { + if err := scenario.SendChat(result.Channel, 0x03, 1, *message, *user); err != nil { + fmt.Fprintf(os.Stderr, "send chat failed: %v\n", err) + } + } + + fmt.Println("[chat] Listening for chat messages. Press Ctrl+C to disconnect.") + waitForSignal() + scenario.Logout(result.Channel) + + case "quests": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + charID := result.Sign.CharIDs[0] + if _, err := scenario.SetupSession(result.Channel, charID); err != nil { + fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + + data, err := scenario.EnumerateQuests(result.Channel, 0, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "enumerate quests failed: %v\n", err) + scenario.Logout(result.Channel) + os.Exit(1) + } + fmt.Printf("[quests] Received %d bytes of quest data\n", len(data)) + scenario.Logout(result.Channel) + + default: + fmt.Fprintf(os.Stderr, "unknown action: %s (supported: login, lobby, session, chat, quests)\n", *action) + os.Exit(1) + } +} + +// waitForSignal blocks until SIGINT or SIGTERM is received. +func waitForSignal() { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + <-sig + fmt.Println("\n[signal] Shutting down...") +} diff --git a/cmd/protbot/protocol/channel.go b/cmd/protbot/protocol/channel.go new file mode 100644 index 000000000..ba8ff9552 --- /dev/null +++ b/cmd/protbot/protocol/channel.go @@ -0,0 +1,190 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "sync" + "sync/atomic" + "time" + + "erupe-ce/cmd/protbot/conn" +) + +// PacketHandler is a callback invoked when a server-pushed packet is received. +type PacketHandler func(opcode uint16, data []byte) + +// ChannelConn manages a connection to a channel server. +type ChannelConn struct { + conn *conn.MHFConn + ackCounter uint32 + waiters sync.Map // map[uint32]chan *AckResponse + handlers sync.Map // map[uint16]PacketHandler + closed atomic.Bool +} + +// OnPacket registers a handler for a specific server-pushed opcode. +// Only one handler per opcode; later registrations replace earlier ones. +func (ch *ChannelConn) OnPacket(opcode uint16, handler PacketHandler) { + ch.handlers.Store(opcode, handler) +} + +// AckResponse holds the parsed ACK data from the server. +type AckResponse struct { + AckHandle uint32 + IsBufferResponse bool + ErrorCode uint8 + Data []byte +} + +// ConnectChannel establishes a connection to a channel server. +// Channel servers do NOT use the 8 NULL byte initialization. +func ConnectChannel(addr string) (*ChannelConn, error) { + c, err := conn.DialDirect(addr) + if err != nil { + return nil, fmt.Errorf("channel connect: %w", err) + } + + ch := &ChannelConn{ + conn: c, + } + + go ch.recvLoop() + return ch, nil +} + +// NextAckHandle returns the next unique ACK handle for packet requests. +func (ch *ChannelConn) NextAckHandle() uint32 { + return atomic.AddUint32(&ch.ackCounter, 1) +} + +// SendPacket encrypts and sends raw packet data (including the 0x00 0x10 terminator +// which is already appended by the Build* functions in packets.go). +func (ch *ChannelConn) SendPacket(data []byte) error { + return ch.conn.SendPacket(data) +} + +// WaitForAck waits for an ACK response matching the given handle. +func (ch *ChannelConn) WaitForAck(handle uint32, timeout time.Duration) (*AckResponse, error) { + waitCh := make(chan *AckResponse, 1) + ch.waiters.Store(handle, waitCh) + defer ch.waiters.Delete(handle) + + select { + case resp := <-waitCh: + return resp, nil + case <-time.After(timeout): + return nil, fmt.Errorf("ACK timeout for handle %d", handle) + } +} + +// Close closes the channel connection. +func (ch *ChannelConn) Close() error { + ch.closed.Store(true) + return ch.conn.Close() +} + +// recvLoop continuously reads packets from the channel server and dispatches ACKs. +func (ch *ChannelConn) recvLoop() { + for { + if ch.closed.Load() { + return + } + + pkt, err := ch.conn.ReadPacket() + if err != nil { + if ch.closed.Load() { + return + } + fmt.Printf("[channel] read error: %v\n", err) + return + } + + if len(pkt) < 2 { + continue + } + + // Strip trailing 0x00 0x10 terminator if present for opcode parsing. + // Packets from server: [opcode uint16][fields...][0x00 0x10] + opcode := binary.BigEndian.Uint16(pkt[0:2]) + + switch opcode { + case MSG_SYS_ACK: + ch.handleAck(pkt[2:]) + case MSG_SYS_PING: + ch.handlePing(pkt[2:]) + default: + if val, ok := ch.handlers.Load(opcode); ok { + val.(PacketHandler)(opcode, pkt[2:]) + } else { + fmt.Printf("[channel] recv opcode 0x%04X (%d bytes)\n", opcode, len(pkt)) + } + } + } +} + +// handleAck parses an ACK packet and dispatches it to a waiting caller. +// Reference: Erupe network/mhfpacket/msg_sys_ack.go +func (ch *ChannelConn) handleAck(data []byte) { + if len(data) < 8 { + return + } + + ackHandle := binary.BigEndian.Uint32(data[0:4]) + isBuffer := data[4] > 0 + errorCode := data[5] + + var ackData []byte + if isBuffer { + payloadSize := binary.BigEndian.Uint16(data[6:8]) + offset := uint32(8) + if payloadSize == 0xFFFF { + if len(data) < 12 { + return + } + payloadSize32 := binary.BigEndian.Uint32(data[8:12]) + offset = 12 + if uint32(len(data)) >= offset+payloadSize32 { + ackData = data[offset : offset+payloadSize32] + } + } else { + if uint32(len(data)) >= offset+uint32(payloadSize) { + ackData = data[offset : offset+uint32(payloadSize)] + } + } + } else { + // Simple ACK: 4 bytes of data after the uint16 field. + if len(data) >= 12 { + ackData = data[8:12] + } + } + + resp := &AckResponse{ + AckHandle: ackHandle, + IsBufferResponse: isBuffer, + ErrorCode: errorCode, + Data: ackData, + } + + if val, ok := ch.waiters.Load(ackHandle); ok { + waitCh := val.(chan *AckResponse) + select { + case waitCh <- resp: + default: + } + } else { + fmt.Printf("[channel] unexpected ACK handle %d (error=%d, buffer=%v, %d bytes)\n", + ackHandle, errorCode, isBuffer, len(ackData)) + } +} + +// handlePing responds to a server ping to keep the connection alive. +func (ch *ChannelConn) handlePing(data []byte) { + var ackHandle uint32 + if len(data) >= 4 { + ackHandle = binary.BigEndian.Uint32(data[0:4]) + } + pkt := BuildPingPacket(ackHandle) + if err := ch.conn.SendPacket(pkt); err != nil { + fmt.Printf("[channel] ping response failed: %v\n", err) + } +} diff --git a/cmd/protbot/protocol/entrance.go b/cmd/protbot/protocol/entrance.go new file mode 100644 index 000000000..dcab094e1 --- /dev/null +++ b/cmd/protbot/protocol/entrance.go @@ -0,0 +1,142 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "net" + + "erupe-ce/common/byteframe" + + "erupe-ce/cmd/protbot/conn" +) + +// ServerEntry represents a channel server from the entrance server response. +type ServerEntry struct { + IP string + Port uint16 + Name string +} + +// DoEntrance connects to the entrance server and retrieves the server list. +// Reference: Erupe server/entranceserver/entrance_server.go and make_resp.go. +func DoEntrance(addr string) ([]ServerEntry, error) { + c, err := conn.DialWithInit(addr) + if err != nil { + return nil, fmt.Errorf("entrance connect: %w", err) + } + defer c.Close() + + // Send a minimal packet (the entrance server reads it, checks len > 5 for USR data). + // An empty/short packet triggers only SV2 response. + bf := byteframe.NewByteFrame() + bf.WriteUint8(0) + if err := c.SendPacket(bf.Data()); err != nil { + return nil, fmt.Errorf("entrance send: %w", err) + } + + resp, err := c.ReadPacket() + if err != nil { + return nil, fmt.Errorf("entrance recv: %w", err) + } + + return parseEntranceResponse(resp) +} + +// parseEntranceResponse parses the Bin8-encrypted entrance server response. +// Reference: Erupe server/entranceserver/make_resp.go (makeHeader, makeSv2Resp) +func parseEntranceResponse(data []byte) ([]ServerEntry, error) { + if len(data) < 2 { + return nil, fmt.Errorf("entrance response too short") + } + + // First byte is the Bin8 encryption key. + key := data[0] + decrypted := conn.DecryptBin8(data[1:], key) + + rbf := byteframe.NewByteFrameFromBytes(decrypted) + + // Read response type header: "SV2" or "SVR" + respType := string(rbf.ReadBytes(3)) + if respType != "SV2" && respType != "SVR" { + return nil, fmt.Errorf("unexpected entrance response type: %s", respType) + } + + entryCount := rbf.ReadUint16() + dataLen := rbf.ReadUint16() + if dataLen == 0 { + return nil, nil + } + expectedSum := rbf.ReadUint32() + serverData := rbf.ReadBytes(uint(dataLen)) + + actualSum := conn.CalcSum32(serverData) + if expectedSum != actualSum { + return nil, fmt.Errorf("entrance checksum mismatch: expected %08X, got %08X", expectedSum, actualSum) + } + + return parseServerEntries(serverData, entryCount) +} + +// parseServerEntries parses the server info binary blob. +// Reference: Erupe server/entranceserver/make_resp.go (encodeServerInfo) +func parseServerEntries(data []byte, entryCount uint16) ([]ServerEntry, error) { + bf := byteframe.NewByteFrameFromBytes(data) + var entries []ServerEntry + + for i := uint16(0); i < entryCount; i++ { + ipBytes := bf.ReadBytes(4) + ip := net.IP([]byte{ + byte(ipBytes[3]), byte(ipBytes[2]), + byte(ipBytes[1]), byte(ipBytes[0]), + }) + + _ = bf.ReadUint16() // serverIdx | 16 + _ = bf.ReadUint16() // 0 + channelCount := bf.ReadUint16() + _ = bf.ReadUint8() // Type + _ = bf.ReadUint8() // Season/rotation + + // G1+ recommended flag + _ = bf.ReadUint8() + + // G51+ (ZZ): skip 1 byte, then read 65-byte padded name + _ = bf.ReadUint8() + nameBytes := bf.ReadBytes(65) + + // GG+: AllowedClientFlags + _ = bf.ReadUint32() + + // Parse name (null-separated: name + description) + name := "" + for j := 0; j < len(nameBytes); j++ { + if nameBytes[j] == 0 { + break + } + name += string(nameBytes[j]) + } + + // Read channel entries + for j := uint16(0); j < channelCount; j++ { + port := bf.ReadUint16() + _ = bf.ReadUint16() // channelIdx | 16 + _ = bf.ReadUint16() // maxPlayers + _ = bf.ReadUint16() // currentPlayers + _ = bf.ReadBytes(14) // remaining channel fields (7 x uint16) + _ = bf.ReadUint16() // 12345 + + serverIP := ip.String() + // Convert 127.0.0.1 representation + if binary.LittleEndian.Uint32(ipBytes) == 0x0100007F { + serverIP = "127.0.0.1" + } + + entries = append(entries, ServerEntry{ + IP: serverIP, + Port: port, + Name: fmt.Sprintf("%s ch%d", name, j+1), + }) + } + } + + return entries, nil +} diff --git a/cmd/protbot/protocol/opcodes.go b/cmd/protbot/protocol/opcodes.go new file mode 100644 index 000000000..37c57a158 --- /dev/null +++ b/cmd/protbot/protocol/opcodes.go @@ -0,0 +1,23 @@ +// Package protocol implements MHF network protocol message building and parsing. +package protocol + +// Packet opcodes (subset from Erupe's network/packetid.go iota). +const ( + MSG_SYS_ACK uint16 = 0x0012 + MSG_SYS_LOGIN uint16 = 0x0014 + MSG_SYS_LOGOUT uint16 = 0x0015 + MSG_SYS_PING uint16 = 0x0017 + MSG_SYS_CAST_BINARY uint16 = 0x0018 + MSG_SYS_TIME uint16 = 0x001A + MSG_SYS_CASTED_BINARY uint16 = 0x001B + MSG_SYS_ISSUE_LOGKEY uint16 = 0x001D + MSG_SYS_ENTER_STAGE uint16 = 0x0022 + MSG_SYS_ENUMERATE_STAGE uint16 = 0x002F + MSG_SYS_INSERT_USER uint16 = 0x0050 + MSG_SYS_DELETE_USER uint16 = 0x0051 + MSG_SYS_UPDATE_RIGHT uint16 = 0x0058 + MSG_SYS_RIGHTS_RELOAD uint16 = 0x005D + MSG_MHF_LOADDATA uint16 = 0x0061 + MSG_MHF_ENUMERATE_QUEST uint16 = 0x009F + MSG_MHF_GET_WEEKLY_SCHED uint16 = 0x00E1 +) diff --git a/cmd/protbot/protocol/packets.go b/cmd/protbot/protocol/packets.go new file mode 100644 index 000000000..7c65f7804 --- /dev/null +++ b/cmd/protbot/protocol/packets.go @@ -0,0 +1,229 @@ +package protocol + +import ( + "erupe-ce/common/byteframe" + "erupe-ce/common/stringsupport" +) + +// BuildLoginPacket builds a MSG_SYS_LOGIN packet. +// Layout mirrors Erupe's MsgSysLogin.Parse: +// +// uint16 opcode +// uint32 ackHandle +// uint32 charID +// uint32 loginTokenNumber +// uint16 hardcodedZero +// uint16 requestVersion (set to 0xCAFE as dummy) +// uint32 charID (repeated) +// uint16 zeroed +// uint16 always 11 +// null-terminated tokenString +// 0x00 0x10 terminator +func BuildLoginPacket(ackHandle, charID, tokenNumber uint32, tokenString string) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_LOGIN) + bf.WriteUint32(ackHandle) + bf.WriteUint32(charID) + bf.WriteUint32(tokenNumber) + bf.WriteUint16(0) // HardcodedZero0 + bf.WriteUint16(0xCAFE) // RequestVersion (dummy) + bf.WriteUint32(charID) // CharID1 (repeated) + bf.WriteUint16(0) // Zeroed + bf.WriteUint16(11) // Always 11 + bf.WriteNullTerminatedBytes([]byte(tokenString)) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildEnumerateStagePacket builds a MSG_SYS_ENUMERATE_STAGE packet. +// Layout mirrors Erupe's MsgSysEnumerateStage.Parse: +// +// uint16 opcode +// uint32 ackHandle +// uint8 always 1 +// uint8 prefix length (including null terminator) +// null-terminated stagePrefix +// 0x00 0x10 terminator +func BuildEnumerateStagePacket(ackHandle uint32, prefix string) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_ENUMERATE_STAGE) + bf.WriteUint32(ackHandle) + bf.WriteUint8(1) // Always 1 + bf.WriteUint8(uint8(len(prefix) + 1)) // Length including null terminator + bf.WriteNullTerminatedBytes([]byte(prefix)) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildEnterStagePacket builds a MSG_SYS_ENTER_STAGE packet. +// Layout mirrors Erupe's MsgSysEnterStage.Parse: +// +// uint16 opcode +// uint32 ackHandle +// uint8 isQuest (0=false) +// uint8 stageID length (including null terminator) +// null-terminated stageID +// 0x00 0x10 terminator +func BuildEnterStagePacket(ackHandle uint32, stageID string) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_ENTER_STAGE) + bf.WriteUint32(ackHandle) + bf.WriteUint8(0) // IsQuest = false + bf.WriteUint8(uint8(len(stageID) + 1)) // Length including null terminator + bf.WriteNullTerminatedBytes([]byte(stageID)) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildPingPacket builds a MSG_SYS_PING response packet. +// +// uint16 opcode +// uint32 ackHandle +// 0x00 0x10 terminator +func BuildPingPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_PING) + bf.WriteUint32(ackHandle) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildLogoutPacket builds a MSG_SYS_LOGOUT packet. +// +// uint16 opcode +// uint8 logoutType (1 = normal logout) +// 0x00 0x10 terminator +func BuildLogoutPacket() []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_LOGOUT) + bf.WriteUint8(1) // LogoutType = normal + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildIssueLogkeyPacket builds a MSG_SYS_ISSUE_LOGKEY packet. +// +// uint16 opcode +// uint32 ackHandle +// uint16 unk0 +// uint16 unk1 +// 0x00 0x10 terminator +func BuildIssueLogkeyPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_ISSUE_LOGKEY) + bf.WriteUint32(ackHandle) + bf.WriteUint16(0) + bf.WriteUint16(0) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildRightsReloadPacket builds a MSG_SYS_RIGHTS_RELOAD packet. +// +// uint16 opcode +// uint32 ackHandle +// uint8 count (0 = empty) +// 0x00 0x10 terminator +func BuildRightsReloadPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_RIGHTS_RELOAD) + bf.WriteUint32(ackHandle) + bf.WriteUint8(0) // Count = 0 (no rights entries) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildLoaddataPacket builds a MSG_MHF_LOADDATA packet. +// +// uint16 opcode +// uint32 ackHandle +// 0x00 0x10 terminator +func BuildLoaddataPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_MHF_LOADDATA) + bf.WriteUint32(ackHandle) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildCastBinaryPacket builds a MSG_SYS_CAST_BINARY packet. +// Layout mirrors Erupe's MsgSysCastBinary.Parse: +// +// uint16 opcode +// uint32 unk (always 0) +// uint8 broadcastType +// uint8 messageType +// uint16 dataSize +// []byte payload +// 0x00 0x10 terminator +func BuildCastBinaryPacket(broadcastType, messageType uint8, payload []byte) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_CAST_BINARY) + bf.WriteUint32(0) // Unk + bf.WriteUint8(broadcastType) + bf.WriteUint8(messageType) + bf.WriteUint16(uint16(len(payload))) + bf.WriteBytes(payload) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildChatPayload builds the inner MsgBinChat binary blob for use with BuildCastBinaryPacket. +// Layout mirrors Erupe's binpacket/msg_bin_chat.go Build: +// +// uint8 unk0 (always 0) +// uint8 chatType +// uint16 flags (always 0) +// uint16 senderNameLen (SJIS bytes + null terminator) +// uint16 messageLen (SJIS bytes + null terminator) +// null-terminated SJIS message +// null-terminated SJIS senderName +func BuildChatPayload(chatType uint8, message, senderName string) []byte { + sjisMsg := stringsupport.UTF8ToSJIS(message) + sjisName := stringsupport.UTF8ToSJIS(senderName) + bf := byteframe.NewByteFrame() + bf.WriteUint8(0) // Unk0 + bf.WriteUint8(chatType) // Type + bf.WriteUint16(0) // Flags + bf.WriteUint16(uint16(len(sjisName) + 1)) // SenderName length (+ null term) + bf.WriteUint16(uint16(len(sjisMsg) + 1)) // Message length (+ null term) + bf.WriteNullTerminatedBytes(sjisMsg) // Message + bf.WriteNullTerminatedBytes(sjisName) // SenderName + return bf.Data() +} + +// BuildEnumerateQuestPacket builds a MSG_MHF_ENUMERATE_QUEST packet. +// +// uint16 opcode +// uint32 ackHandle +// uint8 unk0 (always 0) +// uint8 world +// uint16 counter +// uint16 offset +// uint8 unk1 (always 0) +// 0x00 0x10 terminator +func BuildEnumerateQuestPacket(ackHandle uint32, world uint8, counter, offset uint16) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_MHF_ENUMERATE_QUEST) + bf.WriteUint32(ackHandle) + bf.WriteUint8(0) // Unk0 + bf.WriteUint8(world) + bf.WriteUint16(counter) + bf.WriteUint16(offset) + bf.WriteUint8(0) // Unk1 + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildGetWeeklySchedulePacket builds a MSG_MHF_GET_WEEKLY_SCHEDULE packet. +// +// uint16 opcode +// uint32 ackHandle +// 0x00 0x10 terminator +func BuildGetWeeklySchedulePacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_MHF_GET_WEEKLY_SCHED) + bf.WriteUint32(ackHandle) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} diff --git a/cmd/protbot/protocol/packets_test.go b/cmd/protbot/protocol/packets_test.go new file mode 100644 index 000000000..2b348f419 --- /dev/null +++ b/cmd/protbot/protocol/packets_test.go @@ -0,0 +1,412 @@ +package protocol + +import ( + "encoding/binary" + "testing" + + "erupe-ce/common/byteframe" +) + +// TestBuildLoginPacket verifies that the binary layout matches Erupe's Parse. +func TestBuildLoginPacket(t *testing.T) { + ackHandle := uint32(1) + charID := uint32(100) + tokenNumber := uint32(42) + tokenString := "0123456789ABCDEF" + + pkt := BuildLoginPacket(ackHandle, charID, tokenNumber, tokenString) + + bf := byteframe.NewByteFrameFromBytes(pkt) + + opcode := bf.ReadUint16() + if opcode != MSG_SYS_LOGIN { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_LOGIN) + } + + gotAck := bf.ReadUint32() + if gotAck != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle) + } + + gotCharID0 := bf.ReadUint32() + if gotCharID0 != charID { + t.Fatalf("charID0: got %d, want %d", gotCharID0, charID) + } + + gotTokenNum := bf.ReadUint32() + if gotTokenNum != tokenNumber { + t.Fatalf("tokenNumber: got %d, want %d", gotTokenNum, tokenNumber) + } + + gotZero := bf.ReadUint16() + if gotZero != 0 { + t.Fatalf("hardcodedZero: got %d, want 0", gotZero) + } + + gotVersion := bf.ReadUint16() + if gotVersion != 0xCAFE { + t.Fatalf("requestVersion: got 0x%04X, want 0xCAFE", gotVersion) + } + + gotCharID1 := bf.ReadUint32() + if gotCharID1 != charID { + t.Fatalf("charID1: got %d, want %d", gotCharID1, charID) + } + + gotZeroed := bf.ReadUint16() + if gotZeroed != 0 { + t.Fatalf("zeroed: got %d, want 0", gotZeroed) + } + + gotEleven := bf.ReadUint16() + if gotEleven != 11 { + t.Fatalf("always11: got %d, want 11", gotEleven) + } + + gotToken := string(bf.ReadNullTerminatedBytes()) + if gotToken != tokenString { + t.Fatalf("tokenString: got %q, want %q", gotToken, tokenString) + } + + // Verify terminator. + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildEnumerateStagePacket verifies binary layout matches Erupe's Parse. +func TestBuildEnumerateStagePacket(t *testing.T) { + ackHandle := uint32(5) + prefix := "sl1Ns" + + pkt := BuildEnumerateStagePacket(ackHandle, prefix) + bf := byteframe.NewByteFrameFromBytes(pkt) + + opcode := bf.ReadUint16() + if opcode != MSG_SYS_ENUMERATE_STAGE { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_ENUMERATE_STAGE) + } + + gotAck := bf.ReadUint32() + if gotAck != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle) + } + + alwaysOne := bf.ReadUint8() + if alwaysOne != 1 { + t.Fatalf("alwaysOne: got %d, want 1", alwaysOne) + } + + prefixLen := bf.ReadUint8() + if prefixLen != uint8(len(prefix)+1) { + t.Fatalf("prefixLen: got %d, want %d", prefixLen, len(prefix)+1) + } + + gotPrefix := string(bf.ReadNullTerminatedBytes()) + if gotPrefix != prefix { + t.Fatalf("prefix: got %q, want %q", gotPrefix, prefix) + } + + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildEnterStagePacket verifies binary layout matches Erupe's Parse. +func TestBuildEnterStagePacket(t *testing.T) { + ackHandle := uint32(7) + stageID := "sl1Ns200p0a0u0" + + pkt := BuildEnterStagePacket(ackHandle, stageID) + bf := byteframe.NewByteFrameFromBytes(pkt) + + opcode := bf.ReadUint16() + if opcode != MSG_SYS_ENTER_STAGE { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_ENTER_STAGE) + } + + gotAck := bf.ReadUint32() + if gotAck != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle) + } + + isQuest := bf.ReadUint8() + if isQuest != 0 { + t.Fatalf("isQuest: got %d, want 0", isQuest) + } + + stageLen := bf.ReadUint8() + if stageLen != uint8(len(stageID)+1) { + t.Fatalf("stageLen: got %d, want %d", stageLen, len(stageID)+1) + } + + gotStage := string(bf.ReadNullTerminatedBytes()) + if gotStage != stageID { + t.Fatalf("stageID: got %q, want %q", gotStage, stageID) + } + + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildPingPacket verifies MSG_SYS_PING binary layout. +func TestBuildPingPacket(t *testing.T) { + ackHandle := uint32(99) + pkt := BuildPingPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_PING { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_PING) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildLogoutPacket verifies MSG_SYS_LOGOUT binary layout. +func TestBuildLogoutPacket(t *testing.T) { + pkt := BuildLogoutPacket() + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_LOGOUT { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_LOGOUT) + } + if lt := bf.ReadUint8(); lt != 1 { + t.Fatalf("logoutType: got %d, want 1", lt) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildIssueLogkeyPacket verifies MSG_SYS_ISSUE_LOGKEY binary layout. +func TestBuildIssueLogkeyPacket(t *testing.T) { + ackHandle := uint32(10) + pkt := BuildIssueLogkeyPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_ISSUE_LOGKEY { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_ISSUE_LOGKEY) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + if v := bf.ReadUint16(); v != 0 { + t.Fatalf("unk0: got %d, want 0", v) + } + if v := bf.ReadUint16(); v != 0 { + t.Fatalf("unk1: got %d, want 0", v) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildRightsReloadPacket verifies MSG_SYS_RIGHTS_RELOAD binary layout. +func TestBuildRightsReloadPacket(t *testing.T) { + ackHandle := uint32(20) + pkt := BuildRightsReloadPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_RIGHTS_RELOAD { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_RIGHTS_RELOAD) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + if c := bf.ReadUint8(); c != 0 { + t.Fatalf("count: got %d, want 0", c) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildLoaddataPacket verifies MSG_MHF_LOADDATA binary layout. +func TestBuildLoaddataPacket(t *testing.T) { + ackHandle := uint32(30) + pkt := BuildLoaddataPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_MHF_LOADDATA { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_LOADDATA) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildCastBinaryPacket verifies MSG_SYS_CAST_BINARY binary layout. +func TestBuildCastBinaryPacket(t *testing.T) { + payload := []byte{0xDE, 0xAD, 0xBE, 0xEF} + pkt := BuildCastBinaryPacket(0x03, 1, payload) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_CAST_BINARY { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_CAST_BINARY) + } + if unk := bf.ReadUint32(); unk != 0 { + t.Fatalf("unk: got %d, want 0", unk) + } + if bt := bf.ReadUint8(); bt != 0x03 { + t.Fatalf("broadcastType: got %d, want 3", bt) + } + if mt := bf.ReadUint8(); mt != 1 { + t.Fatalf("messageType: got %d, want 1", mt) + } + if ds := bf.ReadUint16(); ds != uint16(len(payload)) { + t.Fatalf("dataSize: got %d, want %d", ds, len(payload)) + } + gotPayload := bf.ReadBytes(uint(len(payload))) + for i, b := range payload { + if gotPayload[i] != b { + t.Fatalf("payload[%d]: got 0x%02X, want 0x%02X", i, gotPayload[i], b) + } + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildChatPayload verifies the MsgBinChat inner binary layout and SJIS encoding. +func TestBuildChatPayload(t *testing.T) { + chatType := uint8(1) + message := "Hello" + senderName := "TestUser" + + payload := BuildChatPayload(chatType, message, senderName) + bf := byteframe.NewByteFrameFromBytes(payload) + + if unk := bf.ReadUint8(); unk != 0 { + t.Fatalf("unk0: got %d, want 0", unk) + } + if ct := bf.ReadUint8(); ct != chatType { + t.Fatalf("chatType: got %d, want %d", ct, chatType) + } + if flags := bf.ReadUint16(); flags != 0 { + t.Fatalf("flags: got %d, want 0", flags) + } + nameLen := bf.ReadUint16() + msgLen := bf.ReadUint16() + // "Hello" in ASCII/SJIS = 5 bytes + 1 null = 6 + if msgLen != 6 { + t.Fatalf("messageLen: got %d, want 6", msgLen) + } + // "TestUser" in ASCII/SJIS = 8 bytes + 1 null = 9 + if nameLen != 9 { + t.Fatalf("senderNameLen: got %d, want 9", nameLen) + } + + gotMsg := string(bf.ReadNullTerminatedBytes()) + if gotMsg != message { + t.Fatalf("message: got %q, want %q", gotMsg, message) + } + gotName := string(bf.ReadNullTerminatedBytes()) + if gotName != senderName { + t.Fatalf("senderName: got %q, want %q", gotName, senderName) + } +} + +// TestBuildEnumerateQuestPacket verifies MSG_MHF_ENUMERATE_QUEST binary layout. +func TestBuildEnumerateQuestPacket(t *testing.T) { + ackHandle := uint32(40) + world := uint8(2) + counter := uint16(100) + offset := uint16(50) + + pkt := BuildEnumerateQuestPacket(ackHandle, world, counter, offset) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_MHF_ENUMERATE_QUEST { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_ENUMERATE_QUEST) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + if u0 := bf.ReadUint8(); u0 != 0 { + t.Fatalf("unk0: got %d, want 0", u0) + } + if w := bf.ReadUint8(); w != world { + t.Fatalf("world: got %d, want %d", w, world) + } + if c := bf.ReadUint16(); c != counter { + t.Fatalf("counter: got %d, want %d", c, counter) + } + if o := bf.ReadUint16(); o != offset { + t.Fatalf("offset: got %d, want %d", o, offset) + } + if u1 := bf.ReadUint8(); u1 != 0 { + t.Fatalf("unk1: got %d, want 0", u1) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildGetWeeklySchedulePacket verifies MSG_MHF_GET_WEEKLY_SCHEDULE binary layout. +func TestBuildGetWeeklySchedulePacket(t *testing.T) { + ackHandle := uint32(50) + pkt := BuildGetWeeklySchedulePacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_MHF_GET_WEEKLY_SCHED { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_GET_WEEKLY_SCHED) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestOpcodeValues verifies opcode constants match Erupe's iota-based enum. +func TestOpcodeValues(t *testing.T) { + _ = binary.BigEndian // ensure import used + tests := []struct { + name string + got uint16 + want uint16 + }{ + {"MSG_SYS_ACK", MSG_SYS_ACK, 0x0012}, + {"MSG_SYS_LOGIN", MSG_SYS_LOGIN, 0x0014}, + {"MSG_SYS_LOGOUT", MSG_SYS_LOGOUT, 0x0015}, + {"MSG_SYS_PING", MSG_SYS_PING, 0x0017}, + {"MSG_SYS_CAST_BINARY", MSG_SYS_CAST_BINARY, 0x0018}, + {"MSG_SYS_TIME", MSG_SYS_TIME, 0x001A}, + {"MSG_SYS_CASTED_BINARY", MSG_SYS_CASTED_BINARY, 0x001B}, + {"MSG_SYS_ISSUE_LOGKEY", MSG_SYS_ISSUE_LOGKEY, 0x001D}, + {"MSG_SYS_ENTER_STAGE", MSG_SYS_ENTER_STAGE, 0x0022}, + {"MSG_SYS_ENUMERATE_STAGE", MSG_SYS_ENUMERATE_STAGE, 0x002F}, + {"MSG_SYS_INSERT_USER", MSG_SYS_INSERT_USER, 0x0050}, + {"MSG_SYS_DELETE_USER", MSG_SYS_DELETE_USER, 0x0051}, + {"MSG_SYS_UPDATE_RIGHT", MSG_SYS_UPDATE_RIGHT, 0x0058}, + {"MSG_SYS_RIGHTS_RELOAD", MSG_SYS_RIGHTS_RELOAD, 0x005D}, + {"MSG_MHF_LOADDATA", MSG_MHF_LOADDATA, 0x0061}, + {"MSG_MHF_ENUMERATE_QUEST", MSG_MHF_ENUMERATE_QUEST, 0x009F}, + {"MSG_MHF_GET_WEEKLY_SCHED", MSG_MHF_GET_WEEKLY_SCHED, 0x00E1}, + } + for _, tt := range tests { + if tt.got != tt.want { + t.Errorf("%s: got 0x%04X, want 0x%04X", tt.name, tt.got, tt.want) + } + } +} diff --git a/cmd/protbot/protocol/sign.go b/cmd/protbot/protocol/sign.go new file mode 100644 index 000000000..490c49739 --- /dev/null +++ b/cmd/protbot/protocol/sign.go @@ -0,0 +1,105 @@ +package protocol + +import ( + "fmt" + + "erupe-ce/common/byteframe" + "erupe-ce/common/stringsupport" + + "erupe-ce/cmd/protbot/conn" +) + +// SignResult holds the parsed response from a successful DSGN sign-in. +type SignResult struct { + TokenID uint32 + TokenString string // 16 raw bytes as string + Timestamp uint32 + EntranceAddr string + CharIDs []uint32 +} + +// DoSign connects to the sign server and performs a DSGN login. +// Reference: Erupe server/signserver/session.go (handleDSGN) and dsgn_resp.go (makeSignResponse). +func DoSign(addr, username, password string) (*SignResult, error) { + c, err := conn.DialWithInit(addr) + if err != nil { + return nil, fmt.Errorf("sign connect: %w", err) + } + defer c.Close() + + // Build DSGN request: "DSGN:\x00" + SJIS(user) + "\x00" + SJIS(pass) + "\x00" + "\x00" + // The server reads: null-terminated request type, null-terminated user, null-terminated pass, null-terminated unk. + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:\x00")) // reqType (server strips last 3 chars to get "DSGN:") + bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(username)) + bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(password)) + bf.WriteUint8(0) // Unk null-terminated empty string + + if err := c.SendPacket(bf.Data()); err != nil { + return nil, fmt.Errorf("sign send: %w", err) + } + + resp, err := c.ReadPacket() + if err != nil { + return nil, fmt.Errorf("sign recv: %w", err) + } + + return parseSignResponse(resp) +} + +// parseSignResponse parses the binary response from the sign server. +// Reference: Erupe server/signserver/dsgn_resp.go:makeSignResponse +func parseSignResponse(data []byte) (*SignResult, error) { + if len(data) < 1 { + return nil, fmt.Errorf("empty sign response") + } + + rbf := byteframe.NewByteFrameFromBytes(data) + + resultCode := rbf.ReadUint8() + if resultCode != 1 { // SIGN_SUCCESS = 1 + return nil, fmt.Errorf("sign failed with code %d", resultCode) + } + + patchCount := rbf.ReadUint8() // patch server count (usually 2) + _ = rbf.ReadUint8() // entrance server count (usually 1) + charCount := rbf.ReadUint8() // character count + + result := &SignResult{} + result.TokenID = rbf.ReadUint32() + result.TokenString = string(rbf.ReadBytes(16)) // 16 raw bytes + result.Timestamp = rbf.ReadUint32() + + // Skip patch server URLs (pascal strings with uint8 length prefix) + for i := uint8(0); i < patchCount; i++ { + strLen := rbf.ReadUint8() + _ = rbf.ReadBytes(uint(strLen)) + } + + // Read entrance server address (pascal string with uint8 length prefix) + entranceLen := rbf.ReadUint8() + result.EntranceAddr = string(rbf.ReadBytes(uint(entranceLen - 1))) + _ = rbf.ReadUint8() // null terminator + + // Read character entries + for i := uint8(0); i < charCount; i++ { + charID := rbf.ReadUint32() + result.CharIDs = append(result.CharIDs, charID) + + _ = rbf.ReadUint16() // HR + _ = rbf.ReadUint16() // WeaponType + _ = rbf.ReadUint32() // LastLogin + _ = rbf.ReadUint8() // IsFemale + _ = rbf.ReadUint8() // IsNewCharacter + _ = rbf.ReadUint8() // Old GR + _ = rbf.ReadUint8() // Use uint16 GR flag + _ = rbf.ReadBytes(16) // Character name (padded) + _ = rbf.ReadBytes(32) // Unk desc string (padded) + // ZZ mode: additional fields + _ = rbf.ReadUint16() // GR + _ = rbf.ReadUint8() // Unk + _ = rbf.ReadUint8() // Unk + } + + return result, nil +} diff --git a/cmd/protbot/scenario/chat.go b/cmd/protbot/scenario/chat.go new file mode 100644 index 000000000..61394dc4b --- /dev/null +++ b/cmd/protbot/scenario/chat.go @@ -0,0 +1,74 @@ +package scenario + +import ( + "fmt" + + "erupe-ce/common/byteframe" + "erupe-ce/common/stringsupport" + + "erupe-ce/cmd/protbot/protocol" +) + +// ChatMessage holds a parsed incoming chat message. +type ChatMessage struct { + ChatType uint8 + SenderName string + Message string +} + +// SendChat sends a chat message via MSG_SYS_CAST_BINARY with a MsgBinChat payload. +// broadcastType controls delivery scope: 0x03 = stage, 0x06 = world. +func SendChat(ch *protocol.ChannelConn, broadcastType, chatType uint8, message, senderName string) error { + payload := protocol.BuildChatPayload(chatType, message, senderName) + pkt := protocol.BuildCastBinaryPacket(broadcastType, 1, payload) + fmt.Printf("[chat] Sending chat (type=%d, broadcast=%d): %s\n", chatType, broadcastType, message) + return ch.SendPacket(pkt) +} + +// ChatCallback is invoked when a chat message is received. +type ChatCallback func(msg ChatMessage) + +// ListenChat registers a handler on MSG_SYS_CASTED_BINARY that parses chat +// messages (messageType=1) and invokes the callback. +func ListenChat(ch *protocol.ChannelConn, cb ChatCallback) { + ch.OnPacket(protocol.MSG_SYS_CASTED_BINARY, func(opcode uint16, data []byte) { + // MSG_SYS_CASTED_BINARY layout from server: + // uint32 unk + // uint8 broadcastType + // uint8 messageType + // uint16 dataSize + // []byte payload + if len(data) < 8 { + return + } + messageType := data[5] + if messageType != 1 { // Only handle chat messages. + return + } + bf := byteframe.NewByteFrameFromBytes(data) + _ = bf.ReadUint32() // unk + _ = bf.ReadUint8() // broadcastType + _ = bf.ReadUint8() // messageType + dataSize := bf.ReadUint16() + if dataSize == 0 { + return + } + payload := bf.ReadBytes(uint(dataSize)) + + // Parse MsgBinChat inner payload. + pbf := byteframe.NewByteFrameFromBytes(payload) + _ = pbf.ReadUint8() // unk0 + chatType := pbf.ReadUint8() + _ = pbf.ReadUint16() // flags + _ = pbf.ReadUint16() // senderNameLen + _ = pbf.ReadUint16() // messageLen + msg := stringsupport.SJISToUTF8(pbf.ReadNullTerminatedBytes()) + sender := stringsupport.SJISToUTF8(pbf.ReadNullTerminatedBytes()) + + cb(ChatMessage{ + ChatType: chatType, + SenderName: sender, + Message: msg, + }) + }) +} diff --git a/cmd/protbot/scenario/login.go b/cmd/protbot/scenario/login.go new file mode 100644 index 000000000..a90941ef0 --- /dev/null +++ b/cmd/protbot/scenario/login.go @@ -0,0 +1,82 @@ +// Package scenario provides high-level MHF protocol flows. +package scenario + +import ( + "fmt" + "time" + + "erupe-ce/cmd/protbot/protocol" +) + +// LoginResult holds the outcome of a full login flow. +type LoginResult struct { + Sign *protocol.SignResult + Servers []protocol.ServerEntry + Channel *protocol.ChannelConn +} + +// Login performs the full sign → entrance → channel login flow. +func Login(signAddr, username, password string) (*LoginResult, error) { + // Step 1: Sign server authentication. + fmt.Printf("[sign] Connecting to %s...\n", signAddr) + sign, err := protocol.DoSign(signAddr, username, password) + if err != nil { + return nil, fmt.Errorf("sign: %w", err) + } + fmt.Printf("[sign] OK — tokenID=%d, %d character(s), entrance=%s\n", + sign.TokenID, len(sign.CharIDs), sign.EntranceAddr) + + if len(sign.CharIDs) == 0 { + return nil, fmt.Errorf("no characters on account") + } + + // Step 2: Entrance server — get server/channel list. + fmt.Printf("[entrance] Connecting to %s...\n", sign.EntranceAddr) + servers, err := protocol.DoEntrance(sign.EntranceAddr) + if err != nil { + return nil, fmt.Errorf("entrance: %w", err) + } + if len(servers) == 0 { + return nil, fmt.Errorf("no channels available") + } + for i, s := range servers { + fmt.Printf("[entrance] [%d] %s — %s:%d\n", i, s.Name, s.IP, s.Port) + } + + // Step 3: Connect to the first channel server. + first := servers[0] + channelAddr := fmt.Sprintf("%s:%d", first.IP, first.Port) + fmt.Printf("[channel] Connecting to %s...\n", channelAddr) + ch, err := protocol.ConnectChannel(channelAddr) + if err != nil { + return nil, fmt.Errorf("channel connect: %w", err) + } + + // Step 4: Send MSG_SYS_LOGIN. + charID := sign.CharIDs[0] + ack := ch.NextAckHandle() + loginPkt := protocol.BuildLoginPacket(ack, charID, sign.TokenID, sign.TokenString) + fmt.Printf("[channel] Sending MSG_SYS_LOGIN (charID=%d, ackHandle=%d)...\n", charID, ack) + if err := ch.SendPacket(loginPkt); err != nil { + ch.Close() + return nil, fmt.Errorf("channel send login: %w", err) + } + + resp, err := ch.WaitForAck(ack, 10*time.Second) + if err != nil { + ch.Close() + return nil, fmt.Errorf("channel login ack: %w", err) + } + if resp.ErrorCode != 0 { + ch.Close() + return nil, fmt.Errorf("channel login failed: error code %d", resp.ErrorCode) + } + fmt.Printf("[channel] Login ACK received (error=%d, %d bytes data)\n", + resp.ErrorCode, len(resp.Data)) + + return &LoginResult{ + Sign: sign, + Servers: servers, + Channel: ch, + }, nil +} diff --git a/cmd/protbot/scenario/logout.go b/cmd/protbot/scenario/logout.go new file mode 100644 index 000000000..692c97dda --- /dev/null +++ b/cmd/protbot/scenario/logout.go @@ -0,0 +1,17 @@ +package scenario + +import ( + "fmt" + + "erupe-ce/cmd/protbot/protocol" +) + +// Logout sends MSG_SYS_LOGOUT and closes the channel connection. +func Logout(ch *protocol.ChannelConn) error { + fmt.Println("[logout] Sending MSG_SYS_LOGOUT...") + if err := ch.SendPacket(protocol.BuildLogoutPacket()); err != nil { + ch.Close() + return fmt.Errorf("logout send: %w", err) + } + return ch.Close() +} diff --git a/cmd/protbot/scenario/quest.go b/cmd/protbot/scenario/quest.go new file mode 100644 index 000000000..2b3c0b2eb --- /dev/null +++ b/cmd/protbot/scenario/quest.go @@ -0,0 +1,31 @@ +package scenario + +import ( + "fmt" + "time" + + "erupe-ce/cmd/protbot/protocol" +) + +// EnumerateQuests sends MSG_MHF_ENUMERATE_QUEST and returns the raw quest list data. +func EnumerateQuests(ch *protocol.ChannelConn, world uint8, counter uint16) ([]byte, error) { + ack := ch.NextAckHandle() + pkt := protocol.BuildEnumerateQuestPacket(ack, world, counter, 0) + fmt.Printf("[quest] Sending MSG_MHF_ENUMERATE_QUEST (world=%d, counter=%d, ackHandle=%d)...\n", + world, counter, ack) + if err := ch.SendPacket(pkt); err != nil { + return nil, fmt.Errorf("enumerate quest send: %w", err) + } + + resp, err := ch.WaitForAck(ack, 15*time.Second) + if err != nil { + return nil, fmt.Errorf("enumerate quest ack: %w", err) + } + if resp.ErrorCode != 0 { + return nil, fmt.Errorf("enumerate quest failed: error code %d", resp.ErrorCode) + } + fmt.Printf("[quest] ENUMERATE_QUEST ACK (error=%d, %d bytes data)\n", + resp.ErrorCode, len(resp.Data)) + + return resp.Data, nil +} diff --git a/cmd/protbot/scenario/session.go b/cmd/protbot/scenario/session.go new file mode 100644 index 000000000..0f49f8795 --- /dev/null +++ b/cmd/protbot/scenario/session.go @@ -0,0 +1,50 @@ +package scenario + +import ( + "fmt" + "time" + + "erupe-ce/cmd/protbot/protocol" +) + +// SetupSession performs the post-login session setup: ISSUE_LOGKEY, RIGHTS_RELOAD, LOADDATA. +// Returns the loaddata response blob for inspection. +func SetupSession(ch *protocol.ChannelConn, charID uint32) ([]byte, error) { + // Step 1: Issue logkey. + ack := ch.NextAckHandle() + fmt.Printf("[session] Sending MSG_SYS_ISSUE_LOGKEY (ackHandle=%d)...\n", ack) + if err := ch.SendPacket(protocol.BuildIssueLogkeyPacket(ack)); err != nil { + return nil, fmt.Errorf("issue logkey send: %w", err) + } + resp, err := ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return nil, fmt.Errorf("issue logkey ack: %w", err) + } + fmt.Printf("[session] ISSUE_LOGKEY ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data)) + + // Step 2: Rights reload. + ack = ch.NextAckHandle() + fmt.Printf("[session] Sending MSG_SYS_RIGHTS_RELOAD (ackHandle=%d)...\n", ack) + if err := ch.SendPacket(protocol.BuildRightsReloadPacket(ack)); err != nil { + return nil, fmt.Errorf("rights reload send: %w", err) + } + resp, err = ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return nil, fmt.Errorf("rights reload ack: %w", err) + } + fmt.Printf("[session] RIGHTS_RELOAD ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data)) + + // Step 3: Load save data. + ack = ch.NextAckHandle() + fmt.Printf("[session] Sending MSG_MHF_LOADDATA (ackHandle=%d)...\n", ack) + if err := ch.SendPacket(protocol.BuildLoaddataPacket(ack)); err != nil { + return nil, fmt.Errorf("loaddata send: %w", err) + } + resp, err = ch.WaitForAck(ack, 30*time.Second) + if err != nil { + return nil, fmt.Errorf("loaddata ack: %w", err) + } + fmt.Printf("[session] LOADDATA ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data)) + + return resp.Data, nil +} diff --git a/cmd/protbot/scenario/stage.go b/cmd/protbot/scenario/stage.go new file mode 100644 index 000000000..27b5b757d --- /dev/null +++ b/cmd/protbot/scenario/stage.go @@ -0,0 +1,111 @@ +package scenario + +import ( + "encoding/binary" + "fmt" + "time" + + "erupe-ce/common/byteframe" + + "erupe-ce/cmd/protbot/protocol" +) + +// StageInfo holds a parsed stage entry from MSG_SYS_ENUMERATE_STAGE response. +type StageInfo struct { + ID string + Reserved uint16 + Clients uint16 + Displayed uint16 + MaxPlayers uint16 + Flags uint8 +} + +// EnterLobby enumerates available lobby stages and enters the first one. +func EnterLobby(ch *protocol.ChannelConn) error { + // Step 1: Enumerate stages with "sl1Ns" prefix (main lobby stages). + ack := ch.NextAckHandle() + enumPkt := protocol.BuildEnumerateStagePacket(ack, "sl1Ns") + fmt.Printf("[stage] Sending MSG_SYS_ENUMERATE_STAGE (prefix=\"sl1Ns\", ackHandle=%d)...\n", ack) + if err := ch.SendPacket(enumPkt); err != nil { + return fmt.Errorf("enumerate stage send: %w", err) + } + + resp, err := ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return fmt.Errorf("enumerate stage ack: %w", err) + } + if resp.ErrorCode != 0 { + return fmt.Errorf("enumerate stage failed: error code %d", resp.ErrorCode) + } + + stages := parseEnumerateStageResponse(resp.Data) + fmt.Printf("[stage] Found %d stage(s)\n", len(stages)) + for i, s := range stages { + fmt.Printf("[stage] [%d] %s — %d/%d players, flags=0x%02X\n", + i, s.ID, s.Clients, s.MaxPlayers, s.Flags) + } + + // Step 2: Enter the default lobby stage. + // Even if no stages were enumerated, use the default stage ID. + stageID := "sl1Ns200p0a0u0" + if len(stages) > 0 { + stageID = stages[0].ID + } + + ack = ch.NextAckHandle() + enterPkt := protocol.BuildEnterStagePacket(ack, stageID) + fmt.Printf("[stage] Sending MSG_SYS_ENTER_STAGE (stageID=%q, ackHandle=%d)...\n", stageID, ack) + if err := ch.SendPacket(enterPkt); err != nil { + return fmt.Errorf("enter stage send: %w", err) + } + + resp, err = ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return fmt.Errorf("enter stage ack: %w", err) + } + if resp.ErrorCode != 0 { + return fmt.Errorf("enter stage failed: error code %d", resp.ErrorCode) + } + fmt.Printf("[stage] Enter stage ACK received (error=%d)\n", resp.ErrorCode) + + return nil +} + +// parseEnumerateStageResponse parses the ACK data from MSG_SYS_ENUMERATE_STAGE. +// Reference: Erupe server/channelserver/handlers_stage.go (handleMsgSysEnumerateStage) +func parseEnumerateStageResponse(data []byte) []StageInfo { + if len(data) < 2 { + return nil + } + + bf := byteframe.NewByteFrameFromBytes(data) + count := bf.ReadUint16() + + var stages []StageInfo + for i := uint16(0); i < count; i++ { + s := StageInfo{} + s.Reserved = bf.ReadUint16() + s.Clients = bf.ReadUint16() + s.Displayed = bf.ReadUint16() + s.MaxPlayers = bf.ReadUint16() + s.Flags = bf.ReadUint8() + + // Stage ID is a pascal string with uint8 length prefix. + strLen := bf.ReadUint8() + if strLen > 0 { + idBytes := bf.ReadBytes(uint(strLen)) + // Remove null terminator if present. + if len(idBytes) > 0 && idBytes[len(idBytes)-1] == 0 { + idBytes = idBytes[:len(idBytes)-1] + } + s.ID = string(idBytes) + } + + stages = append(stages, s) + } + + // After stages: uint32 timestamp, uint32 max clan members (we ignore these). + _ = binary.BigEndian // suppress unused import if needed + + return stages +} From 486be65a38d4d38d8be1338076c0e35a6c7698fe Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Fri, 20 Feb 2026 14:17:40 +0100 Subject: [PATCH 3/5] fix(protbot,channelserver): fix sign protocol and entrance parsing, guard nil save data The protbot sent "DSGN:\x00" as the sign request type, but the server strips the last 3 characters as a version suffix. Send "DSGN:041" (ZZ client mode 41) to match the real client format. The entrance channel entry parser read 14 bytes for remaining fields but the server writes 18 bytes (9 uint16, not 7), causing a panic when parsing the server list. The channel server panicked on disconnect when a session had no decompressed save data (e.g. protbot or early client disconnect). Guard Save() against nil decompSave. Also fix docker-compose volume mount for Postgres 18 which changed its data directory layout. --- cmd/protbot/protocol/entrance.go | 4 ++-- cmd/protbot/protocol/sign.go | 5 +++-- docker/docker-compose.yml | 2 +- server/channelserver/handlers_character.go | 7 +++++++ 4 files changed, 13 insertions(+), 5 deletions(-) diff --git a/cmd/protbot/protocol/entrance.go b/cmd/protbot/protocol/entrance.go index dcab094e1..d7c516a3f 100644 --- a/cmd/protbot/protocol/entrance.go +++ b/cmd/protbot/protocol/entrance.go @@ -115,13 +115,13 @@ func parseServerEntries(data []byte, entryCount uint16) ([]ServerEntry, error) { name += string(nameBytes[j]) } - // Read channel entries + // Read channel entries (14 x uint16 = 28 bytes each) for j := uint16(0); j < channelCount; j++ { port := bf.ReadUint16() _ = bf.ReadUint16() // channelIdx | 16 _ = bf.ReadUint16() // maxPlayers _ = bf.ReadUint16() // currentPlayers - _ = bf.ReadBytes(14) // remaining channel fields (7 x uint16) + _ = bf.ReadBytes(18) // remaining channel fields (9 x uint16: 6 zeros + unk319 + unk254 + unk255) _ = bf.ReadUint16() // 12345 serverIP := ip.String() diff --git a/cmd/protbot/protocol/sign.go b/cmd/protbot/protocol/sign.go index 490c49739..4f6670b6f 100644 --- a/cmd/protbot/protocol/sign.go +++ b/cmd/protbot/protocol/sign.go @@ -27,10 +27,11 @@ func DoSign(addr, username, password string) (*SignResult, error) { } defer c.Close() - // Build DSGN request: "DSGN:\x00" + SJIS(user) + "\x00" + SJIS(pass) + "\x00" + "\x00" + // Build DSGN request: "DSGN:041" + \x00 + SJIS(user) + \x00 + SJIS(pass) + \x00 + \x00 // The server reads: null-terminated request type, null-terminated user, null-terminated pass, null-terminated unk. + // The request type has a 3-char version suffix (e.g. "041" for ZZ client mode 41) that the server strips. bf := byteframe.NewByteFrame() - bf.WriteNullTerminatedBytes([]byte("DSGN:\x00")) // reqType (server strips last 3 chars to get "DSGN:") + bf.WriteNullTerminatedBytes([]byte("DSGN:041")) // reqType with version suffix (server strips last 3 chars to get "DSGN:") bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(username)) bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(password)) bf.WriteUint8(0) // Unk null-terminated empty string diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 759b16b87..34535e2ef 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -13,7 +13,7 @@ services: ports: - "5432:5432" volumes: - - ./db-data/:/var/lib/postgresql/data/ + - ./db-data/:/var/lib/postgresql/ - ../schemas/:/schemas/ - ./init/setup.sh:/docker-entrypoint-initdb.d/setup.sh healthcheck: diff --git a/server/channelserver/handlers_character.go b/server/channelserver/handlers_character.go index c5558e839..14cd59084 100644 --- a/server/channelserver/handlers_character.go +++ b/server/channelserver/handlers_character.go @@ -48,6 +48,13 @@ func GetCharacterSaveData(s *Session, charID uint32) (*CharacterSaveData, error) } func (save *CharacterSaveData) Save(s *Session) { + if save.decompSave == nil { + s.logger.Warn("No decompressed save data, skipping save", + zap.Uint32("charID", save.CharID), + ) + return + } + if !s.kqfOverride { s.kqf = save.KQF } else { From eab7d1fc4fad848d03f90c0cda1153f55467d3ae Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Fri, 20 Feb 2026 14:36:37 +0100 Subject: [PATCH 4/5] fix(channelserver): eliminate data races in shutdown and session lifecycle The channel server had several concurrency issues found by the race detector during isolation testing: - acceptClients could send on a closed acceptConns channel during shutdown, causing a panic. Replace close(acceptConns) with a done channel and select-based shutdown signaling in both acceptClients and manageSessions. - invalidateSessions read isShuttingDown and iterated sessions without holding the lock. Rewrite with ticker + done channel select and snapshot sessions under lock before processing timeouts. - sendLoop/recvLoop accessed global _config.ErupeConfig.LoopDelay which races with tests modifying the global. Use the per-server erupeConfig instead. - logoutPlayer panicked on DB errors and crashed on nil DB (no-db test scenarios). Guard with nil check and log errors instead. - Shutdown was not idempotent, double-calling caused double-close panic on done channel. Add 5 channel isolation tests verifying independent shutdown, listener failure, session panic recovery, cross-channel registry after shutdown, and stage isolation. --- .../channelserver/channel_isolation_test.go | 214 ++++++++++++++++++ server/channelserver/handlers_session.go | 16 +- server/channelserver/sys_channel_server.go | 55 +++-- server/channelserver/sys_session.go | 5 +- 4 files changed, 260 insertions(+), 30 deletions(-) create mode 100644 server/channelserver/channel_isolation_test.go diff --git a/server/channelserver/channel_isolation_test.go b/server/channelserver/channel_isolation_test.go new file mode 100644 index 000000000..158fca9a3 --- /dev/null +++ b/server/channelserver/channel_isolation_test.go @@ -0,0 +1,214 @@ +package channelserver + +import ( + "net" + "testing" + "time" + + _config "erupe-ce/config" + + "go.uber.org/zap" +) + +// createListeningTestServer creates a channel server that binds to a real TCP port. +// Port 0 lets the OS assign a free port. The server is automatically shut down +// when the test completes. +func createListeningTestServer(t *testing.T, id uint16) *Server { + t.Helper() + logger, _ := zap.NewDevelopment() + s := NewServer(&Config{ + ID: id, + Logger: logger, + ErupeConfig: &_config.Config{ + DebugOptions: _config.DebugOptions{ + LogOutboundMessages: false, + LogInboundMessages: false, + }, + }, + }) + s.Port = 0 // Let OS pick a free port + if err := s.Start(); err != nil { + t.Fatalf("channel %d failed to start: %v", id, err) + } + t.Cleanup(func() { + s.Shutdown() + time.Sleep(200 * time.Millisecond) // Let background goroutines and sessions exit. + }) + return s +} + +// listenerAddr returns the address the server is listening on. +func listenerAddr(s *Server) string { + return s.listener.Addr().String() +} + +// TestChannelIsolation_ShutdownDoesNotAffectOthers verifies that shutting down +// one channel server does not prevent other channels from accepting connections. +func TestChannelIsolation_ShutdownDoesNotAffectOthers(t *testing.T) { + ch1 := createListeningTestServer(t, 1) + ch2 := createListeningTestServer(t, 2) + ch3 := createListeningTestServer(t, 3) + + addr1 := listenerAddr(ch1) + addr2 := listenerAddr(ch2) + addr3 := listenerAddr(ch3) + + // Verify all three channels accept connections initially. + for _, addr := range []string{addr1, addr2, addr3} { + conn, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + t.Fatalf("initial connection to %s failed: %v", addr, err) + } + conn.Close() + } + + // Shut down channel 1. + ch1.Shutdown() + time.Sleep(50 * time.Millisecond) + + // Channel 1 should refuse connections. + _, err := net.DialTimeout("tcp", addr1, 500*time.Millisecond) + if err == nil { + t.Error("channel 1 should refuse connections after shutdown") + } + + // Channels 2 and 3 must still accept connections. + for _, tc := range []struct { + name string + addr string + }{ + {"channel 2", addr2}, + {"channel 3", addr3}, + } { + conn, err := net.DialTimeout("tcp", tc.addr, time.Second) + if err != nil { + t.Errorf("%s should still accept connections after channel 1 shutdown, got: %v", tc.name, err) + } else { + conn.Close() + } + } +} + +// TestChannelIsolation_ListenerCloseDoesNotAffectOthers simulates an unexpected +// listener failure (e.g. port conflict, OS-level error) on one channel and +// verifies other channels continue operating. +func TestChannelIsolation_ListenerCloseDoesNotAffectOthers(t *testing.T) { + ch1 := createListeningTestServer(t, 1) + ch2 := createListeningTestServer(t, 2) + + addr2 := listenerAddr(ch2) + + // Forcibly close channel 1's listener (simulating unexpected failure). + ch1.listener.Close() + time.Sleep(50 * time.Millisecond) + + // Channel 2 must still work. + conn, err := net.DialTimeout("tcp", addr2, time.Second) + if err != nil { + t.Fatalf("channel 2 should still accept connections after channel 1 listener closed: %v", err) + } + conn.Close() +} + +// TestChannelIsolation_SessionPanicDoesNotAffectChannel verifies that a panic +// inside a session handler is recovered and does not crash the channel server. +func TestChannelIsolation_SessionPanicDoesNotAffectChannel(t *testing.T) { + ch := createListeningTestServer(t, 1) + addr := listenerAddr(ch) + + // Connect a client that will trigger a session. + conn1, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + t.Fatalf("first connection failed: %v", err) + } + + // Send garbage data that will cause handlePacketGroup to hit the panic recovery. + // The session's defer/recover should catch it without killing the channel. + conn1.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF}) + time.Sleep(100 * time.Millisecond) + conn1.Close() + time.Sleep(100 * time.Millisecond) + + // The channel should still accept new connections after the panic. + conn2, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + t.Fatalf("channel should still accept connections after session panic: %v", err) + } + conn2.Close() +} + +// TestChannelIsolation_CrossChannelRegistryAfterShutdown verifies that the +// channel registry handles a shut-down channel gracefully during cross-channel +// operations (search, find, disconnect). +func TestChannelIsolation_CrossChannelRegistryAfterShutdown(t *testing.T) { + channels := createTestChannels(3) + reg := NewLocalChannelRegistry(channels) + + // Add sessions to all channels. + for i, ch := range channels { + conn := &mockConn{} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + + // Simulate channel 1 shutting down by marking it and clearing sessions. + channels[0].Lock() + channels[0].isShuttingDown = true + channels[0].sessions = make(map[net.Conn]*Session) + channels[0].Unlock() + + // Registry operations should still work for remaining channels. + found := reg.FindSessionByCharID(2) + if found == nil { + t.Error("FindSessionByCharID(2) should find session on channel 2") + } + + found = reg.FindSessionByCharID(3) + if found == nil { + t.Error("FindSessionByCharID(3) should find session on channel 3") + } + + // Session from shut-down channel should not be found. + found = reg.FindSessionByCharID(1) + if found != nil { + t.Error("FindSessionByCharID(1) should not find session on shut-down channel") + } + + // SearchSessions should return only sessions from live channels. + results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10) + if len(results) != 2 { + t.Errorf("SearchSessions should return 2 results from live channels, got %d", len(results)) + } +} + +// TestChannelIsolation_IndependentStages verifies that stages are per-channel +// and one channel's stages don't leak into another. +func TestChannelIsolation_IndependentStages(t *testing.T) { + channels := createTestChannels(2) + + stageName := "sl1Qs999p0a0u42" + + // Add stage only to channel 1. + channels[0].stagesLock.Lock() + channels[0].stages[stageName] = NewStage(stageName) + channels[0].stagesLock.Unlock() + + // Channel 1 should have the stage. + channels[0].stagesLock.RLock() + _, ok1 := channels[0].stages[stageName] + channels[0].stagesLock.RUnlock() + if !ok1 { + t.Error("channel 1 should have the stage") + } + + // Channel 2 should NOT have the stage. + channels[1].stagesLock.RLock() + _, ok2 := channels[1].stages[stageName] + channels[1].stagesLock.RUnlock() + if ok2 { + t.Error("channel 2 should not have channel 1's stage") + } +} diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go index 1944a7902..ab39c6b95 100644 --- a/server/channelserver/handlers_session.go +++ b/server/channelserver/handlers_session.go @@ -293,14 +293,16 @@ func logoutPlayer(s *Session) { } // Update sign sessions and server player count - _, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token) - if err != nil { - panic(err) - } + if s.server.db != nil { + _, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token) + if err != nil { + s.logger.Error("Failed to clear sign session", zap.Error(err)) + } - _, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID) - if err != nil { - panic(err) + _, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID) + if err != nil { + s.logger.Error("Failed to update player count", zap.Error(err)) + } } if s.stage == nil { diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index f042c6109..bfd22414d 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -50,6 +50,7 @@ type Server struct { sessions map[net.Conn]*Session listener net.Listener // Listener that is created when Server.Start is called. isShuttingDown bool + done chan struct{} // Closed on Shutdown to wake background goroutines. stagesLock sync.RWMutex stages map[string]*Stage @@ -91,6 +92,7 @@ func NewServer(config *Config) *Server { erupeConfig: config.ErupeConfig, acceptConns: make(chan net.Conn), deleteConns: make(chan net.Conn), + done: make(chan struct{}), sessions: make(map[net.Conn]*Session), stages: make(map[string]*Stage), userBinaryParts: make(map[userBinaryPartID][]byte), @@ -156,19 +158,23 @@ func (s *Server) Start() error { return nil } -// Shutdown tries to shut down the server gracefully. +// Shutdown tries to shut down the server gracefully. Safe to call multiple times. func (s *Server) Shutdown() { s.Lock() + alreadyShutDown := s.isShuttingDown s.isShuttingDown = true s.Unlock() + if alreadyShutDown { + return + } + + close(s.done) + if s.listener != nil { _ = s.listener.Close() } - if s.acceptConns != nil { - close(s.acceptConns) - } } func (s *Server) acceptClients() { @@ -186,25 +192,21 @@ func (s *Server) acceptClients() { continue } } - s.acceptConns <- conn + select { + case s.acceptConns <- conn: + case <-s.done: + _ = conn.Close() + return + } } } func (s *Server) manageSessions() { for { select { + case <-s.done: + return case newConn := <-s.acceptConns: - // Gracefully handle acceptConns channel closing. - if newConn == nil { - s.Lock() - shutdown := s.isShuttingDown - s.Unlock() - - if shutdown { - return - } - } - session := NewSession(s, newConn) s.Lock() @@ -236,15 +238,28 @@ func (s *Server) getObjectId() uint16 { } func (s *Server) invalidateSessions() { - for !s.isShuttingDown { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case <-s.done: + return + case <-ticker.C: + } + s.Lock() + var timedOut []*Session for _, sess := range s.sessions { if time.Since(sess.lastPacket) > time.Second*time.Duration(30) { - s.logger.Info("session timeout", zap.String("Name", sess.Name)) - logoutPlayer(sess) + timedOut = append(timedOut, sess) } } - time.Sleep(time.Second * 10) + s.Unlock() + + for _, sess := range timedOut { + s.logger.Info("session timeout", zap.String("Name", sess.Name)) + logoutPlayer(sess) + } } } diff --git a/server/channelserver/sys_session.go b/server/channelserver/sys_session.go index 294d470ab..b30190aec 100644 --- a/server/channelserver/sys_session.go +++ b/server/channelserver/sys_session.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "encoding/hex" "erupe-ce/common/mhfcourse" - _config "erupe-ce/config" "fmt" "io" "net" @@ -172,7 +171,7 @@ func (s *Session) sendLoop() { s.logger.Warn("Failed to send packet", zap.Error(err)) } } - time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond) + time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond) } } @@ -215,7 +214,7 @@ func (s *Session) recvLoop() { return } s.handlePacketGroup(pkt) - time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond) + time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond) } } From c8996e0672b47ca12d42e58cff08fe64a156ea3c Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Fri, 20 Feb 2026 14:38:35 +0100 Subject: [PATCH 5/5] chore: add protbot binary to gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 4132bac7f..17e860fa7 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ screenshots/* # We don't need built files erupe-ce erupe +protbot tools/loganalyzer/loganalyzer # config is install dependent