mirror of
https://github.com/Mezeporta/Erupe.git
synced 2026-03-22 07:32:32 +01:00
Merge pull request #159 from Mezeporta/feat/independent-channel-servers
Summary - Each channel server now runs as a fully independent instance with its own listener, goroutines, and state — one channel crashing or shutting down no longer affects the others - Introduces a ChannelRegistry interface for cross-channel operations (find session, disconnect user, worldcast) replacing direct iteration over a shared []*Server slice - Adds cmd/protbot, a headless MHF protocol bot that exercises the full sign → entrance → channel flow for automated testing - Fixes several data races and panics found by the race detector during isolation testing Changes Channel server isolation (server/channelserver/) - ChannelRegistry interface + LocalChannelRegistry implementation for cross-channel lookups - done channel for clean goroutine shutdown signaling, idempotent Shutdown() - Race-free acceptClients/manageSessions using select on done instead of closing acceptConns - invalidateSessions rewritten with proper locking (snapshot under lock, process outside) - logoutPlayer guards nil DB and logs errors instead of panicking - Session loops use per-server erupeConfig instead of global _config.ErupeConfig - Per-channel Enabled flag in config for selectively disabling channels Protocol bot (cmd/protbot/) - Standalone Blowfish connection package (no dependency on server config) - Sign, entrance, and channel protocol implementations - 5 scenario actions: login, lobby, session, chat, quests - 19 unit tests covering packet building, parsing, and connection handling Bug fixes - Nil decompSave panic on disconnect before character data loads - Docker Postgres 18 volume mount path (/var/lib/postgresql/ not /data/) Test plan - go test -race ./... passes (27 packages, 0 races) - 5 channel isolation tests verify: independent shutdown, listener failure recovery, session panic containment, cross-channel registry after shutdown, stage isolation - Protbot live-tested against Docker stack (all 5 actions) - Existing config.json files work unchanged (Enabled defaults to false but config.example.json sets it explicitly)
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -16,6 +16,7 @@ screenshots/*
|
||||
# We don't need built files
|
||||
erupe-ce
|
||||
erupe
|
||||
protbot
|
||||
tools/loganalyzer/loganalyzer
|
||||
|
||||
# config is install dependent
|
||||
|
||||
37
cmd/protbot/conn/bin8.go
Normal file
37
cmd/protbot/conn/bin8.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package conn
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
var (
|
||||
bin8Key = []byte{0x01, 0x23, 0x34, 0x45, 0x56, 0xAB, 0xCD, 0xEF}
|
||||
sum32Table0 = []byte{0x35, 0x7A, 0xAA, 0x97, 0x53, 0x66, 0x12}
|
||||
sum32Table1 = []byte{0x7A, 0xAA, 0x97, 0x53, 0x66, 0x12, 0xDE, 0xDE, 0x35}
|
||||
)
|
||||
|
||||
// CalcSum32 calculates the custom MHF "sum32" checksum.
|
||||
func CalcSum32(data []byte) uint32 {
|
||||
tableIdx0 := (len(data) + 1) & 0xFF
|
||||
tableIdx1 := int((data[len(data)>>1] + 1) & 0xFF)
|
||||
out := make([]byte, 4)
|
||||
for i := 0; i < len(data); i++ {
|
||||
key := data[i] ^ sum32Table0[(tableIdx0+i)%7] ^ sum32Table1[(tableIdx1+i)%9]
|
||||
out[i&3] = (out[i&3] + key) & 0xFF
|
||||
}
|
||||
return binary.BigEndian.Uint32(out)
|
||||
}
|
||||
|
||||
func rotate(k *uint32) {
|
||||
*k = uint32(((54323 * uint(*k)) + 1) & 0xFFFFFFFF)
|
||||
}
|
||||
|
||||
// DecryptBin8 decrypts MHF "binary8" data.
|
||||
func DecryptBin8(data []byte, key byte) []byte {
|
||||
k := uint32(key)
|
||||
output := make([]byte, len(data))
|
||||
for i := 0; i < len(data); i++ {
|
||||
rotate(&k)
|
||||
tmp := data[i] ^ byte((k>>13)&0xFF)
|
||||
output[i] = tmp ^ bin8Key[i&7]
|
||||
}
|
||||
return output
|
||||
}
|
||||
52
cmd/protbot/conn/bin8_test.go
Normal file
52
cmd/protbot/conn/bin8_test.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package conn
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCalcSum32 verifies the checksum against a known input.
|
||||
func TestCalcSum32(t *testing.T) {
|
||||
// Verify determinism: same input gives same output.
|
||||
data := []byte("Hello, MHF!")
|
||||
sum1 := CalcSum32(data)
|
||||
sum2 := CalcSum32(data)
|
||||
if sum1 != sum2 {
|
||||
t.Fatalf("CalcSum32 not deterministic: %08X != %08X", sum1, sum2)
|
||||
}
|
||||
|
||||
// Different inputs produce different outputs (basic sanity).
|
||||
data2 := []byte("Hello, MHF?")
|
||||
sum3 := CalcSum32(data2)
|
||||
if sum1 == sum3 {
|
||||
t.Fatalf("CalcSum32 collision on different inputs: both %08X", sum1)
|
||||
}
|
||||
}
|
||||
|
||||
// TestDecryptBin8RoundTrip verifies that encrypting and decrypting with Bin8
|
||||
// produces the original data. We only have DecryptBin8, but we can verify
|
||||
// the encrypt→decrypt path by implementing encrypt inline here.
|
||||
func TestDecryptBin8RoundTrip(t *testing.T) {
|
||||
original := []byte("Test data for Bin8 encryption round-trip")
|
||||
key := byte(0x42)
|
||||
|
||||
// Encrypt (inline copy of Erupe's EncryptBin8)
|
||||
k := uint32(key)
|
||||
encrypted := make([]byte, len(original))
|
||||
for i := 0; i < len(original); i++ {
|
||||
rotate(&k)
|
||||
tmp := bin8Key[i&7] ^ byte((k>>13)&0xFF)
|
||||
encrypted[i] = original[i] ^ tmp
|
||||
}
|
||||
|
||||
// Decrypt
|
||||
decrypted := DecryptBin8(encrypted, key)
|
||||
|
||||
if len(decrypted) != len(original) {
|
||||
t.Fatalf("length mismatch: got %d, want %d", len(decrypted), len(original))
|
||||
}
|
||||
for i := range original {
|
||||
if decrypted[i] != original[i] {
|
||||
t.Fatalf("byte %d: got 0x%02X, want 0x%02X", i, decrypted[i], original[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
52
cmd/protbot/conn/conn.go
Normal file
52
cmd/protbot/conn/conn.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package conn
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
)
|
||||
|
||||
// MHFConn wraps a CryptConn and provides convenience methods for MHF connections.
|
||||
type MHFConn struct {
|
||||
*CryptConn
|
||||
RawConn net.Conn
|
||||
}
|
||||
|
||||
// DialWithInit connects to addr and sends the 8 NULL byte initialization
|
||||
// required by sign and entrance servers.
|
||||
func DialWithInit(addr string) (*MHFConn, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
// Sign and entrance servers expect 8 NULL bytes to initialize the connection.
|
||||
_, err = conn.Write(make([]byte, 8))
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, fmt.Errorf("write init bytes to %s: %w", addr, err)
|
||||
}
|
||||
|
||||
return &MHFConn{
|
||||
CryptConn: NewCryptConn(conn),
|
||||
RawConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// DialDirect connects to addr without sending initialization bytes.
|
||||
// Used for channel server connections.
|
||||
func DialDirect(addr string) (*MHFConn, error) {
|
||||
conn, err := net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
return &MHFConn{
|
||||
CryptConn: NewCryptConn(conn),
|
||||
RawConn: conn,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying connection.
|
||||
func (c *MHFConn) Close() error {
|
||||
return c.RawConn.Close()
|
||||
}
|
||||
115
cmd/protbot/conn/crypt_conn.go
Normal file
115
cmd/protbot/conn/crypt_conn.go
Normal file
@@ -0,0 +1,115 @@
|
||||
package conn
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"erupe-ce/network/crypto"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
)
|
||||
|
||||
// CryptConn is an MHF encrypted two-way connection.
|
||||
// Adapted from Erupe's network/crypt_conn.go with config dependency removed.
|
||||
// Hardcoded to ZZ mode (supports Pf0-based extended data size).
|
||||
type CryptConn struct {
|
||||
conn net.Conn
|
||||
readKeyRot uint32
|
||||
sendKeyRot uint32
|
||||
sentPackets int32
|
||||
prevRecvPacketCombinedCheck uint16
|
||||
prevSendPacketCombinedCheck uint16
|
||||
}
|
||||
|
||||
// NewCryptConn creates a new CryptConn with proper default values.
|
||||
func NewCryptConn(conn net.Conn) *CryptConn {
|
||||
return &CryptConn{
|
||||
conn: conn,
|
||||
readKeyRot: 995117,
|
||||
sendKeyRot: 995117,
|
||||
}
|
||||
}
|
||||
|
||||
// ReadPacket reads a packet from the connection and returns the decrypted data.
|
||||
func (cc *CryptConn) ReadPacket() ([]byte, error) {
|
||||
headerData := make([]byte, CryptPacketHeaderLength)
|
||||
_, err := io.ReadFull(cc.conn, headerData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cph, err := NewCryptPacketHeader(headerData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// ZZ mode: extended data size using Pf0 field.
|
||||
encryptedPacketBody := make([]byte, uint32(cph.DataSize)+(uint32(cph.Pf0-0x03)*0x1000))
|
||||
_, err = io.ReadFull(cc.conn, encryptedPacketBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cph.KeyRotDelta != 0 {
|
||||
cc.readKeyRot = uint32(cph.KeyRotDelta) * (cc.readKeyRot + 1)
|
||||
}
|
||||
|
||||
out, combinedCheck, check0, check1, check2 := crypto.Crypto(encryptedPacketBody, cc.readKeyRot, false, nil)
|
||||
if cph.Check0 != check0 || cph.Check1 != check1 || cph.Check2 != check2 {
|
||||
fmt.Printf("got c0 %X, c1 %X, c2 %X\n", check0, check1, check2)
|
||||
fmt.Printf("want c0 %X, c1 %X, c2 %X\n", cph.Check0, cph.Check1, cph.Check2)
|
||||
fmt.Printf("headerData:\n%s\n", hex.Dump(headerData))
|
||||
fmt.Printf("encryptedPacketBody:\n%s\n", hex.Dump(encryptedPacketBody))
|
||||
|
||||
// Attempt bruteforce recovery.
|
||||
fmt.Println("Crypto out of sync? Attempting bruteforce")
|
||||
for key := byte(0); key < 255; key++ {
|
||||
out, combinedCheck, check0, check1, check2 = crypto.Crypto(encryptedPacketBody, 0, false, &key)
|
||||
if cph.Check0 == check0 && cph.Check1 == check1 && cph.Check2 == check2 {
|
||||
fmt.Printf("Bruteforce successful, override key: 0x%X\n", key)
|
||||
cc.prevRecvPacketCombinedCheck = combinedCheck
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("decrypted data checksum doesn't match header")
|
||||
}
|
||||
|
||||
cc.prevRecvPacketCombinedCheck = combinedCheck
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// SendPacket encrypts and sends a packet.
|
||||
func (cc *CryptConn) SendPacket(data []byte) error {
|
||||
keyRotDelta := byte(3)
|
||||
|
||||
if keyRotDelta != 0 {
|
||||
cc.sendKeyRot = uint32(keyRotDelta) * (cc.sendKeyRot + 1)
|
||||
}
|
||||
|
||||
encData, combinedCheck, check0, check1, check2 := crypto.Crypto(data, cc.sendKeyRot, true, nil)
|
||||
|
||||
header := &CryptPacketHeader{}
|
||||
header.Pf0 = byte(((uint(len(encData)) >> 12) & 0xF3) | 3)
|
||||
header.KeyRotDelta = keyRotDelta
|
||||
header.PacketNum = uint16(cc.sentPackets)
|
||||
header.DataSize = uint16(len(encData))
|
||||
header.PrevPacketCombinedCheck = cc.prevSendPacketCombinedCheck
|
||||
header.Check0 = check0
|
||||
header.Check1 = check1
|
||||
header.Check2 = check2
|
||||
|
||||
headerBytes, err := header.Encode()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = cc.conn.Write(append(headerBytes, encData...))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cc.sentPackets++
|
||||
cc.prevSendPacketCombinedCheck = combinedCheck
|
||||
|
||||
return nil
|
||||
}
|
||||
152
cmd/protbot/conn/crypt_conn_test.go
Normal file
152
cmd/protbot/conn/crypt_conn_test.go
Normal file
@@ -0,0 +1,152 @@
|
||||
package conn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// TestCryptConnRoundTrip verifies that encrypting and decrypting a packet
|
||||
// through a pair of CryptConn instances produces the original data.
|
||||
func TestCryptConnRoundTrip(t *testing.T) {
|
||||
// Create an in-process TCP pipe.
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
sender := NewCryptConn(client)
|
||||
receiver := NewCryptConn(server)
|
||||
|
||||
testCases := [][]byte{
|
||||
{0x00, 0x14, 0x00, 0x00, 0x00, 0x01}, // Minimal login-like packet
|
||||
{0xDE, 0xAD, 0xBE, 0xEF},
|
||||
make([]byte, 256), // Larger packet
|
||||
}
|
||||
|
||||
for i, original := range testCases {
|
||||
// Send in a goroutine to avoid blocking.
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- sender.SendPacket(original)
|
||||
}()
|
||||
|
||||
received, err := receiver.ReadPacket()
|
||||
if err != nil {
|
||||
t.Fatalf("case %d: ReadPacket error: %v", i, err)
|
||||
}
|
||||
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("case %d: SendPacket error: %v", i, err)
|
||||
}
|
||||
|
||||
if len(received) != len(original) {
|
||||
t.Fatalf("case %d: length mismatch: got %d, want %d", i, len(received), len(original))
|
||||
}
|
||||
for j := range original {
|
||||
if received[j] != original[j] {
|
||||
t.Fatalf("case %d: byte %d mismatch: got 0x%02X, want 0x%02X", i, j, received[j], original[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestCryptPacketHeaderRoundTrip verifies header encode/decode.
|
||||
func TestCryptPacketHeaderRoundTrip(t *testing.T) {
|
||||
original := &CryptPacketHeader{
|
||||
Pf0: 0x03,
|
||||
KeyRotDelta: 0x03,
|
||||
PacketNum: 42,
|
||||
DataSize: 100,
|
||||
PrevPacketCombinedCheck: 0x1234,
|
||||
Check0: 0xAAAA,
|
||||
Check1: 0xBBBB,
|
||||
Check2: 0xCCCC,
|
||||
}
|
||||
|
||||
encoded, err := original.Encode()
|
||||
if err != nil {
|
||||
t.Fatalf("Encode error: %v", err)
|
||||
}
|
||||
|
||||
if len(encoded) != CryptPacketHeaderLength {
|
||||
t.Fatalf("encoded length: got %d, want %d", len(encoded), CryptPacketHeaderLength)
|
||||
}
|
||||
|
||||
decoded, err := NewCryptPacketHeader(encoded)
|
||||
if err != nil {
|
||||
t.Fatalf("NewCryptPacketHeader error: %v", err)
|
||||
}
|
||||
|
||||
if *decoded != *original {
|
||||
t.Fatalf("header mismatch:\ngot %+v\nwant %+v", *decoded, *original)
|
||||
}
|
||||
}
|
||||
|
||||
// TestMultiPacketSequence verifies that key rotation stays in sync across
|
||||
// multiple sequential packets.
|
||||
func TestMultiPacketSequence(t *testing.T) {
|
||||
server, client := net.Pipe()
|
||||
defer server.Close()
|
||||
defer client.Close()
|
||||
|
||||
sender := NewCryptConn(client)
|
||||
receiver := NewCryptConn(server)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
data := []byte{byte(i), byte(i + 1), byte(i + 2), byte(i + 3)}
|
||||
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- sender.SendPacket(data)
|
||||
}()
|
||||
|
||||
received, err := receiver.ReadPacket()
|
||||
if err != nil {
|
||||
t.Fatalf("packet %d: ReadPacket error: %v", i, err)
|
||||
}
|
||||
|
||||
if err := <-errCh; err != nil {
|
||||
t.Fatalf("packet %d: SendPacket error: %v", i, err)
|
||||
}
|
||||
|
||||
for j := range data {
|
||||
if received[j] != data[j] {
|
||||
t.Fatalf("packet %d byte %d: got 0x%02X, want 0x%02X", i, j, received[j], data[j])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestDialWithInit verifies that DialWithInit sends 8 NULL bytes on connect.
|
||||
func TestDialWithInit(t *testing.T) {
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
done := make(chan []byte, 1)
|
||||
go func() {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 8)
|
||||
_, _ = io.ReadFull(conn, buf)
|
||||
done <- buf
|
||||
}()
|
||||
|
||||
c, err := DialWithInit(listener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
initBytes := <-done
|
||||
for i, b := range initBytes {
|
||||
if b != 0 {
|
||||
t.Fatalf("init byte %d: got 0x%02X, want 0x00", i, b)
|
||||
}
|
||||
}
|
||||
}
|
||||
78
cmd/protbot/conn/crypt_packet.go
Normal file
78
cmd/protbot/conn/crypt_packet.go
Normal file
@@ -0,0 +1,78 @@
|
||||
// Package conn provides MHF encrypted connection primitives.
|
||||
//
|
||||
// This is adapted from Erupe's network/crypt_packet.go to avoid importing
|
||||
// erupe-ce/config (whose init() calls os.Exit without a config file).
|
||||
package conn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
const CryptPacketHeaderLength = 14
|
||||
|
||||
// CryptPacketHeader represents the parsed information of an encrypted packet header.
|
||||
type CryptPacketHeader struct {
|
||||
Pf0 byte
|
||||
KeyRotDelta byte
|
||||
PacketNum uint16
|
||||
DataSize uint16
|
||||
PrevPacketCombinedCheck uint16
|
||||
Check0 uint16
|
||||
Check1 uint16
|
||||
Check2 uint16
|
||||
}
|
||||
|
||||
// NewCryptPacketHeader parses raw bytes into a CryptPacketHeader.
|
||||
func NewCryptPacketHeader(data []byte) (*CryptPacketHeader, error) {
|
||||
var c CryptPacketHeader
|
||||
r := bytes.NewReader(data)
|
||||
|
||||
if err := binary.Read(r, binary.BigEndian, &c.Pf0); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &c.KeyRotDelta); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &c.PacketNum); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &c.DataSize); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &c.PrevPacketCombinedCheck); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &c.Check0); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &c.Check1); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := binary.Read(r, binary.BigEndian, &c.Check2); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
// Encode encodes the CryptPacketHeader into raw bytes.
|
||||
func (c *CryptPacketHeader) Encode() ([]byte, error) {
|
||||
buf := bytes.NewBuffer([]byte{})
|
||||
data := []interface{}{
|
||||
c.Pf0,
|
||||
c.KeyRotDelta,
|
||||
c.PacketNum,
|
||||
c.DataSize,
|
||||
c.PrevPacketCombinedCheck,
|
||||
c.Check0,
|
||||
c.Check1,
|
||||
c.Check2,
|
||||
}
|
||||
for _, v := range data {
|
||||
if err := binary.Write(buf, binary.BigEndian, v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
154
cmd/protbot/main.go
Normal file
154
cmd/protbot/main.go
Normal file
@@ -0,0 +1,154 @@
|
||||
// protbot is a headless MHF protocol bot for testing Erupe server instances.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action login
|
||||
// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action lobby
|
||||
// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action session
|
||||
// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action chat --message "Hello"
|
||||
// protbot --sign-addr 127.0.0.1:53312 --user test --pass test --action quests
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"erupe-ce/cmd/protbot/scenario"
|
||||
)
|
||||
|
||||
func main() {
|
||||
signAddr := flag.String("sign-addr", "127.0.0.1:53312", "Sign server address (host:port)")
|
||||
user := flag.String("user", "", "Username")
|
||||
pass := flag.String("pass", "", "Password")
|
||||
action := flag.String("action", "login", "Action to perform: login, lobby, session, chat, quests")
|
||||
message := flag.String("message", "", "Chat message to send (used with --action chat)")
|
||||
flag.Parse()
|
||||
|
||||
if *user == "" || *pass == "" {
|
||||
fmt.Fprintln(os.Stderr, "error: --user and --pass are required")
|
||||
flag.Usage()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
switch *action {
|
||||
case "login":
|
||||
result, err := scenario.Login(*signAddr, *user, *pass)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("[done] Login successful!")
|
||||
result.Channel.Close()
|
||||
|
||||
case "lobby":
|
||||
result, err := scenario.Login(*signAddr, *user, *pass)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := scenario.EnterLobby(result.Channel); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err)
|
||||
result.Channel.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("[done] Lobby entry successful!")
|
||||
result.Channel.Close()
|
||||
|
||||
case "session":
|
||||
result, err := scenario.Login(*signAddr, *user, *pass)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
charID := result.Sign.CharIDs[0]
|
||||
if _, err := scenario.SetupSession(result.Channel, charID); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err)
|
||||
result.Channel.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := scenario.EnterLobby(result.Channel); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err)
|
||||
result.Channel.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Println("[session] Connected. Press Ctrl+C to disconnect.")
|
||||
waitForSignal()
|
||||
scenario.Logout(result.Channel)
|
||||
|
||||
case "chat":
|
||||
result, err := scenario.Login(*signAddr, *user, *pass)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
charID := result.Sign.CharIDs[0]
|
||||
if _, err := scenario.SetupSession(result.Channel, charID); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err)
|
||||
result.Channel.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := scenario.EnterLobby(result.Channel); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err)
|
||||
result.Channel.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Register chat listener.
|
||||
scenario.ListenChat(result.Channel, func(msg scenario.ChatMessage) {
|
||||
fmt.Printf("[chat] <%s> (type=%d): %s\n", msg.SenderName, msg.ChatType, msg.Message)
|
||||
})
|
||||
|
||||
// Send a message if provided.
|
||||
if *message != "" {
|
||||
if err := scenario.SendChat(result.Channel, 0x03, 1, *message, *user); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "send chat failed: %v\n", err)
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("[chat] Listening for chat messages. Press Ctrl+C to disconnect.")
|
||||
waitForSignal()
|
||||
scenario.Logout(result.Channel)
|
||||
|
||||
case "quests":
|
||||
result, err := scenario.Login(*signAddr, *user, *pass)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "login failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
charID := result.Sign.CharIDs[0]
|
||||
if _, err := scenario.SetupSession(result.Channel, charID); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "session setup failed: %v\n", err)
|
||||
result.Channel.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
if err := scenario.EnterLobby(result.Channel); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "enter lobby failed: %v\n", err)
|
||||
result.Channel.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
data, err := scenario.EnumerateQuests(result.Channel, 0, 0)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "enumerate quests failed: %v\n", err)
|
||||
scenario.Logout(result.Channel)
|
||||
os.Exit(1)
|
||||
}
|
||||
fmt.Printf("[quests] Received %d bytes of quest data\n", len(data))
|
||||
scenario.Logout(result.Channel)
|
||||
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "unknown action: %s (supported: login, lobby, session, chat, quests)\n", *action)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
// waitForSignal blocks until SIGINT or SIGTERM is received.
|
||||
func waitForSignal() {
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sig
|
||||
fmt.Println("\n[signal] Shutting down...")
|
||||
}
|
||||
190
cmd/protbot/protocol/channel.go
Normal file
190
cmd/protbot/protocol/channel.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"erupe-ce/cmd/protbot/conn"
|
||||
)
|
||||
|
||||
// PacketHandler is a callback invoked when a server-pushed packet is received.
|
||||
type PacketHandler func(opcode uint16, data []byte)
|
||||
|
||||
// ChannelConn manages a connection to a channel server.
|
||||
type ChannelConn struct {
|
||||
conn *conn.MHFConn
|
||||
ackCounter uint32
|
||||
waiters sync.Map // map[uint32]chan *AckResponse
|
||||
handlers sync.Map // map[uint16]PacketHandler
|
||||
closed atomic.Bool
|
||||
}
|
||||
|
||||
// OnPacket registers a handler for a specific server-pushed opcode.
|
||||
// Only one handler per opcode; later registrations replace earlier ones.
|
||||
func (ch *ChannelConn) OnPacket(opcode uint16, handler PacketHandler) {
|
||||
ch.handlers.Store(opcode, handler)
|
||||
}
|
||||
|
||||
// AckResponse holds the parsed ACK data from the server.
|
||||
type AckResponse struct {
|
||||
AckHandle uint32
|
||||
IsBufferResponse bool
|
||||
ErrorCode uint8
|
||||
Data []byte
|
||||
}
|
||||
|
||||
// ConnectChannel establishes a connection to a channel server.
|
||||
// Channel servers do NOT use the 8 NULL byte initialization.
|
||||
func ConnectChannel(addr string) (*ChannelConn, error) {
|
||||
c, err := conn.DialDirect(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("channel connect: %w", err)
|
||||
}
|
||||
|
||||
ch := &ChannelConn{
|
||||
conn: c,
|
||||
}
|
||||
|
||||
go ch.recvLoop()
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
// NextAckHandle returns the next unique ACK handle for packet requests.
|
||||
func (ch *ChannelConn) NextAckHandle() uint32 {
|
||||
return atomic.AddUint32(&ch.ackCounter, 1)
|
||||
}
|
||||
|
||||
// SendPacket encrypts and sends raw packet data (including the 0x00 0x10 terminator
|
||||
// which is already appended by the Build* functions in packets.go).
|
||||
func (ch *ChannelConn) SendPacket(data []byte) error {
|
||||
return ch.conn.SendPacket(data)
|
||||
}
|
||||
|
||||
// WaitForAck waits for an ACK response matching the given handle.
|
||||
func (ch *ChannelConn) WaitForAck(handle uint32, timeout time.Duration) (*AckResponse, error) {
|
||||
waitCh := make(chan *AckResponse, 1)
|
||||
ch.waiters.Store(handle, waitCh)
|
||||
defer ch.waiters.Delete(handle)
|
||||
|
||||
select {
|
||||
case resp := <-waitCh:
|
||||
return resp, nil
|
||||
case <-time.After(timeout):
|
||||
return nil, fmt.Errorf("ACK timeout for handle %d", handle)
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the channel connection.
|
||||
func (ch *ChannelConn) Close() error {
|
||||
ch.closed.Store(true)
|
||||
return ch.conn.Close()
|
||||
}
|
||||
|
||||
// recvLoop continuously reads packets from the channel server and dispatches ACKs.
|
||||
func (ch *ChannelConn) recvLoop() {
|
||||
for {
|
||||
if ch.closed.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
pkt, err := ch.conn.ReadPacket()
|
||||
if err != nil {
|
||||
if ch.closed.Load() {
|
||||
return
|
||||
}
|
||||
fmt.Printf("[channel] read error: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(pkt) < 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
// Strip trailing 0x00 0x10 terminator if present for opcode parsing.
|
||||
// Packets from server: [opcode uint16][fields...][0x00 0x10]
|
||||
opcode := binary.BigEndian.Uint16(pkt[0:2])
|
||||
|
||||
switch opcode {
|
||||
case MSG_SYS_ACK:
|
||||
ch.handleAck(pkt[2:])
|
||||
case MSG_SYS_PING:
|
||||
ch.handlePing(pkt[2:])
|
||||
default:
|
||||
if val, ok := ch.handlers.Load(opcode); ok {
|
||||
val.(PacketHandler)(opcode, pkt[2:])
|
||||
} else {
|
||||
fmt.Printf("[channel] recv opcode 0x%04X (%d bytes)\n", opcode, len(pkt))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleAck parses an ACK packet and dispatches it to a waiting caller.
|
||||
// Reference: Erupe network/mhfpacket/msg_sys_ack.go
|
||||
func (ch *ChannelConn) handleAck(data []byte) {
|
||||
if len(data) < 8 {
|
||||
return
|
||||
}
|
||||
|
||||
ackHandle := binary.BigEndian.Uint32(data[0:4])
|
||||
isBuffer := data[4] > 0
|
||||
errorCode := data[5]
|
||||
|
||||
var ackData []byte
|
||||
if isBuffer {
|
||||
payloadSize := binary.BigEndian.Uint16(data[6:8])
|
||||
offset := uint32(8)
|
||||
if payloadSize == 0xFFFF {
|
||||
if len(data) < 12 {
|
||||
return
|
||||
}
|
||||
payloadSize32 := binary.BigEndian.Uint32(data[8:12])
|
||||
offset = 12
|
||||
if uint32(len(data)) >= offset+payloadSize32 {
|
||||
ackData = data[offset : offset+payloadSize32]
|
||||
}
|
||||
} else {
|
||||
if uint32(len(data)) >= offset+uint32(payloadSize) {
|
||||
ackData = data[offset : offset+uint32(payloadSize)]
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Simple ACK: 4 bytes of data after the uint16 field.
|
||||
if len(data) >= 12 {
|
||||
ackData = data[8:12]
|
||||
}
|
||||
}
|
||||
|
||||
resp := &AckResponse{
|
||||
AckHandle: ackHandle,
|
||||
IsBufferResponse: isBuffer,
|
||||
ErrorCode: errorCode,
|
||||
Data: ackData,
|
||||
}
|
||||
|
||||
if val, ok := ch.waiters.Load(ackHandle); ok {
|
||||
waitCh := val.(chan *AckResponse)
|
||||
select {
|
||||
case waitCh <- resp:
|
||||
default:
|
||||
}
|
||||
} else {
|
||||
fmt.Printf("[channel] unexpected ACK handle %d (error=%d, buffer=%v, %d bytes)\n",
|
||||
ackHandle, errorCode, isBuffer, len(ackData))
|
||||
}
|
||||
}
|
||||
|
||||
// handlePing responds to a server ping to keep the connection alive.
|
||||
func (ch *ChannelConn) handlePing(data []byte) {
|
||||
var ackHandle uint32
|
||||
if len(data) >= 4 {
|
||||
ackHandle = binary.BigEndian.Uint32(data[0:4])
|
||||
}
|
||||
pkt := BuildPingPacket(ackHandle)
|
||||
if err := ch.conn.SendPacket(pkt); err != nil {
|
||||
fmt.Printf("[channel] ping response failed: %v\n", err)
|
||||
}
|
||||
}
|
||||
142
cmd/protbot/protocol/entrance.go
Normal file
142
cmd/protbot/protocol/entrance.go
Normal file
@@ -0,0 +1,142 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"net"
|
||||
|
||||
"erupe-ce/common/byteframe"
|
||||
|
||||
"erupe-ce/cmd/protbot/conn"
|
||||
)
|
||||
|
||||
// ServerEntry represents a channel server from the entrance server response.
|
||||
type ServerEntry struct {
|
||||
IP string
|
||||
Port uint16
|
||||
Name string
|
||||
}
|
||||
|
||||
// DoEntrance connects to the entrance server and retrieves the server list.
|
||||
// Reference: Erupe server/entranceserver/entrance_server.go and make_resp.go.
|
||||
func DoEntrance(addr string) ([]ServerEntry, error) {
|
||||
c, err := conn.DialWithInit(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entrance connect: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Send a minimal packet (the entrance server reads it, checks len > 5 for USR data).
|
||||
// An empty/short packet triggers only SV2 response.
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint8(0)
|
||||
if err := c.SendPacket(bf.Data()); err != nil {
|
||||
return nil, fmt.Errorf("entrance send: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.ReadPacket()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entrance recv: %w", err)
|
||||
}
|
||||
|
||||
return parseEntranceResponse(resp)
|
||||
}
|
||||
|
||||
// parseEntranceResponse parses the Bin8-encrypted entrance server response.
|
||||
// Reference: Erupe server/entranceserver/make_resp.go (makeHeader, makeSv2Resp)
|
||||
func parseEntranceResponse(data []byte) ([]ServerEntry, error) {
|
||||
if len(data) < 2 {
|
||||
return nil, fmt.Errorf("entrance response too short")
|
||||
}
|
||||
|
||||
// First byte is the Bin8 encryption key.
|
||||
key := data[0]
|
||||
decrypted := conn.DecryptBin8(data[1:], key)
|
||||
|
||||
rbf := byteframe.NewByteFrameFromBytes(decrypted)
|
||||
|
||||
// Read response type header: "SV2" or "SVR"
|
||||
respType := string(rbf.ReadBytes(3))
|
||||
if respType != "SV2" && respType != "SVR" {
|
||||
return nil, fmt.Errorf("unexpected entrance response type: %s", respType)
|
||||
}
|
||||
|
||||
entryCount := rbf.ReadUint16()
|
||||
dataLen := rbf.ReadUint16()
|
||||
if dataLen == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
expectedSum := rbf.ReadUint32()
|
||||
serverData := rbf.ReadBytes(uint(dataLen))
|
||||
|
||||
actualSum := conn.CalcSum32(serverData)
|
||||
if expectedSum != actualSum {
|
||||
return nil, fmt.Errorf("entrance checksum mismatch: expected %08X, got %08X", expectedSum, actualSum)
|
||||
}
|
||||
|
||||
return parseServerEntries(serverData, entryCount)
|
||||
}
|
||||
|
||||
// parseServerEntries parses the server info binary blob.
|
||||
// Reference: Erupe server/entranceserver/make_resp.go (encodeServerInfo)
|
||||
func parseServerEntries(data []byte, entryCount uint16) ([]ServerEntry, error) {
|
||||
bf := byteframe.NewByteFrameFromBytes(data)
|
||||
var entries []ServerEntry
|
||||
|
||||
for i := uint16(0); i < entryCount; i++ {
|
||||
ipBytes := bf.ReadBytes(4)
|
||||
ip := net.IP([]byte{
|
||||
byte(ipBytes[3]), byte(ipBytes[2]),
|
||||
byte(ipBytes[1]), byte(ipBytes[0]),
|
||||
})
|
||||
|
||||
_ = bf.ReadUint16() // serverIdx | 16
|
||||
_ = bf.ReadUint16() // 0
|
||||
channelCount := bf.ReadUint16()
|
||||
_ = bf.ReadUint8() // Type
|
||||
_ = bf.ReadUint8() // Season/rotation
|
||||
|
||||
// G1+ recommended flag
|
||||
_ = bf.ReadUint8()
|
||||
|
||||
// G51+ (ZZ): skip 1 byte, then read 65-byte padded name
|
||||
_ = bf.ReadUint8()
|
||||
nameBytes := bf.ReadBytes(65)
|
||||
|
||||
// GG+: AllowedClientFlags
|
||||
_ = bf.ReadUint32()
|
||||
|
||||
// Parse name (null-separated: name + description)
|
||||
name := ""
|
||||
for j := 0; j < len(nameBytes); j++ {
|
||||
if nameBytes[j] == 0 {
|
||||
break
|
||||
}
|
||||
name += string(nameBytes[j])
|
||||
}
|
||||
|
||||
// Read channel entries (14 x uint16 = 28 bytes each)
|
||||
for j := uint16(0); j < channelCount; j++ {
|
||||
port := bf.ReadUint16()
|
||||
_ = bf.ReadUint16() // channelIdx | 16
|
||||
_ = bf.ReadUint16() // maxPlayers
|
||||
_ = bf.ReadUint16() // currentPlayers
|
||||
_ = bf.ReadBytes(18) // remaining channel fields (9 x uint16: 6 zeros + unk319 + unk254 + unk255)
|
||||
_ = bf.ReadUint16() // 12345
|
||||
|
||||
serverIP := ip.String()
|
||||
// Convert 127.0.0.1 representation
|
||||
if binary.LittleEndian.Uint32(ipBytes) == 0x0100007F {
|
||||
serverIP = "127.0.0.1"
|
||||
}
|
||||
|
||||
entries = append(entries, ServerEntry{
|
||||
IP: serverIP,
|
||||
Port: port,
|
||||
Name: fmt.Sprintf("%s ch%d", name, j+1),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return entries, nil
|
||||
}
|
||||
23
cmd/protbot/protocol/opcodes.go
Normal file
23
cmd/protbot/protocol/opcodes.go
Normal file
@@ -0,0 +1,23 @@
|
||||
// Package protocol implements MHF network protocol message building and parsing.
|
||||
package protocol
|
||||
|
||||
// Packet opcodes (subset from Erupe's network/packetid.go iota).
|
||||
const (
|
||||
MSG_SYS_ACK uint16 = 0x0012
|
||||
MSG_SYS_LOGIN uint16 = 0x0014
|
||||
MSG_SYS_LOGOUT uint16 = 0x0015
|
||||
MSG_SYS_PING uint16 = 0x0017
|
||||
MSG_SYS_CAST_BINARY uint16 = 0x0018
|
||||
MSG_SYS_TIME uint16 = 0x001A
|
||||
MSG_SYS_CASTED_BINARY uint16 = 0x001B
|
||||
MSG_SYS_ISSUE_LOGKEY uint16 = 0x001D
|
||||
MSG_SYS_ENTER_STAGE uint16 = 0x0022
|
||||
MSG_SYS_ENUMERATE_STAGE uint16 = 0x002F
|
||||
MSG_SYS_INSERT_USER uint16 = 0x0050
|
||||
MSG_SYS_DELETE_USER uint16 = 0x0051
|
||||
MSG_SYS_UPDATE_RIGHT uint16 = 0x0058
|
||||
MSG_SYS_RIGHTS_RELOAD uint16 = 0x005D
|
||||
MSG_MHF_LOADDATA uint16 = 0x0061
|
||||
MSG_MHF_ENUMERATE_QUEST uint16 = 0x009F
|
||||
MSG_MHF_GET_WEEKLY_SCHED uint16 = 0x00E1
|
||||
)
|
||||
229
cmd/protbot/protocol/packets.go
Normal file
229
cmd/protbot/protocol/packets.go
Normal file
@@ -0,0 +1,229 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"erupe-ce/common/byteframe"
|
||||
"erupe-ce/common/stringsupport"
|
||||
)
|
||||
|
||||
// BuildLoginPacket builds a MSG_SYS_LOGIN packet.
|
||||
// Layout mirrors Erupe's MsgSysLogin.Parse:
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// uint32 charID
|
||||
// uint32 loginTokenNumber
|
||||
// uint16 hardcodedZero
|
||||
// uint16 requestVersion (set to 0xCAFE as dummy)
|
||||
// uint32 charID (repeated)
|
||||
// uint16 zeroed
|
||||
// uint16 always 11
|
||||
// null-terminated tokenString
|
||||
// 0x00 0x10 terminator
|
||||
func BuildLoginPacket(ackHandle, charID, tokenNumber uint32, tokenString string) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_LOGIN)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteUint32(charID)
|
||||
bf.WriteUint32(tokenNumber)
|
||||
bf.WriteUint16(0) // HardcodedZero0
|
||||
bf.WriteUint16(0xCAFE) // RequestVersion (dummy)
|
||||
bf.WriteUint32(charID) // CharID1 (repeated)
|
||||
bf.WriteUint16(0) // Zeroed
|
||||
bf.WriteUint16(11) // Always 11
|
||||
bf.WriteNullTerminatedBytes([]byte(tokenString))
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildEnumerateStagePacket builds a MSG_SYS_ENUMERATE_STAGE packet.
|
||||
// Layout mirrors Erupe's MsgSysEnumerateStage.Parse:
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// uint8 always 1
|
||||
// uint8 prefix length (including null terminator)
|
||||
// null-terminated stagePrefix
|
||||
// 0x00 0x10 terminator
|
||||
func BuildEnumerateStagePacket(ackHandle uint32, prefix string) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_ENUMERATE_STAGE)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteUint8(1) // Always 1
|
||||
bf.WriteUint8(uint8(len(prefix) + 1)) // Length including null terminator
|
||||
bf.WriteNullTerminatedBytes([]byte(prefix))
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildEnterStagePacket builds a MSG_SYS_ENTER_STAGE packet.
|
||||
// Layout mirrors Erupe's MsgSysEnterStage.Parse:
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// uint8 isQuest (0=false)
|
||||
// uint8 stageID length (including null terminator)
|
||||
// null-terminated stageID
|
||||
// 0x00 0x10 terminator
|
||||
func BuildEnterStagePacket(ackHandle uint32, stageID string) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_ENTER_STAGE)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteUint8(0) // IsQuest = false
|
||||
bf.WriteUint8(uint8(len(stageID) + 1)) // Length including null terminator
|
||||
bf.WriteNullTerminatedBytes([]byte(stageID))
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildPingPacket builds a MSG_SYS_PING response packet.
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// 0x00 0x10 terminator
|
||||
func BuildPingPacket(ackHandle uint32) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_PING)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildLogoutPacket builds a MSG_SYS_LOGOUT packet.
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint8 logoutType (1 = normal logout)
|
||||
// 0x00 0x10 terminator
|
||||
func BuildLogoutPacket() []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_LOGOUT)
|
||||
bf.WriteUint8(1) // LogoutType = normal
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildIssueLogkeyPacket builds a MSG_SYS_ISSUE_LOGKEY packet.
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// uint16 unk0
|
||||
// uint16 unk1
|
||||
// 0x00 0x10 terminator
|
||||
func BuildIssueLogkeyPacket(ackHandle uint32) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_ISSUE_LOGKEY)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteUint16(0)
|
||||
bf.WriteUint16(0)
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildRightsReloadPacket builds a MSG_SYS_RIGHTS_RELOAD packet.
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// uint8 count (0 = empty)
|
||||
// 0x00 0x10 terminator
|
||||
func BuildRightsReloadPacket(ackHandle uint32) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_RIGHTS_RELOAD)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteUint8(0) // Count = 0 (no rights entries)
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildLoaddataPacket builds a MSG_MHF_LOADDATA packet.
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// 0x00 0x10 terminator
|
||||
func BuildLoaddataPacket(ackHandle uint32) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_MHF_LOADDATA)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildCastBinaryPacket builds a MSG_SYS_CAST_BINARY packet.
|
||||
// Layout mirrors Erupe's MsgSysCastBinary.Parse:
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 unk (always 0)
|
||||
// uint8 broadcastType
|
||||
// uint8 messageType
|
||||
// uint16 dataSize
|
||||
// []byte payload
|
||||
// 0x00 0x10 terminator
|
||||
func BuildCastBinaryPacket(broadcastType, messageType uint8, payload []byte) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_SYS_CAST_BINARY)
|
||||
bf.WriteUint32(0) // Unk
|
||||
bf.WriteUint8(broadcastType)
|
||||
bf.WriteUint8(messageType)
|
||||
bf.WriteUint16(uint16(len(payload)))
|
||||
bf.WriteBytes(payload)
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildChatPayload builds the inner MsgBinChat binary blob for use with BuildCastBinaryPacket.
|
||||
// Layout mirrors Erupe's binpacket/msg_bin_chat.go Build:
|
||||
//
|
||||
// uint8 unk0 (always 0)
|
||||
// uint8 chatType
|
||||
// uint16 flags (always 0)
|
||||
// uint16 senderNameLen (SJIS bytes + null terminator)
|
||||
// uint16 messageLen (SJIS bytes + null terminator)
|
||||
// null-terminated SJIS message
|
||||
// null-terminated SJIS senderName
|
||||
func BuildChatPayload(chatType uint8, message, senderName string) []byte {
|
||||
sjisMsg := stringsupport.UTF8ToSJIS(message)
|
||||
sjisName := stringsupport.UTF8ToSJIS(senderName)
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint8(0) // Unk0
|
||||
bf.WriteUint8(chatType) // Type
|
||||
bf.WriteUint16(0) // Flags
|
||||
bf.WriteUint16(uint16(len(sjisName) + 1)) // SenderName length (+ null term)
|
||||
bf.WriteUint16(uint16(len(sjisMsg) + 1)) // Message length (+ null term)
|
||||
bf.WriteNullTerminatedBytes(sjisMsg) // Message
|
||||
bf.WriteNullTerminatedBytes(sjisName) // SenderName
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildEnumerateQuestPacket builds a MSG_MHF_ENUMERATE_QUEST packet.
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// uint8 unk0 (always 0)
|
||||
// uint8 world
|
||||
// uint16 counter
|
||||
// uint16 offset
|
||||
// uint8 unk1 (always 0)
|
||||
// 0x00 0x10 terminator
|
||||
func BuildEnumerateQuestPacket(ackHandle uint32, world uint8, counter, offset uint16) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_MHF_ENUMERATE_QUEST)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteUint8(0) // Unk0
|
||||
bf.WriteUint8(world)
|
||||
bf.WriteUint16(counter)
|
||||
bf.WriteUint16(offset)
|
||||
bf.WriteUint8(0) // Unk1
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
|
||||
// BuildGetWeeklySchedulePacket builds a MSG_MHF_GET_WEEKLY_SCHEDULE packet.
|
||||
//
|
||||
// uint16 opcode
|
||||
// uint32 ackHandle
|
||||
// 0x00 0x10 terminator
|
||||
func BuildGetWeeklySchedulePacket(ackHandle uint32) []byte {
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteUint16(MSG_MHF_GET_WEEKLY_SCHED)
|
||||
bf.WriteUint32(ackHandle)
|
||||
bf.WriteBytes([]byte{0x00, 0x10})
|
||||
return bf.Data()
|
||||
}
|
||||
412
cmd/protbot/protocol/packets_test.go
Normal file
412
cmd/protbot/protocol/packets_test.go
Normal file
@@ -0,0 +1,412 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"testing"
|
||||
|
||||
"erupe-ce/common/byteframe"
|
||||
)
|
||||
|
||||
// TestBuildLoginPacket verifies that the binary layout matches Erupe's Parse.
|
||||
func TestBuildLoginPacket(t *testing.T) {
|
||||
ackHandle := uint32(1)
|
||||
charID := uint32(100)
|
||||
tokenNumber := uint32(42)
|
||||
tokenString := "0123456789ABCDEF"
|
||||
|
||||
pkt := BuildLoginPacket(ackHandle, charID, tokenNumber, tokenString)
|
||||
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
opcode := bf.ReadUint16()
|
||||
if opcode != MSG_SYS_LOGIN {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_LOGIN)
|
||||
}
|
||||
|
||||
gotAck := bf.ReadUint32()
|
||||
if gotAck != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle)
|
||||
}
|
||||
|
||||
gotCharID0 := bf.ReadUint32()
|
||||
if gotCharID0 != charID {
|
||||
t.Fatalf("charID0: got %d, want %d", gotCharID0, charID)
|
||||
}
|
||||
|
||||
gotTokenNum := bf.ReadUint32()
|
||||
if gotTokenNum != tokenNumber {
|
||||
t.Fatalf("tokenNumber: got %d, want %d", gotTokenNum, tokenNumber)
|
||||
}
|
||||
|
||||
gotZero := bf.ReadUint16()
|
||||
if gotZero != 0 {
|
||||
t.Fatalf("hardcodedZero: got %d, want 0", gotZero)
|
||||
}
|
||||
|
||||
gotVersion := bf.ReadUint16()
|
||||
if gotVersion != 0xCAFE {
|
||||
t.Fatalf("requestVersion: got 0x%04X, want 0xCAFE", gotVersion)
|
||||
}
|
||||
|
||||
gotCharID1 := bf.ReadUint32()
|
||||
if gotCharID1 != charID {
|
||||
t.Fatalf("charID1: got %d, want %d", gotCharID1, charID)
|
||||
}
|
||||
|
||||
gotZeroed := bf.ReadUint16()
|
||||
if gotZeroed != 0 {
|
||||
t.Fatalf("zeroed: got %d, want 0", gotZeroed)
|
||||
}
|
||||
|
||||
gotEleven := bf.ReadUint16()
|
||||
if gotEleven != 11 {
|
||||
t.Fatalf("always11: got %d, want 11", gotEleven)
|
||||
}
|
||||
|
||||
gotToken := string(bf.ReadNullTerminatedBytes())
|
||||
if gotToken != tokenString {
|
||||
t.Fatalf("tokenString: got %q, want %q", gotToken, tokenString)
|
||||
}
|
||||
|
||||
// Verify terminator.
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildEnumerateStagePacket verifies binary layout matches Erupe's Parse.
|
||||
func TestBuildEnumerateStagePacket(t *testing.T) {
|
||||
ackHandle := uint32(5)
|
||||
prefix := "sl1Ns"
|
||||
|
||||
pkt := BuildEnumerateStagePacket(ackHandle, prefix)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
opcode := bf.ReadUint16()
|
||||
if opcode != MSG_SYS_ENUMERATE_STAGE {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_ENUMERATE_STAGE)
|
||||
}
|
||||
|
||||
gotAck := bf.ReadUint32()
|
||||
if gotAck != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle)
|
||||
}
|
||||
|
||||
alwaysOne := bf.ReadUint8()
|
||||
if alwaysOne != 1 {
|
||||
t.Fatalf("alwaysOne: got %d, want 1", alwaysOne)
|
||||
}
|
||||
|
||||
prefixLen := bf.ReadUint8()
|
||||
if prefixLen != uint8(len(prefix)+1) {
|
||||
t.Fatalf("prefixLen: got %d, want %d", prefixLen, len(prefix)+1)
|
||||
}
|
||||
|
||||
gotPrefix := string(bf.ReadNullTerminatedBytes())
|
||||
if gotPrefix != prefix {
|
||||
t.Fatalf("prefix: got %q, want %q", gotPrefix, prefix)
|
||||
}
|
||||
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildEnterStagePacket verifies binary layout matches Erupe's Parse.
|
||||
func TestBuildEnterStagePacket(t *testing.T) {
|
||||
ackHandle := uint32(7)
|
||||
stageID := "sl1Ns200p0a0u0"
|
||||
|
||||
pkt := BuildEnterStagePacket(ackHandle, stageID)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
opcode := bf.ReadUint16()
|
||||
if opcode != MSG_SYS_ENTER_STAGE {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", opcode, MSG_SYS_ENTER_STAGE)
|
||||
}
|
||||
|
||||
gotAck := bf.ReadUint32()
|
||||
if gotAck != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", gotAck, ackHandle)
|
||||
}
|
||||
|
||||
isQuest := bf.ReadUint8()
|
||||
if isQuest != 0 {
|
||||
t.Fatalf("isQuest: got %d, want 0", isQuest)
|
||||
}
|
||||
|
||||
stageLen := bf.ReadUint8()
|
||||
if stageLen != uint8(len(stageID)+1) {
|
||||
t.Fatalf("stageLen: got %d, want %d", stageLen, len(stageID)+1)
|
||||
}
|
||||
|
||||
gotStage := string(bf.ReadNullTerminatedBytes())
|
||||
if gotStage != stageID {
|
||||
t.Fatalf("stageID: got %q, want %q", gotStage, stageID)
|
||||
}
|
||||
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildPingPacket verifies MSG_SYS_PING binary layout.
|
||||
func TestBuildPingPacket(t *testing.T) {
|
||||
ackHandle := uint32(99)
|
||||
pkt := BuildPingPacket(ackHandle)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_SYS_PING {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_PING)
|
||||
}
|
||||
if ack := bf.ReadUint32(); ack != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle)
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildLogoutPacket verifies MSG_SYS_LOGOUT binary layout.
|
||||
func TestBuildLogoutPacket(t *testing.T) {
|
||||
pkt := BuildLogoutPacket()
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_SYS_LOGOUT {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_LOGOUT)
|
||||
}
|
||||
if lt := bf.ReadUint8(); lt != 1 {
|
||||
t.Fatalf("logoutType: got %d, want 1", lt)
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildIssueLogkeyPacket verifies MSG_SYS_ISSUE_LOGKEY binary layout.
|
||||
func TestBuildIssueLogkeyPacket(t *testing.T) {
|
||||
ackHandle := uint32(10)
|
||||
pkt := BuildIssueLogkeyPacket(ackHandle)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_SYS_ISSUE_LOGKEY {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_ISSUE_LOGKEY)
|
||||
}
|
||||
if ack := bf.ReadUint32(); ack != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle)
|
||||
}
|
||||
if v := bf.ReadUint16(); v != 0 {
|
||||
t.Fatalf("unk0: got %d, want 0", v)
|
||||
}
|
||||
if v := bf.ReadUint16(); v != 0 {
|
||||
t.Fatalf("unk1: got %d, want 0", v)
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildRightsReloadPacket verifies MSG_SYS_RIGHTS_RELOAD binary layout.
|
||||
func TestBuildRightsReloadPacket(t *testing.T) {
|
||||
ackHandle := uint32(20)
|
||||
pkt := BuildRightsReloadPacket(ackHandle)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_SYS_RIGHTS_RELOAD {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_RIGHTS_RELOAD)
|
||||
}
|
||||
if ack := bf.ReadUint32(); ack != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle)
|
||||
}
|
||||
if c := bf.ReadUint8(); c != 0 {
|
||||
t.Fatalf("count: got %d, want 0", c)
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildLoaddataPacket verifies MSG_MHF_LOADDATA binary layout.
|
||||
func TestBuildLoaddataPacket(t *testing.T) {
|
||||
ackHandle := uint32(30)
|
||||
pkt := BuildLoaddataPacket(ackHandle)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_MHF_LOADDATA {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_LOADDATA)
|
||||
}
|
||||
if ack := bf.ReadUint32(); ack != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle)
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildCastBinaryPacket verifies MSG_SYS_CAST_BINARY binary layout.
|
||||
func TestBuildCastBinaryPacket(t *testing.T) {
|
||||
payload := []byte{0xDE, 0xAD, 0xBE, 0xEF}
|
||||
pkt := BuildCastBinaryPacket(0x03, 1, payload)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_SYS_CAST_BINARY {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_SYS_CAST_BINARY)
|
||||
}
|
||||
if unk := bf.ReadUint32(); unk != 0 {
|
||||
t.Fatalf("unk: got %d, want 0", unk)
|
||||
}
|
||||
if bt := bf.ReadUint8(); bt != 0x03 {
|
||||
t.Fatalf("broadcastType: got %d, want 3", bt)
|
||||
}
|
||||
if mt := bf.ReadUint8(); mt != 1 {
|
||||
t.Fatalf("messageType: got %d, want 1", mt)
|
||||
}
|
||||
if ds := bf.ReadUint16(); ds != uint16(len(payload)) {
|
||||
t.Fatalf("dataSize: got %d, want %d", ds, len(payload))
|
||||
}
|
||||
gotPayload := bf.ReadBytes(uint(len(payload)))
|
||||
for i, b := range payload {
|
||||
if gotPayload[i] != b {
|
||||
t.Fatalf("payload[%d]: got 0x%02X, want 0x%02X", i, gotPayload[i], b)
|
||||
}
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildChatPayload verifies the MsgBinChat inner binary layout and SJIS encoding.
|
||||
func TestBuildChatPayload(t *testing.T) {
|
||||
chatType := uint8(1)
|
||||
message := "Hello"
|
||||
senderName := "TestUser"
|
||||
|
||||
payload := BuildChatPayload(chatType, message, senderName)
|
||||
bf := byteframe.NewByteFrameFromBytes(payload)
|
||||
|
||||
if unk := bf.ReadUint8(); unk != 0 {
|
||||
t.Fatalf("unk0: got %d, want 0", unk)
|
||||
}
|
||||
if ct := bf.ReadUint8(); ct != chatType {
|
||||
t.Fatalf("chatType: got %d, want %d", ct, chatType)
|
||||
}
|
||||
if flags := bf.ReadUint16(); flags != 0 {
|
||||
t.Fatalf("flags: got %d, want 0", flags)
|
||||
}
|
||||
nameLen := bf.ReadUint16()
|
||||
msgLen := bf.ReadUint16()
|
||||
// "Hello" in ASCII/SJIS = 5 bytes + 1 null = 6
|
||||
if msgLen != 6 {
|
||||
t.Fatalf("messageLen: got %d, want 6", msgLen)
|
||||
}
|
||||
// "TestUser" in ASCII/SJIS = 8 bytes + 1 null = 9
|
||||
if nameLen != 9 {
|
||||
t.Fatalf("senderNameLen: got %d, want 9", nameLen)
|
||||
}
|
||||
|
||||
gotMsg := string(bf.ReadNullTerminatedBytes())
|
||||
if gotMsg != message {
|
||||
t.Fatalf("message: got %q, want %q", gotMsg, message)
|
||||
}
|
||||
gotName := string(bf.ReadNullTerminatedBytes())
|
||||
if gotName != senderName {
|
||||
t.Fatalf("senderName: got %q, want %q", gotName, senderName)
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildEnumerateQuestPacket verifies MSG_MHF_ENUMERATE_QUEST binary layout.
|
||||
func TestBuildEnumerateQuestPacket(t *testing.T) {
|
||||
ackHandle := uint32(40)
|
||||
world := uint8(2)
|
||||
counter := uint16(100)
|
||||
offset := uint16(50)
|
||||
|
||||
pkt := BuildEnumerateQuestPacket(ackHandle, world, counter, offset)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_MHF_ENUMERATE_QUEST {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_ENUMERATE_QUEST)
|
||||
}
|
||||
if ack := bf.ReadUint32(); ack != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle)
|
||||
}
|
||||
if u0 := bf.ReadUint8(); u0 != 0 {
|
||||
t.Fatalf("unk0: got %d, want 0", u0)
|
||||
}
|
||||
if w := bf.ReadUint8(); w != world {
|
||||
t.Fatalf("world: got %d, want %d", w, world)
|
||||
}
|
||||
if c := bf.ReadUint16(); c != counter {
|
||||
t.Fatalf("counter: got %d, want %d", c, counter)
|
||||
}
|
||||
if o := bf.ReadUint16(); o != offset {
|
||||
t.Fatalf("offset: got %d, want %d", o, offset)
|
||||
}
|
||||
if u1 := bf.ReadUint8(); u1 != 0 {
|
||||
t.Fatalf("unk1: got %d, want 0", u1)
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestBuildGetWeeklySchedulePacket verifies MSG_MHF_GET_WEEKLY_SCHEDULE binary layout.
|
||||
func TestBuildGetWeeklySchedulePacket(t *testing.T) {
|
||||
ackHandle := uint32(50)
|
||||
pkt := BuildGetWeeklySchedulePacket(ackHandle)
|
||||
bf := byteframe.NewByteFrameFromBytes(pkt)
|
||||
|
||||
if op := bf.ReadUint16(); op != MSG_MHF_GET_WEEKLY_SCHED {
|
||||
t.Fatalf("opcode: got 0x%04X, want 0x%04X", op, MSG_MHF_GET_WEEKLY_SCHED)
|
||||
}
|
||||
if ack := bf.ReadUint32(); ack != ackHandle {
|
||||
t.Fatalf("ackHandle: got %d, want %d", ack, ackHandle)
|
||||
}
|
||||
term := bf.ReadBytes(2)
|
||||
if term[0] != 0x00 || term[1] != 0x10 {
|
||||
t.Fatalf("terminator: got %02X %02X, want 00 10", term[0], term[1])
|
||||
}
|
||||
}
|
||||
|
||||
// TestOpcodeValues verifies opcode constants match Erupe's iota-based enum.
|
||||
func TestOpcodeValues(t *testing.T) {
|
||||
_ = binary.BigEndian // ensure import used
|
||||
tests := []struct {
|
||||
name string
|
||||
got uint16
|
||||
want uint16
|
||||
}{
|
||||
{"MSG_SYS_ACK", MSG_SYS_ACK, 0x0012},
|
||||
{"MSG_SYS_LOGIN", MSG_SYS_LOGIN, 0x0014},
|
||||
{"MSG_SYS_LOGOUT", MSG_SYS_LOGOUT, 0x0015},
|
||||
{"MSG_SYS_PING", MSG_SYS_PING, 0x0017},
|
||||
{"MSG_SYS_CAST_BINARY", MSG_SYS_CAST_BINARY, 0x0018},
|
||||
{"MSG_SYS_TIME", MSG_SYS_TIME, 0x001A},
|
||||
{"MSG_SYS_CASTED_BINARY", MSG_SYS_CASTED_BINARY, 0x001B},
|
||||
{"MSG_SYS_ISSUE_LOGKEY", MSG_SYS_ISSUE_LOGKEY, 0x001D},
|
||||
{"MSG_SYS_ENTER_STAGE", MSG_SYS_ENTER_STAGE, 0x0022},
|
||||
{"MSG_SYS_ENUMERATE_STAGE", MSG_SYS_ENUMERATE_STAGE, 0x002F},
|
||||
{"MSG_SYS_INSERT_USER", MSG_SYS_INSERT_USER, 0x0050},
|
||||
{"MSG_SYS_DELETE_USER", MSG_SYS_DELETE_USER, 0x0051},
|
||||
{"MSG_SYS_UPDATE_RIGHT", MSG_SYS_UPDATE_RIGHT, 0x0058},
|
||||
{"MSG_SYS_RIGHTS_RELOAD", MSG_SYS_RIGHTS_RELOAD, 0x005D},
|
||||
{"MSG_MHF_LOADDATA", MSG_MHF_LOADDATA, 0x0061},
|
||||
{"MSG_MHF_ENUMERATE_QUEST", MSG_MHF_ENUMERATE_QUEST, 0x009F},
|
||||
{"MSG_MHF_GET_WEEKLY_SCHED", MSG_MHF_GET_WEEKLY_SCHED, 0x00E1},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.got != tt.want {
|
||||
t.Errorf("%s: got 0x%04X, want 0x%04X", tt.name, tt.got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
106
cmd/protbot/protocol/sign.go
Normal file
106
cmd/protbot/protocol/sign.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package protocol
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"erupe-ce/common/byteframe"
|
||||
"erupe-ce/common/stringsupport"
|
||||
|
||||
"erupe-ce/cmd/protbot/conn"
|
||||
)
|
||||
|
||||
// SignResult holds the parsed response from a successful DSGN sign-in.
|
||||
type SignResult struct {
|
||||
TokenID uint32
|
||||
TokenString string // 16 raw bytes as string
|
||||
Timestamp uint32
|
||||
EntranceAddr string
|
||||
CharIDs []uint32
|
||||
}
|
||||
|
||||
// DoSign connects to the sign server and performs a DSGN login.
|
||||
// Reference: Erupe server/signserver/session.go (handleDSGN) and dsgn_resp.go (makeSignResponse).
|
||||
func DoSign(addr, username, password string) (*SignResult, error) {
|
||||
c, err := conn.DialWithInit(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign connect: %w", err)
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
// Build DSGN request: "DSGN:041" + \x00 + SJIS(user) + \x00 + SJIS(pass) + \x00 + \x00
|
||||
// The server reads: null-terminated request type, null-terminated user, null-terminated pass, null-terminated unk.
|
||||
// The request type has a 3-char version suffix (e.g. "041" for ZZ client mode 41) that the server strips.
|
||||
bf := byteframe.NewByteFrame()
|
||||
bf.WriteNullTerminatedBytes([]byte("DSGN:041")) // reqType with version suffix (server strips last 3 chars to get "DSGN:")
|
||||
bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(username))
|
||||
bf.WriteNullTerminatedBytes(stringsupport.UTF8ToSJIS(password))
|
||||
bf.WriteUint8(0) // Unk null-terminated empty string
|
||||
|
||||
if err := c.SendPacket(bf.Data()); err != nil {
|
||||
return nil, fmt.Errorf("sign send: %w", err)
|
||||
}
|
||||
|
||||
resp, err := c.ReadPacket()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign recv: %w", err)
|
||||
}
|
||||
|
||||
return parseSignResponse(resp)
|
||||
}
|
||||
|
||||
// parseSignResponse parses the binary response from the sign server.
|
||||
// Reference: Erupe server/signserver/dsgn_resp.go:makeSignResponse
|
||||
func parseSignResponse(data []byte) (*SignResult, error) {
|
||||
if len(data) < 1 {
|
||||
return nil, fmt.Errorf("empty sign response")
|
||||
}
|
||||
|
||||
rbf := byteframe.NewByteFrameFromBytes(data)
|
||||
|
||||
resultCode := rbf.ReadUint8()
|
||||
if resultCode != 1 { // SIGN_SUCCESS = 1
|
||||
return nil, fmt.Errorf("sign failed with code %d", resultCode)
|
||||
}
|
||||
|
||||
patchCount := rbf.ReadUint8() // patch server count (usually 2)
|
||||
_ = rbf.ReadUint8() // entrance server count (usually 1)
|
||||
charCount := rbf.ReadUint8() // character count
|
||||
|
||||
result := &SignResult{}
|
||||
result.TokenID = rbf.ReadUint32()
|
||||
result.TokenString = string(rbf.ReadBytes(16)) // 16 raw bytes
|
||||
result.Timestamp = rbf.ReadUint32()
|
||||
|
||||
// Skip patch server URLs (pascal strings with uint8 length prefix)
|
||||
for i := uint8(0); i < patchCount; i++ {
|
||||
strLen := rbf.ReadUint8()
|
||||
_ = rbf.ReadBytes(uint(strLen))
|
||||
}
|
||||
|
||||
// Read entrance server address (pascal string with uint8 length prefix)
|
||||
entranceLen := rbf.ReadUint8()
|
||||
result.EntranceAddr = string(rbf.ReadBytes(uint(entranceLen - 1)))
|
||||
_ = rbf.ReadUint8() // null terminator
|
||||
|
||||
// Read character entries
|
||||
for i := uint8(0); i < charCount; i++ {
|
||||
charID := rbf.ReadUint32()
|
||||
result.CharIDs = append(result.CharIDs, charID)
|
||||
|
||||
_ = rbf.ReadUint16() // HR
|
||||
_ = rbf.ReadUint16() // WeaponType
|
||||
_ = rbf.ReadUint32() // LastLogin
|
||||
_ = rbf.ReadUint8() // IsFemale
|
||||
_ = rbf.ReadUint8() // IsNewCharacter
|
||||
_ = rbf.ReadUint8() // Old GR
|
||||
_ = rbf.ReadUint8() // Use uint16 GR flag
|
||||
_ = rbf.ReadBytes(16) // Character name (padded)
|
||||
_ = rbf.ReadBytes(32) // Unk desc string (padded)
|
||||
// ZZ mode: additional fields
|
||||
_ = rbf.ReadUint16() // GR
|
||||
_ = rbf.ReadUint8() // Unk
|
||||
_ = rbf.ReadUint8() // Unk
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
74
cmd/protbot/scenario/chat.go
Normal file
74
cmd/protbot/scenario/chat.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package scenario
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"erupe-ce/common/byteframe"
|
||||
"erupe-ce/common/stringsupport"
|
||||
|
||||
"erupe-ce/cmd/protbot/protocol"
|
||||
)
|
||||
|
||||
// ChatMessage holds a parsed incoming chat message.
|
||||
type ChatMessage struct {
|
||||
ChatType uint8
|
||||
SenderName string
|
||||
Message string
|
||||
}
|
||||
|
||||
// SendChat sends a chat message via MSG_SYS_CAST_BINARY with a MsgBinChat payload.
|
||||
// broadcastType controls delivery scope: 0x03 = stage, 0x06 = world.
|
||||
func SendChat(ch *protocol.ChannelConn, broadcastType, chatType uint8, message, senderName string) error {
|
||||
payload := protocol.BuildChatPayload(chatType, message, senderName)
|
||||
pkt := protocol.BuildCastBinaryPacket(broadcastType, 1, payload)
|
||||
fmt.Printf("[chat] Sending chat (type=%d, broadcast=%d): %s\n", chatType, broadcastType, message)
|
||||
return ch.SendPacket(pkt)
|
||||
}
|
||||
|
||||
// ChatCallback is invoked when a chat message is received.
|
||||
type ChatCallback func(msg ChatMessage)
|
||||
|
||||
// ListenChat registers a handler on MSG_SYS_CASTED_BINARY that parses chat
|
||||
// messages (messageType=1) and invokes the callback.
|
||||
func ListenChat(ch *protocol.ChannelConn, cb ChatCallback) {
|
||||
ch.OnPacket(protocol.MSG_SYS_CASTED_BINARY, func(opcode uint16, data []byte) {
|
||||
// MSG_SYS_CASTED_BINARY layout from server:
|
||||
// uint32 unk
|
||||
// uint8 broadcastType
|
||||
// uint8 messageType
|
||||
// uint16 dataSize
|
||||
// []byte payload
|
||||
if len(data) < 8 {
|
||||
return
|
||||
}
|
||||
messageType := data[5]
|
||||
if messageType != 1 { // Only handle chat messages.
|
||||
return
|
||||
}
|
||||
bf := byteframe.NewByteFrameFromBytes(data)
|
||||
_ = bf.ReadUint32() // unk
|
||||
_ = bf.ReadUint8() // broadcastType
|
||||
_ = bf.ReadUint8() // messageType
|
||||
dataSize := bf.ReadUint16()
|
||||
if dataSize == 0 {
|
||||
return
|
||||
}
|
||||
payload := bf.ReadBytes(uint(dataSize))
|
||||
|
||||
// Parse MsgBinChat inner payload.
|
||||
pbf := byteframe.NewByteFrameFromBytes(payload)
|
||||
_ = pbf.ReadUint8() // unk0
|
||||
chatType := pbf.ReadUint8()
|
||||
_ = pbf.ReadUint16() // flags
|
||||
_ = pbf.ReadUint16() // senderNameLen
|
||||
_ = pbf.ReadUint16() // messageLen
|
||||
msg := stringsupport.SJISToUTF8(pbf.ReadNullTerminatedBytes())
|
||||
sender := stringsupport.SJISToUTF8(pbf.ReadNullTerminatedBytes())
|
||||
|
||||
cb(ChatMessage{
|
||||
ChatType: chatType,
|
||||
SenderName: sender,
|
||||
Message: msg,
|
||||
})
|
||||
})
|
||||
}
|
||||
82
cmd/protbot/scenario/login.go
Normal file
82
cmd/protbot/scenario/login.go
Normal file
@@ -0,0 +1,82 @@
|
||||
// Package scenario provides high-level MHF protocol flows.
|
||||
package scenario
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"erupe-ce/cmd/protbot/protocol"
|
||||
)
|
||||
|
||||
// LoginResult holds the outcome of a full login flow.
|
||||
type LoginResult struct {
|
||||
Sign *protocol.SignResult
|
||||
Servers []protocol.ServerEntry
|
||||
Channel *protocol.ChannelConn
|
||||
}
|
||||
|
||||
// Login performs the full sign → entrance → channel login flow.
|
||||
func Login(signAddr, username, password string) (*LoginResult, error) {
|
||||
// Step 1: Sign server authentication.
|
||||
fmt.Printf("[sign] Connecting to %s...\n", signAddr)
|
||||
sign, err := protocol.DoSign(signAddr, username, password)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("sign: %w", err)
|
||||
}
|
||||
fmt.Printf("[sign] OK — tokenID=%d, %d character(s), entrance=%s\n",
|
||||
sign.TokenID, len(sign.CharIDs), sign.EntranceAddr)
|
||||
|
||||
if len(sign.CharIDs) == 0 {
|
||||
return nil, fmt.Errorf("no characters on account")
|
||||
}
|
||||
|
||||
// Step 2: Entrance server — get server/channel list.
|
||||
fmt.Printf("[entrance] Connecting to %s...\n", sign.EntranceAddr)
|
||||
servers, err := protocol.DoEntrance(sign.EntranceAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("entrance: %w", err)
|
||||
}
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no channels available")
|
||||
}
|
||||
for i, s := range servers {
|
||||
fmt.Printf("[entrance] [%d] %s — %s:%d\n", i, s.Name, s.IP, s.Port)
|
||||
}
|
||||
|
||||
// Step 3: Connect to the first channel server.
|
||||
first := servers[0]
|
||||
channelAddr := fmt.Sprintf("%s:%d", first.IP, first.Port)
|
||||
fmt.Printf("[channel] Connecting to %s...\n", channelAddr)
|
||||
ch, err := protocol.ConnectChannel(channelAddr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("channel connect: %w", err)
|
||||
}
|
||||
|
||||
// Step 4: Send MSG_SYS_LOGIN.
|
||||
charID := sign.CharIDs[0]
|
||||
ack := ch.NextAckHandle()
|
||||
loginPkt := protocol.BuildLoginPacket(ack, charID, sign.TokenID, sign.TokenString)
|
||||
fmt.Printf("[channel] Sending MSG_SYS_LOGIN (charID=%d, ackHandle=%d)...\n", charID, ack)
|
||||
if err := ch.SendPacket(loginPkt); err != nil {
|
||||
ch.Close()
|
||||
return nil, fmt.Errorf("channel send login: %w", err)
|
||||
}
|
||||
|
||||
resp, err := ch.WaitForAck(ack, 10*time.Second)
|
||||
if err != nil {
|
||||
ch.Close()
|
||||
return nil, fmt.Errorf("channel login ack: %w", err)
|
||||
}
|
||||
if resp.ErrorCode != 0 {
|
||||
ch.Close()
|
||||
return nil, fmt.Errorf("channel login failed: error code %d", resp.ErrorCode)
|
||||
}
|
||||
fmt.Printf("[channel] Login ACK received (error=%d, %d bytes data)\n",
|
||||
resp.ErrorCode, len(resp.Data))
|
||||
|
||||
return &LoginResult{
|
||||
Sign: sign,
|
||||
Servers: servers,
|
||||
Channel: ch,
|
||||
}, nil
|
||||
}
|
||||
17
cmd/protbot/scenario/logout.go
Normal file
17
cmd/protbot/scenario/logout.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package scenario
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"erupe-ce/cmd/protbot/protocol"
|
||||
)
|
||||
|
||||
// Logout sends MSG_SYS_LOGOUT and closes the channel connection.
|
||||
func Logout(ch *protocol.ChannelConn) error {
|
||||
fmt.Println("[logout] Sending MSG_SYS_LOGOUT...")
|
||||
if err := ch.SendPacket(protocol.BuildLogoutPacket()); err != nil {
|
||||
ch.Close()
|
||||
return fmt.Errorf("logout send: %w", err)
|
||||
}
|
||||
return ch.Close()
|
||||
}
|
||||
31
cmd/protbot/scenario/quest.go
Normal file
31
cmd/protbot/scenario/quest.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package scenario
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"erupe-ce/cmd/protbot/protocol"
|
||||
)
|
||||
|
||||
// EnumerateQuests sends MSG_MHF_ENUMERATE_QUEST and returns the raw quest list data.
|
||||
func EnumerateQuests(ch *protocol.ChannelConn, world uint8, counter uint16) ([]byte, error) {
|
||||
ack := ch.NextAckHandle()
|
||||
pkt := protocol.BuildEnumerateQuestPacket(ack, world, counter, 0)
|
||||
fmt.Printf("[quest] Sending MSG_MHF_ENUMERATE_QUEST (world=%d, counter=%d, ackHandle=%d)...\n",
|
||||
world, counter, ack)
|
||||
if err := ch.SendPacket(pkt); err != nil {
|
||||
return nil, fmt.Errorf("enumerate quest send: %w", err)
|
||||
}
|
||||
|
||||
resp, err := ch.WaitForAck(ack, 15*time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("enumerate quest ack: %w", err)
|
||||
}
|
||||
if resp.ErrorCode != 0 {
|
||||
return nil, fmt.Errorf("enumerate quest failed: error code %d", resp.ErrorCode)
|
||||
}
|
||||
fmt.Printf("[quest] ENUMERATE_QUEST ACK (error=%d, %d bytes data)\n",
|
||||
resp.ErrorCode, len(resp.Data))
|
||||
|
||||
return resp.Data, nil
|
||||
}
|
||||
50
cmd/protbot/scenario/session.go
Normal file
50
cmd/protbot/scenario/session.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package scenario
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"erupe-ce/cmd/protbot/protocol"
|
||||
)
|
||||
|
||||
// SetupSession performs the post-login session setup: ISSUE_LOGKEY, RIGHTS_RELOAD, LOADDATA.
|
||||
// Returns the loaddata response blob for inspection.
|
||||
func SetupSession(ch *protocol.ChannelConn, charID uint32) ([]byte, error) {
|
||||
// Step 1: Issue logkey.
|
||||
ack := ch.NextAckHandle()
|
||||
fmt.Printf("[session] Sending MSG_SYS_ISSUE_LOGKEY (ackHandle=%d)...\n", ack)
|
||||
if err := ch.SendPacket(protocol.BuildIssueLogkeyPacket(ack)); err != nil {
|
||||
return nil, fmt.Errorf("issue logkey send: %w", err)
|
||||
}
|
||||
resp, err := ch.WaitForAck(ack, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("issue logkey ack: %w", err)
|
||||
}
|
||||
fmt.Printf("[session] ISSUE_LOGKEY ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data))
|
||||
|
||||
// Step 2: Rights reload.
|
||||
ack = ch.NextAckHandle()
|
||||
fmt.Printf("[session] Sending MSG_SYS_RIGHTS_RELOAD (ackHandle=%d)...\n", ack)
|
||||
if err := ch.SendPacket(protocol.BuildRightsReloadPacket(ack)); err != nil {
|
||||
return nil, fmt.Errorf("rights reload send: %w", err)
|
||||
}
|
||||
resp, err = ch.WaitForAck(ack, 10*time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("rights reload ack: %w", err)
|
||||
}
|
||||
fmt.Printf("[session] RIGHTS_RELOAD ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data))
|
||||
|
||||
// Step 3: Load save data.
|
||||
ack = ch.NextAckHandle()
|
||||
fmt.Printf("[session] Sending MSG_MHF_LOADDATA (ackHandle=%d)...\n", ack)
|
||||
if err := ch.SendPacket(protocol.BuildLoaddataPacket(ack)); err != nil {
|
||||
return nil, fmt.Errorf("loaddata send: %w", err)
|
||||
}
|
||||
resp, err = ch.WaitForAck(ack, 30*time.Second)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loaddata ack: %w", err)
|
||||
}
|
||||
fmt.Printf("[session] LOADDATA ACK (error=%d, %d bytes)\n", resp.ErrorCode, len(resp.Data))
|
||||
|
||||
return resp.Data, nil
|
||||
}
|
||||
111
cmd/protbot/scenario/stage.go
Normal file
111
cmd/protbot/scenario/stage.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package scenario
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"erupe-ce/common/byteframe"
|
||||
|
||||
"erupe-ce/cmd/protbot/protocol"
|
||||
)
|
||||
|
||||
// StageInfo holds a parsed stage entry from MSG_SYS_ENUMERATE_STAGE response.
|
||||
type StageInfo struct {
|
||||
ID string
|
||||
Reserved uint16
|
||||
Clients uint16
|
||||
Displayed uint16
|
||||
MaxPlayers uint16
|
||||
Flags uint8
|
||||
}
|
||||
|
||||
// EnterLobby enumerates available lobby stages and enters the first one.
|
||||
func EnterLobby(ch *protocol.ChannelConn) error {
|
||||
// Step 1: Enumerate stages with "sl1Ns" prefix (main lobby stages).
|
||||
ack := ch.NextAckHandle()
|
||||
enumPkt := protocol.BuildEnumerateStagePacket(ack, "sl1Ns")
|
||||
fmt.Printf("[stage] Sending MSG_SYS_ENUMERATE_STAGE (prefix=\"sl1Ns\", ackHandle=%d)...\n", ack)
|
||||
if err := ch.SendPacket(enumPkt); err != nil {
|
||||
return fmt.Errorf("enumerate stage send: %w", err)
|
||||
}
|
||||
|
||||
resp, err := ch.WaitForAck(ack, 10*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enumerate stage ack: %w", err)
|
||||
}
|
||||
if resp.ErrorCode != 0 {
|
||||
return fmt.Errorf("enumerate stage failed: error code %d", resp.ErrorCode)
|
||||
}
|
||||
|
||||
stages := parseEnumerateStageResponse(resp.Data)
|
||||
fmt.Printf("[stage] Found %d stage(s)\n", len(stages))
|
||||
for i, s := range stages {
|
||||
fmt.Printf("[stage] [%d] %s — %d/%d players, flags=0x%02X\n",
|
||||
i, s.ID, s.Clients, s.MaxPlayers, s.Flags)
|
||||
}
|
||||
|
||||
// Step 2: Enter the default lobby stage.
|
||||
// Even if no stages were enumerated, use the default stage ID.
|
||||
stageID := "sl1Ns200p0a0u0"
|
||||
if len(stages) > 0 {
|
||||
stageID = stages[0].ID
|
||||
}
|
||||
|
||||
ack = ch.NextAckHandle()
|
||||
enterPkt := protocol.BuildEnterStagePacket(ack, stageID)
|
||||
fmt.Printf("[stage] Sending MSG_SYS_ENTER_STAGE (stageID=%q, ackHandle=%d)...\n", stageID, ack)
|
||||
if err := ch.SendPacket(enterPkt); err != nil {
|
||||
return fmt.Errorf("enter stage send: %w", err)
|
||||
}
|
||||
|
||||
resp, err = ch.WaitForAck(ack, 10*time.Second)
|
||||
if err != nil {
|
||||
return fmt.Errorf("enter stage ack: %w", err)
|
||||
}
|
||||
if resp.ErrorCode != 0 {
|
||||
return fmt.Errorf("enter stage failed: error code %d", resp.ErrorCode)
|
||||
}
|
||||
fmt.Printf("[stage] Enter stage ACK received (error=%d)\n", resp.ErrorCode)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// parseEnumerateStageResponse parses the ACK data from MSG_SYS_ENUMERATE_STAGE.
|
||||
// Reference: Erupe server/channelserver/handlers_stage.go (handleMsgSysEnumerateStage)
|
||||
func parseEnumerateStageResponse(data []byte) []StageInfo {
|
||||
if len(data) < 2 {
|
||||
return nil
|
||||
}
|
||||
|
||||
bf := byteframe.NewByteFrameFromBytes(data)
|
||||
count := bf.ReadUint16()
|
||||
|
||||
var stages []StageInfo
|
||||
for i := uint16(0); i < count; i++ {
|
||||
s := StageInfo{}
|
||||
s.Reserved = bf.ReadUint16()
|
||||
s.Clients = bf.ReadUint16()
|
||||
s.Displayed = bf.ReadUint16()
|
||||
s.MaxPlayers = bf.ReadUint16()
|
||||
s.Flags = bf.ReadUint8()
|
||||
|
||||
// Stage ID is a pascal string with uint8 length prefix.
|
||||
strLen := bf.ReadUint8()
|
||||
if strLen > 0 {
|
||||
idBytes := bf.ReadBytes(uint(strLen))
|
||||
// Remove null terminator if present.
|
||||
if len(idBytes) > 0 && idBytes[len(idBytes)-1] == 0 {
|
||||
idBytes = idBytes[:len(idBytes)-1]
|
||||
}
|
||||
s.ID = string(idBytes)
|
||||
}
|
||||
|
||||
stages = append(stages, s)
|
||||
}
|
||||
|
||||
// After stages: uint32 timestamp, uint32 max clan members (we ignore these).
|
||||
_ = binary.BigEndian // suppress unused import if needed
|
||||
|
||||
return stages
|
||||
}
|
||||
@@ -219,34 +219,34 @@
|
||||
{
|
||||
"Name": "Newbie", "Description": "", "IP": "", "Type": 3, "Recommended": 2, "AllowedClientFlags": 0,
|
||||
"Channels": [
|
||||
{ "Port": 54001, "MaxPlayers": 100 },
|
||||
{ "Port": 54002, "MaxPlayers": 100 }
|
||||
{ "Port": 54001, "MaxPlayers": 100, "Enabled": true },
|
||||
{ "Port": 54002, "MaxPlayers": 100, "Enabled": true }
|
||||
]
|
||||
}, {
|
||||
"Name": "Normal", "Description": "", "IP": "", "Type": 1, "Recommended": 0, "AllowedClientFlags": 0,
|
||||
"Channels": [
|
||||
{ "Port": 54003, "MaxPlayers": 100 },
|
||||
{ "Port": 54004, "MaxPlayers": 100 }
|
||||
{ "Port": 54003, "MaxPlayers": 100, "Enabled": true },
|
||||
{ "Port": 54004, "MaxPlayers": 100, "Enabled": true }
|
||||
]
|
||||
}, {
|
||||
"Name": "Cities", "Description": "", "IP": "", "Type": 2, "Recommended": 0, "AllowedClientFlags": 0,
|
||||
"Channels": [
|
||||
{ "Port": 54005, "MaxPlayers": 100 }
|
||||
{ "Port": 54005, "MaxPlayers": 100, "Enabled": true }
|
||||
]
|
||||
}, {
|
||||
"Name": "Tavern", "Description": "", "IP": "", "Type": 4, "Recommended": 0, "AllowedClientFlags": 0,
|
||||
"Channels": [
|
||||
{ "Port": 54006, "MaxPlayers": 100 }
|
||||
{ "Port": 54006, "MaxPlayers": 100, "Enabled": true }
|
||||
]
|
||||
}, {
|
||||
"Name": "Return", "Description": "", "IP": "", "Type": 5, "Recommended": 0, "AllowedClientFlags": 0,
|
||||
"Channels": [
|
||||
{ "Port": 54007, "MaxPlayers": 100 }
|
||||
{ "Port": 54007, "MaxPlayers": 100, "Enabled": true }
|
||||
]
|
||||
}, {
|
||||
"Name": "MezFes", "Description": "", "IP": "", "Type": 6, "Recommended": 6, "AllowedClientFlags": 0,
|
||||
"Channels": [
|
||||
{ "Port": 54008, "MaxPlayers": 100 }
|
||||
{ "Port": 54008, "MaxPlayers": 100, "Enabled": true }
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -297,6 +297,15 @@ type EntranceChannelInfo struct {
|
||||
Port uint16
|
||||
MaxPlayers uint16
|
||||
CurrentPlayers uint16
|
||||
Enabled *bool // nil defaults to true for backward compatibility
|
||||
}
|
||||
|
||||
// IsEnabled returns whether this channel is enabled. Defaults to true if Enabled is nil.
|
||||
func (c *EntranceChannelInfo) IsEnabled() bool {
|
||||
if c.Enabled == nil {
|
||||
return true
|
||||
}
|
||||
return *c.Enabled
|
||||
}
|
||||
|
||||
var ErupeConfig *Config
|
||||
|
||||
@@ -536,6 +536,34 @@ func TestEntranceChannelInfo(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// TestEntranceChannelInfoIsEnabled tests the Enabled field and IsEnabled helper
|
||||
func TestEntranceChannelInfoIsEnabled(t *testing.T) {
|
||||
trueVal := true
|
||||
falseVal := false
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
enabled *bool
|
||||
want bool
|
||||
}{
|
||||
{"nil defaults to true", nil, true},
|
||||
{"explicit true", &trueVal, true},
|
||||
{"explicit false", &falseVal, false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info := EntranceChannelInfo{
|
||||
Port: 10001,
|
||||
Enabled: tt.enabled,
|
||||
}
|
||||
if got := info.IsEnabled(); got != tt.want {
|
||||
t.Errorf("IsEnabled() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDiscord verifies Discord struct
|
||||
func TestDiscord(t *testing.T) {
|
||||
discord := Discord{
|
||||
|
||||
@@ -13,7 +13,7 @@ services:
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- ./db-data/:/var/lib/postgresql/data/
|
||||
- ./db-data/:/var/lib/postgresql/
|
||||
- ../schemas/:/schemas/
|
||||
- ./init/setup.sh:/docker-entrypoint-initdb.d/setup.sh
|
||||
healthcheck:
|
||||
|
||||
38
main.go
38
main.go
@@ -16,6 +16,7 @@ import (
|
||||
"erupe-ce/server/discordbot"
|
||||
"erupe-ce/server/entranceserver"
|
||||
"erupe-ce/server/signserver"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
_ "github.com/lib/pq"
|
||||
@@ -129,11 +130,30 @@ func main() {
|
||||
}
|
||||
logger.Info("Database: Started successfully")
|
||||
|
||||
// Clear stale data
|
||||
if config.DebugOptions.ProxyPort == 0 {
|
||||
_ = db.MustExec("DELETE FROM sign_sessions")
|
||||
// Pre-compute all server IDs this instance will own, so we only
|
||||
// delete our own rows (safe for multi-instance on the same DB).
|
||||
var ownedServerIDs []string
|
||||
{
|
||||
si := 0
|
||||
for _, ee := range config.Entrance.Entries {
|
||||
ci := 0
|
||||
for range ee.Channels {
|
||||
sid := (4096 + si*256) + (16 + ci)
|
||||
ownedServerIDs = append(ownedServerIDs, fmt.Sprint(sid))
|
||||
ci++
|
||||
}
|
||||
si++
|
||||
}
|
||||
}
|
||||
|
||||
// Clear stale data scoped to this instance's server IDs
|
||||
if len(ownedServerIDs) > 0 {
|
||||
idList := strings.Join(ownedServerIDs, ",")
|
||||
if config.DebugOptions.ProxyPort == 0 {
|
||||
_ = db.MustExec("DELETE FROM sign_sessions WHERE server_id IN (" + idList + ")")
|
||||
}
|
||||
_ = db.MustExec("DELETE FROM servers WHERE server_id IN (" + idList + ")")
|
||||
}
|
||||
_ = db.MustExec("DELETE FROM servers")
|
||||
_ = db.MustExec(`UPDATE guild_characters SET treasure_hunt=NULL`)
|
||||
|
||||
// Clean the DB if the option is on.
|
||||
@@ -213,6 +233,12 @@ func main() {
|
||||
for j, ee := range config.Entrance.Entries {
|
||||
for i, ce := range ee.Channels {
|
||||
sid := (4096 + si*256) + (16 + ci)
|
||||
if !ce.IsEnabled() {
|
||||
logger.Info(fmt.Sprintf("Channel %d (%d): Disabled via config", count, ce.Port))
|
||||
ci++
|
||||
count++
|
||||
continue
|
||||
}
|
||||
c := *channelserver.NewServer(&channelserver.Config{
|
||||
ID: uint16(sid),
|
||||
Logger: logger.Named("channel-" + fmt.Sprint(count)),
|
||||
@@ -237,9 +263,9 @@ func main() {
|
||||
)
|
||||
channels = append(channels, &c)
|
||||
logger.Info(fmt.Sprintf("Channel %d (%d): Started successfully", count, ce.Port))
|
||||
ci++
|
||||
count++
|
||||
}
|
||||
ci++
|
||||
}
|
||||
ci = 0
|
||||
si++
|
||||
@@ -248,8 +274,10 @@ func main() {
|
||||
// Register all servers in DB
|
||||
_ = db.MustExec(channelQuery)
|
||||
|
||||
registry := channelserver.NewLocalChannelRegistry(channels)
|
||||
for _, c := range channels {
|
||||
c.Channels = channels
|
||||
c.Registry = registry
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
214
server/channelserver/channel_isolation_test.go
Normal file
214
server/channelserver/channel_isolation_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
_config "erupe-ce/config"
|
||||
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
// createListeningTestServer creates a channel server that binds to a real TCP port.
|
||||
// Port 0 lets the OS assign a free port. The server is automatically shut down
|
||||
// when the test completes.
|
||||
func createListeningTestServer(t *testing.T, id uint16) *Server {
|
||||
t.Helper()
|
||||
logger, _ := zap.NewDevelopment()
|
||||
s := NewServer(&Config{
|
||||
ID: id,
|
||||
Logger: logger,
|
||||
ErupeConfig: &_config.Config{
|
||||
DebugOptions: _config.DebugOptions{
|
||||
LogOutboundMessages: false,
|
||||
LogInboundMessages: false,
|
||||
},
|
||||
},
|
||||
})
|
||||
s.Port = 0 // Let OS pick a free port
|
||||
if err := s.Start(); err != nil {
|
||||
t.Fatalf("channel %d failed to start: %v", id, err)
|
||||
}
|
||||
t.Cleanup(func() {
|
||||
s.Shutdown()
|
||||
time.Sleep(200 * time.Millisecond) // Let background goroutines and sessions exit.
|
||||
})
|
||||
return s
|
||||
}
|
||||
|
||||
// listenerAddr returns the address the server is listening on.
|
||||
func listenerAddr(s *Server) string {
|
||||
return s.listener.Addr().String()
|
||||
}
|
||||
|
||||
// TestChannelIsolation_ShutdownDoesNotAffectOthers verifies that shutting down
|
||||
// one channel server does not prevent other channels from accepting connections.
|
||||
func TestChannelIsolation_ShutdownDoesNotAffectOthers(t *testing.T) {
|
||||
ch1 := createListeningTestServer(t, 1)
|
||||
ch2 := createListeningTestServer(t, 2)
|
||||
ch3 := createListeningTestServer(t, 3)
|
||||
|
||||
addr1 := listenerAddr(ch1)
|
||||
addr2 := listenerAddr(ch2)
|
||||
addr3 := listenerAddr(ch3)
|
||||
|
||||
// Verify all three channels accept connections initially.
|
||||
for _, addr := range []string{addr1, addr2, addr3} {
|
||||
conn, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("initial connection to %s failed: %v", addr, err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
// Shut down channel 1.
|
||||
ch1.Shutdown()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Channel 1 should refuse connections.
|
||||
_, err := net.DialTimeout("tcp", addr1, 500*time.Millisecond)
|
||||
if err == nil {
|
||||
t.Error("channel 1 should refuse connections after shutdown")
|
||||
}
|
||||
|
||||
// Channels 2 and 3 must still accept connections.
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
addr string
|
||||
}{
|
||||
{"channel 2", addr2},
|
||||
{"channel 3", addr3},
|
||||
} {
|
||||
conn, err := net.DialTimeout("tcp", tc.addr, time.Second)
|
||||
if err != nil {
|
||||
t.Errorf("%s should still accept connections after channel 1 shutdown, got: %v", tc.name, err)
|
||||
} else {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelIsolation_ListenerCloseDoesNotAffectOthers simulates an unexpected
|
||||
// listener failure (e.g. port conflict, OS-level error) on one channel and
|
||||
// verifies other channels continue operating.
|
||||
func TestChannelIsolation_ListenerCloseDoesNotAffectOthers(t *testing.T) {
|
||||
ch1 := createListeningTestServer(t, 1)
|
||||
ch2 := createListeningTestServer(t, 2)
|
||||
|
||||
addr2 := listenerAddr(ch2)
|
||||
|
||||
// Forcibly close channel 1's listener (simulating unexpected failure).
|
||||
ch1.listener.Close()
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
|
||||
// Channel 2 must still work.
|
||||
conn, err := net.DialTimeout("tcp", addr2, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("channel 2 should still accept connections after channel 1 listener closed: %v", err)
|
||||
}
|
||||
conn.Close()
|
||||
}
|
||||
|
||||
// TestChannelIsolation_SessionPanicDoesNotAffectChannel verifies that a panic
|
||||
// inside a session handler is recovered and does not crash the channel server.
|
||||
func TestChannelIsolation_SessionPanicDoesNotAffectChannel(t *testing.T) {
|
||||
ch := createListeningTestServer(t, 1)
|
||||
addr := listenerAddr(ch)
|
||||
|
||||
// Connect a client that will trigger a session.
|
||||
conn1, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("first connection failed: %v", err)
|
||||
}
|
||||
|
||||
// Send garbage data that will cause handlePacketGroup to hit the panic recovery.
|
||||
// The session's defer/recover should catch it without killing the channel.
|
||||
conn1.Write([]byte{0xFF, 0xFF, 0xFF, 0xFF})
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
conn1.Close()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// The channel should still accept new connections after the panic.
|
||||
conn2, err := net.DialTimeout("tcp", addr, time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("channel should still accept connections after session panic: %v", err)
|
||||
}
|
||||
conn2.Close()
|
||||
}
|
||||
|
||||
// TestChannelIsolation_CrossChannelRegistryAfterShutdown verifies that the
|
||||
// channel registry handles a shut-down channel gracefully during cross-channel
|
||||
// operations (search, find, disconnect).
|
||||
func TestChannelIsolation_CrossChannelRegistryAfterShutdown(t *testing.T) {
|
||||
channels := createTestChannels(3)
|
||||
reg := NewLocalChannelRegistry(channels)
|
||||
|
||||
// Add sessions to all channels.
|
||||
for i, ch := range channels {
|
||||
conn := &mockConn{}
|
||||
sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player")
|
||||
sess.stage = NewStage("sl1Ns200p0a0u0")
|
||||
ch.Lock()
|
||||
ch.sessions[conn] = sess
|
||||
ch.Unlock()
|
||||
}
|
||||
|
||||
// Simulate channel 1 shutting down by marking it and clearing sessions.
|
||||
channels[0].Lock()
|
||||
channels[0].isShuttingDown = true
|
||||
channels[0].sessions = make(map[net.Conn]*Session)
|
||||
channels[0].Unlock()
|
||||
|
||||
// Registry operations should still work for remaining channels.
|
||||
found := reg.FindSessionByCharID(2)
|
||||
if found == nil {
|
||||
t.Error("FindSessionByCharID(2) should find session on channel 2")
|
||||
}
|
||||
|
||||
found = reg.FindSessionByCharID(3)
|
||||
if found == nil {
|
||||
t.Error("FindSessionByCharID(3) should find session on channel 3")
|
||||
}
|
||||
|
||||
// Session from shut-down channel should not be found.
|
||||
found = reg.FindSessionByCharID(1)
|
||||
if found != nil {
|
||||
t.Error("FindSessionByCharID(1) should not find session on shut-down channel")
|
||||
}
|
||||
|
||||
// SearchSessions should return only sessions from live channels.
|
||||
results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("SearchSessions should return 2 results from live channels, got %d", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
// TestChannelIsolation_IndependentStages verifies that stages are per-channel
|
||||
// and one channel's stages don't leak into another.
|
||||
func TestChannelIsolation_IndependentStages(t *testing.T) {
|
||||
channels := createTestChannels(2)
|
||||
|
||||
stageName := "sl1Qs999p0a0u42"
|
||||
|
||||
// Add stage only to channel 1.
|
||||
channels[0].stagesLock.Lock()
|
||||
channels[0].stages[stageName] = NewStage(stageName)
|
||||
channels[0].stagesLock.Unlock()
|
||||
|
||||
// Channel 1 should have the stage.
|
||||
channels[0].stagesLock.RLock()
|
||||
_, ok1 := channels[0].stages[stageName]
|
||||
channels[0].stagesLock.RUnlock()
|
||||
if !ok1 {
|
||||
t.Error("channel 1 should have the stage")
|
||||
}
|
||||
|
||||
// Channel 2 should NOT have the stage.
|
||||
channels[1].stagesLock.RLock()
|
||||
_, ok2 := channels[1].stages[stageName]
|
||||
channels[1].stagesLock.RUnlock()
|
||||
if ok2 {
|
||||
t.Error("channel 2 should not have channel 1's stage")
|
||||
}
|
||||
}
|
||||
58
server/channelserver/channel_registry.go
Normal file
58
server/channelserver/channel_registry.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"erupe-ce/network/mhfpacket"
|
||||
"net"
|
||||
)
|
||||
|
||||
// ChannelRegistry abstracts cross-channel operations behind an interface.
|
||||
// The default LocalChannelRegistry wraps the in-process []*Server slice.
|
||||
// Future implementations may use DB/Redis/NATS for multi-process deployments.
|
||||
type ChannelRegistry interface {
|
||||
// Worldcast broadcasts a packet to all sessions across all channels.
|
||||
Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server)
|
||||
|
||||
// FindSessionByCharID looks up a session by character ID across all channels.
|
||||
FindSessionByCharID(charID uint32) *Session
|
||||
|
||||
// DisconnectUser disconnects all sessions belonging to the given character IDs.
|
||||
DisconnectUser(cids []uint32)
|
||||
|
||||
// FindChannelForStage searches all channels for a stage whose ID has the
|
||||
// given suffix and returns the owning channel's GlobalID, or "" if not found.
|
||||
FindChannelForStage(stageSuffix string) string
|
||||
|
||||
// SearchSessions searches sessions across all channels using a predicate,
|
||||
// returning up to max snapshot results.
|
||||
SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot
|
||||
|
||||
// SearchStages searches stages across all channels with a prefix filter,
|
||||
// returning up to max snapshot results.
|
||||
SearchStages(stagePrefix string, max int) []StageSnapshot
|
||||
|
||||
// NotifyMailToCharID finds the session for charID and sends a mail notification.
|
||||
NotifyMailToCharID(charID uint32, sender *Session, mail *Mail)
|
||||
}
|
||||
|
||||
// SessionSnapshot is an immutable copy of session data taken under lock.
|
||||
type SessionSnapshot struct {
|
||||
CharID uint32
|
||||
Name string
|
||||
StageID string
|
||||
ServerIP net.IP
|
||||
ServerPort uint16
|
||||
UserBinary3 []byte // Copy of userBinaryParts index 3
|
||||
}
|
||||
|
||||
// StageSnapshot is an immutable copy of stage data taken under lock.
|
||||
type StageSnapshot struct {
|
||||
ServerIP net.IP
|
||||
ServerPort uint16
|
||||
StageID string
|
||||
ClientCount int
|
||||
Reserved int
|
||||
MaxPlayers uint16
|
||||
RawBinData0 []byte
|
||||
RawBinData1 []byte
|
||||
RawBinData3 []byte
|
||||
}
|
||||
156
server/channelserver/channel_registry_local.go
Normal file
156
server/channelserver/channel_registry_local.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"erupe-ce/network/mhfpacket"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// LocalChannelRegistry is the in-process ChannelRegistry backed by []*Server.
|
||||
type LocalChannelRegistry struct {
|
||||
channels []*Server
|
||||
}
|
||||
|
||||
// NewLocalChannelRegistry creates a LocalChannelRegistry wrapping the given channels.
|
||||
func NewLocalChannelRegistry(channels []*Server) *LocalChannelRegistry {
|
||||
return &LocalChannelRegistry{channels: channels}
|
||||
}
|
||||
|
||||
func (r *LocalChannelRegistry) Worldcast(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) {
|
||||
for _, c := range r.channels {
|
||||
if c == ignoredChannel {
|
||||
continue
|
||||
}
|
||||
c.BroadcastMHF(pkt, ignoredSession)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *LocalChannelRegistry) FindSessionByCharID(charID uint32) *Session {
|
||||
for _, c := range r.channels {
|
||||
c.Lock()
|
||||
for _, session := range c.sessions {
|
||||
if session.charID == charID {
|
||||
c.Unlock()
|
||||
return session
|
||||
}
|
||||
}
|
||||
c.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *LocalChannelRegistry) DisconnectUser(cids []uint32) {
|
||||
for _, c := range r.channels {
|
||||
c.Lock()
|
||||
for _, session := range c.sessions {
|
||||
for _, cid := range cids {
|
||||
if session.charID == cid {
|
||||
_ = session.rawConn.Close()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (r *LocalChannelRegistry) FindChannelForStage(stageSuffix string) string {
|
||||
for _, channel := range r.channels {
|
||||
channel.stagesLock.RLock()
|
||||
for id := range channel.stages {
|
||||
if strings.HasSuffix(id, stageSuffix) {
|
||||
gid := channel.GlobalID
|
||||
channel.stagesLock.RUnlock()
|
||||
return gid
|
||||
}
|
||||
}
|
||||
channel.stagesLock.RUnlock()
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (r *LocalChannelRegistry) SearchSessions(predicate func(SessionSnapshot) bool, max int) []SessionSnapshot {
|
||||
var results []SessionSnapshot
|
||||
for _, c := range r.channels {
|
||||
if len(results) >= max {
|
||||
break
|
||||
}
|
||||
c.Lock()
|
||||
c.userBinaryPartsLock.RLock()
|
||||
for _, session := range c.sessions {
|
||||
if len(results) >= max {
|
||||
break
|
||||
}
|
||||
snap := SessionSnapshot{
|
||||
CharID: session.charID,
|
||||
Name: session.Name,
|
||||
ServerIP: net.ParseIP(c.IP).To4(),
|
||||
ServerPort: c.Port,
|
||||
}
|
||||
if session.stage != nil {
|
||||
snap.StageID = session.stage.id
|
||||
}
|
||||
ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}]
|
||||
if len(ub3) > 0 {
|
||||
snap.UserBinary3 = make([]byte, len(ub3))
|
||||
copy(snap.UserBinary3, ub3)
|
||||
}
|
||||
if predicate(snap) {
|
||||
results = append(results, snap)
|
||||
}
|
||||
}
|
||||
c.userBinaryPartsLock.RUnlock()
|
||||
c.Unlock()
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (r *LocalChannelRegistry) SearchStages(stagePrefix string, max int) []StageSnapshot {
|
||||
var results []StageSnapshot
|
||||
for _, c := range r.channels {
|
||||
if len(results) >= max {
|
||||
break
|
||||
}
|
||||
c.stagesLock.RLock()
|
||||
for _, stage := range c.stages {
|
||||
if len(results) >= max {
|
||||
break
|
||||
}
|
||||
if !strings.HasPrefix(stage.id, stagePrefix) {
|
||||
continue
|
||||
}
|
||||
stage.RLock()
|
||||
bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}]
|
||||
bin0Copy := make([]byte, len(bin0))
|
||||
copy(bin0Copy, bin0)
|
||||
bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}]
|
||||
bin1Copy := make([]byte, len(bin1))
|
||||
copy(bin1Copy, bin1)
|
||||
bin3 := stage.rawBinaryData[stageBinaryKey{1, 3}]
|
||||
bin3Copy := make([]byte, len(bin3))
|
||||
copy(bin3Copy, bin3)
|
||||
|
||||
results = append(results, StageSnapshot{
|
||||
ServerIP: net.ParseIP(c.IP).To4(),
|
||||
ServerPort: c.Port,
|
||||
StageID: stage.id,
|
||||
ClientCount: len(stage.clients) + len(stage.reservedClientSlots),
|
||||
Reserved: len(stage.reservedClientSlots),
|
||||
MaxPlayers: stage.maxPlayers,
|
||||
RawBinData0: bin0Copy,
|
||||
RawBinData1: bin1Copy,
|
||||
RawBinData3: bin3Copy,
|
||||
})
|
||||
stage.RUnlock()
|
||||
}
|
||||
c.stagesLock.RUnlock()
|
||||
}
|
||||
return results
|
||||
}
|
||||
|
||||
func (r *LocalChannelRegistry) NotifyMailToCharID(charID uint32, sender *Session, mail *Mail) {
|
||||
session := r.FindSessionByCharID(charID)
|
||||
if session != nil {
|
||||
SendMailNotification(sender, mail, session)
|
||||
}
|
||||
}
|
||||
190
server/channelserver/channel_registry_test.go
Normal file
190
server/channelserver/channel_registry_test.go
Normal file
@@ -0,0 +1,190 @@
|
||||
package channelserver
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func createTestChannels(count int) []*Server {
|
||||
channels := make([]*Server, count)
|
||||
for i := 0; i < count; i++ {
|
||||
s := createTestServer()
|
||||
s.ID = uint16(0x1010 + i)
|
||||
s.IP = "10.0.0.1"
|
||||
s.Port = uint16(54001 + i)
|
||||
s.GlobalID = "0101"
|
||||
s.userBinaryParts = make(map[userBinaryPartID][]byte)
|
||||
channels[i] = s
|
||||
}
|
||||
return channels
|
||||
}
|
||||
|
||||
func TestLocalRegistryFindSessionByCharID(t *testing.T) {
|
||||
channels := createTestChannels(2)
|
||||
reg := NewLocalChannelRegistry(channels)
|
||||
|
||||
conn1 := &mockConn{}
|
||||
sess1 := createTestSessionForServer(channels[0], conn1, 100, "Alice")
|
||||
channels[0].Lock()
|
||||
channels[0].sessions[conn1] = sess1
|
||||
channels[0].Unlock()
|
||||
|
||||
conn2 := &mockConn{}
|
||||
sess2 := createTestSessionForServer(channels[1], conn2, 200, "Bob")
|
||||
channels[1].Lock()
|
||||
channels[1].sessions[conn2] = sess2
|
||||
channels[1].Unlock()
|
||||
|
||||
// Find on first channel
|
||||
found := reg.FindSessionByCharID(100)
|
||||
if found == nil || found.charID != 100 {
|
||||
t.Errorf("FindSessionByCharID(100) = %v, want session with charID 100", found)
|
||||
}
|
||||
|
||||
// Find on second channel
|
||||
found = reg.FindSessionByCharID(200)
|
||||
if found == nil || found.charID != 200 {
|
||||
t.Errorf("FindSessionByCharID(200) = %v, want session with charID 200", found)
|
||||
}
|
||||
|
||||
// Not found
|
||||
found = reg.FindSessionByCharID(999)
|
||||
if found != nil {
|
||||
t.Errorf("FindSessionByCharID(999) = %v, want nil", found)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalRegistryFindChannelForStage(t *testing.T) {
|
||||
channels := createTestChannels(2)
|
||||
channels[0].GlobalID = "0101"
|
||||
channels[1].GlobalID = "0102"
|
||||
reg := NewLocalChannelRegistry(channels)
|
||||
|
||||
channels[1].stagesLock.Lock()
|
||||
channels[1].stages["sl2Qs123p0a0u42"] = NewStage("sl2Qs123p0a0u42")
|
||||
channels[1].stagesLock.Unlock()
|
||||
|
||||
gid := reg.FindChannelForStage("u42")
|
||||
if gid != "0102" {
|
||||
t.Errorf("FindChannelForStage(u42) = %q, want %q", gid, "0102")
|
||||
}
|
||||
|
||||
gid = reg.FindChannelForStage("u999")
|
||||
if gid != "" {
|
||||
t.Errorf("FindChannelForStage(u999) = %q, want empty", gid)
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalRegistryDisconnectUser(t *testing.T) {
|
||||
channels := createTestChannels(1)
|
||||
reg := NewLocalChannelRegistry(channels)
|
||||
|
||||
conn := &mockConn{}
|
||||
sess := createTestSessionForServer(channels[0], conn, 42, "Target")
|
||||
channels[0].Lock()
|
||||
channels[0].sessions[conn] = sess
|
||||
channels[0].Unlock()
|
||||
|
||||
reg.DisconnectUser([]uint32{42})
|
||||
|
||||
if !conn.WasClosed() {
|
||||
t.Error("DisconnectUser should have closed the connection for charID 42")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalRegistrySearchSessions(t *testing.T) {
|
||||
channels := createTestChannels(2)
|
||||
reg := NewLocalChannelRegistry(channels)
|
||||
|
||||
// Add 3 sessions across 2 channels
|
||||
for i, ch := range channels {
|
||||
conn := &mockConn{}
|
||||
sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player")
|
||||
sess.stage = NewStage("sl1Ns200p0a0u0")
|
||||
ch.Lock()
|
||||
ch.sessions[conn] = sess
|
||||
ch.Unlock()
|
||||
}
|
||||
conn3 := &mockConn{}
|
||||
sess3 := createTestSessionForServer(channels[0], conn3, 3, "Player")
|
||||
sess3.stage = NewStage("sl1Ns200p0a0u0")
|
||||
channels[0].Lock()
|
||||
channels[0].sessions[conn3] = sess3
|
||||
channels[0].Unlock()
|
||||
|
||||
// Search all
|
||||
results := reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 10)
|
||||
if len(results) != 3 {
|
||||
t.Errorf("SearchSessions(all) returned %d results, want 3", len(results))
|
||||
}
|
||||
|
||||
// Search with max
|
||||
results = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 2)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("SearchSessions(max=2) returned %d results, want 2", len(results))
|
||||
}
|
||||
|
||||
// Search with predicate
|
||||
results = reg.SearchSessions(func(s SessionSnapshot) bool { return s.CharID == 1 }, 10)
|
||||
if len(results) != 1 {
|
||||
t.Errorf("SearchSessions(charID==1) returned %d results, want 1", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalRegistrySearchStages(t *testing.T) {
|
||||
channels := createTestChannels(1)
|
||||
reg := NewLocalChannelRegistry(channels)
|
||||
|
||||
channels[0].stagesLock.Lock()
|
||||
channels[0].stages["sl2Ls210test1"] = NewStage("sl2Ls210test1")
|
||||
channels[0].stages["sl2Ls210test2"] = NewStage("sl2Ls210test2")
|
||||
channels[0].stages["sl1Ns200other"] = NewStage("sl1Ns200other")
|
||||
channels[0].stagesLock.Unlock()
|
||||
|
||||
results := reg.SearchStages("sl2Ls210", 10)
|
||||
if len(results) != 2 {
|
||||
t.Errorf("SearchStages(sl2Ls210) returned %d results, want 2", len(results))
|
||||
}
|
||||
|
||||
results = reg.SearchStages("sl2Ls210", 1)
|
||||
if len(results) != 1 {
|
||||
t.Errorf("SearchStages(sl2Ls210, max=1) returned %d results, want 1", len(results))
|
||||
}
|
||||
}
|
||||
|
||||
func TestLocalRegistryConcurrentAccess(t *testing.T) {
|
||||
channels := createTestChannels(2)
|
||||
reg := NewLocalChannelRegistry(channels)
|
||||
|
||||
// Populate some sessions
|
||||
for _, ch := range channels {
|
||||
for i := 0; i < 10; i++ {
|
||||
conn := &mockConn{remoteAddr: &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 50000 + i}}
|
||||
sess := createTestSessionForServer(ch, conn, uint32(i+1), "Player")
|
||||
sess.stage = NewStage("sl1Ns200p0a0u0")
|
||||
ch.Lock()
|
||||
ch.sessions[conn] = sess
|
||||
ch.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
// Run concurrent operations
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < 50; i++ {
|
||||
wg.Add(3)
|
||||
go func(id int) {
|
||||
defer wg.Done()
|
||||
_ = reg.FindSessionByCharID(uint32(id%10 + 1))
|
||||
}(i)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = reg.FindChannelForStage("u0")
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = reg.SearchSessions(func(s SessionSnapshot) bool { return true }, 5)
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
@@ -48,6 +48,13 @@ func GetCharacterSaveData(s *Session, charID uint32) (*CharacterSaveData, error)
|
||||
}
|
||||
|
||||
func (save *CharacterSaveData) Save(s *Session) {
|
||||
if save.decompSave == nil {
|
||||
s.logger.Warn("No decompressed save data, skipping save",
|
||||
zap.Uint32("charID", save.CharID),
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if !s.kqfOverride {
|
||||
s.kqf = save.KQF
|
||||
} else {
|
||||
|
||||
@@ -304,11 +304,26 @@ func handleMsgMhfOperateGuildMember(s *Session, p mhfpacket.MHFPacket) {
|
||||
doAckSimpleFail(s, pkt.AckHandle, make([]byte, 4))
|
||||
} else {
|
||||
_ = mail.Send(s, nil)
|
||||
for _, channel := range s.server.Channels {
|
||||
for _, session := range channel.sessions {
|
||||
if session.charID == pkt.CharID {
|
||||
SendMailNotification(s, &mail, session)
|
||||
if s.server.Registry != nil {
|
||||
s.server.Registry.NotifyMailToCharID(pkt.CharID, s, &mail)
|
||||
} else {
|
||||
// Fallback: find the target session under lock, then notify outside the lock.
|
||||
var targetSession *Session
|
||||
for _, channel := range s.server.Channels {
|
||||
channel.Lock()
|
||||
for _, session := range channel.sessions {
|
||||
if session.charID == pkt.CharID {
|
||||
targetSession = session
|
||||
break
|
||||
}
|
||||
}
|
||||
channel.Unlock()
|
||||
if targetSession != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
if targetSession != nil {
|
||||
SendMailNotification(s, &mail, targetSession)
|
||||
}
|
||||
}
|
||||
doAckSimpleSucceed(s, pkt.AckHandle, make([]byte, 4))
|
||||
|
||||
@@ -293,14 +293,16 @@ func logoutPlayer(s *Session) {
|
||||
}
|
||||
|
||||
// Update sign sessions and server player count
|
||||
_, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if s.server.db != nil {
|
||||
_, err := s.server.db.Exec("UPDATE sign_sessions SET server_id=NULL, char_id=NULL WHERE token=$1", s.token)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to clear sign session", zap.Error(err))
|
||||
}
|
||||
|
||||
_, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
_, err = s.server.db.Exec("UPDATE servers SET current_players=$1 WHERE server_id=$2", len(s.server.sessions), s.server.ID)
|
||||
if err != nil {
|
||||
s.logger.Error("Failed to update player count", zap.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
if s.stage == nil {
|
||||
@@ -399,11 +401,17 @@ func handleMsgSysEcho(s *Session, p mhfpacket.MHFPacket) {}
|
||||
func handleMsgSysLockGlobalSema(s *Session, p mhfpacket.MHFPacket) {
|
||||
pkt := p.(*mhfpacket.MsgSysLockGlobalSema)
|
||||
var sgid string
|
||||
for _, channel := range s.server.Channels {
|
||||
for id := range channel.stages {
|
||||
if strings.HasSuffix(id, pkt.UserIDString) {
|
||||
sgid = channel.GlobalID
|
||||
if s.server.Registry != nil {
|
||||
sgid = s.server.Registry.FindChannelForStage(pkt.UserIDString)
|
||||
} else {
|
||||
for _, channel := range s.server.Channels {
|
||||
channel.stagesLock.RLock()
|
||||
for id := range channel.stages {
|
||||
if strings.HasSuffix(id, pkt.UserIDString) {
|
||||
sgid = channel.GlobalID
|
||||
}
|
||||
}
|
||||
channel.stagesLock.RUnlock()
|
||||
}
|
||||
}
|
||||
bf := byteframe.NewByteFrame()
|
||||
@@ -468,7 +476,23 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
||||
resp.WriteUint16(0)
|
||||
switch pkt.SearchType {
|
||||
case 1, 2, 3: // usersearchidx, usersearchname, lobbysearchname
|
||||
// Snapshot matching sessions under lock, then build response outside locks.
|
||||
type sessionResult struct {
|
||||
charID uint32
|
||||
name []byte
|
||||
stageID []byte
|
||||
ip net.IP
|
||||
port uint16
|
||||
userBin3 []byte
|
||||
}
|
||||
var results []sessionResult
|
||||
|
||||
for _, c := range s.server.Channels {
|
||||
if count == maxResults {
|
||||
break
|
||||
}
|
||||
c.Lock()
|
||||
c.userBinaryPartsLock.RLock()
|
||||
for _, session := range c.sessions {
|
||||
if count == maxResults {
|
||||
break
|
||||
@@ -483,31 +507,45 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
sessionName := stringsupport.UTF8ToSJIS(session.Name)
|
||||
sessionStage := stringsupport.UTF8ToSJIS(session.stage.id)
|
||||
if !local {
|
||||
resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4()))
|
||||
} else {
|
||||
resp.WriteUint32(0x0100007F)
|
||||
}
|
||||
resp.WriteUint16(c.Port)
|
||||
resp.WriteUint32(session.charID)
|
||||
resp.WriteUint8(uint8(len(sessionStage) + 1))
|
||||
resp.WriteUint8(uint8(len(sessionName) + 1))
|
||||
resp.WriteUint16(uint16(len(c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}])))
|
||||
|
||||
// TODO: This case might be <=G2
|
||||
if _config.ErupeConfig.RealClientMode <= _config.G1 {
|
||||
resp.WriteBytes(make([]byte, 8))
|
||||
} else {
|
||||
resp.WriteBytes(make([]byte, 40))
|
||||
}
|
||||
resp.WriteBytes(make([]byte, 8))
|
||||
|
||||
resp.WriteNullTerminatedBytes(sessionStage)
|
||||
resp.WriteNullTerminatedBytes(sessionName)
|
||||
resp.WriteBytes(c.userBinaryParts[userBinaryPartID{session.charID, 3}])
|
||||
ub3 := c.userBinaryParts[userBinaryPartID{charID: session.charID, index: 3}]
|
||||
ub3Copy := make([]byte, len(ub3))
|
||||
copy(ub3Copy, ub3)
|
||||
results = append(results, sessionResult{
|
||||
charID: session.charID,
|
||||
name: stringsupport.UTF8ToSJIS(session.Name),
|
||||
stageID: stringsupport.UTF8ToSJIS(session.stage.id),
|
||||
ip: net.ParseIP(c.IP).To4(),
|
||||
port: c.Port,
|
||||
userBin3: ub3Copy,
|
||||
})
|
||||
}
|
||||
c.userBinaryPartsLock.RUnlock()
|
||||
c.Unlock()
|
||||
}
|
||||
|
||||
for _, r := range results {
|
||||
if !local {
|
||||
resp.WriteUint32(binary.LittleEndian.Uint32(r.ip))
|
||||
} else {
|
||||
resp.WriteUint32(0x0100007F)
|
||||
}
|
||||
resp.WriteUint16(r.port)
|
||||
resp.WriteUint32(r.charID)
|
||||
resp.WriteUint8(uint8(len(r.stageID) + 1))
|
||||
resp.WriteUint8(uint8(len(r.name) + 1))
|
||||
resp.WriteUint16(uint16(len(r.userBin3)))
|
||||
|
||||
// TODO: This case might be <=G2
|
||||
if _config.ErupeConfig.RealClientMode <= _config.G1 {
|
||||
resp.WriteBytes(make([]byte, 8))
|
||||
} else {
|
||||
resp.WriteBytes(make([]byte, 40))
|
||||
}
|
||||
resp.WriteBytes(make([]byte, 8))
|
||||
|
||||
resp.WriteNullTerminatedBytes(r.stageID)
|
||||
resp.WriteNullTerminatedBytes(r.name)
|
||||
resp.WriteBytes(r.userBin3)
|
||||
}
|
||||
case 4: // lobbysearch
|
||||
type FindPartyParams struct {
|
||||
@@ -594,12 +632,31 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Snapshot matching stages under lock, then build response outside locks.
|
||||
type stageResult struct {
|
||||
ip net.IP
|
||||
port uint16
|
||||
clientCount int
|
||||
reserved int
|
||||
maxPlayers uint16
|
||||
stageID string
|
||||
stageData []int16
|
||||
rawBinData0 []byte
|
||||
rawBinData1 []byte
|
||||
}
|
||||
var stageResults []stageResult
|
||||
|
||||
for _, c := range s.server.Channels {
|
||||
if count == maxResults {
|
||||
break
|
||||
}
|
||||
c.stagesLock.RLock()
|
||||
for _, stage := range c.stages {
|
||||
if count == maxResults {
|
||||
break
|
||||
}
|
||||
if strings.HasPrefix(stage.id, findPartyParams.StagePrefix) {
|
||||
stage.RLock()
|
||||
sb3 := byteframe.NewByteFrameFromBytes(stage.rawBinaryData[stageBinaryKey{1, 3}])
|
||||
_, _ = sb3.Seek(4, 0)
|
||||
|
||||
@@ -621,6 +678,7 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
||||
|
||||
if findPartyParams.RankRestriction >= 0 {
|
||||
if stageData[0] > findPartyParams.RankRestriction {
|
||||
stage.RUnlock()
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -634,47 +692,72 @@ func handleMsgMhfTransitMessage(s *Session, p mhfpacket.MHFPacket) {
|
||||
}
|
||||
}
|
||||
if !hasTarget {
|
||||
stage.RUnlock()
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// Copy binary data under lock
|
||||
bin0 := stage.rawBinaryData[stageBinaryKey{1, 0}]
|
||||
bin0Copy := make([]byte, len(bin0))
|
||||
copy(bin0Copy, bin0)
|
||||
bin1 := stage.rawBinaryData[stageBinaryKey{1, 1}]
|
||||
bin1Copy := make([]byte, len(bin1))
|
||||
copy(bin1Copy, bin1)
|
||||
|
||||
count++
|
||||
if !local {
|
||||
resp.WriteUint32(binary.LittleEndian.Uint32(net.ParseIP(c.IP).To4()))
|
||||
} else {
|
||||
resp.WriteUint32(0x0100007F)
|
||||
}
|
||||
resp.WriteUint16(c.Port)
|
||||
|
||||
resp.WriteUint16(0) // Static?
|
||||
resp.WriteUint16(0) // Unk, [0 1 2]
|
||||
resp.WriteUint16(uint16(len(stage.clients) + len(stage.reservedClientSlots)))
|
||||
resp.WriteUint16(stage.maxPlayers)
|
||||
// TODO: Retail returned the number of clients in quests, not workshop/my series
|
||||
resp.WriteUint16(uint16(len(stage.reservedClientSlots)))
|
||||
|
||||
resp.WriteUint8(0) // Static?
|
||||
resp.WriteUint8(uint8(stage.maxPlayers))
|
||||
resp.WriteUint8(1) // Static?
|
||||
resp.WriteUint8(uint8(len(stage.id) + 1))
|
||||
resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 0}])))
|
||||
resp.WriteUint8(uint8(len(stage.rawBinaryData[stageBinaryKey{1, 1}])))
|
||||
|
||||
for i := range stageData {
|
||||
if _config.ErupeConfig.RealClientMode >= _config.Z1 {
|
||||
resp.WriteInt16(stageData[i])
|
||||
} else {
|
||||
resp.WriteInt8(int8(stageData[i]))
|
||||
}
|
||||
}
|
||||
resp.WriteUint8(0) // Unk
|
||||
resp.WriteUint8(0) // Unk
|
||||
|
||||
resp.WriteNullTerminatedBytes([]byte(stage.id))
|
||||
resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 0}])
|
||||
resp.WriteBytes(stage.rawBinaryData[stageBinaryKey{1, 1}])
|
||||
stageResults = append(stageResults, stageResult{
|
||||
ip: net.ParseIP(c.IP).To4(),
|
||||
port: c.Port,
|
||||
clientCount: len(stage.clients) + len(stage.reservedClientSlots),
|
||||
reserved: len(stage.reservedClientSlots),
|
||||
maxPlayers: stage.maxPlayers,
|
||||
stageID: stage.id,
|
||||
stageData: stageData,
|
||||
rawBinData0: bin0Copy,
|
||||
rawBinData1: bin1Copy,
|
||||
})
|
||||
stage.RUnlock()
|
||||
}
|
||||
}
|
||||
c.stagesLock.RUnlock()
|
||||
}
|
||||
|
||||
for _, sr := range stageResults {
|
||||
if !local {
|
||||
resp.WriteUint32(binary.LittleEndian.Uint32(sr.ip))
|
||||
} else {
|
||||
resp.WriteUint32(0x0100007F)
|
||||
}
|
||||
resp.WriteUint16(sr.port)
|
||||
|
||||
resp.WriteUint16(0) // Static?
|
||||
resp.WriteUint16(0) // Unk, [0 1 2]
|
||||
resp.WriteUint16(uint16(sr.clientCount))
|
||||
resp.WriteUint16(sr.maxPlayers)
|
||||
// TODO: Retail returned the number of clients in quests, not workshop/my series
|
||||
resp.WriteUint16(uint16(sr.reserved))
|
||||
|
||||
resp.WriteUint8(0) // Static?
|
||||
resp.WriteUint8(uint8(sr.maxPlayers))
|
||||
resp.WriteUint8(1) // Static?
|
||||
resp.WriteUint8(uint8(len(sr.stageID) + 1))
|
||||
resp.WriteUint8(uint8(len(sr.rawBinData0)))
|
||||
resp.WriteUint8(uint8(len(sr.rawBinData1)))
|
||||
|
||||
for i := range sr.stageData {
|
||||
if _config.ErupeConfig.RealClientMode >= _config.Z1 {
|
||||
resp.WriteInt16(sr.stageData[i])
|
||||
} else {
|
||||
resp.WriteInt8(int8(sr.stageData[i]))
|
||||
}
|
||||
}
|
||||
resp.WriteUint8(0) // Unk
|
||||
resp.WriteUint8(0) // Unk
|
||||
|
||||
resp.WriteNullTerminatedBytes([]byte(sr.stageID))
|
||||
resp.WriteBytes(sr.rawBinData0)
|
||||
resp.WriteBytes(sr.rawBinData1)
|
||||
}
|
||||
}
|
||||
_, _ = resp.Seek(0, io.SeekStart)
|
||||
|
||||
@@ -37,6 +37,7 @@ type userBinaryPartID struct {
|
||||
type Server struct {
|
||||
sync.Mutex
|
||||
Channels []*Server
|
||||
Registry ChannelRegistry
|
||||
ID uint16
|
||||
GlobalID string
|
||||
IP string
|
||||
@@ -49,6 +50,7 @@ type Server struct {
|
||||
sessions map[net.Conn]*Session
|
||||
listener net.Listener // Listener that is created when Server.Start is called.
|
||||
isShuttingDown bool
|
||||
done chan struct{} // Closed on Shutdown to wake background goroutines.
|
||||
|
||||
stagesLock sync.RWMutex
|
||||
stages map[string]*Stage
|
||||
@@ -90,6 +92,7 @@ func NewServer(config *Config) *Server {
|
||||
erupeConfig: config.ErupeConfig,
|
||||
acceptConns: make(chan net.Conn),
|
||||
deleteConns: make(chan net.Conn),
|
||||
done: make(chan struct{}),
|
||||
sessions: make(map[net.Conn]*Session),
|
||||
stages: make(map[string]*Stage),
|
||||
userBinaryParts: make(map[userBinaryPartID][]byte),
|
||||
@@ -155,19 +158,23 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Shutdown tries to shut down the server gracefully.
|
||||
// Shutdown tries to shut down the server gracefully. Safe to call multiple times.
|
||||
func (s *Server) Shutdown() {
|
||||
s.Lock()
|
||||
alreadyShutDown := s.isShuttingDown
|
||||
s.isShuttingDown = true
|
||||
s.Unlock()
|
||||
|
||||
if alreadyShutDown {
|
||||
return
|
||||
}
|
||||
|
||||
close(s.done)
|
||||
|
||||
if s.listener != nil {
|
||||
_ = s.listener.Close()
|
||||
}
|
||||
|
||||
if s.acceptConns != nil {
|
||||
close(s.acceptConns)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) acceptClients() {
|
||||
@@ -185,25 +192,21 @@ func (s *Server) acceptClients() {
|
||||
continue
|
||||
}
|
||||
}
|
||||
s.acceptConns <- conn
|
||||
select {
|
||||
case s.acceptConns <- conn:
|
||||
case <-s.done:
|
||||
_ = conn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) manageSessions() {
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case newConn := <-s.acceptConns:
|
||||
// Gracefully handle acceptConns channel closing.
|
||||
if newConn == nil {
|
||||
s.Lock()
|
||||
shutdown := s.isShuttingDown
|
||||
s.Unlock()
|
||||
|
||||
if shutdown {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
session := NewSession(s, newConn)
|
||||
|
||||
s.Lock()
|
||||
@@ -235,15 +238,28 @@ func (s *Server) getObjectId() uint16 {
|
||||
}
|
||||
|
||||
func (s *Server) invalidateSessions() {
|
||||
for !s.isShuttingDown {
|
||||
ticker := time.NewTicker(10 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.done:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
s.Lock()
|
||||
var timedOut []*Session
|
||||
for _, sess := range s.sessions {
|
||||
if time.Since(sess.lastPacket) > time.Second*time.Duration(30) {
|
||||
s.logger.Info("session timeout", zap.String("Name", sess.Name))
|
||||
logoutPlayer(sess)
|
||||
timedOut = append(timedOut, sess)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Second * 10)
|
||||
s.Unlock()
|
||||
|
||||
for _, sess := range timedOut {
|
||||
s.logger.Info("session timeout", zap.String("Name", sess.Name))
|
||||
logoutPlayer(sess)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -271,6 +287,10 @@ func (s *Server) BroadcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session)
|
||||
|
||||
// WorldcastMHF broadcasts a packet to all sessions across all channel servers.
|
||||
func (s *Server) WorldcastMHF(pkt mhfpacket.MHFPacket, ignoredSession *Session, ignoredChannel *Server) {
|
||||
if s.Registry != nil {
|
||||
s.Registry.Worldcast(pkt, ignoredSession, ignoredChannel)
|
||||
return
|
||||
}
|
||||
for _, c := range s.Channels {
|
||||
if c == ignoredChannel {
|
||||
continue
|
||||
@@ -317,12 +337,18 @@ func (s *Server) DiscordScreenShotSend(charName string, title string, descriptio
|
||||
|
||||
// FindSessionByCharID looks up a session by character ID across all channels.
|
||||
func (s *Server) FindSessionByCharID(charID uint32) *Session {
|
||||
if s.Registry != nil {
|
||||
return s.Registry.FindSessionByCharID(charID)
|
||||
}
|
||||
for _, c := range s.Channels {
|
||||
c.Lock()
|
||||
for _, session := range c.sessions {
|
||||
if session.charID == charID {
|
||||
c.Unlock()
|
||||
return session
|
||||
}
|
||||
}
|
||||
c.Unlock()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -341,7 +367,12 @@ func (s *Server) DisconnectUser(uid uint32) {
|
||||
cids = append(cids, cid)
|
||||
}
|
||||
}
|
||||
if s.Registry != nil {
|
||||
s.Registry.DisconnectUser(cids)
|
||||
return
|
||||
}
|
||||
for _, c := range s.Channels {
|
||||
c.Lock()
|
||||
for _, session := range c.sessions {
|
||||
for _, cid := range cids {
|
||||
if session.charID == cid {
|
||||
@@ -350,6 +381,7 @@ func (s *Server) DisconnectUser(uid uint32) {
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"erupe-ce/common/mhfcourse"
|
||||
_config "erupe-ce/config"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -172,7 +171,7 @@ func (s *Session) sendLoop() {
|
||||
s.logger.Warn("Failed to send packet", zap.Error(err))
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond)
|
||||
time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,7 +214,7 @@ func (s *Session) recvLoop() {
|
||||
return
|
||||
}
|
||||
s.handlePacketGroup(pkt)
|
||||
time.Sleep(time.Duration(_config.ErupeConfig.LoopDelay) * time.Millisecond)
|
||||
time.Sleep(time.Duration(s.server.erupeConfig.LoopDelay) * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user