diff --git a/.gitignore b/.gitignore index 4132bac7f..17e860fa7 100644 --- a/.gitignore +++ b/.gitignore @@ -16,6 +16,7 @@ screenshots/* # We don't need built files erupe-ce erupe +protbot tools/loganalyzer/loganalyzer # config is install dependent diff --git a/cmd/protbot/conn/bin8.go b/cmd/protbot/conn/bin8.go new file mode 100644 index 000000000..4a1256fc5 --- /dev/null +++ b/cmd/protbot/conn/bin8.go @@ -0,0 +1,37 @@ +package conn + +import "encoding/binary" + +var ( + bin8Key = []byte{0x01, 0x23, 0x34, 0x45, 0x56, 0xAB, 0xCD, 0xEF} + sum32Table0 = []byte{0x35, 0x7A, 0xAA, 0x97, 0x53, 0x66, 0x12} + sum32Table1 = []byte{0x7A, 0xAA, 0x97, 0x53, 0x66, 0x12, 0xDE, 0xDE, 0x35} +) + +// CalcSum32 calculates the custom MHF "sum32" checksum. +func CalcSum32(data []byte) uint32 { + tableIdx0 := (len(data) + 1) & 0xFF + tableIdx1 := int((data[len(data)>>1] + 1) & 0xFF) + out := make([]byte, 4) + for i := 0; i < len(data); i++ { + key := data[i] ^ sum32Table0[(tableIdx0+i)%7] ^ sum32Table1[(tableIdx1+i)%9] + out[i&3] = (out[i&3] + key) & 0xFF + } + return binary.BigEndian.Uint32(out) +} + +func rotate(k *uint32) { + *k = uint32(((54323 * uint(*k)) + 1) & 0xFFFFFFFF) +} + +// DecryptBin8 decrypts MHF "binary8" data. +func DecryptBin8(data []byte, key byte) []byte { + k := uint32(key) + output := make([]byte, len(data)) + for i := 0; i < len(data); i++ { + rotate(&k) + tmp := data[i] ^ byte((k>>13)&0xFF) + output[i] = tmp ^ bin8Key[i&7] + } + return output +} diff --git a/cmd/protbot/conn/bin8_test.go b/cmd/protbot/conn/bin8_test.go new file mode 100644 index 000000000..fa820c030 --- /dev/null +++ b/cmd/protbot/conn/bin8_test.go @@ -0,0 +1,52 @@ +package conn + +import ( + "testing" +) + +// TestCalcSum32 verifies the checksum against a known input. +func TestCalcSum32(t *testing.T) { + // Verify determinism: same input gives same output. + data := []byte("Hello, MHF!") + sum1 := CalcSum32(data) + sum2 := CalcSum32(data) + if sum1 != sum2 { + t.Fatalf("CalcSum32 not deterministic: %08X != %08X", sum1, sum2) + } + + // Different inputs produce different outputs (basic sanity). + data2 := []byte("Hello, MHF?") + sum3 := CalcSum32(data2) + if sum1 == sum3 { + t.Fatalf("CalcSum32 collision on different inputs: both %08X", sum1) + } +} + +// TestDecryptBin8RoundTrip verifies that encrypting and decrypting with Bin8 +// produces the original data. We only have DecryptBin8, but we can verify +// the encrypt→decrypt path by implementing encrypt inline here. +func TestDecryptBin8RoundTrip(t *testing.T) { + original := []byte("Test data for Bin8 encryption round-trip") + key := byte(0x42) + + // Encrypt (inline copy of Erupe's EncryptBin8) + k := uint32(key) + encrypted := make([]byte, len(original)) + for i := 0; i < len(original); i++ { + rotate(&k) + tmp := bin8Key[i&7] ^ byte((k>>13)&0xFF) + encrypted[i] = original[i] ^ tmp + } + + // Decrypt + decrypted := DecryptBin8(encrypted, key) + + if len(decrypted) != len(original) { + t.Fatalf("length mismatch: got %d, want %d", len(decrypted), len(original)) + } + for i := range original { + if decrypted[i] != original[i] { + t.Fatalf("byte %d: got 0x%02X, want 0x%02X", i, decrypted[i], original[i]) + } + } +} diff --git a/cmd/protbot/conn/conn.go b/cmd/protbot/conn/conn.go new file mode 100644 index 000000000..b7ad33173 --- /dev/null +++ b/cmd/protbot/conn/conn.go @@ -0,0 +1,52 @@ +package conn + +import ( + "fmt" + "net" +) + +// MHFConn wraps a CryptConn and provides convenience methods for MHF connections. +type MHFConn struct { + *CryptConn + RawConn net.Conn +} + +// DialWithInit connects to addr and sends the 8 NULL byte initialization +// required by sign and entrance servers. +func DialWithInit(addr string) (*MHFConn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + + // Sign and entrance servers expect 8 NULL bytes to initialize the connection. + _, err = conn.Write(make([]byte, 8)) + if err != nil { + conn.Close() + return nil, fmt.Errorf("write init bytes to %s: %w", addr, err) + } + + return &MHFConn{ + CryptConn: NewCryptConn(conn), + RawConn: conn, + }, nil +} + +// DialDirect connects to addr without sending initialization bytes. +// Used for channel server connections. +func DialDirect(addr string) (*MHFConn, error) { + conn, err := net.Dial("tcp", addr) + if err != nil { + return nil, fmt.Errorf("dial %s: %w", addr, err) + } + + return &MHFConn{ + CryptConn: NewCryptConn(conn), + RawConn: conn, + }, nil +} + +// Close closes the underlying connection. +func (c *MHFConn) Close() error { + return c.RawConn.Close() +} diff --git a/cmd/protbot/conn/crypt_conn.go b/cmd/protbot/conn/crypt_conn.go new file mode 100644 index 000000000..e07bcf5f5 --- /dev/null +++ b/cmd/protbot/conn/crypt_conn.go @@ -0,0 +1,115 @@ +package conn + +import ( + "encoding/hex" + "errors" + "erupe-ce/network/crypto" + "fmt" + "io" + "net" +) + +// CryptConn is an MHF encrypted two-way connection. +// Adapted from Erupe's network/crypt_conn.go with config dependency removed. +// Hardcoded to ZZ mode (supports Pf0-based extended data size). +type CryptConn struct { + conn net.Conn + readKeyRot uint32 + sendKeyRot uint32 + sentPackets int32 + prevRecvPacketCombinedCheck uint16 + prevSendPacketCombinedCheck uint16 +} + +// NewCryptConn creates a new CryptConn with proper default values. +func NewCryptConn(conn net.Conn) *CryptConn { + return &CryptConn{ + conn: conn, + readKeyRot: 995117, + sendKeyRot: 995117, + } +} + +// ReadPacket reads a packet from the connection and returns the decrypted data. +func (cc *CryptConn) ReadPacket() ([]byte, error) { + headerData := make([]byte, CryptPacketHeaderLength) + _, err := io.ReadFull(cc.conn, headerData) + if err != nil { + return nil, err + } + + cph, err := NewCryptPacketHeader(headerData) + if err != nil { + return nil, err + } + + // ZZ mode: extended data size using Pf0 field. + encryptedPacketBody := make([]byte, uint32(cph.DataSize)+(uint32(cph.Pf0-0x03)*0x1000)) + _, err = io.ReadFull(cc.conn, encryptedPacketBody) + if err != nil { + return nil, err + } + + if cph.KeyRotDelta != 0 { + cc.readKeyRot = uint32(cph.KeyRotDelta) * (cc.readKeyRot + 1) + } + + out, combinedCheck, check0, check1, check2 := crypto.Crypto(encryptedPacketBody, cc.readKeyRot, false, nil) + if cph.Check0 != check0 || cph.Check1 != check1 || cph.Check2 != check2 { + fmt.Printf("got c0 %X, c1 %X, c2 %X\n", check0, check1, check2) + fmt.Printf("want c0 %X, c1 %X, c2 %X\n", cph.Check0, cph.Check1, cph.Check2) + fmt.Printf("headerData:\n%s\n", hex.Dump(headerData)) + fmt.Printf("encryptedPacketBody:\n%s\n", hex.Dump(encryptedPacketBody)) + + // Attempt bruteforce recovery. + fmt.Println("Crypto out of sync? Attempting bruteforce") + for key := byte(0); key < 255; key++ { + out, combinedCheck, check0, check1, check2 = crypto.Crypto(encryptedPacketBody, 0, false, &key) + if cph.Check0 == check0 && cph.Check1 == check1 && cph.Check2 == check2 { + fmt.Printf("Bruteforce successful, override key: 0x%X\n", key) + cc.prevRecvPacketCombinedCheck = combinedCheck + return out, nil + } + } + + return nil, errors.New("decrypted data checksum doesn't match header") + } + + cc.prevRecvPacketCombinedCheck = combinedCheck + return out, nil +} + +// SendPacket encrypts and sends a packet. +func (cc *CryptConn) SendPacket(data []byte) error { + keyRotDelta := byte(3) + + if keyRotDelta != 0 { + cc.sendKeyRot = uint32(keyRotDelta) * (cc.sendKeyRot + 1) + } + + encData, combinedCheck, check0, check1, check2 := crypto.Crypto(data, cc.sendKeyRot, true, nil) + + header := &CryptPacketHeader{} + header.Pf0 = byte(((uint(len(encData)) >> 12) & 0xF3) | 3) + header.KeyRotDelta = keyRotDelta + header.PacketNum = uint16(cc.sentPackets) + header.DataSize = uint16(len(encData)) + header.PrevPacketCombinedCheck = cc.prevSendPacketCombinedCheck + header.Check0 = check0 + header.Check1 = check1 + header.Check2 = check2 + + headerBytes, err := header.Encode() + if err != nil { + return err + } + + _, err = cc.conn.Write(append(headerBytes, encData...)) + if err != nil { + return err + } + cc.sentPackets++ + cc.prevSendPacketCombinedCheck = combinedCheck + + return nil +} diff --git a/cmd/protbot/conn/crypt_conn_test.go b/cmd/protbot/conn/crypt_conn_test.go new file mode 100644 index 000000000..cb03004db --- /dev/null +++ b/cmd/protbot/conn/crypt_conn_test.go @@ -0,0 +1,152 @@ +package conn + +import ( + "io" + "net" + "testing" +) + +// TestCryptConnRoundTrip verifies that encrypting and decrypting a packet +// through a pair of CryptConn instances produces the original data. +func TestCryptConnRoundTrip(t *testing.T) { + // Create an in-process TCP pipe. + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + sender := NewCryptConn(client) + receiver := NewCryptConn(server) + + testCases := [][]byte{ + {0x00, 0x14, 0x00, 0x00, 0x00, 0x01}, // Minimal login-like packet + {0xDE, 0xAD, 0xBE, 0xEF}, + make([]byte, 256), // Larger packet + } + + for i, original := range testCases { + // Send in a goroutine to avoid blocking. + errCh := make(chan error, 1) + go func() { + errCh <- sender.SendPacket(original) + }() + + received, err := receiver.ReadPacket() + if err != nil { + t.Fatalf("case %d: ReadPacket error: %v", i, err) + } + + if err := <-errCh; err != nil { + t.Fatalf("case %d: SendPacket error: %v", i, err) + } + + if len(received) != len(original) { + t.Fatalf("case %d: length mismatch: got %d, want %d", i, len(received), len(original)) + } + for j := range original { + if received[j] != original[j] { + t.Fatalf("case %d: byte %d mismatch: got 0x%02X, want 0x%02X", i, j, received[j], original[j]) + } + } + } +} + +// TestCryptPacketHeaderRoundTrip verifies header encode/decode. +func TestCryptPacketHeaderRoundTrip(t *testing.T) { + original := &CryptPacketHeader{ + Pf0: 0x03, + KeyRotDelta: 0x03, + PacketNum: 42, + DataSize: 100, + PrevPacketCombinedCheck: 0x1234, + Check0: 0xAAAA, + Check1: 0xBBBB, + Check2: 0xCCCC, + } + + encoded, err := original.Encode() + if err != nil { + t.Fatalf("Encode error: %v", err) + } + + if len(encoded) != CryptPacketHeaderLength { + t.Fatalf("encoded length: got %d, want %d", len(encoded), CryptPacketHeaderLength) + } + + decoded, err := NewCryptPacketHeader(encoded) + if err != nil { + t.Fatalf("NewCryptPacketHeader error: %v", err) + } + + if *decoded != *original { + t.Fatalf("header mismatch:\ngot %+v\nwant %+v", *decoded, *original) + } +} + +// TestMultiPacketSequence verifies that key rotation stays in sync across +// multiple sequential packets. +func TestMultiPacketSequence(t *testing.T) { + server, client := net.Pipe() + defer server.Close() + defer client.Close() + + sender := NewCryptConn(client) + receiver := NewCryptConn(server) + + for i := 0; i < 10; i++ { + data := []byte{byte(i), byte(i + 1), byte(i + 2), byte(i + 3)} + + errCh := make(chan error, 1) + go func() { + errCh <- sender.SendPacket(data) + }() + + received, err := receiver.ReadPacket() + if err != nil { + t.Fatalf("packet %d: ReadPacket error: %v", i, err) + } + + if err := <-errCh; err != nil { + t.Fatalf("packet %d: SendPacket error: %v", i, err) + } + + for j := range data { + if received[j] != data[j] { + t.Fatalf("packet %d byte %d: got 0x%02X, want 0x%02X", i, j, received[j], data[j]) + } + } + } +} + +// TestDialWithInit verifies that DialWithInit sends 8 NULL bytes on connect. +func TestDialWithInit(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer listener.Close() + + done := make(chan []byte, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + buf := make([]byte, 8) + _, _ = io.ReadFull(conn, buf) + done <- buf + }() + + c, err := DialWithInit(listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + + initBytes := <-done + for i, b := range initBytes { + if b != 0 { + t.Fatalf("init byte %d: got 0x%02X, want 0x00", i, b) + } + } +} diff --git a/cmd/protbot/conn/crypt_packet.go b/cmd/protbot/conn/crypt_packet.go new file mode 100644 index 000000000..058a7e2bb --- /dev/null +++ b/cmd/protbot/conn/crypt_packet.go @@ -0,0 +1,78 @@ +// Package conn provides MHF encrypted connection primitives. +// +// This is adapted from Erupe's network/crypt_packet.go to avoid importing +// erupe-ce/config (whose init() calls os.Exit without a config file). +package conn + +import ( + "bytes" + "encoding/binary" +) + +const CryptPacketHeaderLength = 14 + +// CryptPacketHeader represents the parsed information of an encrypted packet header. +type CryptPacketHeader struct { + Pf0 byte + KeyRotDelta byte + PacketNum uint16 + DataSize uint16 + PrevPacketCombinedCheck uint16 + Check0 uint16 + Check1 uint16 + Check2 uint16 +} + +// NewCryptPacketHeader parses raw bytes into a CryptPacketHeader. +func NewCryptPacketHeader(data []byte) (*CryptPacketHeader, error) { + var c CryptPacketHeader + r := bytes.NewReader(data) + + if err := binary.Read(r, binary.BigEndian, &c.Pf0); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.KeyRotDelta); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.PacketNum); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.DataSize); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.PrevPacketCombinedCheck); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.Check0); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.Check1); err != nil { + return nil, err + } + if err := binary.Read(r, binary.BigEndian, &c.Check2); err != nil { + return nil, err + } + + return &c, nil +} + +// Encode encodes the CryptPacketHeader into raw bytes. +func (c *CryptPacketHeader) Encode() ([]byte, error) { + buf := bytes.NewBuffer([]byte{}) + data := []interface{}{ + c.Pf0, + c.KeyRotDelta, + c.PacketNum, + c.DataSize, + c.PrevPacketCombinedCheck, + c.Check0, + c.Check1, + c.Check2, + } + for _, v := range data { + if err := binary.Write(buf, binary.BigEndian, v); err != nil { + return nil, err + } + } + return buf.Bytes(), nil +} diff --git a/cmd/protbot/main.go b/cmd/protbot/main.go new file mode 100644 index 000000000..4b1e0f72b --- /dev/null +++ b/cmd/protbot/main.go @@ -0,0 +1,154 @@ +// protbot is a headless MHF protocol bot for testing Erupe server instances. +// +// Usage: +// +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action login +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action lobby +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action session +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action chat --message "Hello" +// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action quests +package main + +import ( + "flag" + "fmt" + "os" + "os/signal" + "syscall" + + "erupe-ce/cmd/protbot/scenario" +) + +func main() { + signAddr := flag.String("sign-addr", "127.0.0.1:53312", "Sign server address (host:port)") + user := flag.String("user", "", "Username") + pass := flag.String("pass", "", "Password") + action := flag.String("action", "login", "Action to perform: login, lobby, session, chat, quests") + message := flag.String("message", "", "Chat message to send (used with --action chat)") + flag.Parse() + + if *user == "" || *pass == "" { + fmt.Fprintln(os.Stderr, "error: --user and --pass are required") + flag.Usage() + os.Exit(1) + } + + switch *action { + case "login": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + fmt.Println("[done] Login successful!") + result.Channel.Close() + + case "lobby": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + fmt.Println("[done] Lobby entry successful!") + result.Channel.Close() + + case "session": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + charID := result.Sign.CharIDs[0] + if _, err := scenario.SetupSession(result.Channel, charID); err != nil { + fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + fmt.Println("[session] Connected. Press Ctrl+C to disconnect.") + waitForSignal() + scenario.Logout(result.Channel) + + case "chat": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + charID := result.Sign.CharIDs[0] + if _, err := scenario.SetupSession(result.Channel, charID); err != nil { + fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + + // Register chat listener. + scenario.ListenChat(result.Channel, func(msg scenario.ChatMessage) { + fmt.Printf("[chat] <%s> (type=%d): %s\n", msg.SenderName, msg.ChatType, msg.Message) + }) + + // Send a message if provided. + if *message != "" { + if err := scenario.SendChat(result.Channel, 0x03, 1, *message, *user); err != nil { + fmt.Fprintf(os.Stderr, "send chat failed: %v\n", err) + } + } + + fmt.Println("[chat] Listening for chat messages. Press Ctrl+C to disconnect.") + waitForSignal() + scenario.Logout(result.Channel) + + case "quests": + result, err := scenario.Login(*signAddr, *user, *pass) + if err != nil { + fmt.Fprintf(os.Stderr, "login failed: %v\n", err) + os.Exit(1) + } + charID := result.Sign.CharIDs[0] + if _, err := scenario.SetupSession(result.Channel, charID); err != nil { + fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + if err := scenario.EnterLobby(result.Channel); err != nil { + fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err) + result.Channel.Close() + os.Exit(1) + } + + data, err := scenario.EnumerateQuests(result.Channel, 0, 0) + if err != nil { + fmt.Fprintf(os.Stderr, "enumerate quests failed: %v\n", err) + scenario.Logout(result.Channel) + os.Exit(1) + } + fmt.Printf("[quests] Received %d bytes of quest data\n", len(data)) + scenario.Logout(result.Channel) + + default: + fmt.Fprintf(os.Stderr, "unknown action: %s (supported: login, lobby, session, chat, quests)\n", *action) + os.Exit(1) + } +} + +// waitForSignal blocks until SIGINT or SIGTERM is received. +func waitForSignal() { + sig := make(chan os.Signal, 1) + signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM) + <-sig + fmt.Println("\n[signal] Shutting down...") +} diff --git a/cmd/protbot/protocol/channel.go b/cmd/protbot/protocol/channel.go new file mode 100644 index 000000000..ba8ff9552 --- /dev/null +++ b/cmd/protbot/protocol/channel.go @@ -0,0 +1,190 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "sync" + "sync/atomic" + "time" + + "erupe-ce/cmd/protbot/conn" +) + +// PacketHandler is a callback invoked when a server-pushed packet is received. +type PacketHandler func(opcode uint16, data []byte) + +// ChannelConn manages a connection to a channel server. +type ChannelConn struct { + conn *conn.MHFConn + ackCounter uint32 + waiters sync.Map // map[uint32]chan *AckResponse + handlers sync.Map // map[uint16]PacketHandler + closed atomic.Bool +} + +// OnPacket registers a handler for a specific server-pushed opcode. +// Only one handler per opcode; later registrations replace earlier ones. +func (ch *ChannelConn) OnPacket(opcode uint16, handler PacketHandler) { + ch.handlers.Store(opcode, handler) +} + +// AckResponse holds the parsed ACK data from the server. +type AckResponse struct { + AckHandle uint32 + IsBufferResponse bool + ErrorCode uint8 + Data []byte +} + +// ConnectChannel establishes a connection to a channel server. +// Channel servers do NOT use the 8 NULL byte initialization. +func ConnectChannel(addr string) (*ChannelConn, error) { + c, err := conn.DialDirect(addr) + if err != nil { + return nil, fmt.Errorf("channel connect: %w", err) + } + + ch := &ChannelConn{ + conn: c, + } + + go ch.recvLoop() + return ch, nil +} + +// NextAckHandle returns the next unique ACK handle for packet requests. +func (ch *ChannelConn) NextAckHandle() uint32 { + return atomic.AddUint32(&ch.ackCounter, 1) +} + +// SendPacket encrypts and sends raw packet data (including the 0x00 0x10 terminator +// which is already appended by the Build* functions in packets.go). +func (ch *ChannelConn) SendPacket(data []byte) error { + return ch.conn.SendPacket(data) +} + +// WaitForAck waits for an ACK response matching the given handle. +func (ch *ChannelConn) WaitForAck(handle uint32, timeout time.Duration) (*AckResponse, error) { + waitCh := make(chan *AckResponse, 1) + ch.waiters.Store(handle, waitCh) + defer ch.waiters.Delete(handle) + + select { + case resp := <-waitCh: + return resp, nil + case <-time.After(timeout): + return nil, fmt.Errorf("ACK timeout for handle %d", handle) + } +} + +// Close closes the channel connection. +func (ch *ChannelConn) Close() error { + ch.closed.Store(true) + return ch.conn.Close() +} + +// recvLoop continuously reads packets from the channel server and dispatches ACKs. +func (ch *ChannelConn) recvLoop() { + for { + if ch.closed.Load() { + return + } + + pkt, err := ch.conn.ReadPacket() + if err != nil { + if ch.closed.Load() { + return + } + fmt.Printf("[channel] read error: %v\n", err) + return + } + + if len(pkt) < 2 { + continue + } + + // Strip trailing 0x00 0x10 terminator if present for opcode parsing. + // Packets from server: [opcode uint16][fields...][0x00 0x10] + opcode := binary.BigEndian.Uint16(pkt[0:2]) + + switch opcode { + case MSG_SYS_ACK: + ch.handleAck(pkt[2:]) + case MSG_SYS_PING: + ch.handlePing(pkt[2:]) + default: + if val, ok := ch.handlers.Load(opcode); ok { + val.(PacketHandler)(opcode, pkt[2:]) + } else { + fmt.Printf("[channel] recv opcode 0x%04X (%d bytes)\n", opcode, len(pkt)) + } + } + } +} + +// handleAck parses an ACK packet and dispatches it to a waiting caller. +// Reference: Erupe network/mhfpacket/msg_sys_ack.go +func (ch *ChannelConn) handleAck(data []byte) { + if len(data) < 8 { + return + } + + ackHandle := binary.BigEndian.Uint32(data[0:4]) + isBuffer := data[4] > 0 + errorCode := data[5] + + var ackData []byte + if isBuffer { + payloadSize := binary.BigEndian.Uint16(data[6:8]) + offset := uint32(8) + if payloadSize == 0xFFFF { + if len(data) < 12 { + return + } + payloadSize32 := binary.BigEndian.Uint32(data[8:12]) + offset = 12 + if uint32(len(data)) >= offset+payloadSize32 { + ackData = data[offset : offset+payloadSize32] + } + } else { + if uint32(len(data)) >= offset+uint32(payloadSize) { + ackData = data[offset : offset+uint32(payloadSize)] + } + } + } else { + // Simple ACK: 4 bytes of data after the uint16 field. + if len(data) >= 12 { + ackData = data[8:12] + } + } + + resp := &AckResponse{ + AckHandle: ackHandle, + IsBufferResponse: isBuffer, + ErrorCode: errorCode, + Data: ackData, + } + + if val, ok := ch.waiters.Load(ackHandle); ok { + waitCh := val.(chan *AckResponse) + select { + case waitCh <- resp: + default: + } + } else { + fmt.Printf("[channel] unexpected ACK handle %d (error=%d, buffer=%v, %d bytes)\n", + ackHandle, errorCode, isBuffer, len(ackData)) + } +} + +// handlePing responds to a server ping to keep the connection alive. +func (ch *ChannelConn) handlePing(data []byte) { + var ackHandle uint32 + if len(data) >= 4 { + ackHandle = binary.BigEndian.Uint32(data[0:4]) + } + pkt := BuildPingPacket(ackHandle) + if err := ch.conn.SendPacket(pkt); err != nil { + fmt.Printf("[channel] ping response failed: %v\n", err) + } +} diff --git a/cmd/protbot/protocol/entrance.go b/cmd/protbot/protocol/entrance.go new file mode 100644 index 000000000..d7c516a3f --- /dev/null +++ b/cmd/protbot/protocol/entrance.go @@ -0,0 +1,142 @@ +package protocol + +import ( + "encoding/binary" + "fmt" + "net" + + "erupe-ce/common/byteframe" + + "erupe-ce/cmd/protbot/conn" +) + +// ServerEntry represents a channel server from the entrance server response. +type ServerEntry struct { + IP string + Port uint16 + Name string +} + +// DoEntrance connects to the entrance server and retrieves the server list. +// Reference: Erupe server/entranceserver/entrance_server.go and make_resp.go. +func DoEntrance(addr string) ([]ServerEntry, error) { + c, err := conn.DialWithInit(addr) + if err != nil { + return nil, fmt.Errorf("entrance connect: %w", err) + } + defer c.Close() + + // Send a minimal packet (the entrance server reads it, checks len > 5 for USR data). + // An empty/short packet triggers only SV2 response. + bf := byteframe.NewByteFrame() + bf.WriteUint8(0) + if err := c.SendPacket(bf.Data()); err != nil { + return nil, fmt.Errorf("entrance send: %w", err) + } + + resp, err := c.ReadPacket() + if err != nil { + return nil, fmt.Errorf("entrance recv: %w", err) + } + + return parseEntranceResponse(resp) +} + +// parseEntranceResponse parses the Bin8-encrypted entrance server response. +// Reference: Erupe server/entranceserver/make_resp.go (makeHeader, makeSv2Resp) +func parseEntranceResponse(data []byte) ([]ServerEntry, error) { + if len(data) < 2 { + return nil, fmt.Errorf("entrance response too short") + } + + // First byte is the Bin8 encryption key. + key := data[0] + decrypted := conn.DecryptBin8(data[1:], key) + + rbf := byteframe.NewByteFrameFromBytes(decrypted) + + // Read response type header: "SV2" or "SVR" + respType := string(rbf.ReadBytes(3)) + if respType != "SV2" && respType != "SVR" { + return nil, fmt.Errorf("unexpected entrance response type: %s", respType) + } + + entryCount := rbf.ReadUint16() + dataLen := rbf.ReadUint16() + if dataLen == 0 { + return nil, nil + } + expectedSum := rbf.ReadUint32() + serverData := rbf.ReadBytes(uint(dataLen)) + + actualSum := conn.CalcSum32(serverData) + if expectedSum != actualSum { + return nil, fmt.Errorf("entrance checksum mismatch: expected %08X, got %08X", expectedSum, actualSum) + } + + return parseServerEntries(serverData, entryCount) +} + +// parseServerEntries parses the server info binary blob. +// Reference: Erupe server/entranceserver/make_resp.go (encodeServerInfo) +func parseServerEntries(data []byte, entryCount uint16) ([]ServerEntry, error) { + bf := byteframe.NewByteFrameFromBytes(data) + var entries []ServerEntry + + for i := uint16(0); i < entryCount; i++ { + ipBytes := bf.ReadBytes(4) + ip := net.IP([]byte{ + byte(ipBytes[3]), byte(ipBytes[2]), + byte(ipBytes[1]), byte(ipBytes[0]), + }) + + _ = bf.ReadUint16() // serverIdx | 16 + _ = bf.ReadUint16() // 0 + channelCount := bf.ReadUint16() + _ = bf.ReadUint8() // Type + _ = bf.ReadUint8() // Season/rotation + + // G1+ recommended flag + _ = bf.ReadUint8() + + // G51+ (ZZ): skip 1 byte, then read 65-byte padded name + _ = bf.ReadUint8() + nameBytes := bf.ReadBytes(65) + + // GG+: AllowedClientFlags + _ = bf.ReadUint32() + + // Parse name (null-separated: name + description) + name := "" + for j := 0; j < len(nameBytes); j++ { + if nameBytes[j] == 0 { + break + } + name += string(nameBytes[j]) + } + + // Read channel entries (14 x uint16 = 28 bytes each) + for j := uint16(0); j < channelCount; j++ { + port := bf.ReadUint16() + _ = bf.ReadUint16() // channelIdx | 16 + _ = bf.ReadUint16() // maxPlayers + _ = bf.ReadUint16() // currentPlayers + _ = bf.ReadBytes(18) // remaining channel fields (9 x uint16: 6 zeros + unk319 + unk254 + unk255) + _ = bf.ReadUint16() // 12345 + + serverIP := ip.String() + // Convert 127.0.0.1 representation + if binary.LittleEndian.Uint32(ipBytes) == 0x0100007F { + serverIP = "127.0.0.1" + } + + entries = append(entries, ServerEntry{ + IP: serverIP, + Port: port, + Name: fmt.Sprintf("%s ch%d", name, j+1), + }) + } + } + + return entries, nil +} diff --git a/cmd/protbot/protocol/opcodes.go b/cmd/protbot/protocol/opcodes.go new file mode 100644 index 000000000..37c57a158 --- /dev/null +++ b/cmd/protbot/protocol/opcodes.go @@ -0,0 +1,23 @@ +// Package protocol implements MHF network protocol message building and parsing. +package protocol + +// Packet opcodes (subset from Erupe's network/packetid.go iota). +const ( + MSG_SYS_ACK uint16 = 0x0012 + MSG_SYS_LOGIN uint16 = 0x0014 + MSG_SYS_LOGOUT uint16 = 0x0015 + MSG_SYS_PING uint16 = 0x0017 + MSG_SYS_CAST_BINARY uint16 = 0x0018 + MSG_SYS_TIME uint16 = 0x001A + MSG_SYS_CASTED_BINARY uint16 = 0x001B + MSG_SYS_ISSUE_LOGKEY uint16 = 0x001D + MSG_SYS_ENTER_STAGE uint16 = 0x0022 + MSG_SYS_ENUMERATE_STAGE uint16 = 0x002F + MSG_SYS_INSERT_USER uint16 = 0x0050 + MSG_SYS_DELETE_USER uint16 = 0x0051 + MSG_SYS_UPDATE_RIGHT uint16 = 0x0058 + MSG_SYS_RIGHTS_RELOAD uint16 = 0x005D + MSG_MHF_LOADDATA uint16 = 0x0061 + MSG_MHF_ENUMERATE_QUEST uint16 = 0x009F + MSG_MHF_GET_WEEKLY_SCHED uint16 = 0x00E1 +) diff --git a/cmd/protbot/protocol/packets.go b/cmd/protbot/protocol/packets.go new file mode 100644 index 000000000..7c65f7804 --- /dev/null +++ b/cmd/protbot/protocol/packets.go @@ -0,0 +1,229 @@ +package protocol + +import ( + "erupe-ce/common/byteframe" + "erupe-ce/common/stringsupport" +) + +// BuildLoginPacket builds a MSG_SYS_LOGIN packet. +// Layout mirrors Erupe's MsgSysLogin.Parse: +// +// uint16 opcode +// uint32 ackHandle +// uint32 charID +// uint32 loginTokenNumber +// uint16 hardcodedZero +// uint16 requestVersion (set to 0xCAFE as dummy) +// uint32 charID (repeated) +// uint16 zeroed +// uint16 always 11 +// null-terminated tokenString +// 0x00 0x10 terminator +func BuildLoginPacket(ackHandle, charID, tokenNumber uint32, tokenString string) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_LOGIN) + bf.WriteUint32(ackHandle) + bf.WriteUint32(charID) + bf.WriteUint32(tokenNumber) + bf.WriteUint16(0) // HardcodedZero0 + bf.WriteUint16(0xCAFE) // RequestVersion (dummy) + bf.WriteUint32(charID) // CharID1 (repeated) + bf.WriteUint16(0) // Zeroed + bf.WriteUint16(11) // Always 11 + bf.WriteNullTerminatedBytes([]byte(tokenString)) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildEnumerateStagePacket builds a MSG_SYS_ENUMERATE_STAGE packet. +// Layout mirrors Erupe's MsgSysEnumerateStage.Parse: +// +// uint16 opcode +// uint32 ackHandle +// uint8 always 1 +// uint8 prefix length (including null terminator) +// null-terminated stagePrefix +// 0x00 0x10 terminator +func BuildEnumerateStagePacket(ackHandle uint32, prefix string) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_ENUMERATE_STAGE) + bf.WriteUint32(ackHandle) + bf.WriteUint8(1) // Always 1 + bf.WriteUint8(uint8(len(prefix) + 1)) // Length including null terminator + bf.WriteNullTerminatedBytes([]byte(prefix)) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildEnterStagePacket builds a MSG_SYS_ENTER_STAGE packet. +// Layout mirrors Erupe's MsgSysEnterStage.Parse: +// +// uint16 opcode +// uint32 ackHandle +// uint8 isQuest (0=false) +// uint8 stageID length (including null terminator) +// null-terminated stageID +// 0x00 0x10 terminator +func BuildEnterStagePacket(ackHandle uint32, stageID string) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_ENTER_STAGE) + bf.WriteUint32(ackHandle) + bf.WriteUint8(0) // IsQuest = false + bf.WriteUint8(uint8(len(stageID) + 1)) // Length including null terminator + bf.WriteNullTerminatedBytes([]byte(stageID)) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildPingPacket builds a MSG_SYS_PING response packet. +// +// uint16 opcode +// uint32 ackHandle +// 0x00 0x10 terminator +func BuildPingPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_PING) + bf.WriteUint32(ackHandle) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildLogoutPacket builds a MSG_SYS_LOGOUT packet. +// +// uint16 opcode +// uint8 logoutType (1 = normal logout) +// 0x00 0x10 terminator +func BuildLogoutPacket() []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_LOGOUT) + bf.WriteUint8(1) // LogoutType = normal + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildIssueLogkeyPacket builds a MSG_SYS_ISSUE_LOGKEY packet. +// +// uint16 opcode +// uint32 ackHandle +// uint16 unk0 +// uint16 unk1 +// 0x00 0x10 terminator +func BuildIssueLogkeyPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_ISSUE_LOGKEY) + bf.WriteUint32(ackHandle) + bf.WriteUint16(0) + bf.WriteUint16(0) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildRightsReloadPacket builds a MSG_SYS_RIGHTS_RELOAD packet. +// +// uint16 opcode +// uint32 ackHandle +// uint8 count (0 = empty) +// 0x00 0x10 terminator +func BuildRightsReloadPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_RIGHTS_RELOAD) + bf.WriteUint32(ackHandle) + bf.WriteUint8(0) // Count = 0 (no rights entries) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildLoaddataPacket builds a MSG_MHF_LOADDATA packet. +// +// uint16 opcode +// uint32 ackHandle +// 0x00 0x10 terminator +func BuildLoaddataPacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_MHF_LOADDATA) + bf.WriteUint32(ackHandle) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildCastBinaryPacket builds a MSG_SYS_CAST_BINARY packet. +// Layout mirrors Erupe's MsgSysCastBinary.Parse: +// +// uint16 opcode +// uint32 unk (always 0) +// uint8 broadcastType +// uint8 messageType +// uint16 dataSize +// []byte payload +// 0x00 0x10 terminator +func BuildCastBinaryPacket(broadcastType, messageType uint8, payload []byte) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_SYS_CAST_BINARY) + bf.WriteUint32(0) // Unk + bf.WriteUint8(broadcastType) + bf.WriteUint8(messageType) + bf.WriteUint16(uint16(len(payload))) + bf.WriteBytes(payload) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildChatPayload builds the inner MsgBinChat binary blob for use with BuildCastBinaryPacket. +// Layout mirrors Erupe's binpacket/msg_bin_chat.go Build: +// +// uint8 unk0 (always 0) +// uint8 chatType +// uint16 flags (always 0) +// uint16 senderNameLen (SJIS bytes + null terminator) +// uint16 messageLen (SJIS bytes + null terminator) +// null-terminated SJIS message +// null-terminated SJIS senderName +func BuildChatPayload(chatType uint8, message, senderName string) []byte { + sjisMsg := stringsupport.UTF8ToSJIS(message) + sjisName := stringsupport.UTF8ToSJIS(senderName) + bf := byteframe.NewByteFrame() + bf.WriteUint8(0) // Unk0 + bf.WriteUint8(chatType) // Type + bf.WriteUint16(0) // Flags + bf.WriteUint16(uint16(len(sjisName) + 1)) // SenderName length (+ null term) + bf.WriteUint16(uint16(len(sjisMsg) + 1)) // Message length (+ null term) + bf.WriteNullTerminatedBytes(sjisMsg) // Message + bf.WriteNullTerminatedBytes(sjisName) // SenderName + return bf.Data() +} + +// BuildEnumerateQuestPacket builds a MSG_MHF_ENUMERATE_QUEST packet. +// +// uint16 opcode +// uint32 ackHandle +// uint8 unk0 (always 0) +// uint8 world +// uint16 counter +// uint16 offset +// uint8 unk1 (always 0) +// 0x00 0x10 terminator +func BuildEnumerateQuestPacket(ackHandle uint32, world uint8, counter, offset uint16) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_MHF_ENUMERATE_QUEST) + bf.WriteUint32(ackHandle) + bf.WriteUint8(0) // Unk0 + bf.WriteUint8(world) + bf.WriteUint16(counter) + bf.WriteUint16(offset) + bf.WriteUint8(0) // Unk1 + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} + +// BuildGetWeeklySchedulePacket builds a MSG_MHF_GET_WEEKLY_SCHEDULE packet. +// +// uint16 opcode +// uint32 ackHandle +// 0x00 0x10 terminator +func BuildGetWeeklySchedulePacket(ackHandle uint32) []byte { + bf := byteframe.NewByteFrame() + bf.WriteUint16(MSG_MHF_GET_WEEKLY_SCHED) + bf.WriteUint32(ackHandle) + bf.WriteBytes([]byte{0x00, 0x10}) + return bf.Data() +} diff --git a/cmd/protbot/protocol/packets_test.go b/cmd/protbot/protocol/packets_test.go new file mode 100644 index 000000000..2b348f419 --- /dev/null +++ b/cmd/protbot/protocol/packets_test.go @@ -0,0 +1,412 @@ +package protocol + +import ( + "encoding/binary" + "testing" + + "erupe-ce/common/byteframe" +) + +// TestBuildLoginPacket verifies that the binary layout matches Erupe's Parse. +func TestBuildLoginPacket(t *testing.T) { + ackHandle := uint32(1) + charID := uint32(100) + tokenNumber := uint32(42) + tokenString := "0123456789ABCDEF" + + pkt := BuildLoginPacket(ackHandle, charID, tokenNumber, tokenString) + + bf := byteframe.NewByteFrameFromBytes(pkt) + + opcode := bf.ReadUint16() + if opcode != MSG_SYS_LOGIN { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_LOGIN) + } + + gotAck := bf.ReadUint32() + if gotAck != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle) + } + + gotCharID0 := bf.ReadUint32() + if gotCharID0 != charID { + t.Fatalf("charID0: got %d, want %d", gotCharID0, charID) + } + + gotTokenNum := bf.ReadUint32() + if gotTokenNum != tokenNumber { + t.Fatalf("tokenNumber: got %d, want %d", gotTokenNum, tokenNumber) + } + + gotZero := bf.ReadUint16() + if gotZero != 0 { + t.Fatalf("hardcodedZero: got %d, want 0", gotZero) + } + + gotVersion := bf.ReadUint16() + if gotVersion != 0xCAFE { + t.Fatalf("requestVersion: got 0x%04X, want 0xCAFE", gotVersion) + } + + gotCharID1 := bf.ReadUint32() + if gotCharID1 != charID { + t.Fatalf("charID1: got %d, want %d", gotCharID1, charID) + } + + gotZeroed := bf.ReadUint16() + if gotZeroed != 0 { + t.Fatalf("zeroed: got %d, want 0", gotZeroed) + } + + gotEleven := bf.ReadUint16() + if gotEleven != 11 { + t.Fatalf("always11: got %d, want 11", gotEleven) + } + + gotToken := string(bf.ReadNullTerminatedBytes()) + if gotToken != tokenString { + t.Fatalf("tokenString: got %q, want %q", gotToken, tokenString) + } + + // Verify terminator. + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildEnumerateStagePacket verifies binary layout matches Erupe's Parse. +func TestBuildEnumerateStagePacket(t *testing.T) { + ackHandle := uint32(5) + prefix := "sl1Ns" + + pkt := BuildEnumerateStagePacket(ackHandle, prefix) + bf := byteframe.NewByteFrameFromBytes(pkt) + + opcode := bf.ReadUint16() + if opcode != MSG_SYS_ENUMERATE_STAGE { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_ENUMERATE_STAGE) + } + + gotAck := bf.ReadUint32() + if gotAck != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle) + } + + alwaysOne := bf.ReadUint8() + if alwaysOne != 1 { + t.Fatalf("alwaysOne: got %d, want 1", alwaysOne) + } + + prefixLen := bf.ReadUint8() + if prefixLen != uint8(len(prefix)+1) { + t.Fatalf("prefixLen: got %d, want %d", prefixLen, len(prefix)+1) + } + + gotPrefix := string(bf.ReadNullTerminatedBytes()) + if gotPrefix != prefix { + t.Fatalf("prefix: got %q, want %q", gotPrefix, prefix) + } + + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildEnterStagePacket verifies binary layout matches Erupe's Parse. +func TestBuildEnterStagePacket(t *testing.T) { + ackHandle := uint32(7) + stageID := "sl1Ns200p0a0u0" + + pkt := BuildEnterStagePacket(ackHandle, stageID) + bf := byteframe.NewByteFrameFromBytes(pkt) + + opcode := bf.ReadUint16() + if opcode != MSG_SYS_ENTER_STAGE { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_ENTER_STAGE) + } + + gotAck := bf.ReadUint32() + if gotAck != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle) + } + + isQuest := bf.ReadUint8() + if isQuest != 0 { + t.Fatalf("isQuest: got %d, want 0", isQuest) + } + + stageLen := bf.ReadUint8() + if stageLen != uint8(len(stageID)+1) { + t.Fatalf("stageLen: got %d, want %d", stageLen, len(stageID)+1) + } + + gotStage := string(bf.ReadNullTerminatedBytes()) + if gotStage != stageID { + t.Fatalf("stageID: got %q, want %q", gotStage, stageID) + } + + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildPingPacket verifies MSG_SYS_PING binary layout. +func TestBuildPingPacket(t *testing.T) { + ackHandle := uint32(99) + pkt := BuildPingPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_PING { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_PING) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildLogoutPacket verifies MSG_SYS_LOGOUT binary layout. +func TestBuildLogoutPacket(t *testing.T) { + pkt := BuildLogoutPacket() + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_LOGOUT { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_LOGOUT) + } + if lt := bf.ReadUint8(); lt != 1 { + t.Fatalf("logoutType: got %d, want 1", lt) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildIssueLogkeyPacket verifies MSG_SYS_ISSUE_LOGKEY binary layout. +func TestBuildIssueLogkeyPacket(t *testing.T) { + ackHandle := uint32(10) + pkt := BuildIssueLogkeyPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_ISSUE_LOGKEY { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_ISSUE_LOGKEY) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + if v := bf.ReadUint16(); v != 0 { + t.Fatalf("unk0: got %d, want 0", v) + } + if v := bf.ReadUint16(); v != 0 { + t.Fatalf("unk1: got %d, want 0", v) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildRightsReloadPacket verifies MSG_SYS_RIGHTS_RELOAD binary layout. +func TestBuildRightsReloadPacket(t *testing.T) { + ackHandle := uint32(20) + pkt := BuildRightsReloadPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_RIGHTS_RELOAD { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_RIGHTS_RELOAD) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + if c := bf.ReadUint8(); c != 0 { + t.Fatalf("count: got %d, want 0", c) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildLoaddataPacket verifies MSG_MHF_LOADDATA binary layout. +func TestBuildLoaddataPacket(t *testing.T) { + ackHandle := uint32(30) + pkt := BuildLoaddataPacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_MHF_LOADDATA { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_LOADDATA) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildCastBinaryPacket verifies MSG_SYS_CAST_BINARY binary layout. +func TestBuildCastBinaryPacket(t *testing.T) { + payload := []byte{0xDE, 0xAD, 0xBE, 0xEF} + pkt := BuildCastBinaryPacket(0x03, 1, payload) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_SYS_CAST_BINARY { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_CAST_BINARY) + } + if unk := bf.ReadUint32(); unk != 0 { + t.Fatalf("unk: got %d, want 0", unk) + } + if bt := bf.ReadUint8(); bt != 0x03 { + t.Fatalf("broadcastType: got %d, want 3", bt) + } + if mt := bf.ReadUint8(); mt != 1 { + t.Fatalf("messageType: got %d, want 1", mt) + } + if ds := bf.ReadUint16(); ds != uint16(len(payload)) { + t.Fatalf("dataSize: got %d, want %d", ds, len(payload)) + } + gotPayload := bf.ReadBytes(uint(len(payload))) + for i, b := range payload { + if gotPayload[i] != b { + t.Fatalf("payload[%d]: got 0x%02X, want 0x%02X", i, gotPayload[i], b) + } + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildChatPayload verifies the MsgBinChat inner binary layout and SJIS encoding. +func TestBuildChatPayload(t *testing.T) { + chatType := uint8(1) + message := "Hello" + senderName := "TestUser" + + payload := BuildChatPayload(chatType, message, senderName) + bf := byteframe.NewByteFrameFromBytes(payload) + + if unk := bf.ReadUint8(); unk != 0 { + t.Fatalf("unk0: got %d, want 0", unk) + } + if ct := bf.ReadUint8(); ct != chatType { + t.Fatalf("chatType: got %d, want %d", ct, chatType) + } + if flags := bf.ReadUint16(); flags != 0 { + t.Fatalf("flags: got %d, want 0", flags) + } + nameLen := bf.ReadUint16() + msgLen := bf.ReadUint16() + // "Hello" in ASCII/SJIS = 5 bytes + 1 null = 6 + if msgLen != 6 { + t.Fatalf("messageLen: got %d, want 6", msgLen) + } + // "TestUser" in ASCII/SJIS = 8 bytes + 1 null = 9 + if nameLen != 9 { + t.Fatalf("senderNameLen: got %d, want 9", nameLen) + } + + gotMsg := string(bf.ReadNullTerminatedBytes()) + if gotMsg != message { + t.Fatalf("message: got %q, want %q", gotMsg, message) + } + gotName := string(bf.ReadNullTerminatedBytes()) + if gotName != senderName { + t.Fatalf("senderName: got %q, want %q", gotName, senderName) + } +} + +// TestBuildEnumerateQuestPacket verifies MSG_MHF_ENUMERATE_QUEST binary layout. +func TestBuildEnumerateQuestPacket(t *testing.T) { + ackHandle := uint32(40) + world := uint8(2) + counter := uint16(100) + offset := uint16(50) + + pkt := BuildEnumerateQuestPacket(ackHandle, world, counter, offset) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_MHF_ENUMERATE_QUEST { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_ENUMERATE_QUEST) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + if u0 := bf.ReadUint8(); u0 != 0 { + t.Fatalf("unk0: got %d, want 0", u0) + } + if w := bf.ReadUint8(); w != world { + t.Fatalf("world: got %d, want %d", w, world) + } + if c := bf.ReadUint16(); c != counter { + t.Fatalf("counter: got %d, want %d", c, counter) + } + if o := bf.ReadUint16(); o != offset { + t.Fatalf("offset: got %d, want %d", o, offset) + } + if u1 := bf.ReadUint8(); u1 != 0 { + t.Fatalf("unk1: got %d, want 0", u1) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestBuildGetWeeklySchedulePacket verifies MSG_MHF_GET_WEEKLY_SCHEDULE binary layout. +func TestBuildGetWeeklySchedulePacket(t *testing.T) { + ackHandle := uint32(50) + pkt := BuildGetWeeklySchedulePacket(ackHandle) + bf := byteframe.NewByteFrameFromBytes(pkt) + + if op := bf.ReadUint16(); op != MSG_MHF_GET_WEEKLY_SCHED { + t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_GET_WEEKLY_SCHED) + } + if ack := bf.ReadUint32(); ack != ackHandle { + t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle) + } + term := bf.ReadBytes(2) + if term[0] != 0x00 || term[1] != 0x10 { + t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1]) + } +} + +// TestOpcodeValues verifies opcode constants match Erupe's iota-based enum. +func TestOpcodeValues(t *testing.T) { + _ = binary.BigEndian // ensure import used + tests := []struct { + name string + got uint16 + want uint16 + }{ + {"MSG_SYS_ACK", MSG_SYS_ACK, 0x0012}, + {"MSG_SYS_LOGIN", MSG_SYS_LOGIN, 0x0014}, + {"MSG_SYS_LOGOUT", MSG_SYS_LOGOUT, 0x0015}, + {"MSG_SYS_PING", MSG_SYS_PING, 0x0017}, + {"MSG_SYS_CAST_BINARY", MSG_SYS_CAST_BINARY, 0x0018}, + {"MSG_SYS_TIME", MSG_SYS_TIME, 0x001A}, + {"MSG_SYS_CASTED_BINARY", MSG_SYS_CASTED_BINARY, 0x001B}, + {"MSG_SYS_ISSUE_LOGKEY", MSG_SYS_ISSUE_LOGKEY, 0x001D}, + {"MSG_SYS_ENTER_STAGE", MSG_SYS_ENTER_STAGE, 0x0022}, + {"MSG_SYS_ENUMERATE_STAGE", MSG_SYS_ENUMERATE_STAGE, 0x002F}, + {"MSG_SYS_INSERT_USER", MSG_SYS_INSERT_USER, 0x0050}, + {"MSG_SYS_DELETE_USER", MSG_SYS_DELETE_USER, 0x0051}, + {"MSG_SYS_UPDATE_RIGHT", MSG_SYS_UPDATE_RIGHT, 0x0058}, + {"MSG_SYS_RIGHTS_RELOAD", MSG_SYS_RIGHTS_RELOAD, 0x005D}, + {"MSG_MHF_LOADDATA", MSG_MHF_LOADDATA, 0x0061}, + {"MSG_MHF_ENUMERATE_QUEST", MSG_MHF_ENUMERATE_QUEST, 0x009F}, + {"MSG_MHF_GET_WEEKLY_SCHED", MSG_MHF_GET_WEEKLY_SCHED, 0x00E1}, + } + for _, tt := range tests { + if tt.got != tt.want { + t.Errorf("%s: got 0x%04X, want 0x%04X", tt.name, tt.got, tt.want) + } + } +} diff --git a/cmd/protbot/protocol/sign.go b/cmd/protbot/protocol/sign.go new file mode 100644 index 000000000..4f6670b6f --- /dev/null +++ b/cmd/protbot/protocol/sign.go @@ -0,0 +1,106 @@ +package protocol + +import ( + "fmt" + + "erupe-ce/common/byteframe" + "erupe-ce/common/stringsupport" + + "erupe-ce/cmd/protbot/conn" +) + +// SignResult holds the parsed response from a successful DSGN sign-in. +type SignResult struct { + TokenID uint32 + TokenString string // 16 raw bytes as string + Timestamp uint32 + EntranceAddr string + CharIDs []uint32 +} + +// DoSign connects to the sign server and performs a DSGN login. +// Reference: Erupe server/signserver/session.go (handleDSGN) and dsgn_resp.go (makeSignResponse). +func DoSign(addr, username, password string) (*SignResult, error) { + c, err := conn.DialWithInit(addr) + if err != nil { + return nil, fmt.Errorf("sign connect: %w", err) + } + defer c.Close() + + // Build DSGN request: "DSGN:041" + \x00 + SJIS(user) + \x00 + SJIS(pass) + \x00 + \x00 + // The server reads: null-terminated request type, null-terminated user, null-terminated pass, null-terminated unk. + // The request type has a 3-char version suffix (e.g. "041" for ZZ client mode 41) that the server strips. + bf := byteframe.NewByteFrame() + bf.WriteNullTerminatedBytes([]byte("DSGN:041")) // reqType with version suffix (server strips last 3 chars to get "DSGN:") + bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(username)) + bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(password)) + bf.WriteUint8(0) // Unk null-terminated empty string + + if err := c.SendPacket(bf.Data()); err != nil { + return nil, fmt.Errorf("sign send: %w", err) + } + + resp, err := c.ReadPacket() + if err != nil { + return nil, fmt.Errorf("sign recv: %w", err) + } + + return parseSignResponse(resp) +} + +// parseSignResponse parses the binary response from the sign server. +// Reference: Erupe server/signserver/dsgn_resp.go:makeSignResponse +func parseSignResponse(data []byte) (*SignResult, error) { + if len(data) < 1 { + return nil, fmt.Errorf("empty sign response") + } + + rbf := byteframe.NewByteFrameFromBytes(data) + + resultCode := rbf.ReadUint8() + if resultCode != 1 { // SIGN_SUCCESS = 1 + return nil, fmt.Errorf("sign failed with code %d", resultCode) + } + + patchCount := rbf.ReadUint8() // patch server count (usually 2) + _ = rbf.ReadUint8() // entrance server count (usually 1) + charCount := rbf.ReadUint8() // character count + + result := &SignResult{} + result.TokenID = rbf.ReadUint32() + result.TokenString = string(rbf.ReadBytes(16)) // 16 raw bytes + result.Timestamp = rbf.ReadUint32() + + // Skip patch server URLs (pascal strings with uint8 length prefix) + for i := uint8(0); i < patchCount; i++ { + strLen := rbf.ReadUint8() + _ = rbf.ReadBytes(uint(strLen)) + } + + // Read entrance server address (pascal string with uint8 length prefix) + entranceLen := rbf.ReadUint8() + result.EntranceAddr = string(rbf.ReadBytes(uint(entranceLen - 1))) + _ = rbf.ReadUint8() // null terminator + + // Read character entries + for i := uint8(0); i < charCount; i++ { + charID := rbf.ReadUint32() + result.CharIDs = append(result.CharIDs, charID) + + _ = rbf.ReadUint16() // HR + _ = rbf.ReadUint16() // WeaponType + _ = rbf.ReadUint32() // LastLogin + _ = rbf.ReadUint8() // IsFemale + _ = rbf.ReadUint8() // IsNewCharacter + _ = rbf.ReadUint8() // Old GR + _ = rbf.ReadUint8() // Use uint16 GR flag + _ = rbf.ReadBytes(16) // Character name (padded) + _ = rbf.ReadBytes(32) // Unk desc string (padded) + // ZZ mode: additional fields + _ = rbf.ReadUint16() // GR + _ = rbf.ReadUint8() // Unk + _ = rbf.ReadUint8() // Unk + } + + return result, nil +} diff --git a/cmd/protbot/scenario/chat.go b/cmd/protbot/scenario/chat.go new file mode 100644 index 000000000..61394dc4b --- /dev/null +++ b/cmd/protbot/scenario/chat.go @@ -0,0 +1,74 @@ +package scenario + +import ( + "fmt" + + "erupe-ce/common/byteframe" + "erupe-ce/common/stringsupport" + + "erupe-ce/cmd/protbot/protocol" +) + +// ChatMessage holds a parsed incoming chat message. +type ChatMessage struct { + ChatType uint8 + SenderName string + Message string +} + +// SendChat sends a chat message via MSG_SYS_CAST_BINARY with a MsgBinChat payload. +// broadcastType controls delivery scope: 0x03 = stage, 0x06 = world. +func SendChat(ch *protocol.ChannelConn, broadcastType, chatType uint8, message, senderName string) error { + payload := protocol.BuildChatPayload(chatType, message, senderName) + pkt := protocol.BuildCastBinaryPacket(broadcastType, 1, payload) + fmt.Printf("[chat] Sending chat (type=%d, broadcast=%d): %s\n", chatType, broadcastType, message) + return ch.SendPacket(pkt) +} + +// ChatCallback is invoked when a chat message is received. +type ChatCallback func(msg ChatMessage) + +// ListenChat registers a handler on MSG_SYS_CASTED_BINARY that parses chat +// messages (messageType=1) and invokes the callback. +func ListenChat(ch *protocol.ChannelConn, cb ChatCallback) { + ch.OnPacket(protocol.MSG_SYS_CASTED_BINARY, func(opcode uint16, data []byte) { + // MSG_SYS_CASTED_BINARY layout from server: + // uint32 unk + // uint8 broadcastType + // uint8 messageType + // uint16 dataSize + // []byte payload + if len(data) < 8 { + return + } + messageType := data[5] + if messageType != 1 { // Only handle chat messages. + return + } + bf := byteframe.NewByteFrameFromBytes(data) + _ = bf.ReadUint32() // unk + _ = bf.ReadUint8() // broadcastType + _ = bf.ReadUint8() // messageType + dataSize := bf.ReadUint16() + if dataSize == 0 { + return + } + payload := bf.ReadBytes(uint(dataSize)) + + // Parse MsgBinChat inner payload. + pbf := byteframe.NewByteFrameFromBytes(payload) + _ = pbf.ReadUint8() // unk0 + chatType := pbf.ReadUint8() + _ = pbf.ReadUint16() // flags + _ = pbf.ReadUint16() // senderNameLen + _ = pbf.ReadUint16() // messageLen + msg := stringsupport.SJISToUTF8(pbf.ReadNullTerminatedBytes()) + sender := stringsupport.SJISToUTF8(pbf.ReadNullTerminatedBytes()) + + cb(ChatMessage{ + ChatType: chatType, + SenderName: sender, + Message: msg, + }) + }) +} diff --git a/cmd/protbot/scenario/login.go b/cmd/protbot/scenario/login.go new file mode 100644 index 000000000..a90941ef0 --- /dev/null +++ b/cmd/protbot/scenario/login.go @@ -0,0 +1,82 @@ +// Package scenario provides high-level MHF protocol flows. +package scenario + +import ( + "fmt" + "time" + + "erupe-ce/cmd/protbot/protocol" +) + +// LoginResult holds the outcome of a full login flow. +type LoginResult struct { + Sign *protocol.SignResult + Servers []protocol.ServerEntry + Channel *protocol.ChannelConn +} + +// Login performs the full sign → entrance → channel login flow. +func Login(signAddr, username, password string) (*LoginResult, error) { + // Step 1: Sign server authentication. + fmt.Printf("[sign] Connecting to %s...\n", signAddr) + sign, err := protocol.DoSign(signAddr, username, password) + if err != nil { + return nil, fmt.Errorf("sign: %w", err) + } + fmt.Printf("[sign] OK — tokenID=%d, %d character(s), entrance=%s\n", + sign.TokenID, len(sign.CharIDs), sign.EntranceAddr) + + if len(sign.CharIDs) == 0 { + return nil, fmt.Errorf("no characters on account") + } + + // Step 2: Entrance server — get server/channel list. + fmt.Printf("[entrance] Connecting to %s...\n", sign.EntranceAddr) + servers, err := protocol.DoEntrance(sign.EntranceAddr) + if err != nil { + return nil, fmt.Errorf("entrance: %w", err) + } + if len(servers) == 0 { + return nil, fmt.Errorf("no channels available") + } + for i, s := range servers { + fmt.Printf("[entrance] [%d] %s — %s:%d\n", i, s.Name, s.IP, s.Port) + } + + // Step 3: Connect to the first channel server. + first := servers[0] + channelAddr := fmt.Sprintf("%s:%d", first.IP, first.Port) + fmt.Printf("[channel] Connecting to %s...\n", channelAddr) + ch, err := protocol.ConnectChannel(channelAddr) + if err != nil { + return nil, fmt.Errorf("channel connect: %w", err) + } + + // Step 4: Send MSG_SYS_LOGIN. + charID := sign.CharIDs[0] + ack := ch.NextAckHandle() + loginPkt := protocol.BuildLoginPacket(ack, charID, sign.TokenID, sign.TokenString) + fmt.Printf("[channel] Sending MSG_SYS_LOGIN (charID=%d, ackHandle=%d)...\n", charID, ack) + if err := ch.SendPacket(loginPkt); err != nil { + ch.Close() + return nil, fmt.Errorf("channel send login: %w", err) + } + + resp, err := ch.WaitForAck(ack, 10*time.Second) + if err != nil { + ch.Close() + return nil, fmt.Errorf("channel login ack: %w", err) + } + if resp.ErrorCode != 0 { + ch.Close() + return nil, fmt.Errorf("channel login failed: error code %d", resp.ErrorCode) + } + fmt.Printf("[channel] Login ACK received (error=%d, %d bytes data)\n", + resp.ErrorCode, len(resp.Data)) + + return &LoginResult{ + Sign: sign, + Servers: servers, + Channel: ch, + }, nil +} diff --git a/cmd/protbot/scenario/logout.go b/cmd/protbot/scenario/logout.go new file mode 100644 index 000000000..692c97dda --- /dev/null +++ b/cmd/protbot/scenario/logout.go @@ -0,0 +1,17 @@ +package scenario + +import ( + "fmt" + + "erupe-ce/cmd/protbot/protocol" +) + +// Logout sends MSG_SYS_LOGOUT and closes the channel connection. +func Logout(ch *protocol.ChannelConn) error { + fmt.Println("[logout] Sending MSG_SYS_LOGOUT...") + if err := ch.SendPacket(protocol.BuildLogoutPacket()); err != nil { + ch.Close() + return fmt.Errorf("logout send: %w", err) + } + return ch.Close() +} diff --git a/cmd/protbot/scenario/quest.go b/cmd/protbot/scenario/quest.go new file mode 100644 index 000000000..2b3c0b2eb --- /dev/null +++ b/cmd/protbot/scenario/quest.go @@ -0,0 +1,31 @@ +package scenario + +import ( + "fmt" + "time" + + "erupe-ce/cmd/protbot/protocol" +) + +// EnumerateQuests sends MSG_MHF_ENUMERATE_QUEST and returns the raw quest list data. +func EnumerateQuests(ch *protocol.ChannelConn, world uint8, counter uint16) ([]byte, error) { + ack := ch.NextAckHandle() + pkt := protocol.BuildEnumerateQuestPacket(ack, world, counter, 0) + fmt.Printf("[quest] Sending MSG_MHF_ENUMERATE_QUEST (world=%d, counter=%d, ackHandle=%d)...\n", + world, counter, ack) + if err := ch.SendPacket(pkt); err != nil { + return nil, fmt.Errorf("enumerate quest send: %w", err) + } + + resp, err := ch.WaitForAck(ack, 15*time.Second) + if err != nil { + return nil, fmt.Errorf("enumerate quest ack: %w", err) + } + if resp.ErrorCode != 0 { + return nil, fmt.Errorf("enumerate quest failed: error code %d", resp.ErrorCode) + } + fmt.Printf("[quest] ENUMERATE_QUEST ACK (error=%d, %d bytes data)\n", + resp.ErrorCode, len(resp.Data)) + + return resp.Data, nil +} diff --git a/cmd/protbot/scenario/session.go b/cmd/protbot/scenario/session.go new file mode 100644 index 000000000..0f49f8795 --- /dev/null +++ b/cmd/protbot/scenario/session.go @@ -0,0 +1,50 @@ +package scenario + +import ( + "fmt" + "time" + + "erupe-ce/cmd/protbot/protocol" +) + +// SetupSession performs the post-login session setup: ISSUE_LOGKEY, RIGHTS_RELOAD, LOADDATA. +// Returns the loaddata response blob for inspection. +func SetupSession(ch *protocol.ChannelConn, charID uint32) ([]byte, error) { + // Step 1: Issue logkey. + ack := ch.NextAckHandle() + fmt.Printf("[session] Sending MSG_SYS_ISSUE_LOGKEY (ackHandle=%d)...\n", ack) + if err := ch.SendPacket(protocol.BuildIssueLogkeyPacket(ack)); err != nil { + return nil, fmt.Errorf("issue logkey send: %w", err) + } + resp, err := ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return nil, fmt.Errorf("issue logkey ack: %w", err) + } + fmt.Printf("[session] ISSUE_LOGKEY ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data)) + + // Step 2: Rights reload. + ack = ch.NextAckHandle() + fmt.Printf("[session] Sending MSG_SYS_RIGHTS_RELOAD (ackHandle=%d)...\n", ack) + if err := ch.SendPacket(protocol.BuildRightsReloadPacket(ack)); err != nil { + return nil, fmt.Errorf("rights reload send: %w", err) + } + resp, err = ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return nil, fmt.Errorf("rights reload ack: %w", err) + } + fmt.Printf("[session] RIGHTS_RELOAD ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data)) + + // Step 3: Load save data. + ack = ch.NextAckHandle() + fmt.Printf("[session] Sending MSG_MHF_LOADDATA (ackHandle=%d)...\n", ack) + if err := ch.SendPacket(protocol.BuildLoaddataPacket(ack)); err != nil { + return nil, fmt.Errorf("loaddata send: %w", err) + } + resp, err = ch.WaitForAck(ack, 30*time.Second) + if err != nil { + return nil, fmt.Errorf("loaddata ack: %w", err) + } + fmt.Printf("[session] LOADDATA ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data)) + + return resp.Data, nil +} diff --git a/cmd/protbot/scenario/stage.go b/cmd/protbot/scenario/stage.go new file mode 100644 index 000000000..27b5b757d --- /dev/null +++ b/cmd/protbot/scenario/stage.go @@ -0,0 +1,111 @@ +package scenario + +import ( + "encoding/binary" + "fmt" + "time" + + "erupe-ce/common/byteframe" + + "erupe-ce/cmd/protbot/protocol" +) + +// StageInfo holds a parsed stage entry from MSG_SYS_ENUMERATE_STAGE response. +type StageInfo struct { + ID string + Reserved uint16 + Clients uint16 + Displayed uint16 + MaxPlayers uint16 + Flags uint8 +} + +// EnterLobby enumerates available lobby stages and enters the first one. +func EnterLobby(ch *protocol.ChannelConn) error { + // Step 1: Enumerate stages with "sl1Ns" prefix (main lobby stages). + ack := ch.NextAckHandle() + enumPkt := protocol.BuildEnumerateStagePacket(ack, "sl1Ns") + fmt.Printf("[stage] Sending MSG_SYS_ENUMERATE_STAGE (prefix=\"sl1Ns\", ackHandle=%d)...\n", ack) + if err := ch.SendPacket(enumPkt); err != nil { + return fmt.Errorf("enumerate stage send: %w", err) + } + + resp, err := ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return fmt.Errorf("enumerate stage ack: %w", err) + } + if resp.ErrorCode != 0 { + return fmt.Errorf("enumerate stage failed: error code %d", resp.ErrorCode) + } + + stages := parseEnumerateStageResponse(resp.Data) + fmt.Printf("[stage] Found %d stage(s)\n", len(stages)) + for i, s := range stages { + fmt.Printf("[stage] [%d] %s — %d/%d players, flags=0x%02X\n", + i, s.ID, s.Clients, s.MaxPlayers, s.Flags) + } + + // Step 2: Enter the default lobby stage. + // Even if no stages were enumerated, use the default stage ID. + stageID := "sl1Ns200p0a0u0" + if len(stages) > 0 { + stageID = stages[0].ID + } + + ack = ch.NextAckHandle() + enterPkt := protocol.BuildEnterStagePacket(ack, stageID) + fmt.Printf("[stage] Sending MSG_SYS_ENTER_STAGE (stageID=%q, ackHandle=%d)...\n", stageID, ack) + if err := ch.SendPacket(enterPkt); err != nil { + return fmt.Errorf("enter stage send: %w", err) + } + + resp, err = ch.WaitForAck(ack, 10*time.Second) + if err != nil { + return fmt.Errorf("enter stage ack: %w", err) + } + if resp.ErrorCode != 0 { + return fmt.Errorf("enter stage failed: error code %d", resp.ErrorCode) + } + fmt.Printf("[stage] Enter stage ACK received (error=%d)\n", resp.ErrorCode) + + return nil +} + +// parseEnumerateStageResponse parses the ACK data from MSG_SYS_ENUMERATE_STAGE. +// Reference: Erupe server/channelserver/handlers_stage.go (handleMsgSysEnumerateStage) +func parseEnumerateStageResponse(data []byte) []StageInfo { + if len(data) < 2 { + return nil + } + + bf := byteframe.NewByteFrameFromBytes(data) + count := bf.ReadUint16() + + var stages []StageInfo + for i := uint16(0); i < count; i++ { + s := StageInfo{} + s.Reserved = bf.ReadUint16() + s.Clients = bf.ReadUint16() + s.Displayed = bf.ReadUint16() + s.MaxPlayers = bf.ReadUint16() + s.Flags = bf.ReadUint8() + + // Stage ID is a pascal string with uint8 length prefix. + strLen := bf.ReadUint8() + if strLen > 0 { + idBytes := bf.ReadBytes(uint(strLen)) + // Remove null terminator if present. + if len(idBytes) > 0 && idBytes[len(idBytes)-1] == 0 { + idBytes = idBytes[:len(idBytes)-1] + } + s.ID = string(idBytes) + } + + stages = append(stages, s) + } + + // After stages: uint32 timestamp, uint32 max clan members (we ignore these). + _ = binary.BigEndian // suppress unused import if needed + + return stages +} diff --git a/config.example.json b/config.example.json index a1951e791..e92a7a92f 100644 --- a/config.example.json +++ b/config.example.json @@ -219,34 +219,34 @@ { "Name": "Newbie", "Description": "", "IP": "", "Type": 3, "Recommended": 2, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54001, "MaxPlayers": 100 }, - { "Port": 54002, "MaxPlayers": 100 } + { "Port": 54001, "MaxPlayers": 100, "Enabled": true }, + { "Port": 54002, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Normal", "Description": "", "IP": "", "Type": 1, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54003, "MaxPlayers": 100 }, - { "Port": 54004, "MaxPlayers": 100 } + { "Port": 54003, "MaxPlayers": 100, "Enabled": true }, + { "Port": 54004, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Cities", "Description": "", "IP": "", "Type": 2, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54005, "MaxPlayers": 100 } + { "Port": 54005, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Tavern", "Description": "", "IP": "", "Type": 4, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54006, "MaxPlayers": 100 } + { "Port": 54006, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "Return", "Description": "", "IP": "", "Type": 5, "Recommended": 0, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54007, "MaxPlayers": 100 } + { "Port": 54007, "MaxPlayers": 100, "Enabled": true } ] }, { "Name": "MezFes", "Description": "", "IP": "", "Type": 6, "Recommended": 6, "AllowedClientFlags": 0, "Channels": [ - { "Port": 54008, "MaxPlayers": 100 } + { "Port": 54008, "MaxPlayers": 100, "Enabled": true } ] } ] diff --git a/config/config.go b/config/config.go index e30cbcd12..6c26798aa 100644 --- a/config/config.go +++ b/config/config.go @@ -297,6 +297,15 @@ type EntranceChannelInfo struct { Port uint16 MaxPlayers uint16 CurrentPlayers uint16 + Enabled *bool // nil defaults to true for backward compatibility +} + +// IsEnabled returns whether this channel is enabled. Defaults to true if Enabled is nil. +func (c *EntranceChannelInfo) IsEnabled() bool { + if c.Enabled == nil { + return true + } + return *c.Enabled } var ErupeConfig *Config diff --git a/config/config_test.go b/config/config_test.go index 782b3ef89..cbad553ec 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -536,6 +536,34 @@ func TestEntranceChannelInfo(t *testing.T) { } } +// TestEntranceChannelInfoIsEnabled tests the Enabled field and IsEnabled helper +func TestEntranceChannelInfoIsEnabled(t *testing.T) { + trueVal := true + falseVal := false + + tests := []struct { + name string + enabled *bool + want bool + }{ + {"nil defaults to true", nil, true}, + {"explicit true", &trueVal, true}, + {"explicit false", &falseVal, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + info := EntranceChannelInfo{ + Port: 10001, + Enabled: tt.enabled, + } + if got := info.IsEnabled(); got != tt.want { + t.Errorf("IsEnabled() = %v, want %v", got, tt.want) + } + }) + } +} + // TestDiscord verifies Discord struct func TestDiscord(t *testing.T) { discord := Discord{ diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 759b16b87..34535e2ef 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -13,7 +13,7 @@ services: ports: - "5432:5432" volumes: - - ./db-data/:/var/lib/postgresql/data/ + - ./db-data/:/var/lib/postgresql/ - ../schemas/:/schemas/ - ./init/setup.sh:/docker-entrypoint-initdb.d/setup.sh healthcheck: diff --git a/main.go b/main.go index de44e0693..c6da1c977 100644 --- a/main.go +++ b/main.go @@ -16,6 +16,7 @@ import ( "erupe-ce/server/discordbot" "erupe-ce/server/entranceserver" "erupe-ce/server/signserver" + "strings" "github.com/jmoiron/sqlx" _ "github.com/lib/pq" @@ -129,11 +130,30 @@ func main() { } logger.Info("Database: Started successfully") - // Clear stale data - if config.DebugOptions.ProxyPort == 0 { - _ = db.MustExec("DELETE FROM sign_sessions") + // Pre-compute all server IDs this instance will own, so we only + // delete our own rows (safe for multi-instance on the same DB). + var ownedServerIDs []string + { + si := 0 + for _, ee := range config.Entrance.Entries { + ci := 0 + for range ee.Channels { + sid := (4096 + si*256) + (16 + ci) + ownedServerIDs = append(ownedServerIDs, fmt.Sprint(sid)) + ci++ + } + si++ + } + } + + // Clear stale data scoped to this instance's server IDs + if len(ownedServerIDs) > 0 { + idList := strings.Join(ownedServerIDs, ",") + if config.DebugOptions.ProxyPort == 0 { + _ = db.MustExec("DELETE FROM sign_sessions WHERE server_id IN (" + idList + ")") + } + _ = db.MustExec("DELETE FROM servers WHERE server_id IN (" + idList + ")") } - _ = db.MustExec("DELETE FROM servers") _ = db.MustExec(`UPDATE guild_characters SET treasure_hunt=NULL`) // Clean the DB if the option is on. @@ -213,6 +233,12 @@ func main() { for j, ee := range config.Entrance.Entries { for i, ce := range ee.Channels { sid := (4096 + si*256) + (16 + ci) + if !ce.IsEnabled() { + logger.Info(fmt.Sprintf("Channel %d (%d): Disabled via config", count, ce.Port)) + ci++ + count++ + continue + } c := *channelserver.NewServer(&channelserver.Config{ ID: uint16(sid), Logger: logger.Named("channel-" + fmt.Sprint(count)), @@ -237,9 +263,9 @@ func main() { ) channels = append(channels, &c) logger.Info(fmt.Sprintf("Channel %d (%d): Started successfully", count, ce.Port)) - ci++ count++ } + ci++ } ci = 0 si++ @@ -248,8 +274,10 @@ func main() { // Register all servers in DB _ = db.MustExec(channelQuery) + registry := channelserver.NewLocalChannelRegistry(channels) for _, c := range channels { c.Channels = channels + c.Registry = registry } } diff --git a/server/channelserver/channel_isolation_test.go b/server/channelserver/channel_isolation_test.go new file mode 100644 index 000000000..158fca9a3 --- /dev/null +++ b/server/channelserver/channel_isolation_test.go @@ -0,0 +1,214 @@ +package channelserver + +import ( + "net" + "testing" + "time" + + _config "erupe-ce/config" + + "go.uber.org/zap" +) + +// createListeningTestServer creates a channel server that binds to a real TCP port. +// Port 0 lets the OS assign a free port. The server is automatically shut down +// when the test completes. +func createListeningTestServer(t *testing.T, id uint16) *Server { + t.Helper() + logger, _ := zap.NewDevelopment() + s := NewServer(&Config{ + ID: id, + Logger: logger, + ErupeConfig: &_config.Config{ + DebugOptions: _config.DebugOptions{ + LogOutboundMessages: false, + LogInboundMessages: false, + }, + }, + }) + s.Port = 0 // Let OS pick a free port + if err := s.Start(); err != nil { + t.Fatalf("channel %d failed to start: %v", id, err) + } + t.Cleanup(func() { + s.Shutdown() + time.Sleep(200 * time.Millisecond) // Let background goroutines and sessions exit. + }) + return s +} + +// listenerAddr returns the address the server is listening on. +func listenerAddr(s *Server) string { + return s.listener.Addr().String() +} + +// TestChannelIsolation_ShutdownDoesNotAffectOthers verifies that shutting down +// one channel server does not prevent other channels from accepting connections. +func TestChannelIsolation_ShutdownDoesNotAffectOthers(t *testing.T) { + ch1 := createListeningTestServer(t, 1) + ch2 := createListeningTestServer(t, 2) + ch3 := createListeningTestServer(t, 3) + + addr1 := listenerAddr(ch1) + addr2 := listenerAddr(ch2) + addr3 := listenerAddr(ch3) + + // Verify all three channels accept connections initially. + for _, addr := range []string{addr1, addr2, addr3} { + conn, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + t.Fatalf("initial connection to %s failed: %v", addr, err) + } + conn.Close() + } + + // Shut down channel 1. + ch1.Shutdown() + time.Sleep(50 * time.Millisecond) + + // Channel 1 should refuse connections. + _, err := net.DialTimeout("tcp", addr1, 500*time.Millisecond) + if err == nil { + t.Error("channel 1 should refuse connections after shutdown") + } + + // Channels 2 and 3 must still accept connections. + for _, tc := range []struct { + name string + addr string + }{ + {"channel 2", addr2}, + {"channel 3", addr3}, + } { + conn, err := net.DialTimeout("tcp", tc.addr, time.Second) + if err != nil { + t.Errorf("%s should still accept connections after channel 1 shutdown, got: %v", tc.name, err) + } else { + conn.Close() + } + } +} + +// TestChannelIsolation_ListenerCloseDoesNotAffectOthers simulates an unexpected +// listener failure (e.g. port conflict, OS-level error) on one channel and +// verifies other channels continue operating. +func TestChannelIsolation_ListenerCloseDoesNotAffectOthers(t *testing.T) { + ch1 := createListeningTestServer(t, 1) + ch2 := createListeningTestServer(t, 2) + + addr2 := listenerAddr(ch2) + + // Forcibly close channel 1's listener (simulating unexpected failure). + ch1.listener.Close() + time.Sleep(50 * time.Millisecond) + + // Channel 2 must still work. + conn, err := net.DialTimeout("tcp", addr2, time.Second) + if err != nil { + t.Fatalf("channel 2 should still accept connections after channel 1 listener closed: %v", err) + } + conn.Close() +} + +// TestChannelIsolation_SessionPanicDoesNotAffectChannel verifies that a panic +// inside a session handler is recovered and does not crash the channel server. +func TestChannelIsolation_SessionPanicDoesNotAffectChannel(t *testing.T) { + ch := createListeningTestServer(t, 1) + addr := listenerAddr(ch) + + // Connect a client that will trigger a session. + conn1, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + t.Fatalf("first connection failed: %v", err) + } + + // Send garbage data that will cause handlePacketGroup to hit the panic recovery. + // The session's defer/recover should catch it without killing the channel. + conn1.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF}) + time.Sleep(100 * time.Millisecond) + conn1.Close() + time.Sleep(100 * time.Millisecond) + + // The channel should still accept new connections after the panic. + conn2, err := net.DialTimeout("tcp", addr, time.Second) + if err != nil { + t.Fatalf("channel should still accept connections after session panic: %v", err) + } + conn2.Close() +} + +// TestChannelIsolation_CrossChannelRegistryAfterShutdown verifies that the +// channel registry handles a shut-down channel gracefully during cross-channel +// operations (search, find, disconnect). +func TestChannelIsolation_CrossChannelRegistryAfterShutdown(t *testing.T) { + channels := createTestChannels(3) + reg := NewLocalChannelRegistry(channels) + + // Add sessions to all channels. + for i, ch := range channels { + conn := &mockConn{} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + + // Simulate channel 1 shutting down by marking it and clearing sessions. + channels[0].Lock() + channels[0].isShuttingDown = true + channels[0].sessions = make(map[net.Conn]*Session) + channels[0].Unlock() + + // Registry operations should still work for remaining channels. + found := reg.FindSessionByCharID(2) + if found == nil { + t.Error("FindSessionByCharID(2) should find session on channel 2") + } + + found = reg.FindSessionByCharID(3) + if found == nil { + t.Error("FindSessionByCharID(3) should find session on channel 3") + } + + // Session from shut-down channel should not be found. + found = reg.FindSessionByCharID(1) + if found != nil { + t.Error("FindSessionByCharID(1) should not find session on shut-down channel") + } + + // SearchSessions should return only sessions from live channels. + results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10) + if len(results) != 2 { + t.Errorf("SearchSessions should return 2 results from live channels, got %d", len(results)) + } +} + +// TestChannelIsolation_IndependentStages verifies that stages are per-channel +// and one channel's stages don't leak into another. +func TestChannelIsolation_IndependentStages(t *testing.T) { + channels := createTestChannels(2) + + stageName := "sl1Qs999p0a0u42" + + // Add stage only to channel 1. + channels[0].stagesLock.Lock() + channels[0].stages[stageName] = NewStage(stageName) + channels[0].stagesLock.Unlock() + + // Channel 1 should have the stage. + channels[0].stagesLock.RLock() + _, ok1 := channels[0].stages[stageName] + channels[0].stagesLock.RUnlock() + if !ok1 { + t.Error("channel 1 should have the stage") + } + + // Channel 2 should NOT have the stage. + channels[1].stagesLock.RLock() + _, ok2 := channels[1].stages[stageName] + channels[1].stagesLock.RUnlock() + if ok2 { + t.Error("channel 2 should not have channel 1's stage") + } +} diff --git a/server/channelserver/channel_registry.go b/server/channelserver/channel_registry.go new file mode 100644 index 000000000..af391a727 --- /dev/null +++ b/server/channelserver/channel_registry.go @@ -0,0 +1,58 @@ +package channelserver + +import ( + "erupe-ce/network/mhfpacket" + "net" +) + +// ChannelRegistry abstracts cross-channel operations behind an interface. +// The default LocalChannelRegistry wraps the in-process []*Server slice. +// Future implementations may use DB/Redis/NATS for multi-process deployments. +type ChannelRegistry interface { + // Worldcast broadcasts a packet to all sessions across all channels. + Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) + + // FindSessionByCharID looks up a session by character ID across all channels. + FindSessionByCharID(charID uint32) *Session + + // DisconnectUser disconnects all sessions belonging to the given character IDs. + DisconnectUser(cids []uint32) + + // FindChannelForStage searches all channels for a stage whose ID has the + // given suffix and returns the owning channel's GlobalID, or "" if not found. + FindChannelForStage(stageSuffix string) string + + // SearchSessions searches sessions across all channels using a predicate, + // returning up to max snapshot results. + SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot + + // SearchStages searches stages across all channels with a prefix filter, + // returning up to max snapshot results. + SearchStages(stagePrefix string, max int) []StageSnapshot + + // NotifyMailToCharID finds the session for charID and sends a mail notification. + NotifyMailToCharID(charID uint32, sender *Session, mail *Mail) +} + +// SessionSnapshot is an immutable copy of session data taken under lock. +type SessionSnapshot struct { + CharID uint32 + Name string + StageID string + ServerIP net.IP + ServerPort uint16 + UserBinary3 []byte // Copy of userBinaryParts index 3 +} + +// StageSnapshot is an immutable copy of stage data taken under lock. +type StageSnapshot struct { + ServerIP net.IP + ServerPort uint16 + StageID string + ClientCount int + Reserved int + MaxPlayers uint16 + RawBinData0 []byte + RawBinData1 []byte + RawBinData3 []byte +} diff --git a/server/channelserver/channel_registry_local.go b/server/channelserver/channel_registry_local.go new file mode 100644 index 000000000..a04651d0e --- /dev/null +++ b/server/channelserver/channel_registry_local.go @@ -0,0 +1,156 @@ +package channelserver + +import ( + "erupe-ce/network/mhfpacket" + "net" + "strings" +) + +// LocalChannelRegistry is the in-process ChannelRegistry backed by []*Server. +type LocalChannelRegistry struct { + channels []*Server +} + +// NewLocalChannelRegistry creates a LocalChannelRegistry wrapping the given channels. +func NewLocalChannelRegistry(channels []*Server) *LocalChannelRegistry { + return &LocalChannelRegistry{channels: channels} +} + +func (r *LocalChannelRegistry) Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) { + for _, c := range r.channels { + if c == ignoredChannel { + continue + } + c.BroadcastMHF(pkt, ignoredSession) + } +} + +func (r *LocalChannelRegistry) FindSessionByCharID(charID uint32) *Session { + for _, c := range r.channels { + c.Lock() + for _, session := range c.sessions { + if session.charID == charID { + c.Unlock() + return session + } + } + c.Unlock() + } + return nil +} + +func (r *LocalChannelRegistry) DisconnectUser(cids []uint32) { + for _, c := range r.channels { + c.Lock() + for _, session := range c.sessions { + for _, cid := range cids { + if session.charID == cid { + _ = session.rawConn.Close() + break + } + } + } + c.Unlock() + } +} + +func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string { + for _, channel := range r.channels { + channel.stagesLock.RLock() + for id := range channel.stages { + if strings.HasSuffix(id, stageSuffix) { + gid := channel.GlobalID + channel.stagesLock.RUnlock() + return gid + } + } + channel.stagesLock.RUnlock() + } + return "" +} + +func (r *LocalChannelRegistry) SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot { + var results []SessionSnapshot + for _, c := range r.channels { + if len(results) >= max { + break + } + c.Lock() + c.userBinaryPartsLock.RLock() + for _, session := range c.sessions { + if len(results) >= max { + break + } + snap := SessionSnapshot{ + CharID: session.charID, + Name: session.Name, + ServerIP: net.ParseIP(c.IP).To4(), + ServerPort: c.Port, + } + if session.stage != nil { + snap.StageID = session.stage.id + } + ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}] + if len(ub3) > 0 { + snap.UserBinary3 = make([]byte, len(ub3)) + copy(snap.UserBinary3, ub3) + } + if predicate(snap) { + results = append(results, snap) + } + } + c.userBinaryPartsLock.RUnlock() + c.Unlock() + } + return results +} + +func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []StageSnapshot { + var results []StageSnapshot + for _, c := range r.channels { + if len(results) >= max { + break + } + c.stagesLock.RLock() + for _, stage := range c.stages { + if len(results) >= max { + break + } + if !strings.HasPrefix(stage.id, stagePrefix) { + continue + } + stage.RLock() + bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}] + bin0Copy := make([]byte, len(bin0)) + copy(bin0Copy, bin0) + bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}] + bin1Copy := make([]byte, len(bin1)) + copy(bin1Copy, bin1) + bin3 := stage.rawBinaryData[stageBinaryKey{1, 3}] + bin3Copy := make([]byte, len(bin3)) + copy(bin3Copy, bin3) + + results = append(results, StageSnapshot{ + ServerIP: net.ParseIP(c.IP).To4(), + ServerPort: c.Port, + StageID: stage.id, + ClientCount: len(stage.clients) + len(stage.reservedClientSlots), + Reserved: len(stage.reservedClientSlots), + MaxPlayers: stage.maxPlayers, + RawBinData0: bin0Copy, + RawBinData1: bin1Copy, + RawBinData3: bin3Copy, + }) + stage.RUnlock() + } + c.stagesLock.RUnlock() + } + return results +} + +func (r *LocalChannelRegistry) NotifyMailToCharID(charID uint32, sender *Session, mail *Mail) { + session := r.FindSessionByCharID(charID) + if session != nil { + SendMailNotification(sender, mail, session) + } +} diff --git a/server/channelserver/channel_registry_test.go b/server/channelserver/channel_registry_test.go new file mode 100644 index 000000000..bdeccffbf --- /dev/null +++ b/server/channelserver/channel_registry_test.go @@ -0,0 +1,190 @@ +package channelserver + +import ( + "net" + "sync" + "testing" +) + +func createTestChannels(count int) []*Server { + channels := make([]*Server, count) + for i := 0; i < count; i++ { + s := createTestServer() + s.ID = uint16(0x1010 + i) + s.IP = "10.0.0.1" + s.Port = uint16(54001 + i) + s.GlobalID = "0101" + s.userBinaryParts = make(map[userBinaryPartID][]byte) + channels[i] = s + } + return channels +} + +func TestLocalRegistryFindSessionByCharID(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + conn1 := &mockConn{} + sess1 := createTestSessionForServer(channels[0], conn1, 100, "Alice") + channels[0].Lock() + channels[0].sessions[conn1] = sess1 + channels[0].Unlock() + + conn2 := &mockConn{} + sess2 := createTestSessionForServer(channels[1], conn2, 200, "Bob") + channels[1].Lock() + channels[1].sessions[conn2] = sess2 + channels[1].Unlock() + + // Find on first channel + found := reg.FindSessionByCharID(100) + if found == nil || found.charID != 100 { + t.Errorf("FindSessionByCharID(100) = %v, want session with charID 100", found) + } + + // Find on second channel + found = reg.FindSessionByCharID(200) + if found == nil || found.charID != 200 { + t.Errorf("FindSessionByCharID(200) = %v, want session with charID 200", found) + } + + // Not found + found = reg.FindSessionByCharID(999) + if found != nil { + t.Errorf("FindSessionByCharID(999) = %v, want nil", found) + } +} + +func TestLocalRegistryFindChannelForStage(t *testing.T) { + channels := createTestChannels(2) + channels[0].GlobalID = "0101" + channels[1].GlobalID = "0102" + reg := NewLocalChannelRegistry(channels) + + channels[1].stagesLock.Lock() + channels[1].stages["sl2Qs123p0a0u42"] = NewStage("sl2Qs123p0a0u42") + channels[1].stagesLock.Unlock() + + gid := reg.FindChannelForStage("u42") + if gid != "0102" { + t.Errorf("FindChannelForStage(u42) = %q, want %q", gid, "0102") + } + + gid = reg.FindChannelForStage("u999") + if gid != "" { + t.Errorf("FindChannelForStage(u999) = %q, want empty", gid) + } +} + +func TestLocalRegistryDisconnectUser(t *testing.T) { + channels := createTestChannels(1) + reg := NewLocalChannelRegistry(channels) + + conn := &mockConn{} + sess := createTestSessionForServer(channels[0], conn, 42, "Target") + channels[0].Lock() + channels[0].sessions[conn] = sess + channels[0].Unlock() + + reg.DisconnectUser([]uint32{42}) + + if !conn.WasClosed() { + t.Error("DisconnectUser should have closed the connection for charID 42") + } +} + +func TestLocalRegistrySearchSessions(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + // Add 3 sessions across 2 channels + for i, ch := range channels { + conn := &mockConn{} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + conn3 := &mockConn{} + sess3 := createTestSessionForServer(channels[0], conn3, 3, "Player") + sess3.stage = NewStage("sl1Ns200p0a0u0") + channels[0].Lock() + channels[0].sessions[conn3] = sess3 + channels[0].Unlock() + + // Search all + results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10) + if len(results) != 3 { + t.Errorf("SearchSessions(all) returned %d results, want 3", len(results)) + } + + // Search with max + results = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 2) + if len(results) != 2 { + t.Errorf("SearchSessions(max=2) returned %d results, want 2", len(results)) + } + + // Search with predicate + results = reg.SearchSessions(func(s SessionSnapshot) bool { return s.CharID == 1 }, 10) + if len(results) != 1 { + t.Errorf("SearchSessions(charID==1) returned %d results, want 1", len(results)) + } +} + +func TestLocalRegistrySearchStages(t *testing.T) { + channels := createTestChannels(1) + reg := NewLocalChannelRegistry(channels) + + channels[0].stagesLock.Lock() + channels[0].stages["sl2Ls210test1"] = NewStage("sl2Ls210test1") + channels[0].stages["sl2Ls210test2"] = NewStage("sl2Ls210test2") + channels[0].stages["sl1Ns200other"] = NewStage("sl1Ns200other") + channels[0].stagesLock.Unlock() + + results := reg.SearchStages("sl2Ls210", 10) + if len(results) != 2 { + t.Errorf("SearchStages(sl2Ls210) returned %d results, want 2", len(results)) + } + + results = reg.SearchStages("sl2Ls210", 1) + if len(results) != 1 { + t.Errorf("SearchStages(sl2Ls210, max=1) returned %d results, want 1", len(results)) + } +} + +func TestLocalRegistryConcurrentAccess(t *testing.T) { + channels := createTestChannels(2) + reg := NewLocalChannelRegistry(channels) + + // Populate some sessions + for _, ch := range channels { + for i := 0; i < 10; i++ { + conn := &mockConn{remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 50000 + i}} + sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player") + sess.stage = NewStage("sl1Ns200p0a0u0") + ch.Lock() + ch.sessions[conn] = sess + ch.Unlock() + } + } + + // Run concurrent operations + var wg sync.WaitGroup + for i := 0; i < 50; i++ { + wg.Add(3) + go func(id int) { + defer wg.Done() + _ = reg.FindSessionByCharID(uint32(id%10 + 1)) + }(i) + go func() { + defer wg.Done() + _ = reg.FindChannelForStage("u0") + }() + go func() { + defer wg.Done() + _ = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 5) + }() + } + wg.Wait() +} diff --git a/server/channelserver/handlers_character.go b/server/channelserver/handlers_character.go index c5558e839..14cd59084 100644 --- a/server/channelserver/handlers_character.go +++ b/server/channelserver/handlers_character.go @@ -48,6 +48,13 @@ func GetCharacterSaveData(s *Session, charID uint32) (*CharacterSaveData, error) } func (save *CharacterSaveData) Save(s *Session) { + if save.decompSave == nil { + s.logger.Warn("No decompressed save data, skipping save", + zap.Uint32("charID", save.CharID), + ) + return + } + if !s.kqfOverride { s.kqf = save.KQF } else { diff --git a/server/channelserver/handlers_guild_ops.go b/server/channelserver/handlers_guild_ops.go index 5a9a1b3e0..dc4086aea 100644 --- a/server/channelserver/handlers_guild_ops.go +++ b/server/channelserver/handlers_guild_ops.go @@ -304,11 +304,26 @@ func handleMsgMhfOperateGuildMember(s *Session, p mhfpacket.MHFPacket) { doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4)) } else { _ = mail.Send(s, nil) - for _, channel := range s.server.Channels { - for _, session := range channel.sessions { - if session.charID == pkt.CharID { - SendMailNotification(s, &mail, session) + if s.server.Registry != nil { + s.server.Registry.NotifyMailToCharID(pkt.CharID, s, &mail) + } else { + // Fallback: find the target session under lock, then notify outside the lock. + var targetSession *Session + for _, channel := range s.server.Channels { + channel.Lock() + for _, session := range channel.sessions { + if session.charID == pkt.CharID { + targetSession = session + break + } } + channel.Unlock() + if targetSession != nil { + break + } + } + if targetSession != nil { + SendMailNotification(s, &mail, targetSession) } } doAckSimpleSucceed(s, pkt.AckHandle, make([]byte, 4)) diff --git a/server/channelserver/handlers_session.go b/server/channelserver/handlers_session.go index 6e805a7dc..ab39c6b95 100644 --- a/server/channelserver/handlers_session.go +++ b/server/channelserver/handlers_session.go @@ -293,14 +293,16 @@ func logoutPlayer(s *Session) { } // Update sign sessions and server player count - _, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token) - if err != nil { - panic(err) - } + if s.server.db != nil { + _, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token) + if err != nil { + s.logger.Error("Failed to clear sign session", zap.Error(err)) + } - _, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID) - if err != nil { - panic(err) + _, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID) + if err != nil { + s.logger.Error("Failed to update player count", zap.Error(err)) + } } if s.stage == nil { @@ -399,11 +401,17 @@ func handleMsgSysEcho(s *Session, p mhfpacket.MHFPacket) {} func handleMsgSysLockGlobalSema(s *Session, p mhfpacket.MHFPacket) { pkt := p.(*mhfpacket.MsgSysLockGlobalSema) var sgid string - for _, channel := range s.server.Channels { - for id := range channel.stages { - if strings.HasSuffix(id, pkt.UserIDString) { - sgid = channel.GlobalID + if s.server.Registry != nil { + sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString) + } else { + for _, channel := range s.server.Channels { + channel.stagesLock.RLock() + for id := range channel.stages { + if strings.HasSuffix(id, pkt.UserIDString) { + sgid = channel.GlobalID + } } + channel.stagesLock.RUnlock() } } bf := byteframe.NewByteFrame() @@ -468,7 +476,23 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { resp.WriteUint16(0) switch pkt.SearchType { case 1, 2, 3: // usersearchidx, usersearchname, lobbysearchname + // Snapshot matching sessions under lock, then build response outside locks. + type sessionResult struct { + charID uint32 + name []byte + stageID []byte + ip net.IP + port uint16 + userBin3 []byte + } + var results []sessionResult + for _, c := range s.server.Channels { + if count == maxResults { + break + } + c.Lock() + c.userBinaryPartsLock.RLock() for _, session := range c.sessions { if count == maxResults { break @@ -483,31 +507,45 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { continue } count++ - sessionName := stringsupport.UTF8ToSJIS(session.Name) - sessionStage := stringsupport.UTF8ToSJIS(session.stage.id) - if !local { - resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4())) - } else { - resp.WriteUint32(0x0100007F) - } - resp.WriteUint16(c.Port) - resp.WriteUint32(session.charID) - resp.WriteUint8(uint8(len(sessionStage) + 1)) - resp.WriteUint8(uint8(len(sessionName) + 1)) - resp.WriteUint16(uint16(len(c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}]))) - - // TODO: This case might be <=G2 - if _config.ErupeConfig.RealClientMode <= _config.G1 { - resp.WriteBytes(make([]byte, 8)) - } else { - resp.WriteBytes(make([]byte, 40)) - } - resp.WriteBytes(make([]byte, 8)) - - resp.WriteNullTerminatedBytes(sessionStage) - resp.WriteNullTerminatedBytes(sessionName) - resp.WriteBytes(c.userBinaryParts[userBinaryPartID{session.charID, 3}]) + ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}] + ub3Copy := make([]byte, len(ub3)) + copy(ub3Copy, ub3) + results = append(results, sessionResult{ + charID: session.charID, + name: stringsupport.UTF8ToSJIS(session.Name), + stageID: stringsupport.UTF8ToSJIS(session.stage.id), + ip: net.ParseIP(c.IP).To4(), + port: c.Port, + userBin3: ub3Copy, + }) } + c.userBinaryPartsLock.RUnlock() + c.Unlock() + } + + for _, r := range results { + if !local { + resp.WriteUint32(binary.LittleEndian.Uint32(r.ip)) + } else { + resp.WriteUint32(0x0100007F) + } + resp.WriteUint16(r.port) + resp.WriteUint32(r.charID) + resp.WriteUint8(uint8(len(r.stageID) + 1)) + resp.WriteUint8(uint8(len(r.name) + 1)) + resp.WriteUint16(uint16(len(r.userBin3))) + + // TODO: This case might be <=G2 + if _config.ErupeConfig.RealClientMode <= _config.G1 { + resp.WriteBytes(make([]byte, 8)) + } else { + resp.WriteBytes(make([]byte, 40)) + } + resp.WriteBytes(make([]byte, 8)) + + resp.WriteNullTerminatedBytes(r.stageID) + resp.WriteNullTerminatedBytes(r.name) + resp.WriteBytes(r.userBin3) } case 4: // lobbysearch type FindPartyParams struct { @@ -594,12 +632,31 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { } } } + // Snapshot matching stages under lock, then build response outside locks. + type stageResult struct { + ip net.IP + port uint16 + clientCount int + reserved int + maxPlayers uint16 + stageID string + stageData []int16 + rawBinData0 []byte + rawBinData1 []byte + } + var stageResults []stageResult + for _, c := range s.server.Channels { + if count == maxResults { + break + } + c.stagesLock.RLock() for _, stage := range c.stages { if count == maxResults { break } if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) { + stage.RLock() sb3 := byteframe.NewByteFrameFromBytes(stage.rawBinaryData[stageBinaryKey{1, 3}]) _, _ = sb3.Seek(4, 0) @@ -621,6 +678,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { if findPartyParams.RankRestriction >= 0 { if stageData[0] > findPartyParams.RankRestriction { + stage.RUnlock() continue } } @@ -634,47 +692,72 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) { } } if !hasTarget { + stage.RUnlock() continue } } + // Copy binary data under lock + bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}] + bin0Copy := make([]byte, len(bin0)) + copy(bin0Copy, bin0) + bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}] + bin1Copy := make([]byte, len(bin1)) + copy(bin1Copy, bin1) + count++ - if !local { - resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4())) - } else { - resp.WriteUint32(0x0100007F) - } - resp.WriteUint16(c.Port) - - resp.WriteUint16(0) // Static? - resp.WriteUint16(0) // Unk, [0 1 2] - resp.WriteUint16(uint16(len(stage.clients) + len(stage.reservedClientSlots))) - resp.WriteUint16(stage.maxPlayers) - // TODO: Retail returned the number of clients in quests, not workshop/my series - resp.WriteUint16(uint16(len(stage.reservedClientSlots))) - - resp.WriteUint8(0) // Static? - resp.WriteUint8(uint8(stage.maxPlayers)) - resp.WriteUint8(1) // Static? - resp.WriteUint8(uint8(len(stage.id) + 1)) - resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 0}]))) - resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 1}]))) - - for i := range stageData { - if _config.ErupeConfig.RealClientMode >= _config.Z1 { - resp.WriteInt16(stageData[i]) - } else { - resp.WriteInt8(int8(stageData[i])) - } - } - resp.WriteUint8(0) // Unk - resp.WriteUint8(0) // Unk - - resp.WriteNullTerminatedBytes([]byte(stage.id)) - resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 0}]) - resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 1}]) + stageResults = append(stageResults, stageResult{ + ip: net.ParseIP(c.IP).To4(), + port: c.Port, + clientCount: len(stage.clients) + len(stage.reservedClientSlots), + reserved: len(stage.reservedClientSlots), + maxPlayers: stage.maxPlayers, + stageID: stage.id, + stageData: stageData, + rawBinData0: bin0Copy, + rawBinData1: bin1Copy, + }) + stage.RUnlock() } } + c.stagesLock.RUnlock() + } + + for _, sr := range stageResults { + if !local { + resp.WriteUint32(binary.LittleEndian.Uint32(sr.ip)) + } else { + resp.WriteUint32(0x0100007F) + } + resp.WriteUint16(sr.port) + + resp.WriteUint16(0) // Static? + resp.WriteUint16(0) // Unk, [0 1 2] + resp.WriteUint16(uint16(sr.clientCount)) + resp.WriteUint16(sr.maxPlayers) + // TODO: Retail returned the number of clients in quests, not workshop/my series + resp.WriteUint16(uint16(sr.reserved)) + + resp.WriteUint8(0) // Static? + resp.WriteUint8(uint8(sr.maxPlayers)) + resp.WriteUint8(1) // Static? + resp.WriteUint8(uint8(len(sr.stageID) + 1)) + resp.WriteUint8(uint8(len(sr.rawBinData0))) + resp.WriteUint8(uint8(len(sr.rawBinData1))) + + for i := range sr.stageData { + if _config.ErupeConfig.RealClientMode >= _config.Z1 { + resp.WriteInt16(sr.stageData[i]) + } else { + resp.WriteInt8(int8(sr.stageData[i])) + } + } + resp.WriteUint8(0) // Unk + resp.WriteUint8(0) // Unk + + resp.WriteNullTerminatedBytes([]byte(sr.stageID)) + resp.WriteBytes(sr.rawBinData0) + resp.WriteBytes(sr.rawBinData1) } } _, _ = resp.Seek(0, io.SeekStart) diff --git a/server/channelserver/sys_channel_server.go b/server/channelserver/sys_channel_server.go index d951f93d0..bfd22414d 100644 --- a/server/channelserver/sys_channel_server.go +++ b/server/channelserver/sys_channel_server.go @@ -37,6 +37,7 @@ type userBinaryPartID struct { type Server struct { sync.Mutex Channels []*Server + Registry ChannelRegistry ID uint16 GlobalID string IP string @@ -49,6 +50,7 @@ type Server struct { sessions map[net.Conn]*Session listener net.Listener // Listener that is created when Server.Start is called. isShuttingDown bool + done chan struct{} // Closed on Shutdown to wake background goroutines. stagesLock sync.RWMutex stages map[string]*Stage @@ -90,6 +92,7 @@ func NewServer(config *Config) *Server { erupeConfig: config.ErupeConfig, acceptConns: make(chan net.Conn), deleteConns: make(chan net.Conn), + done: make(chan struct{}), sessions: make(map[net.Conn]*Session), stages: make(map[string]*Stage), userBinaryParts: make(map[userBinaryPartID][]byte), @@ -155,19 +158,23 @@ func (s *Server) Start() error { return nil } -// Shutdown tries to shut down the server gracefully. +// Shutdown tries to shut down the server gracefully. Safe to call multiple times. func (s *Server) Shutdown() { s.Lock() + alreadyShutDown := s.isShuttingDown s.isShuttingDown = true s.Unlock() + if alreadyShutDown { + return + } + + close(s.done) + if s.listener != nil { _ = s.listener.Close() } - if s.acceptConns != nil { - close(s.acceptConns) - } } func (s *Server) acceptClients() { @@ -185,25 +192,21 @@ func (s *Server) acceptClients() { continue } } - s.acceptConns <- conn + select { + case s.acceptConns <- conn: + case <-s.done: + _ = conn.Close() + return + } } } func (s *Server) manageSessions() { for { select { + case <-s.done: + return case newConn := <-s.acceptConns: - // Gracefully handle acceptConns channel closing. - if newConn == nil { - s.Lock() - shutdown := s.isShuttingDown - s.Unlock() - - if shutdown { - return - } - } - session := NewSession(s, newConn) s.Lock() @@ -235,15 +238,28 @@ func (s *Server) getObjectId() uint16 { } func (s *Server) invalidateSessions() { - for !s.isShuttingDown { + ticker := time.NewTicker(10 * time.Second) + defer ticker.Stop() + for { + select { + case <-s.done: + return + case <-ticker.C: + } + s.Lock() + var timedOut []*Session for _, sess := range s.sessions { if time.Since(sess.lastPacket) > time.Second*time.Duration(30) { - s.logger.Info("session timeout", zap.String("Name", sess.Name)) - logoutPlayer(sess) + timedOut = append(timedOut, sess) } } - time.Sleep(time.Second * 10) + s.Unlock() + + for _, sess := range timedOut { + s.logger.Info("session timeout", zap.String("Name", sess.Name)) + logoutPlayer(sess) + } } } @@ -271,6 +287,10 @@ func (s *Server) BroadcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session) // WorldcastMHF broadcasts a packet to all sessions across all channel servers. func (s *Server) WorldcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) { + if s.Registry != nil { + s.Registry.Worldcast(pkt, ignoredSession, ignoredChannel) + return + } for _, c := range s.Channels { if c == ignoredChannel { continue @@ -317,12 +337,18 @@ func (s *Server) DiscordScreenShotSend(charName string, title string, descriptio // FindSessionByCharID looks up a session by character ID across all channels. func (s *Server) FindSessionByCharID(charID uint32) *Session { + if s.Registry != nil { + return s.Registry.FindSessionByCharID(charID) + } for _, c := range s.Channels { + c.Lock() for _, session := range c.sessions { if session.charID == charID { + c.Unlock() return session } } + c.Unlock() } return nil } @@ -341,7 +367,12 @@ func (s *Server) DisconnectUser(uid uint32) { cids = append(cids, cid) } } + if s.Registry != nil { + s.Registry.DisconnectUser(cids) + return + } for _, c := range s.Channels { + c.Lock() for _, session := range c.sessions { for _, cid := range cids { if session.charID == cid { @@ -350,6 +381,7 @@ func (s *Server) DisconnectUser(uid uint32) { } } } + c.Unlock() } } diff --git a/server/channelserver/sys_session.go b/server/channelserver/sys_session.go index 294d470ab..b30190aec 100644 --- a/server/channelserver/sys_session.go +++ b/server/channelserver/sys_session.go @@ -4,7 +4,6 @@ import ( "encoding/binary" "encoding/hex" "erupe-ce/common/mhfcourse" - _config "erupe-ce/config" "fmt" "io" "net" @@ -172,7 +171,7 @@ func (s *Session) sendLoop() { s.logger.Warn("Failed to send packet", zap.Error(err)) } } - time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond) + time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond) } } @@ -215,7 +214,7 @@ func (s *Session) recvLoop() { return } s.handlePacketGroup(pkt) - time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond) + time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond) } }