From e929346bf3487012831c1ff7c7a6257dc4edc2be Mon Sep 17 00:00:00 2001 From: Houmgaor Date: Fri, 30 Jan 2026 00:19:27 +0100 Subject: [PATCH] test: add unit tests for core packages Add comprehensive test coverage for: - common/token: token generation and RNG tests - common/stringsupport: string encoding, CSV operations - common/byteframe: binary read/write operations - common/mhfcourse: course/subscription logic - network/crypt_packet: packet header parsing - network/binpacket: binary packet round-trips - network/mhfpacket: packet interface and opcode mapping - config: configuration struct and loading - server/entranceserver: response building - server/signserver: response ID constants - server/signv2server: HTTP endpoint validation - server/channelserver: session, semaphore, and handler tests All tests pass with race detector enabled. --- common/byteframe/byteframe_test.go | 467 +++++++++++++++++ common/mhfcourse/mhfcourse_test.go | 288 +++++++++++ common/stringsupport/string_convert_test.go | 341 ++++++++++++ common/token/token_test.go | 120 +++++ config/config_test.go | 542 ++++++++++++++++++++ network/binpacket/binpacket_test.go | 401 +++++++++++++++ network/crypt_packet_test.go | 234 +++++++++ network/mhfpacket/mhfpacket_test.go | 463 +++++++++++++++++ server/channelserver/handlers_test.go | 268 ++++++++++ server/channelserver/sys_semaphore_test.go | 384 ++++++++++++++ server/channelserver/sys_session_test.go | 375 ++++++++++++++ server/entranceserver/make_resp_test.go | 139 +++++ server/signserver/sign_server_test.go | 212 ++++++++ server/signv2server/endpoints_test.go | 349 +++++++++++++ 14 files changed, 4583 insertions(+) create mode 100644 common/byteframe/byteframe_test.go create mode 100644 common/mhfcourse/mhfcourse_test.go create mode 100644 common/stringsupport/string_convert_test.go create mode 100644 common/token/token_test.go create mode 100644 config/config_test.go create mode 100644 network/binpacket/binpacket_test.go create mode 100644 network/crypt_packet_test.go create mode 100644 network/mhfpacket/mhfpacket_test.go create mode 100644 server/channelserver/handlers_test.go create mode 100644 server/channelserver/sys_semaphore_test.go create mode 100644 server/channelserver/sys_session_test.go create mode 100644 server/entranceserver/make_resp_test.go create mode 100644 server/signserver/sign_server_test.go create mode 100644 server/signv2server/endpoints_test.go diff --git a/common/byteframe/byteframe_test.go b/common/byteframe/byteframe_test.go new file mode 100644 index 000000000..fdd19f73d --- /dev/null +++ b/common/byteframe/byteframe_test.go @@ -0,0 +1,467 @@ +package byteframe + +import ( + "io" + "math" + "testing" +) + +func TestNewByteFrame(t *testing.T) { + bf := NewByteFrame() + + if bf == nil { + t.Fatal("NewByteFrame() returned nil") + } + if len(bf.Data()) != 0 { + t.Errorf("NewByteFrame().Data() len = %d, want 0", len(bf.Data())) + } +} + +func TestNewByteFrameFromBytes(t *testing.T) { + data := []byte{1, 2, 3, 4, 5} + bf := NewByteFrameFromBytes(data) + + if bf == nil { + t.Fatal("NewByteFrameFromBytes() returned nil") + } + if len(bf.Data()) != len(data) { + t.Errorf("NewByteFrameFromBytes().Data() len = %d, want %d", len(bf.Data()), len(data)) + } + + // Verify data is copied, not referenced + data[0] = 99 + if bf.Data()[0] == 99 { + t.Error("NewByteFrameFromBytes() did not copy data") + } +} + +func TestWriteReadUint8(t *testing.T) { + tests := []uint8{0, 1, 127, 128, 255} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint8(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadUint8() + if got != val { + t.Errorf("Write/ReadUint8(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadUint16(t *testing.T) { + tests := []uint16{0, 1, 255, 256, 32767, 65535} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint16(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadUint16() + if got != val { + t.Errorf("Write/ReadUint16(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadUint32(t *testing.T) { + tests := []uint32{0, 1, 255, 65535, 2147483647, 4294967295} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint32(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadUint32() + if got != val { + t.Errorf("Write/ReadUint32(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadUint64(t *testing.T) { + tests := []uint64{0, 1, 255, 65535, 4294967295, 18446744073709551615} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint64(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadUint64() + if got != val { + t.Errorf("Write/ReadUint64(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadInt8(t *testing.T) { + tests := []int8{-128, -1, 0, 1, 127} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteInt8(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadInt8() + if got != val { + t.Errorf("Write/ReadInt8(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadInt16(t *testing.T) { + tests := []int16{-32768, -1, 0, 1, 32767} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteInt16(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadInt16() + if got != val { + t.Errorf("Write/ReadInt16(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadInt32(t *testing.T) { + tests := []int32{-2147483648, -1, 0, 1, 2147483647} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteInt32(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadInt32() + if got != val { + t.Errorf("Write/ReadInt32(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadInt64(t *testing.T) { + tests := []int64{-9223372036854775808, -1, 0, 1, 9223372036854775807} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteInt64(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadInt64() + if got != val { + t.Errorf("Write/ReadInt64(%d) = %d", val, got) + } + }) + } +} + +func TestWriteReadFloat32(t *testing.T) { + tests := []float32{0, 1.5, -1.5, 3.14159, math.MaxFloat32, math.SmallestNonzeroFloat32} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteFloat32(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadFloat32() + if got != val { + t.Errorf("Write/ReadFloat32(%f) = %f", val, got) + } + }) + } +} + +func TestWriteReadFloat64(t *testing.T) { + tests := []float64{0, 1.5, -1.5, 3.14159265358979, math.MaxFloat64, math.SmallestNonzeroFloat64} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteFloat64(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadFloat64() + if got != val { + t.Errorf("Write/ReadFloat64(%f) = %f", val, got) + } + }) + } +} + +func TestWriteReadBool(t *testing.T) { + tests := []bool{true, false} + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteBool(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadBool() + if got != val { + t.Errorf("Write/ReadBool(%v) = %v", val, got) + } + }) + } +} + +func TestWriteReadBytes(t *testing.T) { + tests := [][]byte{ + {}, + {1}, + {1, 2, 3, 4, 5}, + {0, 255, 128, 64, 32}, + } + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteBytes(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadBytes(uint(len(val))) + if len(got) != len(val) { + t.Errorf("Write/ReadBytes len = %d, want %d", len(got), len(val)) + return + } + for i := range got { + if got[i] != val[i] { + t.Errorf("Write/ReadBytes[%d] = %d, want %d", i, got[i], val[i]) + } + } + }) + } +} + +func TestWriteReadNullTerminatedBytes(t *testing.T) { + tests := [][]byte{ + {}, + {65}, + {65, 66, 67}, + } + + for _, val := range tests { + t.Run("", func(t *testing.T) { + bf := NewByteFrame() + bf.WriteNullTerminatedBytes(val) + bf.Seek(0, io.SeekStart) + got := bf.ReadNullTerminatedBytes() + if len(got) != len(val) { + t.Errorf("Write/ReadNullTerminatedBytes len = %d, want %d", len(got), len(val)) + return + } + for i := range got { + if got[i] != val[i] { + t.Errorf("Write/ReadNullTerminatedBytes[%d] = %d, want %d", i, got[i], val[i]) + } + } + }) + } +} + +func TestSeek(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint32(0x12345678) + bf.WriteUint32(0xDEADBEEF) + + // SeekStart + pos, err := bf.Seek(0, io.SeekStart) + if err != nil { + t.Errorf("Seek(0, SeekStart) error = %v", err) + } + if pos != 0 { + t.Errorf("Seek(0, SeekStart) pos = %d, want 0", pos) + } + + val := bf.ReadUint32() + if val != 0x12345678 { + t.Errorf("After Seek(0, SeekStart) ReadUint32() = %x, want 0x12345678", val) + } + + // SeekCurrent + pos, err = bf.Seek(-4, io.SeekCurrent) + if err != nil { + t.Errorf("Seek(-4, SeekCurrent) error = %v", err) + } + if pos != 0 { + t.Errorf("Seek(-4, SeekCurrent) pos = %d, want 0", pos) + } + + // SeekEnd + pos, err = bf.Seek(-4, io.SeekEnd) + if err != nil { + t.Errorf("Seek(-4, SeekEnd) error = %v", err) + } + if pos != 4 { + t.Errorf("Seek(-4, SeekEnd) pos = %d, want 4", pos) + } + + val = bf.ReadUint32() + if val != 0xDEADBEEF { + t.Errorf("After Seek(-4, SeekEnd) ReadUint32() = %x, want 0xDEADBEEF", val) + } +} + +func TestSeekErrors(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint32(0x12345678) + + // Seek beyond end + _, err := bf.Seek(100, io.SeekStart) + if err == nil { + t.Error("Seek(100, SeekStart) should return error") + } + + // Seek before start + _, err = bf.Seek(-100, io.SeekCurrent) + if err == nil { + t.Error("Seek(-100, SeekCurrent) should return error") + } + + // Seek before start from end + _, err = bf.Seek(-100, io.SeekEnd) + if err == nil { + t.Error("Seek(-100, SeekEnd) should return error") + } +} + +func TestEndianness(t *testing.T) { + // Test big endian (default) + bf := NewByteFrame() + bf.WriteUint16(0x1234) + data := bf.Data() + if data[0] != 0x12 || data[1] != 0x34 { + t.Errorf("Big endian WriteUint16(0x1234) = %v, want [0x12, 0x34]", data) + } + + // Test little endian + bf = NewByteFrame() + bf.SetLE() + bf.WriteUint16(0x1234) + data = bf.Data() + if data[0] != 0x34 || data[1] != 0x12 { + t.Errorf("Little endian WriteUint16(0x1234) = %v, want [0x34, 0x12]", data) + } + + // Test switching back to big endian + bf = NewByteFrame() + bf.SetLE() + bf.SetBE() + bf.WriteUint16(0x1234) + data = bf.Data() + if data[0] != 0x12 || data[1] != 0x34 { + t.Errorf("Switched back to big endian WriteUint16(0x1234) = %v, want [0x12, 0x34]", data) + } +} + +func TestDataFromCurrent(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint8(1) + bf.WriteUint8(2) + bf.WriteUint8(3) + bf.WriteUint8(4) + + bf.Seek(2, io.SeekStart) + remaining := bf.DataFromCurrent() + + if len(remaining) != 2 { + t.Errorf("DataFromCurrent() len = %d, want 2", len(remaining)) + } + if remaining[0] != 3 || remaining[1] != 4 { + t.Errorf("DataFromCurrent() = %v, want [3, 4]", remaining) + } +} + +func TestBufferGrowth(t *testing.T) { + bf := NewByteFrame() + + // Write more data than initial buffer size (4 bytes) + for i := 0; i < 100; i++ { + bf.WriteUint32(uint32(i)) + } + + if len(bf.Data()) != 400 { + t.Errorf("After writing 100 uint32s, Data() len = %d, want 400", len(bf.Data())) + } + + // Verify data integrity + bf.Seek(0, io.SeekStart) + for i := 0; i < 100; i++ { + val := bf.ReadUint32() + if val != uint32(i) { + t.Errorf("After growth, ReadUint32()[%d] = %d, want %d", i, val, i) + } + } +} + +func TestMultipleWrites(t *testing.T) { + bf := NewByteFrame() + + bf.WriteUint8(0x01) + bf.WriteUint16(0x0203) + bf.WriteUint32(0x04050607) + bf.WriteUint64(0x08090A0B0C0D0E0F) + + expected := []byte{ + 0x01, + 0x02, 0x03, + 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, + } + + data := bf.Data() + if len(data) != len(expected) { + t.Errorf("Multiple writes Data() len = %d, want %d", len(data), len(expected)) + return + } + + for i := range expected { + if data[i] != expected[i] { + t.Errorf("Multiple writes Data()[%d] = %x, want %x", i, data[i], expected[i]) + } + } +} + +func TestReadPanicsOnOverflow(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("ReadUint32 on empty buffer should panic") + } + }() + + bf := NewByteFrame() + bf.ReadUint32() +} + +func TestReadBoolNonZero(t *testing.T) { + // Test that any non-zero value is considered true + bf := NewByteFrameFromBytes([]byte{0, 1, 2, 255}) + + if bf.ReadBool() != false { + t.Error("ReadBool(0) should be false") + } + if bf.ReadBool() != true { + t.Error("ReadBool(1) should be true") + } + if bf.ReadBool() != true { + t.Error("ReadBool(2) should be true") + } + if bf.ReadBool() != true { + t.Error("ReadBool(255) should be true") + } +} + +func TestReadNullTerminatedBytesNoTerminator(t *testing.T) { + // Test behavior when there's no null terminator + bf := NewByteFrameFromBytes([]byte{65, 66, 67}) + result := bf.ReadNullTerminatedBytes() + + if len(result) != 0 { + t.Errorf("ReadNullTerminatedBytes with no terminator should return empty, got %v", result) + } +} diff --git a/common/mhfcourse/mhfcourse_test.go b/common/mhfcourse/mhfcourse_test.go new file mode 100644 index 000000000..75f1d8745 --- /dev/null +++ b/common/mhfcourse/mhfcourse_test.go @@ -0,0 +1,288 @@ +package mhfcourse + +import ( + "math" + "testing" +) + +func TestCourses(t *testing.T) { + courses := Courses() + + if len(courses) != 32 { + t.Errorf("Courses() len = %d, want 32", len(courses)) + } + + for i, course := range courses { + if course.ID != uint16(i) { + t.Errorf("Courses()[%d].ID = %d, want %d", i, course.ID, i) + } + } +} + +func TestCourseValue(t *testing.T) { + tests := []struct { + id uint16 + want uint32 + }{ + {0, 1}, + {1, 2}, + {2, 4}, + {3, 8}, + {4, 16}, + {5, 32}, + {10, 1024}, + {20, 1048576}, + {31, 2147483648}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + c := Course{ID: tt.id} + got := c.Value() + if got != tt.want { + t.Errorf("Course{ID: %d}.Value() = %d, want %d", tt.id, got, tt.want) + } + }) + } +} + +func TestCourseValueIsPowerOf2(t *testing.T) { + for i := uint16(0); i < 32; i++ { + c := Course{ID: i} + val := c.Value() + expected := uint32(math.Pow(2, float64(i))) + if val != expected { + t.Errorf("Course{ID: %d}.Value() = %d, want %d (2^%d)", i, val, expected, i) + } + } +} + +func TestCourseAliases(t *testing.T) { + tests := []struct { + id uint16 + wantLen int + contains string + }{ + {1, 2, "Trial"}, + {2, 2, "HunterLife"}, + {3, 3, "Extra"}, + {6, 1, "Premium"}, + {8, 4, "Assist"}, + {26, 4, "NetCafe"}, + {29, 1, "Free"}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + c := Course{ID: tt.id} + aliases := c.Aliases() + + if len(aliases) != tt.wantLen { + t.Errorf("Course{ID: %d}.Aliases() len = %d, want %d", tt.id, len(aliases), tt.wantLen) + } + + found := false + for _, alias := range aliases { + if alias == tt.contains { + found = true + break + } + } + if !found { + t.Errorf("Course{ID: %d}.Aliases() should contain %q", tt.id, tt.contains) + } + }) + } +} + +func TestCourseAliasesUnknown(t *testing.T) { + // Test IDs without aliases + unknownIDs := []uint16{13, 14, 15, 16, 17, 18, 19} + + for _, id := range unknownIDs { + c := Course{ID: id} + aliases := c.Aliases() + if aliases != nil { + t.Errorf("Course{ID: %d}.Aliases() = %v, want nil", id, aliases) + } + } +} + +func TestCourseExists(t *testing.T) { + courses := []Course{ + {ID: 1}, + {ID: 5}, + {ID: 10}, + } + + tests := []struct { + id uint16 + want bool + }{ + {1, true}, + {5, true}, + {10, true}, + {0, false}, + {2, false}, + {99, false}, + } + + for _, tt := range tests { + t.Run("", func(t *testing.T) { + got := CourseExists(tt.id, courses) + if got != tt.want { + t.Errorf("CourseExists(%d, courses) = %v, want %v", tt.id, got, tt.want) + } + }) + } +} + +func TestCourseExistsEmptySlice(t *testing.T) { + var courses []Course + + if CourseExists(1, courses) { + t.Error("CourseExists(1, nil) should return false") + } +} + +func TestGetCourseStruct(t *testing.T) { + tests := []struct { + name string + rights uint32 + wantMinLen int + shouldHave []uint16 + shouldNotHave []uint16 + }{ + { + name: "zero rights", + rights: 0, + wantMinLen: 1, // Always includes ID: 1 (Trial) + shouldHave: []uint16{1}, + }, + { + name: "HunterLife course", + rights: 4, // 2^2 = 4 for ID 2 + wantMinLen: 2, + shouldHave: []uint16{1, 2}, + }, + { + name: "multiple courses", + rights: 6, // 2^1 + 2^2 = 2 + 4 = 6 for IDs 1 and 2 + wantMinLen: 2, + shouldHave: []uint16{1, 2}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + courses, _ := GetCourseStruct(tt.rights) + + if len(courses) < tt.wantMinLen { + t.Errorf("GetCourseStruct(%d) len = %d, want >= %d", tt.rights, len(courses), tt.wantMinLen) + } + + for _, id := range tt.shouldHave { + if !CourseExists(id, courses) { + t.Errorf("GetCourseStruct(%d) should have course ID %d", tt.rights, id) + } + } + + for _, id := range tt.shouldNotHave { + if CourseExists(id, courses) { + t.Errorf("GetCourseStruct(%d) should not have course ID %d", tt.rights, id) + } + } + }) + } +} + +func TestGetCourseStructReturnsRights(t *testing.T) { + // GetCourseStruct returns the recalculated rights value + _, rights := GetCourseStruct(0) + + // Should at least include the Trial course (ID: 1, Value: 2) + if rights < 2 { + t.Errorf("GetCourseStruct(0) rights = %d, want >= 2", rights) + } +} + +func TestGetCourseStructNetCafeCourses(t *testing.T) { + // Test that course 26 (NetCafe) adds course 25 (CAFE_SP) + courses, _ := GetCourseStruct(Course{ID: 26}.Value()) + + if !CourseExists(25, courses) { + t.Error("GetCourseStruct with course 26 should add course 25") + } + if !CourseExists(30, courses) { + t.Error("GetCourseStruct with course 26 should add course 30") + } +} + +func TestGetCourseStructNCourse(t *testing.T) { + // Test that course 9 (N) adds course 30 + courses, _ := GetCourseStruct(Course{ID: 9}.Value()) + + if !CourseExists(30, courses) { + t.Error("GetCourseStruct with course 9 should add course 30") + } +} + +func TestCourseExpiry(t *testing.T) { + // Test that courses returned by GetCourseStruct have expiry set + courses, _ := GetCourseStruct(4) // HunterLife + + for _, c := range courses { + // Course ID 1 is always added without expiry in some cases + if c.ID != 1 && c.ID != 25 && c.ID != 30 { + if c.Expiry.IsZero() { + // Note: expiry is only set for courses extracted from rights + // This behavior is expected + } + } + } +} + +func TestAllCoursesHaveValidValues(t *testing.T) { + courses := Courses() + + for _, c := range courses { + val := c.Value() + // Verify value is a power of 2 + if val == 0 || (val&(val-1)) != 0 { + t.Errorf("Course{ID: %d}.Value() = %d is not a power of 2", c.ID, val) + } + } +} + +func TestKnownAliasesExist(t *testing.T) { + knownCourses := map[string]uint16{ + "Trial": 1, + "HunterLife": 2, + "Extra": 3, + "Mobile": 5, + "Premium": 6, + "Assist": 8, + "Hiden": 10, + "NetCafe": 26, + "Free": 29, + } + + for name, expectedID := range knownCourses { + t.Run(name, func(t *testing.T) { + c := Course{ID: expectedID} + aliases := c.Aliases() + + found := false + for _, alias := range aliases { + if alias == name { + found = true + break + } + } + + if !found { + t.Errorf("Course ID %d should have alias %q, got %v", expectedID, name, aliases) + } + }) + } +} diff --git a/common/stringsupport/string_convert_test.go b/common/stringsupport/string_convert_test.go new file mode 100644 index 000000000..5c221977f --- /dev/null +++ b/common/stringsupport/string_convert_test.go @@ -0,0 +1,341 @@ +package stringsupport + +import ( + "bytes" + "testing" + + "golang.org/x/text/encoding/japanese" +) + +func TestStringConverterDecode(t *testing.T) { + sc := &StringConverter{Encoding: japanese.ShiftJIS} + + tests := []struct { + name string + input []byte + want string + wantErr bool + }{ + {"empty", []byte{}, "", false}, + {"ascii", []byte("Hello"), "Hello", false}, + {"japanese hello", []byte{0x82, 0xb1, 0x82, 0xf1, 0x82, 0xc9, 0x82, 0xbf, 0x82, 0xcd}, "こんにちは", false}, + {"mixed", []byte{0x41, 0x42, 0x43}, "ABC", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := sc.Decode(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Decode() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Decode() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestStringConverterEncode(t *testing.T) { + sc := &StringConverter{Encoding: japanese.ShiftJIS} + + tests := []struct { + name string + input string + want []byte + wantErr bool + }{ + {"empty", "", []byte{}, false}, + {"ascii", "Hello", []byte("Hello"), false}, + {"japanese hello", "こんにちは", []byte{0x82, 0xb1, 0x82, 0xf1, 0x82, 0xc9, 0x82, 0xbf, 0x82, 0xcd}, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := sc.Encode(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("Encode() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !bytes.Equal(got, tt.want) { + t.Errorf("Encode() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestStringConverterMustDecode(t *testing.T) { + sc := &StringConverter{Encoding: japanese.ShiftJIS} + + // Valid input should not panic + result := sc.MustDecode([]byte("Hello")) + if result != "Hello" { + t.Errorf("MustDecode() = %q, want %q", result, "Hello") + } +} + +func TestStringConverterMustEncode(t *testing.T) { + sc := &StringConverter{Encoding: japanese.ShiftJIS} + + // Valid input should not panic + result := sc.MustEncode("Hello") + if !bytes.Equal(result, []byte("Hello")) { + t.Errorf("MustEncode() = %v, want %v", result, []byte("Hello")) + } +} + +func TestUTF8ToSJIS(t *testing.T) { + tests := []struct { + name string + input string + want []byte + }{ + {"empty", "", []byte{}}, + {"ascii", "ABC", []byte("ABC")}, + {"japanese", "こんにちは", []byte{0x82, 0xb1, 0x82, 0xf1, 0x82, 0xc9, 0x82, 0xbf, 0x82, 0xcd}}, + {"mixed", "Hello世界", []byte{0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x90, 0xa2, 0x8a, 0x45}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := UTF8ToSJIS(tt.input) + if !bytes.Equal(got, tt.want) { + t.Errorf("UTF8ToSJIS(%q) = %v, want %v", tt.input, got, tt.want) + } + }) + } +} + +func TestSJISToUTF8(t *testing.T) { + tests := []struct { + name string + input []byte + want string + }{ + {"empty", []byte{}, ""}, + {"ascii", []byte("ABC"), "ABC"}, + {"japanese", []byte{0x82, 0xb1, 0x82, 0xf1, 0x82, 0xc9, 0x82, 0xbf, 0x82, 0xcd}, "こんにちは"}, + {"mixed", []byte{0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x90, 0xa2, 0x8a, 0x45}, "Hello世界"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := SJISToUTF8(tt.input) + if got != tt.want { + t.Errorf("SJISToUTF8(%v) = %q, want %q", tt.input, got, tt.want) + } + }) + } +} + +func TestUTF8ToSJISRoundTrip(t *testing.T) { + tests := []string{ + "Hello", + "ABC123", + "こんにちは", + "テスト", + "モンスターハンター", + } + + for _, input := range tests { + t.Run(input, func(t *testing.T) { + encoded := UTF8ToSJIS(input) + decoded := SJISToUTF8(encoded) + if decoded != input { + t.Errorf("Round trip failed: %q -> %v -> %q", input, encoded, decoded) + } + }) + } +} + +func TestPaddedString(t *testing.T) { + tests := []struct { + name string + input string + size uint + transform bool + wantLen int + wantEnd byte + }{ + {"empty ascii", "", 10, false, 10, 0}, + {"short ascii", "Hi", 10, false, 10, 0}, + {"exact ascii", "1234567890", 10, false, 10, 0}, + {"empty sjis", "", 10, true, 10, 0}, + {"short sjis", "Hi", 10, true, 10, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := PaddedString(tt.input, tt.size, tt.transform) + if len(got) != tt.wantLen { + t.Errorf("PaddedString() len = %d, want %d", len(got), tt.wantLen) + } + if got[len(got)-1] != tt.wantEnd { + t.Errorf("PaddedString() last byte = %d, want %d", got[len(got)-1], tt.wantEnd) + } + }) + } +} + +func TestPaddedStringContent(t *testing.T) { + // Verify the content is correctly placed at the beginning + result := PaddedString("ABC", 10, false) + + if result[0] != 'A' || result[1] != 'B' || result[2] != 'C' { + t.Errorf("PaddedString() content mismatch: got %v", result[:3]) + } + + // Rest should be zeros (except last which is forced to 0) + for i := 3; i < 10; i++ { + if result[i] != 0 { + t.Errorf("PaddedString() byte at %d = %d, want 0", i, result[i]) + } + } +} + +func TestCSVAdd(t *testing.T) { + tests := []struct { + name string + csv string + v int + want string + }{ + {"empty add", "", 5, "5"}, + {"add to existing", "1,2,3", 4, "1,2,3,4"}, + {"add duplicate", "1,2,3", 2, "1,2,3"}, + {"add to single", "1", 2, "1,2"}, + {"add zero", "", 0, "0"}, + {"add negative", "1,2", -5, "1,2,-5"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CSVAdd(tt.csv, tt.v) + if got != tt.want { + t.Errorf("CSVAdd(%q, %d) = %q, want %q", tt.csv, tt.v, got, tt.want) + } + }) + } +} + +func TestCSVRemove(t *testing.T) { + tests := []struct { + name string + csv string + v int + want string + }{ + {"remove from middle", "1,2,3", 2, "1,3"}, + {"remove first", "1,2,3", 1, "3,2"}, + {"remove last", "1,2,3", 3, "1,2"}, + {"remove only", "5", 5, ""}, + {"remove nonexistent", "1,2,3", 99, "1,2,3"}, + {"remove from empty", "", 5, ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CSVRemove(tt.csv, tt.v) + if got != tt.want { + t.Errorf("CSVRemove(%q, %d) = %q, want %q", tt.csv, tt.v, got, tt.want) + } + }) + } +} + +func TestCSVContains(t *testing.T) { + tests := []struct { + name string + csv string + v int + want bool + }{ + {"contains first", "1,2,3", 1, true}, + {"contains middle", "1,2,3", 2, true}, + {"contains last", "1,2,3", 3, true}, + {"not contains", "1,2,3", 99, false}, + {"empty csv", "", 5, false}, + {"single contains", "5", 5, true}, + {"single not contains", "5", 3, false}, + {"contains zero", "0,1,2", 0, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CSVContains(tt.csv, tt.v) + if got != tt.want { + t.Errorf("CSVContains(%q, %d) = %v, want %v", tt.csv, tt.v, got, tt.want) + } + }) + } +} + +func TestCSVLength(t *testing.T) { + tests := []struct { + name string + csv string + want int + }{ + {"empty", "", 0}, + {"single", "5", 1}, + {"two", "1,2", 2}, + {"three", "1,2,3", 3}, + {"many", "1,2,3,4,5,6,7,8,9,10", 10}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CSVLength(tt.csv) + if got != tt.want { + t.Errorf("CSVLength(%q) = %d, want %d", tt.csv, got, tt.want) + } + }) + } +} + +func TestCSVElems(t *testing.T) { + tests := []struct { + name string + csv string + want []int + }{ + {"empty", "", nil}, + {"single", "5", []int{5}}, + {"multiple", "1,2,3", []int{1, 2, 3}}, + {"with zero", "0,1,2", []int{0, 1, 2}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CSVElems(tt.csv) + if len(got) != len(tt.want) { + t.Errorf("CSVElems(%q) len = %d, want %d", tt.csv, len(got), len(tt.want)) + return + } + for i := range got { + if got[i] != tt.want[i] { + t.Errorf("CSVElems(%q)[%d] = %d, want %d", tt.csv, i, got[i], tt.want[i]) + } + } + }) + } +} + +func TestCSVAddRemoveRoundTrip(t *testing.T) { + csv := "" + csv = CSVAdd(csv, 1) + csv = CSVAdd(csv, 2) + csv = CSVAdd(csv, 3) + + if !CSVContains(csv, 1) || !CSVContains(csv, 2) || !CSVContains(csv, 3) { + t.Error("CSVAdd did not add all elements") + } + + csv = CSVRemove(csv, 2) + if CSVContains(csv, 2) { + t.Error("CSVRemove did not remove element") + } + if CSVLength(csv) != 2 { + t.Errorf("CSVLength after remove = %d, want 2", CSVLength(csv)) + } +} diff --git a/common/token/token_test.go b/common/token/token_test.go new file mode 100644 index 000000000..9919faf53 --- /dev/null +++ b/common/token/token_test.go @@ -0,0 +1,120 @@ +package token + +import ( + "regexp" + "testing" +) + +func TestGenerate(t *testing.T) { + tests := []struct { + name string + length int + }{ + {"zero length", 0}, + {"short token", 8}, + {"medium token", 32}, + {"long token", 256}, + {"single char", 1}, + } + + alphanumeric := regexp.MustCompile(`^[a-zA-Z0-9]*$`) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Generate(tt.length) + if len(result) != tt.length { + t.Errorf("Generate(%d) length = %d, want %d", tt.length, len(result), tt.length) + } + if !alphanumeric.MatchString(result) { + t.Errorf("Generate(%d) = %q, contains non-alphanumeric characters", tt.length, result) + } + }) + } +} + +func TestGenerateUniqueness(t *testing.T) { + // Generate multiple tokens and check they're different + tokens := make(map[string]bool) + iterations := 100 + length := 32 + + for i := 0; i < iterations; i++ { + token := Generate(length) + if tokens[token] { + t.Errorf("Generate(%d) produced duplicate token: %s", length, token) + } + tokens[token] = true + } +} + +func TestGenerateCharacterDistribution(t *testing.T) { + // Generate a long token and verify it uses various characters + token := Generate(1000) + + hasLower := regexp.MustCompile(`[a-z]`).MatchString(token) + hasUpper := regexp.MustCompile(`[A-Z]`).MatchString(token) + hasDigit := regexp.MustCompile(`[0-9]`).MatchString(token) + + if !hasLower { + t.Error("Generate(1000) did not produce any lowercase letters") + } + if !hasUpper { + t.Error("Generate(1000) did not produce any uppercase letters") + } + if !hasDigit { + t.Error("Generate(1000) did not produce any digits") + } +} + +func TestRNG(t *testing.T) { + rng1 := RNG() + rng2 := RNG() + + if rng1 == nil { + t.Error("RNG() returned nil") + } + if rng2 == nil { + t.Error("RNG() returned nil") + } + + // Both should generate valid random numbers + val1 := rng1.Intn(100) + val2 := rng2.Intn(100) + + if val1 < 0 || val1 >= 100 { + t.Errorf("RNG().Intn(100) = %d, want value in [0, 100)", val1) + } + if val2 < 0 || val2 >= 100 { + t.Errorf("RNG().Intn(100) = %d, want value in [0, 100)", val2) + } +} + +func TestRNGIndependence(t *testing.T) { + // Create multiple RNGs and verify they produce different sequences + rng1 := RNG() + rng2 := RNG() + + // Generate sequences + seq1 := make([]int, 10) + seq2 := make([]int, 10) + + for i := 0; i < 10; i++ { + seq1[i] = rng1.Intn(1000000) + seq2[i] = rng2.Intn(1000000) + } + + // Check that sequences are likely different (not identical) + identical := true + for i := 0; i < 10; i++ { + if seq1[i] != seq2[i] { + identical = false + break + } + } + + // Note: There's an extremely small chance both RNGs could produce + // the same sequence, but it's astronomically unlikely + if identical { + t.Log("Warning: Two independent RNGs produced identical sequences (this is extremely unlikely)") + } +} diff --git a/config/config_test.go b/config/config_test.go new file mode 100644 index 000000000..c7a70bfaf --- /dev/null +++ b/config/config_test.go @@ -0,0 +1,542 @@ +package config + +import ( + "os" + "path/filepath" + "testing" +) + +func TestConfigStructDefaults(t *testing.T) { + // Test that Config struct has expected zero values + c := &Config{} + + if c.DevMode != false { + t.Error("DevMode default should be false") + } + if c.Host != "" { + t.Error("Host default should be empty") + } + if c.HideLoginNotice != false { + t.Error("HideLoginNotice default should be false") + } +} + +func TestDevModeOptionsDefaults(t *testing.T) { + d := DevModeOptions{} + + if d.AutoCreateAccount != false { + t.Error("AutoCreateAccount default should be false") + } + if d.CleanDB != false { + t.Error("CleanDB default should be false") + } + if d.MaxLauncherHR != false { + t.Error("MaxLauncherHR default should be false") + } + if d.LogInboundMessages != false { + t.Error("LogInboundMessages default should be false") + } + if d.LogOutboundMessages != false { + t.Error("LogOutboundMessages default should be false") + } + if d.MaxHexdumpLength != 0 { + t.Error("MaxHexdumpLength default should be 0") + } +} + +func TestGameplayOptionsDefaults(t *testing.T) { + g := GameplayOptions{} + + if g.FeaturedWeapons != 0 { + t.Error("FeaturedWeapons default should be 0") + } + if g.MaximumNP != 0 { + t.Error("MaximumNP default should be 0") + } + if g.MaximumRP != 0 { + t.Error("MaximumRP default should be 0") + } + if g.DisableLoginBoost != false { + t.Error("DisableLoginBoost default should be false") + } +} + +func TestLoggingDefaults(t *testing.T) { + l := Logging{} + + if l.LogToFile != false { + t.Error("LogToFile default should be false") + } + if l.LogFilePath != "" { + t.Error("LogFilePath default should be empty") + } + if l.LogMaxSize != 0 { + t.Error("LogMaxSize default should be 0") + } +} + +func TestDatabaseStruct(t *testing.T) { + d := Database{ + Host: "localhost", + Port: 5432, + User: "postgres", + Password: "secret", + Database: "erupe", + } + + if d.Host != "localhost" { + t.Errorf("Host = %s, want localhost", d.Host) + } + if d.Port != 5432 { + t.Errorf("Port = %d, want 5432", d.Port) + } + if d.User != "postgres" { + t.Errorf("User = %s, want postgres", d.User) + } + if d.Password != "secret" { + t.Errorf("Password = %s, want secret", d.Password) + } + if d.Database != "erupe" { + t.Errorf("Database = %s, want erupe", d.Database) + } +} + +func TestSignStruct(t *testing.T) { + s := Sign{ + Enabled: true, + Port: 53312, + } + + if s.Enabled != true { + t.Error("Enabled should be true") + } + if s.Port != 53312 { + t.Errorf("Port = %d, want 53312", s.Port) + } +} + +func TestSignV2Struct(t *testing.T) { + s := SignV2{ + Enabled: true, + Port: 8080, + } + + if s.Enabled != true { + t.Error("Enabled should be true") + } + if s.Port != 8080 { + t.Errorf("Port = %d, want 8080", s.Port) + } +} + +func TestEntranceStruct(t *testing.T) { + e := Entrance{ + Enabled: true, + Port: 53310, + Entries: []EntranceServerInfo{ + { + IP: "127.0.0.1", + Type: 1, + Name: "Test Server", + }, + }, + } + + if e.Enabled != true { + t.Error("Enabled should be true") + } + if e.Port != 53310 { + t.Errorf("Port = %d, want 53310", e.Port) + } + if len(e.Entries) != 1 { + t.Errorf("Entries len = %d, want 1", len(e.Entries)) + } +} + +func TestEntranceServerInfoStruct(t *testing.T) { + info := EntranceServerInfo{ + IP: "192.168.1.1", + Type: 2, + Season: 1, + Recommended: 3, + Name: "Test World", + Description: "A test server", + AllowedClientFlags: 4096, + Channels: []EntranceChannelInfo{ + {Port: 54001, MaxPlayers: 100, CurrentPlayers: 50}, + }, + } + + if info.IP != "192.168.1.1" { + t.Errorf("IP = %s, want 192.168.1.1", info.IP) + } + if info.Type != 2 { + t.Errorf("Type = %d, want 2", info.Type) + } + if info.Season != 1 { + t.Errorf("Season = %d, want 1", info.Season) + } + if info.Recommended != 3 { + t.Errorf("Recommended = %d, want 3", info.Recommended) + } + if info.Name != "Test World" { + t.Errorf("Name = %s, want Test World", info.Name) + } + if info.Description != "A test server" { + t.Errorf("Description = %s, want A test server", info.Description) + } + if info.AllowedClientFlags != 4096 { + t.Errorf("AllowedClientFlags = %d, want 4096", info.AllowedClientFlags) + } + if len(info.Channels) != 1 { + t.Errorf("Channels len = %d, want 1", len(info.Channels)) + } +} + +func TestEntranceChannelInfoStruct(t *testing.T) { + ch := EntranceChannelInfo{ + Port: 54001, + MaxPlayers: 100, + CurrentPlayers: 25, + } + + if ch.Port != 54001 { + t.Errorf("Port = %d, want 54001", ch.Port) + } + if ch.MaxPlayers != 100 { + t.Errorf("MaxPlayers = %d, want 100", ch.MaxPlayers) + } + if ch.CurrentPlayers != 25 { + t.Errorf("CurrentPlayers = %d, want 25", ch.CurrentPlayers) + } +} + +func TestDiscordStruct(t *testing.T) { + d := Discord{ + Enabled: true, + BotToken: "test-token", + RealtimeChannelID: "123456789", + } + + if d.Enabled != true { + t.Error("Enabled should be true") + } + if d.BotToken != "test-token" { + t.Errorf("BotToken = %s, want test-token", d.BotToken) + } + if d.RealtimeChannelID != "123456789" { + t.Errorf("RealtimeChannelID = %s, want 123456789", d.RealtimeChannelID) + } +} + +func TestCommandStruct(t *testing.T) { + cmd := Command{ + Name: "teleport", + Enabled: true, + Prefix: "!", + } + + if cmd.Name != "teleport" { + t.Errorf("Name = %s, want teleport", cmd.Name) + } + if cmd.Enabled != true { + t.Error("Enabled should be true") + } + if cmd.Prefix != "!" { + t.Errorf("Prefix = %s, want !", cmd.Prefix) + } +} + +func TestCourseStruct(t *testing.T) { + course := Course{ + Name: "Premium", + Enabled: true, + } + + if course.Name != "Premium" { + t.Errorf("Name = %s, want Premium", course.Name) + } + if course.Enabled != true { + t.Error("Enabled should be true") + } +} + +func TestSaveDumpOptionsStruct(t *testing.T) { + s := SaveDumpOptions{ + Enabled: true, + OutputDir: "/tmp/dumps", + } + + if s.Enabled != true { + t.Error("Enabled should be true") + } + if s.OutputDir != "/tmp/dumps" { + t.Errorf("OutputDir = %s, want /tmp/dumps", s.OutputDir) + } +} + +func TestIsTestMode(t *testing.T) { + // When running tests, isTestMode should return true + if !isTestMode() { + t.Error("isTestMode() should return true when running tests") + } +} + +func TestLoadConfigMissingFile(t *testing.T) { + // Create a temporary directory without a config file + tmpDir, err := os.MkdirTemp("", "erupe-config-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Save current directory and change to temp + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current dir: %v", err) + } + defer os.Chdir(origDir) + + err = os.Chdir(tmpDir) + if err != nil { + t.Fatalf("Failed to change to temp dir: %v", err) + } + + // LoadConfig should fail without config.json + _, err = LoadConfig() + if err == nil { + t.Error("LoadConfig() should return error when config file is missing") + } +} + +func TestLoadConfigValidFile(t *testing.T) { + // Create a temporary directory with a valid config file + tmpDir, err := os.MkdirTemp("", "erupe-config-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create minimal config.json + configContent := `{ + "Host": "127.0.0.1", + "DevMode": true, + "Database": { + "Host": "localhost", + "Port": 5432, + "User": "postgres", + "Password": "password", + "Database": "erupe" + } + }` + + configPath := filepath.Join(tmpDir, "config.json") + err = os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to write config file: %v", err) + } + + // Save current directory and change to temp + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current dir: %v", err) + } + defer os.Chdir(origDir) + + err = os.Chdir(tmpDir) + if err != nil { + t.Fatalf("Failed to change to temp dir: %v", err) + } + + // LoadConfig should succeed + cfg, err := LoadConfig() + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + if cfg.Host != "127.0.0.1" { + t.Errorf("Host = %s, want 127.0.0.1", cfg.Host) + } + if cfg.DevMode != true { + t.Error("DevMode should be true") + } + if cfg.Database.Host != "localhost" { + t.Errorf("Database.Host = %s, want localhost", cfg.Database.Host) + } +} + +func TestLoadConfigDefaults(t *testing.T) { + // Create a temporary directory with minimal config + tmpDir, err := os.MkdirTemp("", "erupe-config-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create minimal config.json (just enough to pass) + configContent := `{ + "Host": "192.168.1.1" + }` + + configPath := filepath.Join(tmpDir, "config.json") + err = os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to write config file: %v", err) + } + + // Save current directory and change to temp + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current dir: %v", err) + } + defer os.Chdir(origDir) + + err = os.Chdir(tmpDir) + if err != nil { + t.Fatalf("Failed to change to temp dir: %v", err) + } + + cfg, err := LoadConfig() + if err != nil { + t.Fatalf("LoadConfig() error = %v", err) + } + + // Check Logging defaults are applied + if cfg.Logging.LogToFile != true { + t.Error("Logging.LogToFile should default to true") + } + if cfg.Logging.LogFilePath != "logs/erupe.log" { + t.Errorf("Logging.LogFilePath = %s, want logs/erupe.log", cfg.Logging.LogFilePath) + } + if cfg.Logging.LogMaxSize != 100 { + t.Errorf("Logging.LogMaxSize = %d, want 100", cfg.Logging.LogMaxSize) + } + if cfg.Logging.LogMaxBackups != 3 { + t.Errorf("Logging.LogMaxBackups = %d, want 3", cfg.Logging.LogMaxBackups) + } + if cfg.Logging.LogMaxAge != 28 { + t.Errorf("Logging.LogMaxAge = %d, want 28", cfg.Logging.LogMaxAge) + } + if cfg.Logging.LogCompress != true { + t.Error("Logging.LogCompress should default to true") + } + + // Check SaveDumps defaults + if cfg.DevModeOptions.SaveDumps.Enabled != false { + t.Error("SaveDumps.Enabled should default to false") + } + if cfg.DevModeOptions.SaveDumps.OutputDir != "savedata" { + t.Errorf("SaveDumps.OutputDir = %s, want savedata", cfg.DevModeOptions.SaveDumps.OutputDir) + } +} + +func TestLoadConfigInvalidJSON(t *testing.T) { + // Create a temporary directory with invalid JSON config + tmpDir, err := os.MkdirTemp("", "erupe-config-test-*") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Create invalid JSON + configContent := `{ this is not valid json }` + + configPath := filepath.Join(tmpDir, "config.json") + err = os.WriteFile(configPath, []byte(configContent), 0644) + if err != nil { + t.Fatalf("Failed to write config file: %v", err) + } + + // Save current directory and change to temp + origDir, err := os.Getwd() + if err != nil { + t.Fatalf("Failed to get current dir: %v", err) + } + defer os.Chdir(origDir) + + err = os.Chdir(tmpDir) + if err != nil { + t.Fatalf("Failed to change to temp dir: %v", err) + } + + _, err = LoadConfig() + if err == nil { + t.Error("LoadConfig() should return error for invalid JSON") + } +} + +func TestChannelStruct(t *testing.T) { + ch := Channel{ + Enabled: true, + } + + if ch.Enabled != true { + t.Error("Enabled should be true") + } +} + +func TestConfigCompleteStructure(t *testing.T) { + // Test building a complete config structure + cfg := &Config{ + Host: "192.168.1.100", + BinPath: "/bin", + Language: "JP", + DisableSoftCrash: true, + HideLoginNotice: false, + LoginNotices: []string{"Notice 1", "Notice 2"}, + PatchServerManifest: "http://patch.example.com/manifest", + PatchServerFile: "http://patch.example.com/files", + DevMode: true, + DevModeOptions: DevModeOptions{ + AutoCreateAccount: true, + CleanDB: false, + MaxLauncherHR: true, + LogInboundMessages: true, + LogOutboundMessages: true, + MaxHexdumpLength: 256, + }, + GameplayOptions: GameplayOptions{ + FeaturedWeapons: 5, + MaximumNP: 99999, + MaximumRP: 65535, + DisableLoginBoost: false, + BoostTimeDuration: 60, + GuildMealDuration: 30, + BonusQuestAllowance: 10, + DailyQuestAllowance: 5, + }, + Database: Database{ + Host: "db.example.com", + Port: 5432, + User: "erupe", + Password: "secret", + Database: "erupe_db", + }, + Sign: Sign{ + Enabled: true, + Port: 53312, + }, + SignV2: SignV2{ + Enabled: true, + Port: 8080, + }, + Channel: Channel{ + Enabled: true, + }, + Entrance: Entrance{ + Enabled: true, + Port: 53310, + }, + } + + // Verify values are set correctly + if cfg.Host != "192.168.1.100" { + t.Errorf("Host = %s, want 192.168.1.100", cfg.Host) + } + if cfg.GameplayOptions.MaximumNP != 99999 { + t.Errorf("MaximumNP = %d, want 99999", cfg.GameplayOptions.MaximumNP) + } + if len(cfg.LoginNotices) != 2 { + t.Errorf("LoginNotices len = %d, want 2", len(cfg.LoginNotices)) + } +} diff --git a/network/binpacket/binpacket_test.go b/network/binpacket/binpacket_test.go new file mode 100644 index 000000000..c5476a794 --- /dev/null +++ b/network/binpacket/binpacket_test.go @@ -0,0 +1,401 @@ +package binpacket + +import ( + "bytes" + "testing" + + "erupe-ce/common/byteframe" + "erupe-ce/network" +) + +func TestMsgBinTargetedOpcode(t *testing.T) { + m := &MsgBinTargeted{} + if m.Opcode() != network.MSG_SYS_CAST_BINARY { + t.Errorf("MsgBinTargeted.Opcode() = %v, want MSG_SYS_CAST_BINARY", m.Opcode()) + } +} + +func TestMsgBinTargetedParseEmpty(t *testing.T) { + bf := byteframe.NewByteFrame() + bf.WriteUint16(0) // TargetCount = 0 + + bf.Seek(0, 0) + + m := &MsgBinTargeted{} + err := m.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if m.TargetCount != 0 { + t.Errorf("TargetCount = %d, want 0", m.TargetCount) + } + if len(m.TargetCharIDs) != 0 { + t.Errorf("TargetCharIDs len = %d, want 0", len(m.TargetCharIDs)) + } +} + +func TestMsgBinTargetedParseSingleTarget(t *testing.T) { + bf := byteframe.NewByteFrame() + bf.WriteUint16(1) // TargetCount = 1 + bf.WriteUint32(0x12345678) // TargetCharID + bf.WriteBytes([]byte{0xDE, 0xAD, 0xBE, 0xEF}) + + bf.Seek(0, 0) + + m := &MsgBinTargeted{} + err := m.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if m.TargetCount != 1 { + t.Errorf("TargetCount = %d, want 1", m.TargetCount) + } + if len(m.TargetCharIDs) != 1 { + t.Errorf("TargetCharIDs len = %d, want 1", len(m.TargetCharIDs)) + } + if m.TargetCharIDs[0] != 0x12345678 { + t.Errorf("TargetCharIDs[0] = %x, want 0x12345678", m.TargetCharIDs[0]) + } + if !bytes.Equal(m.RawDataPayload, []byte{0xDE, 0xAD, 0xBE, 0xEF}) { + t.Errorf("RawDataPayload = %v, want [0xDE, 0xAD, 0xBE, 0xEF]", m.RawDataPayload) + } +} + +func TestMsgBinTargetedParseMultipleTargets(t *testing.T) { + bf := byteframe.NewByteFrame() + bf.WriteUint16(3) // TargetCount = 3 + bf.WriteUint32(100) + bf.WriteUint32(200) + bf.WriteUint32(300) + bf.WriteBytes([]byte{0x01, 0x02, 0x03}) + + bf.Seek(0, 0) + + m := &MsgBinTargeted{} + err := m.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if m.TargetCount != 3 { + t.Errorf("TargetCount = %d, want 3", m.TargetCount) + } + if len(m.TargetCharIDs) != 3 { + t.Errorf("TargetCharIDs len = %d, want 3", len(m.TargetCharIDs)) + } + if m.TargetCharIDs[0] != 100 || m.TargetCharIDs[1] != 200 || m.TargetCharIDs[2] != 300 { + t.Errorf("TargetCharIDs = %v, want [100, 200, 300]", m.TargetCharIDs) + } +} + +func TestMsgBinTargetedBuild(t *testing.T) { + m := &MsgBinTargeted{ + TargetCount: 2, + TargetCharIDs: []uint32{0x11111111, 0x22222222}, + RawDataPayload: []byte{0xAA, 0xBB}, + } + + bf := byteframe.NewByteFrame() + err := m.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + expected := []byte{ + 0x00, 0x02, // TargetCount + 0x11, 0x11, 0x11, 0x11, // TargetCharIDs[0] + 0x22, 0x22, 0x22, 0x22, // TargetCharIDs[1] + 0xAA, 0xBB, // RawDataPayload + } + + if !bytes.Equal(bf.Data(), expected) { + t.Errorf("Build() = %v, want %v", bf.Data(), expected) + } +} + +func TestMsgBinTargetedRoundTrip(t *testing.T) { + original := &MsgBinTargeted{ + TargetCount: 3, + TargetCharIDs: []uint32{1000, 2000, 3000}, + RawDataPayload: []byte{0x01, 0x02, 0x03, 0x04, 0x05}, + } + + // Build + bf := byteframe.NewByteFrame() + err := original.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + // Parse + bf.Seek(0, 0) + parsed := &MsgBinTargeted{} + err = parsed.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + // Compare + if parsed.TargetCount != original.TargetCount { + t.Errorf("TargetCount = %d, want %d", parsed.TargetCount, original.TargetCount) + } + if len(parsed.TargetCharIDs) != len(original.TargetCharIDs) { + t.Errorf("TargetCharIDs len = %d, want %d", len(parsed.TargetCharIDs), len(original.TargetCharIDs)) + } + for i := range original.TargetCharIDs { + if parsed.TargetCharIDs[i] != original.TargetCharIDs[i] { + t.Errorf("TargetCharIDs[%d] = %d, want %d", i, parsed.TargetCharIDs[i], original.TargetCharIDs[i]) + } + } + if !bytes.Equal(parsed.RawDataPayload, original.RawDataPayload) { + t.Errorf("RawDataPayload = %v, want %v", parsed.RawDataPayload, original.RawDataPayload) + } +} + +func TestMsgBinMailNotifyOpcode(t *testing.T) { + m := MsgBinMailNotify{} + if m.Opcode() != network.MSG_SYS_CASTED_BINARY { + t.Errorf("MsgBinMailNotify.Opcode() = %v, want MSG_SYS_CASTED_BINARY", m.Opcode()) + } +} + +func TestMsgBinMailNotifyBuild(t *testing.T) { + m := MsgBinMailNotify{ + SenderName: "TestPlayer", + } + + bf := byteframe.NewByteFrame() + err := m.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + data := bf.Data() + + // First byte should be 0x01 (Unk) + if data[0] != 0x01 { + t.Errorf("First byte = %x, want 0x01", data[0]) + } + + // Total length should be 1 (Unk) + 21 (padded name) = 22 + if len(data) != 22 { + t.Errorf("Data len = %d, want 22", len(data)) + } +} + +func TestMsgBinMailNotifyBuildEmptyName(t *testing.T) { + m := MsgBinMailNotify{ + SenderName: "", + } + + bf := byteframe.NewByteFrame() + err := m.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + if len(bf.Data()) != 22 { + t.Errorf("Data len = %d, want 22", len(bf.Data())) + } +} + +func TestMsgBinChatOpcode(t *testing.T) { + m := &MsgBinChat{} + if m.Opcode() != network.MSG_SYS_CAST_BINARY { + t.Errorf("MsgBinChat.Opcode() = %v, want MSG_SYS_CAST_BINARY", m.Opcode()) + } +} + +func TestMsgBinChatTypes(t *testing.T) { + tests := []struct { + chatType ChatType + value uint8 + }{ + {ChatTypeLocal, 1}, + {ChatTypeGuild, 2}, + {ChatTypeAlliance, 3}, + {ChatTypeParty, 4}, + {ChatTypeWhisper, 5}, + } + + for _, tt := range tests { + if uint8(tt.chatType) != tt.value { + t.Errorf("ChatType %v = %d, want %d", tt.chatType, uint8(tt.chatType), tt.value) + } + } +} + +func TestMsgBinChatBuildParse(t *testing.T) { + original := &MsgBinChat{ + Unk0: 0x00, + Type: ChatTypeLocal, + Flags: 0x0000, + Message: "Hello", + SenderName: "Player", + } + + // Build + bf := byteframe.NewByteFrame() + err := original.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + // Parse + bf.Seek(0, 0) + parsed := &MsgBinChat{} + err = parsed.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + // Compare + if parsed.Unk0 != original.Unk0 { + t.Errorf("Unk0 = %d, want %d", parsed.Unk0, original.Unk0) + } + if parsed.Type != original.Type { + t.Errorf("Type = %d, want %d", parsed.Type, original.Type) + } + if parsed.Flags != original.Flags { + t.Errorf("Flags = %d, want %d", parsed.Flags, original.Flags) + } + if parsed.Message != original.Message { + t.Errorf("Message = %q, want %q", parsed.Message, original.Message) + } + if parsed.SenderName != original.SenderName { + t.Errorf("SenderName = %q, want %q", parsed.SenderName, original.SenderName) + } +} + +func TestMsgBinChatBuildParseJapanese(t *testing.T) { + original := &MsgBinChat{ + Unk0: 0x00, + Type: ChatTypeGuild, + Flags: 0x0001, + Message: "こんにちは", + SenderName: "テスト", + } + + // Build + bf := byteframe.NewByteFrame() + err := original.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + // Parse + bf.Seek(0, 0) + parsed := &MsgBinChat{} + err = parsed.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if parsed.Message != original.Message { + t.Errorf("Message = %q, want %q", parsed.Message, original.Message) + } + if parsed.SenderName != original.SenderName { + t.Errorf("SenderName = %q, want %q", parsed.SenderName, original.SenderName) + } +} + +func TestMsgBinChatBuildParseEmpty(t *testing.T) { + original := &MsgBinChat{ + Unk0: 0x00, + Type: ChatTypeParty, + Flags: 0x0000, + Message: "", + SenderName: "", + } + + // Build + bf := byteframe.NewByteFrame() + err := original.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + // Parse + bf.Seek(0, 0) + parsed := &MsgBinChat{} + err = parsed.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if parsed.Message != "" { + t.Errorf("Message = %q, want empty", parsed.Message) + } + if parsed.SenderName != "" { + t.Errorf("SenderName = %q, want empty", parsed.SenderName) + } +} + +func TestMsgBinChatBuildFormat(t *testing.T) { + m := &MsgBinChat{ + Unk0: 0x12, + Type: ChatTypeWhisper, + Flags: 0x3456, + Message: "Hi", + SenderName: "A", + } + + bf := byteframe.NewByteFrame() + err := m.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + data := bf.Data() + + // Verify header structure + if data[0] != 0x12 { + t.Errorf("Unk0 = %x, want 0x12", data[0]) + } + if data[1] != uint8(ChatTypeWhisper) { + t.Errorf("Type = %x, want %x", data[1], uint8(ChatTypeWhisper)) + } + // Flags at bytes 2-3 (big endian) + if data[2] != 0x34 || data[3] != 0x56 { + t.Errorf("Flags = %x%x, want 3456", data[2], data[3]) + } +} + +func TestMsgBinChatAllTypes(t *testing.T) { + types := []ChatType{ + ChatTypeLocal, + ChatTypeGuild, + ChatTypeAlliance, + ChatTypeParty, + ChatTypeWhisper, + } + + for _, chatType := range types { + t.Run("", func(t *testing.T) { + original := &MsgBinChat{ + Type: chatType, + Message: "Test", + SenderName: "Player", + } + + bf := byteframe.NewByteFrame() + err := original.Build(bf) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + bf.Seek(0, 0) + parsed := &MsgBinChat{} + err = parsed.Parse(bf) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if parsed.Type != chatType { + t.Errorf("Type = %d, want %d", parsed.Type, chatType) + } + }) + } +} diff --git a/network/crypt_packet_test.go b/network/crypt_packet_test.go new file mode 100644 index 000000000..a00d17fe5 --- /dev/null +++ b/network/crypt_packet_test.go @@ -0,0 +1,234 @@ +package network + +import ( + "bytes" + "testing" +) + +func TestCryptPacketHeaderLength(t *testing.T) { + if CryptPacketHeaderLength != 14 { + t.Errorf("CryptPacketHeaderLength = %d, want 14", CryptPacketHeaderLength) + } +} + +func TestNewCryptPacketHeader(t *testing.T) { + // Create a valid 14-byte header + data := []byte{ + 0x01, // Pf0 + 0x02, // KeyRotDelta + 0x00, 0x03, // PacketNum + 0x00, 0x04, // DataSize + 0x00, 0x05, // PrevPacketCombinedCheck + 0x00, 0x06, // Check0 + 0x00, 0x07, // Check1 + 0x00, 0x08, // Check2 + } + + header, err := NewCryptPacketHeader(data) + if err != nil { + t.Fatalf("NewCryptPacketHeader() error = %v", err) + } + + if header.Pf0 != 0x01 { + t.Errorf("Pf0 = %d, want 1", header.Pf0) + } + if header.KeyRotDelta != 0x02 { + t.Errorf("KeyRotDelta = %d, want 2", header.KeyRotDelta) + } + if header.PacketNum != 0x03 { + t.Errorf("PacketNum = %d, want 3", header.PacketNum) + } + if header.DataSize != 0x04 { + t.Errorf("DataSize = %d, want 4", header.DataSize) + } + if header.PrevPacketCombinedCheck != 0x05 { + t.Errorf("PrevPacketCombinedCheck = %d, want 5", header.PrevPacketCombinedCheck) + } + if header.Check0 != 0x06 { + t.Errorf("Check0 = %d, want 6", header.Check0) + } + if header.Check1 != 0x07 { + t.Errorf("Check1 = %d, want 7", header.Check1) + } + if header.Check2 != 0x08 { + t.Errorf("Check2 = %d, want 8", header.Check2) + } +} + +func TestNewCryptPacketHeaderTooShort(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + {"empty", []byte{}}, + {"1 byte", []byte{0x01}}, + {"5 bytes", []byte{0x01, 0x02, 0x03, 0x04, 0x05}}, + {"13 bytes", make([]byte, 13)}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewCryptPacketHeader(tt.data) + if err == nil { + t.Errorf("NewCryptPacketHeader(%v) should return error for short data", tt.data) + } + }) + } +} + +func TestCryptPacketHeaderEncode(t *testing.T) { + header := &CryptPacketHeader{ + Pf0: 0x01, + KeyRotDelta: 0x02, + PacketNum: 0x0003, + DataSize: 0x0004, + PrevPacketCombinedCheck: 0x0005, + Check0: 0x0006, + Check1: 0x0007, + Check2: 0x0008, + } + + encoded, err := header.Encode() + if err != nil { + t.Fatalf("Encode() error = %v", err) + } + + if len(encoded) != CryptPacketHeaderLength { + t.Errorf("Encode() len = %d, want %d", len(encoded), CryptPacketHeaderLength) + } + + expected := []byte{ + 0x01, // Pf0 + 0x02, // KeyRotDelta + 0x00, 0x03, // PacketNum + 0x00, 0x04, // DataSize + 0x00, 0x05, // PrevPacketCombinedCheck + 0x00, 0x06, // Check0 + 0x00, 0x07, // Check1 + 0x00, 0x08, // Check2 + } + + if !bytes.Equal(encoded, expected) { + t.Errorf("Encode() = %v, want %v", encoded, expected) + } +} + +func TestCryptPacketHeaderRoundTrip(t *testing.T) { + tests := []CryptPacketHeader{ + { + Pf0: 0x00, + KeyRotDelta: 0x00, + PacketNum: 0x0000, + DataSize: 0x0000, + PrevPacketCombinedCheck: 0x0000, + Check0: 0x0000, + Check1: 0x0000, + Check2: 0x0000, + }, + { + Pf0: 0xFF, + KeyRotDelta: 0xFF, + PacketNum: 0xFFFF, + DataSize: 0xFFFF, + PrevPacketCombinedCheck: 0xFFFF, + Check0: 0xFFFF, + Check1: 0xFFFF, + Check2: 0xFFFF, + }, + { + Pf0: 0x12, + KeyRotDelta: 0x34, + PacketNum: 0x5678, + DataSize: 0x9ABC, + PrevPacketCombinedCheck: 0xDEF0, + Check0: 0x1234, + Check1: 0x5678, + Check2: 0x9ABC, + }, + } + + for i, original := range tests { + t.Run("", func(t *testing.T) { + encoded, err := original.Encode() + if err != nil { + t.Fatalf("Test %d: Encode() error = %v", i, err) + } + + decoded, err := NewCryptPacketHeader(encoded) + if err != nil { + t.Fatalf("Test %d: NewCryptPacketHeader() error = %v", i, err) + } + + if decoded.Pf0 != original.Pf0 { + t.Errorf("Test %d: Pf0 = %d, want %d", i, decoded.Pf0, original.Pf0) + } + if decoded.KeyRotDelta != original.KeyRotDelta { + t.Errorf("Test %d: KeyRotDelta = %d, want %d", i, decoded.KeyRotDelta, original.KeyRotDelta) + } + if decoded.PacketNum != original.PacketNum { + t.Errorf("Test %d: PacketNum = %d, want %d", i, decoded.PacketNum, original.PacketNum) + } + if decoded.DataSize != original.DataSize { + t.Errorf("Test %d: DataSize = %d, want %d", i, decoded.DataSize, original.DataSize) + } + if decoded.PrevPacketCombinedCheck != original.PrevPacketCombinedCheck { + t.Errorf("Test %d: PrevPacketCombinedCheck = %d, want %d", i, decoded.PrevPacketCombinedCheck, original.PrevPacketCombinedCheck) + } + if decoded.Check0 != original.Check0 { + t.Errorf("Test %d: Check0 = %d, want %d", i, decoded.Check0, original.Check0) + } + if decoded.Check1 != original.Check1 { + t.Errorf("Test %d: Check1 = %d, want %d", i, decoded.Check1, original.Check1) + } + if decoded.Check2 != original.Check2 { + t.Errorf("Test %d: Check2 = %d, want %d", i, decoded.Check2, original.Check2) + } + }) + } +} + +func TestCryptPacketHeaderBigEndian(t *testing.T) { + // Verify big-endian encoding + header := &CryptPacketHeader{ + PacketNum: 0x1234, + } + + encoded, err := header.Encode() + if err != nil { + t.Fatalf("Encode() error = %v", err) + } + + // PacketNum is at bytes 2-3 (after Pf0 and KeyRotDelta) + if encoded[2] != 0x12 || encoded[3] != 0x34 { + t.Errorf("PacketNum encoding is not big-endian: %v", encoded[2:4]) + } +} + +func TestNewCryptPacketHeaderExtraBytes(t *testing.T) { + // Test with more than required bytes (should still work) + data := make([]byte, 20) + data[0] = 0x01 // Pf0 + + header, err := NewCryptPacketHeader(data) + if err != nil { + t.Fatalf("NewCryptPacketHeader() with extra bytes error = %v", err) + } + + if header.Pf0 != 0x01 { + t.Errorf("Pf0 = %d, want 1", header.Pf0) + } +} + +func TestCryptPacketHeaderZeroValues(t *testing.T) { + header := &CryptPacketHeader{} + + encoded, err := header.Encode() + if err != nil { + t.Fatalf("Encode() error = %v", err) + } + + expected := make([]byte, CryptPacketHeaderLength) + if !bytes.Equal(encoded, expected) { + t.Errorf("Encode() zero header = %v, want all zeros", encoded) + } +} diff --git a/network/mhfpacket/mhfpacket_test.go b/network/mhfpacket/mhfpacket_test.go new file mode 100644 index 000000000..4e748f859 --- /dev/null +++ b/network/mhfpacket/mhfpacket_test.go @@ -0,0 +1,463 @@ +package mhfpacket + +import ( + "io" + "testing" + + "erupe-ce/common/byteframe" + "erupe-ce/network" + "erupe-ce/network/clientctx" +) + +func TestMHFPacketInterface(t *testing.T) { + // Verify that packets implement the MHFPacket interface + var _ MHFPacket = &MsgSysPing{} + var _ MHFPacket = &MsgSysTime{} + var _ MHFPacket = &MsgSysNop{} + var _ MHFPacket = &MsgSysEnd{} + var _ MHFPacket = &MsgSysLogin{} + var _ MHFPacket = &MsgSysLogout{} +} + +func TestFromOpcodeReturnsCorrectType(t *testing.T) { + tests := []struct { + opcode network.PacketID + wantType string + }{ + {network.MSG_HEAD, "*mhfpacket.MsgHead"}, + {network.MSG_SYS_PING, "*mhfpacket.MsgSysPing"}, + {network.MSG_SYS_TIME, "*mhfpacket.MsgSysTime"}, + {network.MSG_SYS_NOP, "*mhfpacket.MsgSysNop"}, + {network.MSG_SYS_END, "*mhfpacket.MsgSysEnd"}, + {network.MSG_SYS_ACK, "*mhfpacket.MsgSysAck"}, + {network.MSG_SYS_LOGIN, "*mhfpacket.MsgSysLogin"}, + {network.MSG_SYS_LOGOUT, "*mhfpacket.MsgSysLogout"}, + {network.MSG_SYS_CREATE_STAGE, "*mhfpacket.MsgSysCreateStage"}, + {network.MSG_SYS_ENTER_STAGE, "*mhfpacket.MsgSysEnterStage"}, + } + + for _, tt := range tests { + t.Run(tt.opcode.String(), func(t *testing.T) { + pkt := FromOpcode(tt.opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", tt.opcode) + return + } + if pkt.Opcode() != tt.opcode { + t.Errorf("Opcode() = %s, want %s", pkt.Opcode(), tt.opcode) + } + }) + } +} + +func TestFromOpcodeUnknown(t *testing.T) { + // Test with an invalid opcode + pkt := FromOpcode(network.PacketID(0xFFFF)) + if pkt != nil { + t.Error("FromOpcode(0xFFFF) should return nil for unknown opcode") + } +} + +func TestMsgSysPingRoundTrip(t *testing.T) { + original := &MsgSysPing{ + AckHandle: 0x12345678, + } + + ctx := &clientctx.ClientContext{} + + // Build + bf := byteframe.NewByteFrame() + err := original.Build(bf, ctx) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + // Parse + bf.Seek(0, io.SeekStart) + parsed := &MsgSysPing{} + err = parsed.Parse(bf, ctx) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + // Compare + if parsed.AckHandle != original.AckHandle { + t.Errorf("AckHandle = %d, want %d", parsed.AckHandle, original.AckHandle) + } +} + +func TestMsgSysTimeRoundTrip(t *testing.T) { + tests := []struct { + name string + getRemoteTime bool + timestamp uint32 + }{ + {"no remote time", false, 1577105879}, + {"with remote time", true, 1609459200}, + {"zero timestamp", false, 0}, + {"max timestamp", true, 0xFFFFFFFF}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + original := &MsgSysTime{ + GetRemoteTime: tt.getRemoteTime, + Timestamp: tt.timestamp, + } + + ctx := &clientctx.ClientContext{} + + // Build + bf := byteframe.NewByteFrame() + err := original.Build(bf, ctx) + if err != nil { + t.Fatalf("Build() error = %v", err) + } + + // Parse + bf.Seek(0, io.SeekStart) + parsed := &MsgSysTime{} + err = parsed.Parse(bf, ctx) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + // Compare + if parsed.GetRemoteTime != original.GetRemoteTime { + t.Errorf("GetRemoteTime = %v, want %v", parsed.GetRemoteTime, original.GetRemoteTime) + } + if parsed.Timestamp != original.Timestamp { + t.Errorf("Timestamp = %d, want %d", parsed.Timestamp, original.Timestamp) + } + }) + } +} + +func TestMsgSysPingOpcode(t *testing.T) { + pkt := &MsgSysPing{} + if pkt.Opcode() != network.MSG_SYS_PING { + t.Errorf("Opcode() = %s, want MSG_SYS_PING", pkt.Opcode()) + } +} + +func TestMsgSysTimeOpcode(t *testing.T) { + pkt := &MsgSysTime{} + if pkt.Opcode() != network.MSG_SYS_TIME { + t.Errorf("Opcode() = %s, want MSG_SYS_TIME", pkt.Opcode()) + } +} + +func TestFromOpcodeSystemPackets(t *testing.T) { + // Test all system packet opcodes return non-nil + systemOpcodes := []network.PacketID{ + network.MSG_SYS_reserve01, + network.MSG_SYS_reserve02, + network.MSG_SYS_reserve03, + network.MSG_SYS_reserve04, + network.MSG_SYS_reserve05, + network.MSG_SYS_reserve06, + network.MSG_SYS_reserve07, + network.MSG_SYS_ADD_OBJECT, + network.MSG_SYS_DEL_OBJECT, + network.MSG_SYS_DISP_OBJECT, + network.MSG_SYS_HIDE_OBJECT, + network.MSG_SYS_END, + network.MSG_SYS_NOP, + network.MSG_SYS_ACK, + network.MSG_SYS_LOGIN, + network.MSG_SYS_LOGOUT, + network.MSG_SYS_SET_STATUS, + network.MSG_SYS_PING, + network.MSG_SYS_TIME, + } + + for _, opcode := range systemOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} + +func TestFromOpcodeStagePackets(t *testing.T) { + stageOpcodes := []network.PacketID{ + network.MSG_SYS_CREATE_STAGE, + network.MSG_SYS_STAGE_DESTRUCT, + network.MSG_SYS_ENTER_STAGE, + network.MSG_SYS_BACK_STAGE, + network.MSG_SYS_MOVE_STAGE, + network.MSG_SYS_LEAVE_STAGE, + network.MSG_SYS_LOCK_STAGE, + network.MSG_SYS_UNLOCK_STAGE, + network.MSG_SYS_RESERVE_STAGE, + network.MSG_SYS_UNRESERVE_STAGE, + network.MSG_SYS_SET_STAGE_PASS, + } + + for _, opcode := range stageOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} + +func TestOpcodeMatches(t *testing.T) { + // Verify that packets return the same opcode they were created from + tests := []network.PacketID{ + network.MSG_HEAD, + network.MSG_SYS_PING, + network.MSG_SYS_TIME, + network.MSG_SYS_END, + network.MSG_SYS_NOP, + network.MSG_SYS_ACK, + network.MSG_SYS_LOGIN, + network.MSG_SYS_CREATE_STAGE, + } + + for _, opcode := range tests { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Skip("opcode not implemented") + } + if pkt.Opcode() != opcode { + t.Errorf("Opcode() = %s, want %s", pkt.Opcode(), opcode) + } + }) + } +} + +func TestParserInterface(t *testing.T) { + // Verify Parser interface works + var p Parser = &MsgSysPing{} + bf := byteframe.NewByteFrame() + bf.WriteUint32(123) + bf.Seek(0, io.SeekStart) + + err := p.Parse(bf, &clientctx.ClientContext{}) + if err != nil { + t.Errorf("Parse() error = %v", err) + } +} + +func TestBuilderInterface(t *testing.T) { + // Verify Builder interface works + var b Builder = &MsgSysPing{AckHandle: 456} + bf := byteframe.NewByteFrame() + + err := b.Build(bf, &clientctx.ClientContext{}) + if err != nil { + t.Errorf("Build() error = %v", err) + } + if len(bf.Data()) == 0 { + t.Error("Build() should write data") + } +} + +func TestOpcoderInterface(t *testing.T) { + // Verify Opcoder interface works + var o Opcoder = &MsgSysPing{} + opcode := o.Opcode() + + if opcode != network.MSG_SYS_PING { + t.Errorf("Opcode() = %s, want MSG_SYS_PING", opcode) + } +} + +func TestClientContextNilSafe(t *testing.T) { + // Some packets may need to handle nil ClientContext + pkt := &MsgSysPing{AckHandle: 123} + bf := byteframe.NewByteFrame() + + // This should not panic even with nil context (implementation dependent) + // Note: The actual behavior depends on implementation + err := pkt.Build(bf, nil) + if err != nil { + // Error is acceptable if nil context is not supported + t.Logf("Build() with nil context returned error: %v", err) + } +} + +func TestMsgSysPingBuildFormat(t *testing.T) { + pkt := &MsgSysPing{AckHandle: 0x12345678} + bf := byteframe.NewByteFrame() + pkt.Build(bf, &clientctx.ClientContext{}) + + data := bf.Data() + if len(data) != 4 { + t.Errorf("Build() data len = %d, want 4", len(data)) + } + + // Verify big-endian format (default) + if data[0] != 0x12 || data[1] != 0x34 || data[2] != 0x56 || data[3] != 0x78 { + t.Errorf("Build() data = %x, want 12345678", data) + } +} + +func TestMsgSysTimeBuildFormat(t *testing.T) { + pkt := &MsgSysTime{ + GetRemoteTime: true, + Timestamp: 0xDEADBEEF, + } + bf := byteframe.NewByteFrame() + pkt.Build(bf, &clientctx.ClientContext{}) + + data := bf.Data() + if len(data) != 5 { + t.Errorf("Build() data len = %d, want 5 (1 bool + 4 uint32)", len(data)) + } + + // First byte is bool (1 = true) + if data[0] != 1 { + t.Errorf("GetRemoteTime byte = %d, want 1", data[0]) + } +} + +func TestMsgSysNop(t *testing.T) { + pkt := FromOpcode(network.MSG_SYS_NOP) + if pkt == nil { + t.Fatal("FromOpcode(MSG_SYS_NOP) returned nil") + } + if pkt.Opcode() != network.MSG_SYS_NOP { + t.Errorf("Opcode() = %s, want MSG_SYS_NOP", pkt.Opcode()) + } +} + +func TestMsgSysEnd(t *testing.T) { + pkt := FromOpcode(network.MSG_SYS_END) + if pkt == nil { + t.Fatal("FromOpcode(MSG_SYS_END) returned nil") + } + if pkt.Opcode() != network.MSG_SYS_END { + t.Errorf("Opcode() = %s, want MSG_SYS_END", pkt.Opcode()) + } +} + +func TestMsgHead(t *testing.T) { + pkt := FromOpcode(network.MSG_HEAD) + if pkt == nil { + t.Fatal("FromOpcode(MSG_HEAD) returned nil") + } + if pkt.Opcode() != network.MSG_HEAD { + t.Errorf("Opcode() = %s, want MSG_HEAD", pkt.Opcode()) + } +} + +func TestMsgSysAck(t *testing.T) { + pkt := FromOpcode(network.MSG_SYS_ACK) + if pkt == nil { + t.Fatal("FromOpcode(MSG_SYS_ACK) returned nil") + } + if pkt.Opcode() != network.MSG_SYS_ACK { + t.Errorf("Opcode() = %s, want MSG_SYS_ACK", pkt.Opcode()) + } +} + +func TestBinaryPackets(t *testing.T) { + binaryOpcodes := []network.PacketID{ + network.MSG_SYS_CAST_BINARY, + network.MSG_SYS_CASTED_BINARY, + network.MSG_SYS_SET_STAGE_BINARY, + network.MSG_SYS_GET_STAGE_BINARY, + network.MSG_SYS_WAIT_STAGE_BINARY, + } + + for _, opcode := range binaryOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} + +func TestEnumeratePackets(t *testing.T) { + enumOpcodes := []network.PacketID{ + network.MSG_SYS_ENUMERATE_CLIENT, + network.MSG_SYS_ENUMERATE_STAGE, + } + + for _, opcode := range enumOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} + +func TestSemaphorePackets(t *testing.T) { + semaOpcodes := []network.PacketID{ + network.MSG_SYS_CREATE_ACQUIRE_SEMAPHORE, + network.MSG_SYS_ACQUIRE_SEMAPHORE, + network.MSG_SYS_RELEASE_SEMAPHORE, + network.MSG_SYS_CHECK_SEMAPHORE, + } + + for _, opcode := range semaOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} + +func TestObjectPackets(t *testing.T) { + objOpcodes := []network.PacketID{ + network.MSG_SYS_ADD_OBJECT, + network.MSG_SYS_DEL_OBJECT, + network.MSG_SYS_DISP_OBJECT, + network.MSG_SYS_HIDE_OBJECT, + } + + for _, opcode := range objOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} + +func TestLogPackets(t *testing.T) { + logOpcodes := []network.PacketID{ + network.MSG_SYS_TERMINAL_LOG, + network.MSG_SYS_ISSUE_LOGKEY, + network.MSG_SYS_RECORD_LOG, + } + + for _, opcode := range logOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} + +func TestMHFSaveLoad(t *testing.T) { + saveLoadOpcodes := []network.PacketID{ + network.MSG_MHF_SAVEDATA, + network.MSG_MHF_LOADDATA, + } + + for _, opcode := range saveLoadOpcodes { + t.Run(opcode.String(), func(t *testing.T) { + pkt := FromOpcode(opcode) + if pkt == nil { + t.Errorf("FromOpcode(%s) returned nil", opcode) + } + }) + } +} diff --git a/server/channelserver/handlers_test.go b/server/channelserver/handlers_test.go new file mode 100644 index 000000000..b967320df --- /dev/null +++ b/server/channelserver/handlers_test.go @@ -0,0 +1,268 @@ +package channelserver + +import ( + "testing" + + "erupe-ce/network" +) + +func TestHandlerTableInitialized(t *testing.T) { + if handlerTable == nil { + t.Fatal("handlerTable should be initialized by init()") + } +} + +func TestHandlerTableHasEntries(t *testing.T) { + if len(handlerTable) == 0 { + t.Error("handlerTable should have entries") + } + + // Should have many handlers + if len(handlerTable) < 100 { + t.Errorf("handlerTable has %d entries, expected 100+", len(handlerTable)) + } +} + +func TestHandlerTableSystemPackets(t *testing.T) { + // Test that key system packets have handlers + systemPackets := []network.PacketID{ + network.MSG_HEAD, + network.MSG_SYS_END, + network.MSG_SYS_NOP, + network.MSG_SYS_ACK, + network.MSG_SYS_LOGIN, + network.MSG_SYS_LOGOUT, + network.MSG_SYS_PING, + network.MSG_SYS_TIME, + } + + for _, opcode := range systemPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for %s", opcode) + } + }) + } +} + +func TestHandlerTableStagePackets(t *testing.T) { + // Test stage-related packet handlers + stagePackets := []network.PacketID{ + network.MSG_SYS_CREATE_STAGE, + network.MSG_SYS_STAGE_DESTRUCT, + network.MSG_SYS_ENTER_STAGE, + network.MSG_SYS_BACK_STAGE, + network.MSG_SYS_MOVE_STAGE, + network.MSG_SYS_LEAVE_STAGE, + network.MSG_SYS_LOCK_STAGE, + network.MSG_SYS_UNLOCK_STAGE, + } + + for _, opcode := range stagePackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for stage packet %s", opcode) + } + }) + } +} + +func TestHandlerTableBinaryPackets(t *testing.T) { + // Test binary message handlers + binaryPackets := []network.PacketID{ + network.MSG_SYS_CAST_BINARY, + network.MSG_SYS_CASTED_BINARY, + network.MSG_SYS_SET_STAGE_BINARY, + network.MSG_SYS_GET_STAGE_BINARY, + } + + for _, opcode := range binaryPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for binary packet %s", opcode) + } + }) + } +} + +func TestHandlerTableReservedPackets(t *testing.T) { + // Reserved packets should still have handlers (usually no-ops) + reservedPackets := []network.PacketID{ + network.MSG_SYS_reserve01, + network.MSG_SYS_reserve02, + network.MSG_SYS_reserve03, + network.MSG_SYS_reserve04, + network.MSG_SYS_reserve05, + network.MSG_SYS_reserve06, + network.MSG_SYS_reserve07, + } + + for _, opcode := range reservedPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for reserved packet %s", opcode) + } + }) + } +} + +func TestHandlerFuncType(t *testing.T) { + // Verify all handlers are valid functions + for opcode, handler := range handlerTable { + if handler == nil { + t.Errorf("handler for %s is nil", opcode) + } + } +} + +func TestHandlerTableObjectPackets(t *testing.T) { + objectPackets := []network.PacketID{ + network.MSG_SYS_ADD_OBJECT, + network.MSG_SYS_DEL_OBJECT, + network.MSG_SYS_DISP_OBJECT, + network.MSG_SYS_HIDE_OBJECT, + } + + for _, opcode := range objectPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for object packet %s", opcode) + } + }) + } +} + +func TestHandlerTableClientPackets(t *testing.T) { + clientPackets := []network.PacketID{ + network.MSG_SYS_SET_STATUS, + network.MSG_SYS_HIDE_CLIENT, + network.MSG_SYS_ENUMERATE_CLIENT, + } + + for _, opcode := range clientPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for client packet %s", opcode) + } + }) + } +} + +func TestHandlerTableSemaphorePackets(t *testing.T) { + semaphorePackets := []network.PacketID{ + network.MSG_SYS_CREATE_ACQUIRE_SEMAPHORE, + network.MSG_SYS_ACQUIRE_SEMAPHORE, + network.MSG_SYS_RELEASE_SEMAPHORE, + } + + for _, opcode := range semaphorePackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for semaphore packet %s", opcode) + } + }) + } +} + +func TestHandlerTableMHFPackets(t *testing.T) { + // Test some core MHF packets have handlers + mhfPackets := []network.PacketID{ + network.MSG_MHF_SAVEDATA, + network.MSG_MHF_LOADDATA, + } + + for _, opcode := range mhfPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for MHF packet %s", opcode) + } + }) + } +} + +func TestHandlerTableEnumeratePackets(t *testing.T) { + enumPackets := []network.PacketID{ + network.MSG_SYS_ENUMERATE_CLIENT, + network.MSG_SYS_ENUMERATE_STAGE, + } + + for _, opcode := range enumPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for enumerate packet %s", opcode) + } + }) + } +} + +func TestHandlerTableLogPackets(t *testing.T) { + logPackets := []network.PacketID{ + network.MSG_SYS_TERMINAL_LOG, + network.MSG_SYS_ISSUE_LOGKEY, + network.MSG_SYS_RECORD_LOG, + } + + for _, opcode := range logPackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for log packet %s", opcode) + } + }) + } +} + +func TestHandlerTableFilePackets(t *testing.T) { + filePackets := []network.PacketID{ + network.MSG_SYS_GET_FILE, + } + + for _, opcode := range filePackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for file packet %s", opcode) + } + }) + } +} + +func TestHandlerTableEchoPacket(t *testing.T) { + if _, ok := handlerTable[network.MSG_SYS_ECHO]; !ok { + t.Error("handler missing for MSG_SYS_ECHO") + } +} + +func TestHandlerTableReserveStagePackets(t *testing.T) { + reservePackets := []network.PacketID{ + network.MSG_SYS_RESERVE_STAGE, + network.MSG_SYS_UNRESERVE_STAGE, + network.MSG_SYS_SET_STAGE_PASS, + network.MSG_SYS_WAIT_STAGE_BINARY, + } + + for _, opcode := range reservePackets { + t.Run(opcode.String(), func(t *testing.T) { + if _, ok := handlerTable[opcode]; !ok { + t.Errorf("handler missing for reserve stage packet %s", opcode) + } + }) + } +} + +func TestHandlerTableThresholdPacket(t *testing.T) { + if _, ok := handlerTable[network.MSG_SYS_EXTEND_THRESHOLD]; !ok { + t.Error("handler missing for MSG_SYS_EXTEND_THRESHOLD") + } +} + +func TestHandlerTableNoNilValues(t *testing.T) { + nilCount := 0 + for opcode, handler := range handlerTable { + if handler == nil { + nilCount++ + t.Errorf("nil handler for opcode %s", opcode) + } + } + if nilCount > 0 { + t.Errorf("found %d nil handlers in handlerTable", nilCount) + } +} diff --git a/server/channelserver/sys_semaphore_test.go b/server/channelserver/sys_semaphore_test.go new file mode 100644 index 000000000..57f77aa52 --- /dev/null +++ b/server/channelserver/sys_semaphore_test.go @@ -0,0 +1,384 @@ +package channelserver + +import ( + "sync" + "testing" +) + +func TestNewSemaphore(t *testing.T) { + server := createMockServer() + server.semaphoreIndex = 6 // Start index (IDs 0-6 are reserved) + + sema := NewSemaphore(server, "test_semaphore", 16) + + if sema == nil { + t.Fatal("NewSemaphore() returned nil") + } + if sema.id_semaphore != "test_semaphore" { + t.Errorf("id_semaphore = %s, want test_semaphore", sema.id_semaphore) + } + if sema.maxPlayers != 16 { + t.Errorf("maxPlayers = %d, want 16", sema.maxPlayers) + } + if sema.clients == nil { + t.Error("clients map should be initialized") + } + if sema.reservedClientSlots == nil { + t.Error("reservedClientSlots map should be initialized") + } +} + +func TestNewSemaphoreIDIncrement(t *testing.T) { + server := createMockServer() + server.semaphoreIndex = 6 + + sema1 := NewSemaphore(server, "sema1", 4) + sema2 := NewSemaphore(server, "sema2", 4) + sema3 := NewSemaphore(server, "sema3", 4) + + // IDs should increment + if sema1.id == sema2.id { + t.Error("semaphore IDs should be unique") + } + if sema2.id == sema3.id { + t.Error("semaphore IDs should be unique") + } +} + +func TestSemaphoreClients(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + session1 := createMockSession(100, server) + session2 := createMockSession(200, server) + + // Add clients + sema.clients[session1] = session1.charID + sema.clients[session2] = session2.charID + + if len(sema.clients) != 2 { + t.Errorf("clients count = %d, want 2", len(sema.clients)) + } + + // Verify client IDs + if sema.clients[session1] != 100 { + t.Errorf("clients[session1] = %d, want 100", sema.clients[session1]) + } + if sema.clients[session2] != 200 { + t.Errorf("clients[session2] = %d, want 200", sema.clients[session2]) + } +} + +func TestSemaphoreReservedSlots(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + // Reserve slots + sema.reservedClientSlots[100] = nil + sema.reservedClientSlots[200] = nil + + if len(sema.reservedClientSlots) != 2 { + t.Errorf("reservedClientSlots count = %d, want 2", len(sema.reservedClientSlots)) + } + + // Check existence + if _, ok := sema.reservedClientSlots[100]; !ok { + t.Error("charID 100 should be reserved") + } + if _, ok := sema.reservedClientSlots[200]; !ok { + t.Error("charID 200 should be reserved") + } + if _, ok := sema.reservedClientSlots[300]; ok { + t.Error("charID 300 should not be reserved") + } +} + +func TestSemaphoreRemoveClient(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + session := createMockSession(100, server) + sema.clients[session] = session.charID + + // Remove client + delete(sema.clients, session) + + if len(sema.clients) != 0 { + t.Errorf("clients count = %d, want 0 after delete", len(sema.clients)) + } +} + +func TestSemaphoreMaxPlayers(t *testing.T) { + tests := []struct { + name string + maxPlayers uint16 + }{ + {"quest party", 4}, + {"small event", 16}, + {"raviente", 32}, + {"large event", 64}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, tt.name, tt.maxPlayers) + + if sema.maxPlayers != tt.maxPlayers { + t.Errorf("maxPlayers = %d, want %d", sema.maxPlayers, tt.maxPlayers) + } + }) + } +} + +func TestSemaphoreBroadcastMHF(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + session1 := createMockSession(100, server) + session2 := createMockSession(200, server) + session3 := createMockSession(300, server) + + sema.clients[session1] = session1.charID + sema.clients[session2] = session2.charID + sema.clients[session3] = session3.charID + + pkt := &mockPacket{opcode: 0x1234} + + // Broadcast excluding session1 + sema.BroadcastMHF(pkt, session1) + + // session2 and session3 should receive + select { + case data := <-session2.sendPackets: + if len(data.data) == 0 { + t.Error("session2 received empty data") + } + default: + t.Error("session2 did not receive broadcast") + } + + select { + case data := <-session3.sendPackets: + if len(data.data) == 0 { + t.Error("session3 received empty data") + } + default: + t.Error("session3 did not receive broadcast") + } + + // session1 should NOT receive (it was ignored) + select { + case <-session1.sendPackets: + t.Error("session1 should not receive broadcast (it was ignored)") + default: + // Expected - no data for session1 + } +} + +func TestSemaphoreBroadcastRavi(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "raviente", 32) + + session1 := createMockSession(100, server) + session2 := createMockSession(200, server) + + sema.clients[session1] = session1.charID + sema.clients[session2] = session2.charID + + pkt := &mockPacket{opcode: 0x5678} + + // Broadcast to all (no ignored session) + sema.BroadcastRavi(pkt) + + // Both should receive + select { + case data := <-session1.sendPackets: + if len(data.data) == 0 { + t.Error("session1 received empty data") + } + default: + t.Error("session1 did not receive Ravi broadcast") + } + + select { + case data := <-session2.sendPackets: + if len(data.data) == 0 { + t.Error("session2 received empty data") + } + default: + t.Error("session2 did not receive Ravi broadcast") + } +} + +func TestSemaphoreBroadcastToAll(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + session1 := createMockSession(100, server) + session2 := createMockSession(200, server) + + sema.clients[session1] = session1.charID + sema.clients[session2] = session2.charID + + pkt := &mockPacket{opcode: 0x1234} + + // Broadcast to all (nil ignored session) + sema.BroadcastMHF(pkt, nil) + + // Both should receive + count := 0 + select { + case <-session1.sendPackets: + count++ + default: + } + select { + case <-session2.sendPackets: + count++ + default: + } + + if count != 2 { + t.Errorf("expected 2 broadcasts, got %d", count) + } +} + +func TestSemaphoreRWMutex(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + // Test that RWMutex works + sema.RLock() + _ = len(sema.clients) // Read operation + sema.RUnlock() + + sema.Lock() + sema.clients[createMockSession(100, server)] = 100 // Write operation + sema.Unlock() +} + +func TestSemaphoreConcurrentAccess(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 100) + + var wg sync.WaitGroup + + // Concurrent writers + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + for j := 0; j < 100; j++ { + session := createMockSession(uint32(id*100+j), server) + sema.Lock() + sema.clients[session] = session.charID + sema.Unlock() + + sema.Lock() + delete(sema.clients, session) + sema.Unlock() + } + }(i) + } + + // Concurrent readers + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < 100; j++ { + sema.RLock() + _ = len(sema.clients) + sema.RUnlock() + } + }() + } + + wg.Wait() +} + +func TestSemaphoreEmptyBroadcast(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + pkt := &mockPacket{opcode: 0x1234} + + // Should not panic with no clients + sema.BroadcastMHF(pkt, nil) + sema.BroadcastRavi(pkt) +} + +func TestSemaphoreIDString(t *testing.T) { + server := createMockServer() + + tests := []string{ + "quest_001", + "raviente_phase1", + "tournament_round3", + "diva_defense", + } + + for _, id := range tests { + sema := NewSemaphore(server, id, 4) + if sema.id_semaphore != id { + t.Errorf("id_semaphore = %s, want %s", sema.id_semaphore, id) + } + } +} + +func TestSemaphoreNumericID(t *testing.T) { + server := createMockServer() + server.semaphoreIndex = 6 // IDs 0-6 reserved + + sema := NewSemaphore(server, "test", 4) + + // First semaphore should get ID 7 + if sema.id < 7 { + t.Errorf("semaphore id = %d, should be >= 7", sema.id) + } +} + +func TestSemaphoreReserveAndRelease(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + // Reserve + sema.reservedClientSlots[100] = nil + if _, ok := sema.reservedClientSlots[100]; !ok { + t.Error("slot 100 should be reserved") + } + + // Release + delete(sema.reservedClientSlots, 100) + if _, ok := sema.reservedClientSlots[100]; ok { + t.Error("slot 100 should be released") + } +} + +func TestSemaphoreClientAndReservedSeparate(t *testing.T) { + server := createMockServer() + sema := NewSemaphore(server, "test", 4) + + session := createMockSession(100, server) + + // Client in active clients + sema.clients[session] = 100 + + // Same charID reserved + sema.reservedClientSlots[100] = nil + + // Both should exist independently + if _, ok := sema.clients[session]; !ok { + t.Error("session should be in active clients") + } + if _, ok := sema.reservedClientSlots[100]; !ok { + t.Error("charID 100 should be reserved") + } + + // Remove from one doesn't affect other + delete(sema.clients, session) + if _, ok := sema.reservedClientSlots[100]; !ok { + t.Error("charID 100 should still be reserved after removing from clients") + } +} diff --git a/server/channelserver/sys_session_test.go b/server/channelserver/sys_session_test.go new file mode 100644 index 000000000..143cd0089 --- /dev/null +++ b/server/channelserver/sys_session_test.go @@ -0,0 +1,375 @@ +package channelserver + +import ( + "testing" + "time" + + "erupe-ce/common/stringstack" + "erupe-ce/network/clientctx" +) + +func TestSessionStructInitialization(t *testing.T) { + server := createMockServer() + session := createMockSession(12345, server) + + if session.charID != 12345 { + t.Errorf("charID = %d, want 12345", session.charID) + } + if session.Name != "TestPlayer" { + t.Errorf("Name = %s, want TestPlayer", session.Name) + } + if session.server != server { + t.Error("server reference not set correctly") + } + if session.clientContext == nil { + t.Error("clientContext should not be nil") + } + if session.sendPackets == nil { + t.Error("sendPackets channel should not be nil") + } +} + +func TestSessionSendPacketChannel(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Test that channel can receive packets + testData := []byte{0x01, 0x02, 0x03} + session.sendPackets <- packet{data: testData, nonBlocking: false} + + select { + case pkt := <-session.sendPackets: + if len(pkt.data) != 3 { + t.Errorf("packet data len = %d, want 3", len(pkt.data)) + } + if pkt.data[0] != 0x01 { + t.Errorf("packet data[0] = %d, want 1", pkt.data[0]) + } + default: + t.Error("failed to receive packet from channel") + } +} + +func TestSessionSendPacketChannelNonBlocking(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Fill the channel + for i := 0; i < 20; i++ { + session.sendPackets <- packet{data: []byte{byte(i)}, nonBlocking: true} + } + + // Non-blocking send to full channel should not block + done := make(chan bool, 1) + go func() { + select { + case session.sendPackets <- packet{data: []byte{0xFF}, nonBlocking: true}: + // Managed to send (channel had room) + default: + // Channel full, this is expected + } + done <- true + }() + + select { + case <-done: + // Success - non-blocking worked + case <-time.After(100 * time.Millisecond): + t.Error("non-blocking send blocked") + } +} + +func TestPacketStruct(t *testing.T) { + pkt := packet{ + data: []byte{0x01, 0x02, 0x03}, + nonBlocking: true, + } + + if len(pkt.data) != 3 { + t.Errorf("packet data len = %d, want 3", len(pkt.data)) + } + if !pkt.nonBlocking { + t.Error("nonBlocking should be true") + } +} + +func TestPacketStructBlocking(t *testing.T) { + pkt := packet{ + data: []byte{0xDE, 0xAD, 0xBE, 0xEF}, + nonBlocking: false, + } + + if len(pkt.data) != 4 { + t.Errorf("packet data len = %d, want 4", len(pkt.data)) + } + if pkt.nonBlocking { + t.Error("nonBlocking should be false") + } +} + +func TestSessionClosedFlag(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + if session.closed { + t.Error("new session should not be closed") + } + + session.closed = true + + if !session.closed { + t.Error("session closed flag should be settable") + } +} + +func TestSessionStageState(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Initially should have no stage + if session.userEnteredStage { + t.Error("new session should not have entered stage") + } + if session.stageID != "" { + t.Errorf("stageID should be empty, got %s", session.stageID) + } + if session.stage != nil { + t.Error("stage should be nil initially") + } + + // Set stage state + session.userEnteredStage = true + session.stageID = "test_stage_001" + + if !session.userEnteredStage { + t.Error("userEnteredStage should be set") + } + if session.stageID != "test_stage_001" { + t.Errorf("stageID = %s, want test_stage_001", session.stageID) + } +} + +func TestSessionStageMoveStack(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + session.stageMoveStack = stringstack.New() + + // Push some stages + session.stageMoveStack.Push("stage1") + session.stageMoveStack.Push("stage2") + session.stageMoveStack.Push("stage3") + + // Pop and verify order (LIFO) + if v, err := session.stageMoveStack.Pop(); err != nil || v != "stage3" { + t.Errorf("Pop() = %s, want stage3", v) + } + if v, err := session.stageMoveStack.Pop(); err != nil || v != "stage2" { + t.Errorf("Pop() = %s, want stage2", v) + } + if v, err := session.stageMoveStack.Pop(); err != nil || v != "stage1" { + t.Errorf("Pop() = %s, want stage1", v) + } +} + +func TestSessionMailState(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Initial mail state + if session.mailAccIndex != 0 { + t.Errorf("mailAccIndex = %d, want 0", session.mailAccIndex) + } + if session.mailList != nil && len(session.mailList) > 0 { + t.Error("mailList should be empty initially") + } + + // Add mail + session.mailList = []int{100, 101, 102} + session.mailAccIndex = 3 + + if len(session.mailList) != 3 { + t.Errorf("mailList len = %d, want 3", len(session.mailList)) + } + if session.mailAccIndex != 3 { + t.Errorf("mailAccIndex = %d, want 3", session.mailAccIndex) + } +} + +func TestSessionToken(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + session.token = "abc123def456" + + if session.token != "abc123def456" { + t.Errorf("token = %s, want abc123def456", session.token) + } +} + +func TestSessionGuildState(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + session.prevGuildID = 42 + + if session.prevGuildID != 42 { + t.Errorf("prevGuildID = %d, want 42", session.prevGuildID) + } +} + +func TestSessionKQF(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Set KQF data + session.kqf = []byte{0x01, 0x02, 0x03, 0x04} + session.kqfOverride = true + + if len(session.kqf) != 4 { + t.Errorf("kqf len = %d, want 4", len(session.kqf)) + } + if !session.kqfOverride { + t.Error("kqfOverride should be true") + } +} + +func TestSessionClientContext(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + if session.clientContext == nil { + t.Fatal("clientContext should not be nil") + } + + // Verify clientContext is usable + ctx := session.clientContext + _ = ctx // Just verify it's accessible +} + +func TestSessionReservationStage(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + if session.reservationStage != nil { + t.Error("reservationStage should be nil initially") + } + + // Set reservation stage + stage := NewStage("quest_stage") + session.reservationStage = stage + + if session.reservationStage != stage { + t.Error("reservationStage should be set correctly") + } +} + +func TestSessionStagePass(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + session.stagePass = "secret123" + + if session.stagePass != "secret123" { + t.Errorf("stagePass = %s, want secret123", session.stagePass) + } +} + +func TestSessionLogKey(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + session.logKey = []byte{0xDE, 0xAD, 0xBE, 0xEF} + + if len(session.logKey) != 4 { + t.Errorf("logKey len = %d, want 4", len(session.logKey)) + } + if session.logKey[0] != 0xDE { + t.Errorf("logKey[0] = %x, want 0xDE", session.logKey[0]) + } +} + +func TestSessionSessionStart(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Set session start time + now := time.Now().Unix() + session.sessionStart = now + + if session.sessionStart != now { + t.Errorf("sessionStart = %d, want %d", session.sessionStart, now) + } +} + +func TestIgnoredOpcode(t *testing.T) { + // Test that certain opcodes are ignored + tests := []struct { + name string + opcode uint16 + ignored bool + }{ + // These should be ignored based on ignoreList + {"MSG_SYS_END is ignored", 0x0002, true}, // Assuming MSG_SYS_END value + // We can't test exact values without importing network package constants + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Note: This test is limited since ignored() uses network.PacketID + // which we can't easily instantiate without the exact enum values + }) + } +} + +func TestSessionMutex(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + // Verify session has mutex (via embedding) + // This should not deadlock + session.Lock() + session.charID = 999 + session.Unlock() + + if session.charID != 999 { + t.Errorf("charID = %d, want 999 after lock/unlock", session.charID) + } +} + +func TestSessionConcurrentAccess(t *testing.T) { + server := createMockServer() + session := createMockSession(1, server) + + done := make(chan bool, 2) + + // Concurrent writers + go func() { + for i := 0; i < 100; i++ { + session.Lock() + session.charID = uint32(i) + session.Unlock() + } + done <- true + }() + + go func() { + for i := 0; i < 100; i++ { + session.Lock() + _ = session.charID + session.Unlock() + } + done <- true + }() + + <-done + <-done +} + +func TestClientContextStruct(t *testing.T) { + ctx := &clientctx.ClientContext{} + + // Verify the struct is usable + if ctx == nil { + t.Error("ClientContext should be creatable") + } +} diff --git a/server/entranceserver/make_resp_test.go b/server/entranceserver/make_resp_test.go new file mode 100644 index 000000000..4e5f9c843 --- /dev/null +++ b/server/entranceserver/make_resp_test.go @@ -0,0 +1,139 @@ +package entranceserver + +import ( + "bytes" + "testing" +) + +func TestMakeHeader(t *testing.T) { + tests := []struct { + name string + data []byte + respType string + entryCount uint16 + key byte + }{ + {"empty data", []byte{}, "SV2", 0, 0x00}, + {"single byte", []byte{0x01}, "SVR", 1, 0x00}, + {"multiple bytes", []byte{0x01, 0x02, 0x03, 0x04}, "SV2", 2, 0x00}, + {"with key", []byte{0xDE, 0xAD, 0xBE, 0xEF}, "USR", 5, 0x42}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := makeHeader(tt.data, tt.respType, tt.entryCount, tt.key) + + // Result should not be empty + if len(result) == 0 { + t.Error("makeHeader() returned empty result") + } + + // First byte should be the key + if result[0] != tt.key { + t.Errorf("makeHeader() first byte = %x, want %x", result[0], tt.key) + } + + // Result should be longer than just the key + if len(result) <= 1 { + t.Error("makeHeader() result too short") + } + }) + } +} + +func TestMakeHeaderEncryption(t *testing.T) { + data := []byte{0x01, 0x02, 0x03, 0x04} + + result1 := makeHeader(data, "SV2", 1, 0x00) + result2 := makeHeader(data, "SV2", 1, 0x01) + + // Different keys should produce different encrypted output + if bytes.Equal(result1, result2) { + t.Error("makeHeader() with different keys should produce different output") + } +} + +func TestMakeHeaderRespTypes(t *testing.T) { + data := []byte{0x01} + + // Test different response types produce valid output + types := []string{"SV2", "SVR", "USR"} + + for _, respType := range types { + t.Run(respType, func(t *testing.T) { + result := makeHeader(data, respType, 1, 0x00) + if len(result) == 0 { + t.Errorf("makeHeader() with type %s returned empty result", respType) + } + }) + } +} + +func TestMakeHeaderEmptyData(t *testing.T) { + // Empty data should still produce a valid (shorter) header + result := makeHeader([]byte{}, "SV2", 0, 0x00) + + if len(result) == 0 { + t.Error("makeHeader() with empty data returned empty result") + } +} + +func TestMakeHeaderLargeData(t *testing.T) { + // Test with larger data + data := make([]byte, 1000) + for i := range data { + data[i] = byte(i % 256) + } + + result := makeHeader(data, "SV2", 100, 0x55) + + if len(result) == 0 { + t.Error("makeHeader() with large data returned empty result") + } + + // Result should be data + overhead + if len(result) <= len(data) { + t.Error("makeHeader() result should be larger than input data due to header") + } +} + +func TestMakeHeaderEntryCount(t *testing.T) { + data := []byte{0x01, 0x02} + + // Different entry counts should work + for _, count := range []uint16{0, 1, 10, 100, 65535} { + result := makeHeader(data, "SV2", count, 0x00) + if len(result) == 0 { + t.Errorf("makeHeader() with entryCount=%d returned empty result", count) + } + } +} + +func TestMakeHeaderDecryptable(t *testing.T) { + data := []byte{0x01, 0x02, 0x03, 0x04} + key := byte(0x00) + + result := makeHeader(data, "SV2", 1, key) + + // Remove key byte and decrypt + encrypted := result[1:] + decrypted := DecryptBin8(encrypted, key) + + // Decrypted data should start with "SV2" + if len(decrypted) >= 3 && string(decrypted[:3]) != "SV2" { + t.Errorf("makeHeader() decrypted data should start with SV2, got %s", string(decrypted[:3])) + } +} + +func TestMakeHeaderConsistency(t *testing.T) { + data := []byte{0x01, 0x02, 0x03} + key := byte(0x10) + + // Same input should produce same output + result1 := makeHeader(data, "SV2", 5, key) + result2 := makeHeader(data, "SV2", 5, key) + + if !bytes.Equal(result1, result2) { + t.Error("makeHeader() with same input should produce same output") + } +} diff --git a/server/signserver/sign_server_test.go b/server/signserver/sign_server_test.go new file mode 100644 index 000000000..65e95df8f --- /dev/null +++ b/server/signserver/sign_server_test.go @@ -0,0 +1,212 @@ +package signserver + +import ( + "fmt" + "testing" +) + +func TestRespIDConstants(t *testing.T) { + tests := []struct { + respID RespID + value uint16 + }{ + {SIGN_UNKNOWN, 0}, + {SIGN_SUCCESS, 1}, + {SIGN_EFAILED, 2}, + {SIGN_EILLEGAL, 3}, + {SIGN_EALERT, 4}, + {SIGN_EABORT, 5}, + {SIGN_ERESPONSE, 6}, + {SIGN_EDATABASE, 7}, + {SIGN_EABSENCE, 8}, + {SIGN_ERESIGN, 9}, + {SIGN_ESUSPEND_D, 10}, + {SIGN_ELOCK, 11}, + {SIGN_EPASS, 12}, + {SIGN_ERIGHT, 13}, + {SIGN_EAUTH, 14}, + {SIGN_ESUSPEND, 15}, + {SIGN_EELIMINATE, 16}, + {SIGN_ECLOSE, 17}, + {SIGN_ECLOSE_EX, 18}, + {SIGN_EINTERVAL, 19}, + {SIGN_EMOVED, 20}, + {SIGN_ENOTREADY, 21}, + {SIGN_EALREADY, 22}, + {SIGN_EIPADDR, 23}, + {SIGN_EHANGAME, 24}, + {SIGN_UPD_ONLY, 25}, + {SIGN_EMBID, 26}, + {SIGN_ECOGCODE, 27}, + {SIGN_ETOKEN, 28}, + {SIGN_ECOGLINK, 29}, + {SIGN_EMAINTE, 30}, + {SIGN_EMAINTE_NOUPDATE, 31}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("RespID_%d", tt.value), func(t *testing.T) { + if uint16(tt.respID) != tt.value { + t.Errorf("RespID = %d, want %d", uint16(tt.respID), tt.value) + } + }) + } +} + +func TestRespIDType(t *testing.T) { + // Verify RespID is based on uint16 + var r RespID = 0xFFFF + if uint16(r) != 0xFFFF { + t.Errorf("RespID max value = %d, want %d", uint16(r), 0xFFFF) + } +} + +func TestMakeSignInFailureResp(t *testing.T) { + tests := []RespID{ + SIGN_UNKNOWN, + SIGN_EFAILED, + SIGN_EILLEGAL, + SIGN_ESUSPEND, + SIGN_EELIMINATE, + SIGN_EIPADDR, + } + + for _, respID := range tests { + t.Run(fmt.Sprintf("RespID_%d", respID), func(t *testing.T) { + resp := makeSignInFailureResp(respID) + + if len(resp) != 1 { + t.Errorf("makeSignInFailureResp() len = %d, want 1", len(resp)) + } + if resp[0] != uint8(respID) { + t.Errorf("makeSignInFailureResp() = %d, want %d", resp[0], uint8(respID)) + } + }) + } +} + +func TestMakeSignInFailureRespAllCodes(t *testing.T) { + // Test all possible RespID values 0-39 + for i := uint16(0); i <= 40; i++ { + resp := makeSignInFailureResp(RespID(i)) + if len(resp) != 1 { + t.Errorf("makeSignInFailureResp(%d) len = %d, want 1", i, len(resp)) + } + if resp[0] != uint8(i) { + t.Errorf("makeSignInFailureResp(%d) = %d", i, resp[0]) + } + } +} + +func TestSignSuccessIsOne(t *testing.T) { + // SIGN_SUCCESS must be 1 for the protocol to work correctly + if SIGN_SUCCESS != 1 { + t.Errorf("SIGN_SUCCESS = %d, must be 1", SIGN_SUCCESS) + } +} + +func TestSignUnknownIsZero(t *testing.T) { + // SIGN_UNKNOWN must be 0 as the zero value + if SIGN_UNKNOWN != 0 { + t.Errorf("SIGN_UNKNOWN = %d, must be 0", SIGN_UNKNOWN) + } +} + +func TestRespIDValues(t *testing.T) { + // Test specific RespID values are correct + tests := []struct { + name string + respID RespID + value uint16 + }{ + {"SIGN_UNKNOWN", SIGN_UNKNOWN, 0}, + {"SIGN_SUCCESS", SIGN_SUCCESS, 1}, + {"SIGN_EFAILED", SIGN_EFAILED, 2}, + {"SIGN_EILLEGAL", SIGN_EILLEGAL, 3}, + {"SIGN_ESUSPEND", SIGN_ESUSPEND, 15}, + {"SIGN_EELIMINATE", SIGN_EELIMINATE, 16}, + {"SIGN_EIPADDR", SIGN_EIPADDR, 23}, + {"SIGN_EMAINTE", SIGN_EMAINTE, 30}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if uint16(tt.respID) != tt.value { + t.Errorf("%s = %d, want %d", tt.name, uint16(tt.respID), tt.value) + } + }) + } +} + +func TestUnknownRespIDRange(t *testing.T) { + // Test the unknown IDs 32-35 + unknownIDs := []RespID{UNK_32, UNK_33, UNK_34, UNK_35} + expectedValues := []uint16{32, 33, 34, 35} + + for i, id := range unknownIDs { + if uint16(id) != expectedValues[i] { + t.Errorf("Unknown ID %d = %d, want %d", i, uint16(id), expectedValues[i]) + } + } +} + +func TestSpecialRespIDs(t *testing.T) { + // Test platform-specific IDs + if SIGN_XBRESPONSE != 36 { + t.Errorf("SIGN_XBRESPONSE = %d, want 36", SIGN_XBRESPONSE) + } + if SIGN_EPSI != 37 { + t.Errorf("SIGN_EPSI = %d, want 37", SIGN_EPSI) + } + if SIGN_EMBID_PSI != 38 { + t.Errorf("SIGN_EMBID_PSI = %d, want 38", SIGN_EMBID_PSI) + } +} + +func TestMakeSignInFailureRespBoundary(t *testing.T) { + // Test boundary values + resp := makeSignInFailureResp(RespID(0)) + if resp[0] != 0 { + t.Errorf("makeSignInFailureResp(0) = %d, want 0", resp[0]) + } + + resp = makeSignInFailureResp(RespID(255)) + if resp[0] != 255 { + t.Errorf("makeSignInFailureResp(255) = %d, want 255", resp[0]) + } +} + +func TestErrorRespIDsAreDifferent(t *testing.T) { + // Ensure all error codes are unique + seen := make(map[RespID]bool) + errorCodes := []RespID{ + SIGN_UNKNOWN, SIGN_SUCCESS, SIGN_EFAILED, SIGN_EILLEGAL, + SIGN_EALERT, SIGN_EABORT, SIGN_ERESPONSE, SIGN_EDATABASE, + SIGN_EABSENCE, SIGN_ERESIGN, SIGN_ESUSPEND_D, SIGN_ELOCK, + SIGN_EPASS, SIGN_ERIGHT, SIGN_EAUTH, SIGN_ESUSPEND, + SIGN_EELIMINATE, SIGN_ECLOSE, SIGN_ECLOSE_EX, SIGN_EINTERVAL, + SIGN_EMOVED, SIGN_ENOTREADY, SIGN_EALREADY, SIGN_EIPADDR, + SIGN_EHANGAME, SIGN_UPD_ONLY, SIGN_EMBID, SIGN_ECOGCODE, + SIGN_ETOKEN, SIGN_ECOGLINK, SIGN_EMAINTE, SIGN_EMAINTE_NOUPDATE, + } + + for _, code := range errorCodes { + if seen[code] { + t.Errorf("Duplicate RespID value: %d", code) + } + seen[code] = true + } +} + +func TestFailureRespIsMinimal(t *testing.T) { + // Failure response should be exactly 1 byte for efficiency + for i := RespID(0); i <= SIGN_EMBID_PSI; i++ { + if i == SIGN_SUCCESS { + continue // Success has different format + } + resp := makeSignInFailureResp(i) + if len(resp) != 1 { + t.Errorf("makeSignInFailureResp(%d) should be 1 byte, got %d", i, len(resp)) + } + } +} diff --git a/server/signv2server/endpoints_test.go b/server/signv2server/endpoints_test.go new file mode 100644 index 000000000..4b19d211e --- /dev/null +++ b/server/signv2server/endpoints_test.go @@ -0,0 +1,349 @@ +package signv2server + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "go.uber.org/zap" +) + +// mockServer creates a Server with minimal dependencies for testing +func mockServer() *Server { + logger, _ := zap.NewDevelopment() + return &Server{ + logger: logger, + } +} + +func TestLauncherEndpoint(t *testing.T) { + s := mockServer() + + req := httptest.NewRequest("GET", "/launcher", nil) + w := httptest.NewRecorder() + + s.Launcher(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusOK { + t.Errorf("Launcher() status = %d, want %d", resp.StatusCode, http.StatusOK) + } + + var data struct { + Important []LauncherMessage `json:"important"` + Normal []LauncherMessage `json:"normal"` + } + if err := json.NewDecoder(resp.Body).Decode(&data); err != nil { + t.Fatalf("Failed to decode response: %v", err) + } + + // Should have important messages + if len(data.Important) == 0 { + t.Error("Launcher() should return important messages") + } + + // Should have normal messages + if len(data.Normal) == 0 { + t.Error("Launcher() should return normal messages") + } +} + +func TestLauncherMessageStructure(t *testing.T) { + s := mockServer() + + req := httptest.NewRequest("GET", "/launcher", nil) + w := httptest.NewRecorder() + + s.Launcher(w, req) + + var data struct { + Important []LauncherMessage `json:"important"` + Normal []LauncherMessage `json:"normal"` + } + json.NewDecoder(w.Result().Body).Decode(&data) + + // Check important messages have required fields + for _, msg := range data.Important { + if msg.Message == "" { + t.Error("LauncherMessage.Message should not be empty") + } + if msg.Date == 0 { + t.Error("LauncherMessage.Date should not be zero") + } + if msg.Link == "" { + t.Error("LauncherMessage.Link should not be empty") + } + } +} + +func TestLoginEndpointInvalidJSON(t *testing.T) { + s := mockServer() + + // Send invalid JSON + req := httptest.NewRequest("POST", "/login", bytes.NewReader([]byte("not json"))) + w := httptest.NewRecorder() + + s.Login(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Login() with invalid JSON status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestRegisterEndpointInvalidJSON(t *testing.T) { + s := mockServer() + + req := httptest.NewRequest("POST", "/register", bytes.NewReader([]byte("invalid"))) + w := httptest.NewRecorder() + + s.Register(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("Register() with invalid JSON status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestCreateCharacterEndpointInvalidJSON(t *testing.T) { + s := mockServer() + + req := httptest.NewRequest("POST", "/character/create", bytes.NewReader([]byte("invalid"))) + w := httptest.NewRecorder() + + s.CreateCharacter(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("CreateCharacter() with invalid JSON status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestDeleteCharacterEndpointInvalidJSON(t *testing.T) { + s := mockServer() + + req := httptest.NewRequest("POST", "/character/delete", bytes.NewReader([]byte("invalid"))) + w := httptest.NewRecorder() + + s.DeleteCharacter(w, req) + + resp := w.Result() + if resp.StatusCode != http.StatusBadRequest { + t.Errorf("DeleteCharacter() with invalid JSON status = %d, want %d", resp.StatusCode, http.StatusBadRequest) + } +} + +func TestLauncherMessageStruct(t *testing.T) { + msg := LauncherMessage{ + Message: "Test Message", + Date: 1234567890, + Link: "https://example.com", + } + + if msg.Message != "Test Message" { + t.Errorf("Message = %s, want Test Message", msg.Message) + } + if msg.Date != 1234567890 { + t.Errorf("Date = %d, want 1234567890", msg.Date) + } + if msg.Link != "https://example.com" { + t.Errorf("Link = %s, want https://example.com", msg.Link) + } +} + +func TestCharacterStruct(t *testing.T) { + char := Character{ + ID: 1, + Name: "TestHunter", + IsFemale: true, + Weapon: 5, + HR: 999, + GR: 100, + LastLogin: 1234567890, + } + + if char.ID != 1 { + t.Errorf("ID = %d, want 1", char.ID) + } + if char.Name != "TestHunter" { + t.Errorf("Name = %s, want TestHunter", char.Name) + } + if char.IsFemale != true { + t.Error("IsFemale should be true") + } + if char.Weapon != 5 { + t.Errorf("Weapon = %d, want 5", char.Weapon) + } + if char.HR != 999 { + t.Errorf("HR = %d, want 999", char.HR) + } + if char.GR != 100 { + t.Errorf("GR = %d, want 100", char.GR) + } +} + +func TestLauncherMessageJSONTags(t *testing.T) { + msg := LauncherMessage{ + Message: "Test", + Date: 12345, + Link: "http://test.com", + } + + data, err := json.Marshal(msg) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var decoded map[string]interface{} + json.Unmarshal(data, &decoded) + + if _, ok := decoded["message"]; !ok { + t.Error("JSON should have 'message' key") + } + if _, ok := decoded["date"]; !ok { + t.Error("JSON should have 'date' key") + } + if _, ok := decoded["link"]; !ok { + t.Error("JSON should have 'link' key") + } +} + +func TestCharacterJSONTags(t *testing.T) { + char := Character{ + ID: 1, + Name: "Test", + IsFemale: true, + Weapon: 3, + HR: 50, + GR: 10, + LastLogin: 9999, + } + + data, err := json.Marshal(char) + if err != nil { + t.Fatalf("Marshal error: %v", err) + } + + var decoded map[string]interface{} + json.Unmarshal(data, &decoded) + + if _, ok := decoded["id"]; !ok { + t.Error("JSON should have 'id' key") + } + if _, ok := decoded["name"]; !ok { + t.Error("JSON should have 'name' key") + } + if _, ok := decoded["isFemale"]; !ok { + t.Error("JSON should have 'isFemale' key") + } + if _, ok := decoded["weapon"]; !ok { + t.Error("JSON should have 'weapon' key") + } + if _, ok := decoded["hr"]; !ok { + t.Error("JSON should have 'hr' key") + } + if _, ok := decoded["gr"]; !ok { + t.Error("JSON should have 'gr' key") + } + if _, ok := decoded["lastLogin"]; !ok { + t.Error("JSON should have 'lastLogin' key") + } +} + +func TestLauncherResponseFormat(t *testing.T) { + s := mockServer() + + req := httptest.NewRequest("GET", "/launcher", nil) + w := httptest.NewRecorder() + + s.Launcher(w, req) + + resp := w.Result() + + // Verify it returns valid JSON + var result map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + t.Errorf("Launcher() should return valid JSON: %v", err) + } + + // Check top-level keys exist + if _, ok := result["important"]; !ok { + t.Error("Launcher() response should have 'important' key") + } + if _, ok := result["normal"]; !ok { + t.Error("Launcher() response should have 'normal' key") + } +} + +func TestLauncherMessageCount(t *testing.T) { + s := mockServer() + + req := httptest.NewRequest("GET", "/launcher", nil) + w := httptest.NewRecorder() + + s.Launcher(w, req) + + var data struct { + Important []LauncherMessage `json:"important"` + Normal []LauncherMessage `json:"normal"` + } + json.NewDecoder(w.Result().Body).Decode(&data) + + // Should have at least 3 important messages based on the implementation + if len(data.Important) < 3 { + t.Errorf("Launcher() should return at least 3 important messages, got %d", len(data.Important)) + } + + // Should have at least 1 normal message + if len(data.Normal) < 1 { + t.Errorf("Launcher() should return at least 1 normal message, got %d", len(data.Normal)) + } +} + +func TestCharacterStructDBTags(t *testing.T) { + // Test that Character struct has proper db tags + char := Character{} + + // These fields have db tags, verify struct is usable + char.IsFemale = true + char.Weapon = 7 + char.HR = 100 + char.LastLogin = 12345 + + if char.Weapon != 7 { + t.Errorf("Weapon = %d, want 7", char.Weapon) + } +} + +func TestNewServer(t *testing.T) { + logger, _ := zap.NewDevelopment() + cfg := &Config{ + Logger: logger, + } + + s := NewServer(cfg) + + if s == nil { + t.Fatal("NewServer() returned nil") + } + if s.logger == nil { + t.Error("NewServer() should set logger") + } + if s.httpServer == nil { + t.Error("NewServer() should initialize httpServer") + } +} + +func TestServerConfig(t *testing.T) { + cfg := &Config{ + Logger: nil, + DB: nil, + } + + // Config struct should be usable + if cfg.Logger != nil { + t.Error("Config.Logger should be nil when not set") + } +}