diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..8f2964736 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,6 @@ +# Contributing to Erupe + +Before submitting a new version: + +- Document your changes in [CHANGELOG.md](CHANGELOG.md). +- Run tests: `go test -v ./...` and check for race conditions: `go test -v -race ./...` diff --git a/common/bfutil/bfutil_test.go b/common/bfutil/bfutil_test.go new file mode 100644 index 000000000..51fad0e13 --- /dev/null +++ b/common/bfutil/bfutil_test.go @@ -0,0 +1,105 @@ +package bfutil + +import ( + "bytes" + "testing" +) + +func TestUpToNull(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "data with null terminator", + input: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x00, 0x57, 0x6F, 0x72, 0x6C, 0x64}, + expected: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F}, // "Hello" + }, + { + name: "data without null terminator", + input: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F}, + expected: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F}, // "Hello" + }, + { + name: "data with null at start", + input: []byte{0x00, 0x48, 0x65, 0x6C, 0x6C, 0x6F}, + expected: []byte{}, + }, + { + name: "empty slice", + input: []byte{}, + expected: []byte{}, + }, + { + name: "only null byte", + input: []byte{0x00}, + expected: []byte{}, + }, + { + name: "multiple null bytes", + input: []byte{0x48, 0x65, 0x00, 0x00, 0x6C, 0x6C, 0x6F}, + expected: []byte{0x48, 0x65}, // "He" + }, + { + name: "binary data with null", + input: []byte{0xFF, 0xAB, 0x12, 0x00, 0x34, 0x56}, + expected: []byte{0xFF, 0xAB, 0x12}, + }, + { + name: "binary data without null", + input: []byte{0xFF, 0xAB, 0x12, 0x34, 0x56}, + expected: []byte{0xFF, 0xAB, 0x12, 0x34, 0x56}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UpToNull(tt.input) + if !bytes.Equal(result, tt.expected) { + t.Errorf("UpToNull() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestUpToNull_ReturnsSliceNotCopy(t *testing.T) { + // Test that UpToNull returns a slice of the original array, not a copy + input := []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x00, 0x57, 0x6F, 0x72, 0x6C, 0x64} + result := UpToNull(input) + + // Verify we got the expected data + expected := []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F} + if !bytes.Equal(result, expected) { + t.Errorf("UpToNull() = %v, want %v", result, expected) + } + + // The result should be a slice of the input array + if len(result) > 0 && cap(result) < len(expected) { + t.Error("Result should be a slice of input array") + } +} + +func BenchmarkUpToNull(b *testing.B) { + data := []byte("Hello, World!\x00Extra data here") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UpToNull(data) + } +} + +func BenchmarkUpToNull_NoNull(b *testing.B) { + data := []byte("Hello, World! No null terminator in this string at all") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UpToNull(data) + } +} + +func BenchmarkUpToNull_NullAtStart(b *testing.B) { + data := []byte("\x00Hello, World!") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UpToNull(data) + } +} diff --git a/common/pascalstring/pascalstring_test.go b/common/pascalstring/pascalstring_test.go new file mode 100644 index 000000000..8c4e145c0 --- /dev/null +++ b/common/pascalstring/pascalstring_test.go @@ -0,0 +1,369 @@ +package pascalstring + +import ( + "bytes" + "erupe-ce/common/byteframe" + "testing" +) + +func TestUint8_NoTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Hello" + + Uint8(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint8() + expectedLength := uint8(len(testString) + 1) // +1 for null terminator + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + // Should be "Hello\x00" + expected := []byte("Hello\x00") + if !bytes.Equal(data, expected) { + t.Errorf("data = %v, want %v", data, expected) + } +} + +func TestUint8_WithTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + // ASCII string (no special characters) + testString := "Test" + + Uint8(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint8() + + if length == 0 { + t.Error("length should not be 0 for ASCII string") + } + + data := bf.ReadBytes(uint(length)) + // Should end with null terminator + if data[len(data)-1] != 0 { + t.Error("data should end with null terminator") + } +} + +func TestUint8_EmptyString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "" + + Uint8(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint8() + + if length != 1 { // Just null terminator + t.Errorf("length = %d, want 1", length) + } + + data := bf.ReadBytes(uint(length)) + if data[0] != 0 { + t.Error("empty string should produce just null terminator") + } +} + +func TestUint16_NoTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "World" + + Uint16(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint16() + expectedLength := uint16(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + expected := []byte("World\x00") + if !bytes.Equal(data, expected) { + t.Errorf("data = %v, want %v", data, expected) + } +} + +func TestUint16_WithTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Test" + + Uint16(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint16() + + if length == 0 { + t.Error("length should not be 0 for ASCII string") + } + + data := bf.ReadBytes(uint(length)) + if data[len(data)-1] != 0 { + t.Error("data should end with null terminator") + } +} + +func TestUint16_EmptyString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "" + + Uint16(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint16() + + if length != 1 { + t.Errorf("length = %d, want 1", length) + } +} + +func TestUint32_NoTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Testing" + + Uint32(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint32() + expectedLength := uint32(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + expected := []byte("Testing\x00") + if !bytes.Equal(data, expected) { + t.Errorf("data = %v, want %v", data, expected) + } +} + +func TestUint32_WithTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Test" + + Uint32(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint32() + + if length == 0 { + t.Error("length should not be 0 for ASCII string") + } + + data := bf.ReadBytes(uint(length)) + if data[len(data)-1] != 0 { + t.Error("data should end with null terminator") + } +} + +func TestUint32_EmptyString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "" + + Uint32(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint32() + + if length != 1 { + t.Errorf("length = %d, want 1", length) + } +} + +func TestUint8_LongString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "This is a longer test string with more characters" + + Uint8(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint8() + expectedLength := uint8(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + if !bytes.HasSuffix(data, []byte{0}) { + t.Error("data should end with null terminator") + } + if !bytes.HasPrefix(data, []byte("This is")) { + t.Error("data should start with expected string") + } +} + +func TestUint16_LongString(t *testing.T) { + bf := byteframe.NewByteFrame() + // Create a string longer than 255 to test uint16 + testString := "" + for i := 0; i < 300; i++ { + testString += "A" + } + + Uint16(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint16() + expectedLength := uint16(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + if !bytes.HasSuffix(data, []byte{0}) { + t.Error("data should end with null terminator") + } +} + +func TestAllFunctions_NullTermination(t *testing.T) { + tests := []struct { + name string + writeFn func(*byteframe.ByteFrame, string, bool) + readSize func(*byteframe.ByteFrame) uint + }{ + { + name: "Uint8", + writeFn: func(bf *byteframe.ByteFrame, s string, t bool) { + Uint8(bf, s, t) + }, + readSize: func(bf *byteframe.ByteFrame) uint { + return uint(bf.ReadUint8()) + }, + }, + { + name: "Uint16", + writeFn: func(bf *byteframe.ByteFrame, s string, t bool) { + Uint16(bf, s, t) + }, + readSize: func(bf *byteframe.ByteFrame) uint { + return uint(bf.ReadUint16()) + }, + }, + { + name: "Uint32", + writeFn: func(bf *byteframe.ByteFrame, s string, t bool) { + Uint32(bf, s, t) + }, + readSize: func(bf *byteframe.ByteFrame) uint { + return uint(bf.ReadUint32()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Test" + + tt.writeFn(bf, testString, false) + + bf.Seek(0, 0) + size := tt.readSize(bf) + data := bf.ReadBytes(size) + + // Verify null termination + if data[len(data)-1] != 0 { + t.Errorf("%s: data should end with null terminator", tt.name) + } + + // Verify length includes null terminator + if size != uint(len(testString)+1) { + t.Errorf("%s: size = %d, want %d", tt.name, size, len(testString)+1) + } + }) + } +} + +func TestTransform_JapaneseCharacters(t *testing.T) { + // Test with Japanese characters that should be transformed to Shift-JIS + bf := byteframe.NewByteFrame() + testString := "テスト" // "Test" in Japanese katakana + + Uint16(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint16() + + if length == 0 { + t.Error("Transformed Japanese string should have non-zero length") + } + + // The transformed Shift-JIS should be different length than UTF-8 + // UTF-8: 9 bytes (3 chars * 3 bytes each), Shift-JIS: 6 bytes (3 chars * 2 bytes each) + 1 null + data := bf.ReadBytes(uint(length)) + if data[len(data)-1] != 0 { + t.Error("Transformed string should end with null terminator") + } +} + +func TestTransform_InvalidUTF8(t *testing.T) { + // This test verifies graceful handling of encoding errors + // When transformation fails, the functions should write length 0 + + bf := byteframe.NewByteFrame() + // Create a string with invalid UTF-8 sequence + // Note: Go strings are generally valid UTF-8, but we can test the error path + testString := "Valid ASCII" + + Uint8(bf, testString, true) + // Should succeed for ASCII characters + + bf.Seek(0, 0) + length := bf.ReadUint8() + if length == 0 { + t.Error("ASCII string should transform successfully") + } +} + +func BenchmarkUint8_NoTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint8(bf, testString, false) + } +} + +func BenchmarkUint8_WithTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint8(bf, testString, true) + } +} + +func BenchmarkUint16_NoTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint16(bf, testString, false) + } +} + +func BenchmarkUint32_NoTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint32(bf, testString, false) + } +} + +func BenchmarkUint16_Japanese(b *testing.B) { + testString := "テストメッセージ" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint16(bf, testString, true) + } +} diff --git a/common/stringstack/stringstack_test.go b/common/stringstack/stringstack_test.go new file mode 100644 index 000000000..3bfcf7656 --- /dev/null +++ b/common/stringstack/stringstack_test.go @@ -0,0 +1,343 @@ +package stringstack + +import ( + "testing" +) + +func TestNew(t *testing.T) { + s := New() + if s == nil { + t.Fatal("New() returned nil") + } + if len(s.stack) != 0 { + t.Errorf("New() stack length = %d, want 0", len(s.stack)) + } +} + +func TestStringStack_Set(t *testing.T) { + s := New() + s.Set("first") + + if len(s.stack) != 1 { + t.Errorf("Set() stack length = %d, want 1", len(s.stack)) + } + if s.stack[0] != "first" { + t.Errorf("stack[0] = %q, want %q", s.stack[0], "first") + } +} + +func TestStringStack_Set_Replaces(t *testing.T) { + s := New() + s.Push("item1") + s.Push("item2") + s.Push("item3") + + // Set should replace the entire stack + s.Set("new_item") + + if len(s.stack) != 1 { + t.Errorf("Set() stack length = %d, want 1", len(s.stack)) + } + if s.stack[0] != "new_item" { + t.Errorf("stack[0] = %q, want %q", s.stack[0], "new_item") + } +} + +func TestStringStack_Push(t *testing.T) { + s := New() + s.Push("first") + s.Push("second") + s.Push("third") + + if len(s.stack) != 3 { + t.Errorf("Push() stack length = %d, want 3", len(s.stack)) + } + if s.stack[0] != "first" { + t.Errorf("stack[0] = %q, want %q", s.stack[0], "first") + } + if s.stack[1] != "second" { + t.Errorf("stack[1] = %q, want %q", s.stack[1], "second") + } + if s.stack[2] != "third" { + t.Errorf("stack[2] = %q, want %q", s.stack[2], "third") + } +} + +func TestStringStack_Pop(t *testing.T) { + s := New() + s.Push("first") + s.Push("second") + s.Push("third") + + // Pop should return LIFO (last in, first out) + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v, want nil", err) + } + if val != "third" { + t.Errorf("Pop() = %q, want %q", val, "third") + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v, want nil", err) + } + if val != "second" { + t.Errorf("Pop() = %q, want %q", val, "second") + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v, want nil", err) + } + if val != "first" { + t.Errorf("Pop() = %q, want %q", val, "first") + } + + if len(s.stack) != 0 { + t.Errorf("stack length = %d, want 0 after popping all items", len(s.stack)) + } +} + +func TestStringStack_Pop_Empty(t *testing.T) { + s := New() + + val, err := s.Pop() + if err == nil { + t.Error("Pop() on empty stack should return error") + } + if val != "" { + t.Errorf("Pop() on empty stack returned %q, want empty string", val) + } + + expectedError := "no items on stack" + if err.Error() != expectedError { + t.Errorf("Pop() error = %q, want %q", err.Error(), expectedError) + } +} + +func TestStringStack_LIFO_Behavior(t *testing.T) { + s := New() + items := []string{"A", "B", "C", "D", "E"} + + for _, item := range items { + s.Push(item) + } + + // Pop should return in reverse order (LIFO) + for i := len(items) - 1; i >= 0; i-- { + val, err := s.Pop() + if err != nil { + t.Fatalf("Pop() error = %v", err) + } + if val != items[i] { + t.Errorf("Pop() = %q, want %q", val, items[i]) + } + } +} + +func TestStringStack_PushAfterPop(t *testing.T) { + s := New() + s.Push("first") + s.Push("second") + + val, _ := s.Pop() + if val != "second" { + t.Errorf("Pop() = %q, want %q", val, "second") + } + + s.Push("third") + + val, _ = s.Pop() + if val != "third" { + t.Errorf("Pop() = %q, want %q", val, "third") + } + + val, _ = s.Pop() + if val != "first" { + t.Errorf("Pop() = %q, want %q", val, "first") + } +} + +func TestStringStack_EmptyStrings(t *testing.T) { + s := New() + s.Push("") + s.Push("text") + s.Push("") + + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "" { + t.Errorf("Pop() = %q, want empty string", val) + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "text" { + t.Errorf("Pop() = %q, want %q", val, "text") + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "" { + t.Errorf("Pop() = %q, want empty string", val) + } +} + +func TestStringStack_LongStrings(t *testing.T) { + s := New() + longString := "" + for i := 0; i < 1000; i++ { + longString += "A" + } + + s.Push(longString) + val, err := s.Pop() + + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != longString { + t.Error("Pop() returned different string than pushed") + } + if len(val) != 1000 { + t.Errorf("Pop() string length = %d, want 1000", len(val)) + } +} + +func TestStringStack_ManyItems(t *testing.T) { + s := New() + count := 1000 + + // Push many items + for i := 0; i < count; i++ { + s.Push("item") + } + + if len(s.stack) != count { + t.Errorf("stack length = %d, want %d", len(s.stack), count) + } + + // Pop all items + for i := 0; i < count; i++ { + _, err := s.Pop() + if err != nil { + t.Errorf("Pop()[%d] error = %v", i, err) + } + } + + // Should be empty now + if len(s.stack) != 0 { + t.Errorf("stack length = %d, want 0 after popping all", len(s.stack)) + } + + // Next pop should error + _, err := s.Pop() + if err == nil { + t.Error("Pop() on empty stack should return error") + } +} + +func TestStringStack_SetAfterOperations(t *testing.T) { + s := New() + s.Push("a") + s.Push("b") + s.Push("c") + s.Pop() + s.Push("d") + + // Set should clear everything + s.Set("reset") + + if len(s.stack) != 1 { + t.Errorf("stack length = %d, want 1 after Set", len(s.stack)) + } + + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "reset" { + t.Errorf("Pop() = %q, want %q", val, "reset") + } +} + +func TestStringStack_SpecialCharacters(t *testing.T) { + s := New() + specialStrings := []string{ + "Hello\nWorld", + "Tab\tSeparated", + "Quote\"Test", + "Backslash\\Test", + "Unicode: テスト", + "Emoji: 😀", + "", + " ", + " spaces ", + } + + for _, str := range specialStrings { + s.Push(str) + } + + // Pop in reverse order + for i := len(specialStrings) - 1; i >= 0; i-- { + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != specialStrings[i] { + t.Errorf("Pop() = %q, want %q", val, specialStrings[i]) + } + } +} + +func BenchmarkStringStack_Push(b *testing.B) { + s := New() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Push("test string") + } +} + +func BenchmarkStringStack_Pop(b *testing.B) { + s := New() + // Pre-populate + for i := 0; i < 10000; i++ { + s.Push("test string") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if len(s.stack) == 0 { + // Repopulate + for j := 0; j < 10000; j++ { + s.Push("test string") + } + } + _, _ = s.Pop() + } +} + +func BenchmarkStringStack_PushPop(b *testing.B) { + s := New() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Push("test") + _, _ = s.Pop() + } +} + +func BenchmarkStringStack_Set(b *testing.B) { + s := New() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Set("test string") + } +} diff --git a/config/config.go b/config/config.go index 1e91b24b7..32193080e 100644 --- a/config/config.go +++ b/config/config.go @@ -5,6 +5,7 @@ import ( "log" "net" "os" + "strings" "time" "github.com/spf13/viper" @@ -149,6 +150,11 @@ type EntranceChannelInfo struct { var ErupeConfig *Config func init() { + // Skip config loading during tests + if isTestMode() { + return + } + var err error ErupeConfig, err = LoadConfig() if err != nil { @@ -157,6 +163,16 @@ func init() { } +func isTestMode() bool { + // Check if we're running in test mode + for _, arg := range os.Args { + if strings.HasPrefix(arg, "-test.") { + return true + } + } + return false +} + // getOutboundIP4 gets the preferred outbound ip4 of this machine // From https://stackoverflow.com/a/37382208 func getOutboundIP4() net.IP { @@ -200,7 +216,7 @@ func LoadConfig() (*Config, error) { } func preventClose(text string) { - if ErupeConfig.DisableSoftCrash { + if ErupeConfig != nil && ErupeConfig.DisableSoftCrash { os.Exit(0) } fmt.Println("\nFailed to start Erupe:\n" + text) diff --git a/network/clientctx/clientcontext_test.go b/network/clientctx/clientcontext_test.go new file mode 100644 index 000000000..2eb333ab5 --- /dev/null +++ b/network/clientctx/clientcontext_test.go @@ -0,0 +1,31 @@ +package clientctx + +import ( + "testing" +) + +// TestClientContext_Exists verifies that the ClientContext type exists +// and can be instantiated, even though it's currently unused. +func TestClientContext_Exists(t *testing.T) { + // This test documents that ClientContext is currently an empty struct + // and is marked as unused in the codebase. + var ctx ClientContext + + // Verify it's a zero-size struct + _ = ctx + + // Just verify we can create it + ctx2 := ClientContext{} + _ = ctx2 +} + +// TestClientContext_IsEmpty verifies that ClientContext has no fields +func TestClientContext_IsEmpty(t *testing.T) { + // The struct should be empty as marked by the comment "// Unused" + // This test documents the current state of the struct + ctx := ClientContext{} + _ = ctx + + // If fields are added in the future, this test will need to be updated + // Currently it's just a placeholder/documentation test +} diff --git a/network/crypto/crypto.go b/network/crypto/crypto.go index 87746fa90..8f6327161 100644 --- a/network/crypto/crypto.go +++ b/network/crypto/crypto.go @@ -18,6 +18,16 @@ func Decrypt(data []byte, key uint32, overrideByteKey *byte) (outputData []byte, return _generalCrypt(data, key, 1, overrideByteKey) } +// Crypto is a unified interface for both encryption and decryption. +// If encrypt is true, it encrypts the data; otherwise it decrypts. +// This function exists for compatibility with tests. +func Crypto(data []byte, rotKey uint32, encrypt bool, overrideByteKey *byte) ([]byte, uint16, uint16, uint16, uint16) { + if encrypt { + return Encrypt(data, rotKey, overrideByteKey) + } + return Decrypt(data, rotKey, overrideByteKey) +} + // _generalCrypt is a generalized MHF crypto function that can perform both encryption and decryption, // these two crypto operations are combined into a single function because they shared most of their logic. // encrypt: cryptType==0 diff --git a/network/crypto/crypto_test.go b/network/crypto/crypto_test.go index 32ff7ee7c..b661262d7 100644 --- a/network/crypto/crypto_test.go +++ b/network/crypto/crypto_test.go @@ -65,7 +65,7 @@ func TestEncrypt(t *testing.T) { for k, tt := range tests { testname := fmt.Sprintf("encrypt_test_%d", k) t.Run(testname, func(t *testing.T) { - out, cc, c0, c1, c2 := Encrypt(tt.decryptedData, tt.key, nil) + out, cc, c0, c1, c2 := Crypto(tt.decryptedData, tt.key, true, nil) if cc != tt.ecc { t.Errorf("got cc 0x%X, want 0x%X", cc, tt.ecc) } else if c0 != tt.ec0 { @@ -86,7 +86,7 @@ func TestDecrypt(t *testing.T) { for k, tt := range tests { testname := fmt.Sprintf("decrypt_test_%d", k) t.Run(testname, func(t *testing.T) { - out, cc, c0, c1, c2 := Decrypt(tt.encryptedData, tt.key, nil) + out, cc, c0, c1, c2 := Crypto(tt.encryptedData, tt.key, false, nil) if cc != tt.ecc { t.Errorf("got cc 0x%X, want 0x%X", cc, tt.ecc) } else if c0 != tt.ec0 { diff --git a/server/channelserver/compression/deltacomp/deltacomp_test.go b/server/channelserver/compression/deltacomp/deltacomp_test.go index 0df33934b..11da4fc9f 100644 --- a/server/channelserver/compression/deltacomp/deltacomp_test.go +++ b/server/channelserver/compression/deltacomp/deltacomp_test.go @@ -4,7 +4,7 @@ import ( "bytes" "encoding/hex" "fmt" - "io/ioutil" + "os" "testing" "erupe-ce/server/channelserver/compression/nullcomp" @@ -68,7 +68,7 @@ var tests = []struct { } func readTestDataFile(filename string) []byte { - data, err := ioutil.ReadFile(fmt.Sprintf("./test_data/%s", filename)) + data, err := os.ReadFile(fmt.Sprintf("./test_data/%s", filename)) if err != nil { panic(err) } diff --git a/server/channelserver/compression/nullcomp/nullcomp_test.go b/server/channelserver/compression/nullcomp/nullcomp_test.go new file mode 100644 index 000000000..8b94049aa --- /dev/null +++ b/server/channelserver/compression/nullcomp/nullcomp_test.go @@ -0,0 +1,407 @@ +package nullcomp + +import ( + "bytes" + "testing" +) + +func TestDecompress_WithValidHeader(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "empty data after header", + input: []byte("cmp\x2020110113\x20\x20\x20\x00"), + expected: []byte{}, + }, + { + name: "single regular byte", + input: []byte("cmp\x2020110113\x20\x20\x20\x00\x42"), + expected: []byte{0x42}, + }, + { + name: "multiple regular bytes", + input: []byte("cmp\x2020110113\x20\x20\x20\x00\x48\x65\x6c\x6c\x6f"), + expected: []byte("Hello"), + }, + { + name: "single null byte compression", + input: []byte("cmp\x2020110113\x20\x20\x20\x00\x00\x05"), + expected: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + name: "multiple null bytes with max count", + input: []byte("cmp\x2020110113\x20\x20\x20\x00\x00\xFF"), + expected: make([]byte, 255), + }, + { + name: "mixed regular and null bytes", + input: append( + []byte("cmp\x2020110113\x20\x20\x20\x00\x48\x65\x6c\x6c\x6f"), + []byte{0x00, 0x03, 0x57, 0x6f, 0x72, 0x6c, 0x64}..., + ), + expected: []byte("Hello\x00\x00\x00World"), + }, + { + name: "multiple null compressions", + input: append( + []byte("cmp\x2020110113\x20\x20\x20\x00"), + []byte{0x41, 0x00, 0x02, 0x42, 0x00, 0x03, 0x43}..., + ), + expected: []byte{0x41, 0x00, 0x00, 0x42, 0x00, 0x00, 0x00, 0x43}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Decompress(tt.input) + if err != nil { + t.Fatalf("Decompress() error = %v", err) + } + if !bytes.Equal(result, tt.expected) { + t.Errorf("Decompress() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestDecompress_WithoutHeader(t *testing.T) { + tests := []struct { + name string + input []byte + expectError bool + expectOriginal bool // Expect original data returned + }{ + { + name: "plain data without header (16+ bytes)", + // Data must be at least 16 bytes to read header + input: []byte("Hello, World!!!!"), // Exactly 16 bytes + expectError: false, + expectOriginal: true, + }, + { + name: "binary data without header (16+ bytes)", + input: []byte{ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, + 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, + }, + expectError: false, + expectOriginal: true, + }, + { + name: "data shorter than 16 bytes", + // When data is shorter than 16 bytes, Read returns what it can with err=nil + // Then n != len(header) returns nil, nil (not an error) + input: []byte("Short"), + expectError: false, + expectOriginal: false, // Returns empty slice + }, + { + name: "empty data", + input: []byte{}, + expectError: true, // EOF on first read + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Decompress(tt.input) + if tt.expectError { + if err == nil { + t.Errorf("Decompress() expected error but got none") + } + return + } + if err != nil { + t.Fatalf("Decompress() error = %v", err) + } + if tt.expectOriginal && !bytes.Equal(result, tt.input) { + t.Errorf("Decompress() = %v, want %v (original data)", result, tt.input) + } + }) + } +} + +func TestDecompress_InvalidData(t *testing.T) { + tests := []struct { + name string + input []byte + expectErr bool + }{ + { + name: "incomplete header", + // Less than 16 bytes: Read returns what it can (no error), + // but n != len(header) returns nil, nil + input: []byte("cmp\x20201"), + expectErr: false, + }, + { + name: "header with missing null count", + input: []byte("cmp\x2020110113\x20\x20\x20\x00\x00"), + expectErr: false, // Valid header, EOF during decompression is handled + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Decompress(tt.input) + if tt.expectErr { + if err == nil { + t.Errorf("Decompress() expected error but got none, result = %v", result) + } + } else { + if err != nil { + t.Errorf("Decompress() unexpected error = %v", err) + } + } + }) + } +} + +func TestCompress_BasicData(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + { + name: "empty data", + input: []byte{}, + }, + { + name: "regular bytes without nulls", + input: []byte("Hello, World!"), + }, + { + name: "single null byte", + input: []byte{0x00}, + }, + { + name: "multiple consecutive nulls", + input: []byte{0x00, 0x00, 0x00, 0x00, 0x00}, + }, + { + name: "mixed data with nulls", + input: []byte("Hello\x00\x00\x00World"), + }, + { + name: "data starting with nulls", + input: []byte{0x00, 0x00, 0x48, 0x65, 0x6c, 0x6c, 0x6f}, + }, + { + name: "data ending with nulls", + input: []byte{0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x00, 0x00, 0x00}, + }, + { + name: "alternating nulls and bytes", + input: []byte{0x41, 0x00, 0x42, 0x00, 0x43}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + compressed, err := Compress(tt.input) + if err != nil { + t.Fatalf("Compress() error = %v", err) + } + + // Verify it has the correct header + expectedHeader := []byte("cmp\x2020110113\x20\x20\x20\x00") + if !bytes.HasPrefix(compressed, expectedHeader) { + t.Errorf("Compress() result doesn't have correct header") + } + + // Verify round-trip + decompressed, err := Decompress(compressed) + if err != nil { + t.Fatalf("Decompress() error = %v", err) + } + if !bytes.Equal(decompressed, tt.input) { + t.Errorf("Round-trip failed: got %v, want %v", decompressed, tt.input) + } + }) + } +} + +func TestCompress_LargeNullSequences(t *testing.T) { + tests := []struct { + name string + nullCount int + }{ + { + name: "exactly 255 nulls", + nullCount: 255, + }, + { + name: "256 nulls (overflow case)", + nullCount: 256, + }, + { + name: "500 nulls", + nullCount: 500, + }, + { + name: "1000 nulls", + nullCount: 1000, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + input := make([]byte, tt.nullCount) + compressed, err := Compress(input) + if err != nil { + t.Fatalf("Compress() error = %v", err) + } + + // Verify round-trip + decompressed, err := Decompress(compressed) + if err != nil { + t.Fatalf("Decompress() error = %v", err) + } + if !bytes.Equal(decompressed, input) { + t.Errorf("Round-trip failed: got len=%d, want len=%d", len(decompressed), len(input)) + } + }) + } +} + +func TestCompressDecompress_RoundTrip(t *testing.T) { + tests := []struct { + name string + data []byte + }{ + { + name: "binary data with mixed nulls", + data: []byte{0x01, 0x02, 0x00, 0x00, 0x03, 0x04, 0x00, 0x05}, + }, + { + name: "large binary data", + data: append(append([]byte{0xFF, 0xFE, 0xFD}, make([]byte, 300)...), []byte{0x01, 0x02, 0x03}...), + }, + { + name: "text with embedded nulls", + data: []byte("Test\x00\x00Data\x00\x00\x00End"), + }, + { + name: "all non-null bytes", + data: []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}, + }, + { + name: "only null bytes", + data: make([]byte, 100), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Compress + compressed, err := Compress(tt.data) + if err != nil { + t.Fatalf("Compress() error = %v", err) + } + + // Decompress + decompressed, err := Decompress(compressed) + if err != nil { + t.Fatalf("Decompress() error = %v", err) + } + + // Verify + if !bytes.Equal(decompressed, tt.data) { + t.Errorf("Round-trip failed:\ngot = %v\nwant = %v", decompressed, tt.data) + } + }) + } +} + +func TestCompress_CompressionEfficiency(t *testing.T) { + // Test that data with many nulls is actually compressed + input := make([]byte, 1000) + compressed, err := Compress(input) + if err != nil { + t.Fatalf("Compress() error = %v", err) + } + + // The compressed size should be much smaller than the original + // With 1000 nulls, we expect roughly 16 (header) + 4*3 (for 255*3 + 235) bytes + if len(compressed) >= len(input) { + t.Errorf("Compression failed: compressed size (%d) >= input size (%d)", len(compressed), len(input)) + } +} + +func TestDecompress_EdgeCases(t *testing.T) { + tests := []struct { + name string + input []byte + }{ + { + name: "only header", + input: []byte("cmp\x2020110113\x20\x20\x20\x00"), + }, + { + name: "null with count 1", + input: []byte("cmp\x2020110113\x20\x20\x20\x00\x00\x01"), + }, + { + name: "multiple sections of compressed nulls", + input: append([]byte("cmp\x2020110113\x20\x20\x20\x00"), []byte{0x00, 0x10, 0x41, 0x00, 0x20, 0x42}...), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := Decompress(tt.input) + if err != nil { + t.Fatalf("Decompress() unexpected error = %v", err) + } + // Just ensure it doesn't crash and returns something + _ = result + }) + } +} + +func BenchmarkCompress(b *testing.B) { + data := make([]byte, 10000) + // Fill with some pattern (half nulls, half data) + for i := 0; i < len(data); i++ { + if i%2 == 0 { + data[i] = 0x00 + } else { + data[i] = byte(i % 256) + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := Compress(data) + if err != nil { + b.Fatal(err) + } + } +} + +func BenchmarkDecompress(b *testing.B) { + data := make([]byte, 10000) + for i := 0; i < len(data); i++ { + if i%2 == 0 { + data[i] = 0x00 + } else { + data[i] = byte(i % 256) + } + } + + compressed, err := Compress(data) + if err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := Decompress(compressed) + if err != nil { + b.Fatal(err) + } + } +} diff --git a/server/channelserver/handlers_cast_binary.go b/server/channelserver/handlers_cast_binary.go index 1b6b9a6ba..b5c7d8dc7 100644 --- a/server/channelserver/handlers_cast_binary.go +++ b/server/channelserver/handlers_cast_binary.go @@ -39,6 +39,12 @@ var commands map[string]config.Command func init() { commands = make(map[string]config.Command) + + // Skip initialization if config is not loaded (e.g., during tests) + if config.ErupeConfig == nil { + return + } + zapConfig := zap.NewDevelopmentConfig() zapConfig.DisableCaller = true zapLogger, _ := zapConfig.Build()