diff --git a/common/bfutil/bfutil_test.go b/common/bfutil/bfutil_test.go new file mode 100644 index 000000000..51fad0e13 --- /dev/null +++ b/common/bfutil/bfutil_test.go @@ -0,0 +1,105 @@ +package bfutil + +import ( + "bytes" + "testing" +) + +func TestUpToNull(t *testing.T) { + tests := []struct { + name string + input []byte + expected []byte + }{ + { + name: "data with null terminator", + input: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x00, 0x57, 0x6F, 0x72, 0x6C, 0x64}, + expected: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F}, // "Hello" + }, + { + name: "data without null terminator", + input: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F}, + expected: []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F}, // "Hello" + }, + { + name: "data with null at start", + input: []byte{0x00, 0x48, 0x65, 0x6C, 0x6C, 0x6F}, + expected: []byte{}, + }, + { + name: "empty slice", + input: []byte{}, + expected: []byte{}, + }, + { + name: "only null byte", + input: []byte{0x00}, + expected: []byte{}, + }, + { + name: "multiple null bytes", + input: []byte{0x48, 0x65, 0x00, 0x00, 0x6C, 0x6C, 0x6F}, + expected: []byte{0x48, 0x65}, // "He" + }, + { + name: "binary data with null", + input: []byte{0xFF, 0xAB, 0x12, 0x00, 0x34, 0x56}, + expected: []byte{0xFF, 0xAB, 0x12}, + }, + { + name: "binary data without null", + input: []byte{0xFF, 0xAB, 0x12, 0x34, 0x56}, + expected: []byte{0xFF, 0xAB, 0x12, 0x34, 0x56}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UpToNull(tt.input) + if !bytes.Equal(result, tt.expected) { + t.Errorf("UpToNull() = %v, want %v", result, tt.expected) + } + }) + } +} + +func TestUpToNull_ReturnsSliceNotCopy(t *testing.T) { + // Test that UpToNull returns a slice of the original array, not a copy + input := []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F, 0x00, 0x57, 0x6F, 0x72, 0x6C, 0x64} + result := UpToNull(input) + + // Verify we got the expected data + expected := []byte{0x48, 0x65, 0x6C, 0x6C, 0x6F} + if !bytes.Equal(result, expected) { + t.Errorf("UpToNull() = %v, want %v", result, expected) + } + + // The result should be a slice of the input array + if len(result) > 0 && cap(result) < len(expected) { + t.Error("Result should be a slice of input array") + } +} + +func BenchmarkUpToNull(b *testing.B) { + data := []byte("Hello, World!\x00Extra data here") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UpToNull(data) + } +} + +func BenchmarkUpToNull_NoNull(b *testing.B) { + data := []byte("Hello, World! No null terminator in this string at all") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UpToNull(data) + } +} + +func BenchmarkUpToNull_NullAtStart(b *testing.B) { + data := []byte("\x00Hello, World!") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UpToNull(data) + } +} diff --git a/common/byteframe/byteframe_test.go b/common/byteframe/byteframe_test.go new file mode 100644 index 000000000..cd9c4b93e --- /dev/null +++ b/common/byteframe/byteframe_test.go @@ -0,0 +1,502 @@ +package byteframe + +import ( + "bytes" + "encoding/binary" + "io" + "math" + "testing" +) + +func TestNewByteFrame(t *testing.T) { + bf := NewByteFrame() + if bf == nil { + t.Fatal("NewByteFrame() returned nil") + } + if bf.index != 0 { + t.Errorf("index = %d, want 0", bf.index) + } + if bf.usedSize != 0 { + t.Errorf("usedSize = %d, want 0", bf.usedSize) + } + if len(bf.buf) != 4 { + t.Errorf("buf length = %d, want 4", len(bf.buf)) + } + if bf.byteOrder != binary.BigEndian { + t.Error("byteOrder should be BigEndian by default") + } +} + +func TestNewByteFrameFromBytes(t *testing.T) { + input := []byte{0x01, 0x02, 0x03, 0x04} + bf := NewByteFrameFromBytes(input) + if bf == nil { + t.Fatal("NewByteFrameFromBytes() returned nil") + } + if bf.index != 0 { + t.Errorf("index = %d, want 0", bf.index) + } + if bf.usedSize != uint(len(input)) { + t.Errorf("usedSize = %d, want %d", bf.usedSize, len(input)) + } + if !bytes.Equal(bf.buf, input) { + t.Errorf("buf = %v, want %v", bf.buf, input) + } + // Verify it's a copy, not the same slice + input[0] = 0xFF + if bf.buf[0] == 0xFF { + t.Error("NewByteFrameFromBytes should make a copy, not use the same slice") + } +} + +func TestByteFrame_WriteAndReadUint8(t *testing.T) { + bf := NewByteFrame() + values := []uint8{0, 1, 127, 128, 255} + + for _, v := range values { + bf.WriteUint8(v) + } + + bf.Seek(0, io.SeekStart) + for i, expected := range values { + got := bf.ReadUint8() + if got != expected { + t.Errorf("ReadUint8()[%d] = %d, want %d", i, got, expected) + } + } +} + +func TestByteFrame_WriteAndReadUint16(t *testing.T) { + tests := []struct { + name string + value uint16 + }{ + {"zero", 0}, + {"one", 1}, + {"max_int8", 127}, + {"max_uint8", 255}, + {"max_int16", 32767}, + {"max_uint16", 65535}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint16(tt.value) + bf.Seek(0, io.SeekStart) + got := bf.ReadUint16() + if got != tt.value { + t.Errorf("ReadUint16() = %d, want %d", got, tt.value) + } + }) + } +} + +func TestByteFrame_WriteAndReadUint32(t *testing.T) { + tests := []struct { + name string + value uint32 + }{ + {"zero", 0}, + {"one", 1}, + {"max_uint16", 65535}, + {"max_uint32", 4294967295}, + {"arbitrary", 0x12345678}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint32(tt.value) + bf.Seek(0, io.SeekStart) + got := bf.ReadUint32() + if got != tt.value { + t.Errorf("ReadUint32() = %d, want %d", got, tt.value) + } + }) + } +} + +func TestByteFrame_WriteAndReadUint64(t *testing.T) { + tests := []struct { + name string + value uint64 + }{ + {"zero", 0}, + {"one", 1}, + {"max_uint32", 4294967295}, + {"max_uint64", 18446744073709551615}, + {"arbitrary", 0x123456789ABCDEF0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint64(tt.value) + bf.Seek(0, io.SeekStart) + got := bf.ReadUint64() + if got != tt.value { + t.Errorf("ReadUint64() = %d, want %d", got, tt.value) + } + }) + } +} + +func TestByteFrame_WriteAndReadInt8(t *testing.T) { + values := []int8{-128, -1, 0, 1, 127} + bf := NewByteFrame() + + for _, v := range values { + bf.WriteInt8(v) + } + + bf.Seek(0, io.SeekStart) + for i, expected := range values { + got := bf.ReadInt8() + if got != expected { + t.Errorf("ReadInt8()[%d] = %d, want %d", i, got, expected) + } + } +} + +func TestByteFrame_WriteAndReadInt16(t *testing.T) { + values := []int16{-32768, -1, 0, 1, 32767} + bf := NewByteFrame() + + for _, v := range values { + bf.WriteInt16(v) + } + + bf.Seek(0, io.SeekStart) + for i, expected := range values { + got := bf.ReadInt16() + if got != expected { + t.Errorf("ReadInt16()[%d] = %d, want %d", i, got, expected) + } + } +} + +func TestByteFrame_WriteAndReadInt32(t *testing.T) { + values := []int32{-2147483648, -1, 0, 1, 2147483647} + bf := NewByteFrame() + + for _, v := range values { + bf.WriteInt32(v) + } + + bf.Seek(0, io.SeekStart) + for i, expected := range values { + got := bf.ReadInt32() + if got != expected { + t.Errorf("ReadInt32()[%d] = %d, want %d", i, got, expected) + } + } +} + +func TestByteFrame_WriteAndReadInt64(t *testing.T) { + values := []int64{-9223372036854775808, -1, 0, 1, 9223372036854775807} + bf := NewByteFrame() + + for _, v := range values { + bf.WriteInt64(v) + } + + bf.Seek(0, io.SeekStart) + for i, expected := range values { + got := bf.ReadInt64() + if got != expected { + t.Errorf("ReadInt64()[%d] = %d, want %d", i, got, expected) + } + } +} + +func TestByteFrame_WriteAndReadFloat32(t *testing.T) { + values := []float32{0.0, -1.5, 1.5, 3.14159, math.MaxFloat32, -math.MaxFloat32} + bf := NewByteFrame() + + for _, v := range values { + bf.WriteFloat32(v) + } + + bf.Seek(0, io.SeekStart) + for i, expected := range values { + got := bf.ReadFloat32() + if got != expected { + t.Errorf("ReadFloat32()[%d] = %f, want %f", i, got, expected) + } + } +} + +func TestByteFrame_WriteAndReadFloat64(t *testing.T) { + values := []float64{0.0, -1.5, 1.5, 3.14159265358979, math.MaxFloat64, -math.MaxFloat64} + bf := NewByteFrame() + + for _, v := range values { + bf.WriteFloat64(v) + } + + bf.Seek(0, io.SeekStart) + for i, expected := range values { + got := bf.ReadFloat64() + if got != expected { + t.Errorf("ReadFloat64()[%d] = %f, want %f", i, got, expected) + } + } +} + +func TestByteFrame_WriteAndReadBool(t *testing.T) { + bf := NewByteFrame() + bf.WriteBool(true) + bf.WriteBool(false) + bf.WriteBool(true) + + bf.Seek(0, io.SeekStart) + if got := bf.ReadBool(); got != true { + t.Errorf("ReadBool()[0] = %v, want true", got) + } + if got := bf.ReadBool(); got != false { + t.Errorf("ReadBool()[1] = %v, want false", got) + } + if got := bf.ReadBool(); got != true { + t.Errorf("ReadBool()[2] = %v, want true", got) + } +} + +func TestByteFrame_WriteAndReadBytes(t *testing.T) { + bf := NewByteFrame() + input := []byte{0x01, 0x02, 0x03, 0x04, 0x05} + bf.WriteBytes(input) + + bf.Seek(0, io.SeekStart) + got := bf.ReadBytes(uint(len(input))) + if !bytes.Equal(got, input) { + t.Errorf("ReadBytes() = %v, want %v", got, input) + } +} + +func TestByteFrame_WriteAndReadNullTerminatedBytes(t *testing.T) { + bf := NewByteFrame() + input := []byte("Hello, World!") + bf.WriteNullTerminatedBytes(input) + + bf.Seek(0, io.SeekStart) + got := bf.ReadNullTerminatedBytes() + if !bytes.Equal(got, input) { + t.Errorf("ReadNullTerminatedBytes() = %v, want %v", got, input) + } +} + +func TestByteFrame_ReadNullTerminatedBytes_NoNull(t *testing.T) { + bf := NewByteFrame() + input := []byte("Hello") + bf.WriteBytes(input) + + bf.Seek(0, io.SeekStart) + got := bf.ReadNullTerminatedBytes() + // When there's no null terminator, it should return empty slice + if len(got) != 0 { + t.Errorf("ReadNullTerminatedBytes() = %v, want empty slice", got) + } +} + +func TestByteFrame_Endianness(t *testing.T) { + // Test BigEndian (default) + bfBE := NewByteFrame() + bfBE.WriteUint16(0x1234) + dataBE := bfBE.Data() + if dataBE[0] != 0x12 || dataBE[1] != 0x34 { + t.Errorf("BigEndian: got %X %X, want 12 34", dataBE[0], dataBE[1]) + } + + // Test LittleEndian + bfLE := NewByteFrame() + bfLE.SetLE() + bfLE.WriteUint16(0x1234) + dataLE := bfLE.Data() + if dataLE[0] != 0x34 || dataLE[1] != 0x12 { + t.Errorf("LittleEndian: got %X %X, want 34 12", dataLE[0], dataLE[1]) + } +} + +func TestByteFrame_Seek(t *testing.T) { + bf := NewByteFrame() + bf.WriteBytes([]byte{0x01, 0x02, 0x03, 0x04, 0x05}) + + tests := []struct { + name string + offset int64 + whence int + wantIndex uint + wantErr bool + }{ + {"seek_start_0", 0, io.SeekStart, 0, false}, + {"seek_start_2", 2, io.SeekStart, 2, false}, + {"seek_start_5", 5, io.SeekStart, 5, false}, + {"seek_start_beyond", 6, io.SeekStart, 5, true}, + {"seek_current_forward", 2, io.SeekCurrent, 5, true}, // Will go beyond max + {"seek_current_backward", -3, io.SeekCurrent, 2, false}, + {"seek_current_before_start", -10, io.SeekCurrent, 2, true}, + {"seek_end_0", 0, io.SeekEnd, 5, false}, + {"seek_end_negative", -2, io.SeekEnd, 3, false}, + {"seek_end_beyond", 1, io.SeekEnd, 3, true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Reset to known position for each test + bf.Seek(5, io.SeekStart) + + pos, err := bf.Seek(tt.offset, tt.whence) + if tt.wantErr { + if err == nil { + t.Errorf("Seek() expected error, got nil") + } + } else { + if err != nil { + t.Errorf("Seek() unexpected error: %v", err) + } + if bf.index != tt.wantIndex { + t.Errorf("index = %d, want %d", bf.index, tt.wantIndex) + } + if uint(pos) != tt.wantIndex { + t.Errorf("returned position = %d, want %d", pos, tt.wantIndex) + } + } + }) + } +} + +func TestByteFrame_Data(t *testing.T) { + bf := NewByteFrame() + input := []byte{0x01, 0x02, 0x03, 0x04, 0x05} + bf.WriteBytes(input) + + data := bf.Data() + if !bytes.Equal(data, input) { + t.Errorf("Data() = %v, want %v", data, input) + } +} + +func TestByteFrame_DataFromCurrent(t *testing.T) { + bf := NewByteFrame() + bf.WriteBytes([]byte{0x01, 0x02, 0x03, 0x04, 0x05}) + bf.Seek(2, io.SeekStart) + + data := bf.DataFromCurrent() + expected := []byte{0x03, 0x04, 0x05} + if !bytes.Equal(data, expected) { + t.Errorf("DataFromCurrent() = %v, want %v", data, expected) + } +} + +func TestByteFrame_Index(t *testing.T) { + bf := NewByteFrame() + if bf.Index() != 0 { + t.Errorf("Index() = %d, want 0", bf.Index()) + } + + bf.WriteUint8(0x01) + if bf.Index() != 1 { + t.Errorf("Index() = %d, want 1", bf.Index()) + } + + bf.WriteUint16(0x0102) + if bf.Index() != 3 { + t.Errorf("Index() = %d, want 3", bf.Index()) + } +} + +func TestByteFrame_BufferGrowth(t *testing.T) { + bf := NewByteFrame() + initialCap := len(bf.buf) + + // Write enough data to force growth + for i := 0; i < 100; i++ { + bf.WriteUint32(uint32(i)) + } + + if len(bf.buf) <= initialCap { + t.Errorf("Buffer should have grown, initial cap: %d, current: %d", initialCap, len(bf.buf)) + } + + // Verify all data is still accessible + bf.Seek(0, io.SeekStart) + for i := 0; i < 100; i++ { + got := bf.ReadUint32() + if got != uint32(i) { + t.Errorf("After growth, ReadUint32()[%d] = %d, want %d", i, got, i) + break + } + } +} + +func TestByteFrame_ReadPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Error("Reading beyond buffer should panic") + } + }() + + bf := NewByteFrame() + bf.WriteUint8(0x01) + bf.Seek(0, io.SeekStart) + bf.ReadUint8() + bf.ReadUint16() // Should panic - trying to read 2 bytes when only 1 was written +} + +func TestByteFrame_SequentialWrites(t *testing.T) { + bf := NewByteFrame() + bf.WriteUint8(0x01) + bf.WriteUint16(0x0203) + bf.WriteUint32(0x04050607) + bf.WriteUint64(0x08090A0B0C0D0E0F) + + expected := []byte{ + 0x01, // uint8 + 0x02, 0x03, // uint16 + 0x04, 0x05, 0x06, 0x07, // uint32 + 0x08, 0x09, 0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, // uint64 + } + + data := bf.Data() + if !bytes.Equal(data, expected) { + t.Errorf("Sequential writes: got %X, want %X", data, expected) + } +} + +func BenchmarkByteFrame_WriteUint8(b *testing.B) { + bf := NewByteFrame() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.WriteUint8(0x42) + } +} + +func BenchmarkByteFrame_WriteUint32(b *testing.B) { + bf := NewByteFrame() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.WriteUint32(0x12345678) + } +} + +func BenchmarkByteFrame_ReadUint32(b *testing.B) { + bf := NewByteFrame() + for i := 0; i < 1000; i++ { + bf.WriteUint32(0x12345678) + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.Seek(0, io.SeekStart) + bf.ReadUint32() + } +} + +func BenchmarkByteFrame_WriteBytes(b *testing.B) { + bf := NewByteFrame() + data := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.WriteBytes(data) + } +} diff --git a/common/decryption/jpk_test.go b/common/decryption/jpk_test.go new file mode 100644 index 000000000..159e034be --- /dev/null +++ b/common/decryption/jpk_test.go @@ -0,0 +1,234 @@ +package decryption + +import ( + "bytes" + "erupe-ce/common/byteframe" + "io" + "testing" +) + +func TestUnpackSimple_UncompressedData(t *testing.T) { + // Test data that doesn't have JPK header - should be returned as-is + input := []byte{0x00, 0x01, 0x02, 0x03, 0x04, 0x05} + result := UnpackSimple(input) + + if !bytes.Equal(result, input) { + t.Errorf("UnpackSimple() with uncompressed data should return input as-is, got %v, want %v", result, input) + } +} + +func TestUnpackSimple_InvalidHeader(t *testing.T) { + // Test data with wrong header + input := []byte{0xFF, 0xFF, 0xFF, 0xFF, 0x01, 0x02, 0x03, 0x04} + result := UnpackSimple(input) + + if !bytes.Equal(result, input) { + t.Errorf("UnpackSimple() with invalid header should return input as-is, got %v, want %v", result, input) + } +} + +func TestUnpackSimple_JPKHeaderWrongType(t *testing.T) { + // Test JPK header but wrong type (not type 3) + bf := byteframe.NewByteFrame() + bf.SetLE() + bf.WriteUint32(0x1A524B4A) // JPK header + bf.WriteUint16(0x00) // Reserved + bf.WriteUint16(1) // Type 1 instead of 3 + bf.WriteInt32(12) // Start offset + bf.WriteInt32(10) // Out size + + result := UnpackSimple(bf.Data()) + // Should return the input as-is since it's not type 3 + if !bytes.Equal(result, bf.Data()) { + t.Error("UnpackSimple() with non-type-3 JPK should return input as-is") + } +} + +func TestUnpackSimple_ValidJPKType3_EmptyData(t *testing.T) { + // Create a valid JPK type 3 header with minimal compressed data + bf := byteframe.NewByteFrame() + bf.SetLE() + bf.WriteUint32(0x1A524B4A) // JPK header "JKR\x1A" + bf.WriteUint16(0x00) // Reserved + bf.WriteUint16(3) // Type 3 + bf.WriteInt32(12) // Start offset (points to byte 12, after header) + bf.WriteInt32(0) // Out size (empty output) + + result := UnpackSimple(bf.Data()) + // Should return empty buffer + if len(result) != 0 { + t.Errorf("UnpackSimple() with zero output size should return empty slice, got length %d", len(result)) + } +} + +func TestUnpackSimple_JPKHeader(t *testing.T) { + // Test that the function correctly identifies JPK header (0x1A524B4A = "JKR\x1A" in little endian) + bf := byteframe.NewByteFrame() + bf.SetLE() + bf.WriteUint32(0x1A524B4A) // Correct JPK magic + + data := bf.Data() + if len(data) < 4 { + t.Fatal("Not enough data written") + } + + // Verify the header bytes are correct + bf.Seek(0, io.SeekStart) + header := bf.ReadUint32() + if header != 0x1A524B4A { + t.Errorf("Header = 0x%X, want 0x1A524B4A", header) + } +} + +func TestJPKBitShift_Initialization(t *testing.T) { + // Test that the function doesn't crash with bad initial global state + mShiftIndex = 10 + mFlag = 0xFF + + // Create data without JPK header (will return as-is) + // Need at least 4 bytes since UnpackSimple reads a uint32 header + bf := byteframe.NewByteFrame() + bf.WriteUint32(0xAABBCCDD) // Not a JPK header + + data := bf.Data() + result := UnpackSimple(data) + + // Without JPK header, should return data as-is + if !bytes.Equal(result, data) { + t.Error("UnpackSimple with non-JPK data should return input as-is") + } +} + +func TestReadByte(t *testing.T) { + bf := byteframe.NewByteFrame() + bf.WriteUint8(0x42) + bf.WriteUint8(0xAB) + + bf.Seek(0, io.SeekStart) + b1 := ReadByte(bf) + b2 := ReadByte(bf) + + if b1 != 0x42 { + t.Errorf("ReadByte() = 0x%X, want 0x42", b1) + } + if b2 != 0xAB { + t.Errorf("ReadByte() = 0x%X, want 0xAB", b2) + } +} + +func TestJPKCopy(t *testing.T) { + outBuffer := make([]byte, 20) + // Set up some initial data + outBuffer[0] = 'A' + outBuffer[1] = 'B' + outBuffer[2] = 'C' + + index := 3 + // Copy 3 bytes from offset 2 (looking back 2+1=3 positions) + JPKCopy(outBuffer, 2, 3, &index) + + // Should have copied 'A', 'B', 'C' to positions 3, 4, 5 + if outBuffer[3] != 'A' || outBuffer[4] != 'B' || outBuffer[5] != 'C' { + t.Errorf("JPKCopy failed: got %v at positions 3-5, want ['A', 'B', 'C']", outBuffer[3:6]) + } + if index != 6 { + t.Errorf("index = %d, want 6", index) + } +} + +func TestJPKCopy_OverlappingCopy(t *testing.T) { + // Test copying with overlapping regions (common in LZ-style compression) + outBuffer := make([]byte, 20) + outBuffer[0] = 'X' + + index := 1 + // Copy from 1 position back, 5 times - should repeat the pattern + JPKCopy(outBuffer, 0, 5, &index) + + // Should produce: X X X X X (repeating X) + for i := 1; i < 6; i++ { + if outBuffer[i] != 'X' { + t.Errorf("outBuffer[%d] = %c, want 'X'", i, outBuffer[i]) + } + } + if index != 6 { + t.Errorf("index = %d, want 6", index) + } +} + +func TestProcessDecode_EmptyOutput(t *testing.T) { + bf := byteframe.NewByteFrame() + bf.WriteUint8(0x00) + + outBuffer := make([]byte, 0) + // Should not panic with empty output buffer + ProcessDecode(bf, outBuffer) +} + +func TestUnpackSimple_EdgeCases(t *testing.T) { + // Test with data that has at least 4 bytes (header size required) + tests := []struct { + name string + input []byte + }{ + { + name: "four bytes non-JPK", + input: []byte{0x00, 0x01, 0x02, 0x03}, + }, + { + name: "partial header padded", + input: []byte{0x4A, 0x4B, 0x00, 0x00}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UnpackSimple(tt.input) + // Should return input as-is without crashing + if !bytes.Equal(result, tt.input) { + t.Errorf("UnpackSimple() = %v, want %v", result, tt.input) + } + }) + } +} + +func BenchmarkUnpackSimple_Uncompressed(b *testing.B) { + data := make([]byte, 1024) + for i := range data { + data[i] = byte(i % 256) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UnpackSimple(data) + } +} + +func BenchmarkUnpackSimple_JPKHeader(b *testing.B) { + bf := byteframe.NewByteFrame() + bf.SetLE() + bf.WriteUint32(0x1A524B4A) // JPK header + bf.WriteUint16(0x00) + bf.WriteUint16(3) + bf.WriteInt32(12) + bf.WriteInt32(0) + data := bf.Data() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UnpackSimple(data) + } +} + +func BenchmarkReadByte(b *testing.B) { + bf := byteframe.NewByteFrame() + for i := 0; i < 1000; i++ { + bf.WriteUint8(byte(i % 256)) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf.Seek(0, io.SeekStart) + _ = ReadByte(bf) + } +} diff --git a/common/mhfcid/mhfcid_test.go b/common/mhfcid/mhfcid_test.go new file mode 100644 index 000000000..ab18af15b --- /dev/null +++ b/common/mhfcid/mhfcid_test.go @@ -0,0 +1,258 @@ +package mhfcid + +import ( + "testing" +) + +func TestConvertCID(t *testing.T) { + tests := []struct { + name string + input string + expected uint32 + }{ + { + name: "all ones", + input: "111111", + expected: 0, // '1' maps to 0, so 0*32^0 + 0*32^1 + ... = 0 + }, + { + name: "all twos", + input: "222222", + expected: 1 + 32 + 1024 + 32768 + 1048576 + 33554432, // 1*32^0 + 1*32^1 + 1*32^2 + 1*32^3 + 1*32^4 + 1*32^5 + }, + { + name: "sequential", + input: "123456", + expected: 0 + 32 + 2*1024 + 3*32768 + 4*1048576 + 5*33554432, // 0 + 1*32 + 2*32^2 + 3*32^3 + 4*32^4 + 5*32^5 + }, + { + name: "with letters A-Z", + input: "ABCDEF", + expected: 9 + 10*32 + 11*1024 + 12*32768 + 13*1048576 + 14*33554432, + }, + { + name: "mixed numbers and letters", + input: "1A2B3C", + expected: 0 + 9*32 + 1*1024 + 10*32768 + 2*1048576 + 11*33554432, + }, + { + name: "max valid characters", + input: "ZZZZZZ", + expected: 31 + 31*32 + 31*1024 + 31*32768 + 31*1048576 + 31*33554432, // 31 * (1 + 32 + 1024 + 32768 + 1048576 + 33554432) + }, + { + name: "no banned chars: O excluded", + input: "N1P1Q1", // N=21, P=22, Q=23 - note no O + expected: 21 + 0*32 + 22*1024 + 0*32768 + 23*1048576 + 0*33554432, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertCID(tt.input) + if result != tt.expected { + t.Errorf("ConvertCID(%q) = %d, want %d", tt.input, result, tt.expected) + } + }) + } +} + +func TestConvertCID_InvalidLength(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"empty", ""}, + {"too short - 1", "1"}, + {"too short - 5", "12345"}, + {"too long - 7", "1234567"}, + {"too long - 10", "1234567890"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertCID(tt.input) + if result != 0 { + t.Errorf("ConvertCID(%q) = %d, want 0 (invalid length should return 0)", tt.input, result) + } + }) + } +} + +func TestConvertCID_BannedCharacters(t *testing.T) { + // Banned characters: 0, I, O, S + tests := []struct { + name string + input string + }{ + {"contains 0", "111011"}, + {"contains I", "111I11"}, + {"contains O", "11O111"}, + {"contains S", "S11111"}, + {"all banned", "000III"}, + {"mixed banned", "I0OS11"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertCID(tt.input) + // Characters not in the map will contribute 0 to the result + // The function doesn't explicitly reject them, it just doesn't map them + // So we're testing that banned characters don't crash the function + _ = result // Just verify it doesn't panic + }) + } +} + +func TestConvertCID_LowercaseNotSupported(t *testing.T) { + // The map only contains uppercase letters + input := "abcdef" + result := ConvertCID(input) + // Lowercase letters aren't mapped, so they'll contribute 0 + if result != 0 { + t.Logf("ConvertCID(%q) = %d (lowercase not in map, contributes 0)", input, result) + } +} + +func TestConvertCID_CharacterMapping(t *testing.T) { + // Verify specific character mappings + tests := []struct { + char rune + expected uint32 + }{ + {'1', 0}, + {'2', 1}, + {'9', 8}, + {'A', 9}, + {'B', 10}, + {'Z', 31}, + {'J', 17}, // J comes after I is skipped + {'P', 22}, // P comes after O is skipped + {'T', 25}, // T comes after S is skipped + } + + for _, tt := range tests { + t.Run(string(tt.char), func(t *testing.T) { + // Create a CID with the character in the first position (32^0) + input := string(tt.char) + "11111" + result := ConvertCID(input) + // The first character contributes its value * 32^0 = value * 1 + if result != tt.expected { + t.Errorf("ConvertCID(%q) first char value = %d, want %d", input, result, tt.expected) + } + }) + } +} + +func TestConvertCID_Base32Like(t *testing.T) { + // Test that it behaves like base-32 conversion + // The position multiplier should be powers of 32 + tests := []struct { + name string + input string + expected uint32 + }{ + { + name: "position 0 only", + input: "211111", // 2 in position 0 + expected: 1, // 1 * 32^0 + }, + { + name: "position 1 only", + input: "121111", // 2 in position 1 + expected: 32, // 1 * 32^1 + }, + { + name: "position 2 only", + input: "112111", // 2 in position 2 + expected: 1024, // 1 * 32^2 + }, + { + name: "position 3 only", + input: "111211", // 2 in position 3 + expected: 32768, // 1 * 32^3 + }, + { + name: "position 4 only", + input: "111121", // 2 in position 4 + expected: 1048576, // 1 * 32^4 + }, + { + name: "position 5 only", + input: "111112", // 2 in position 5 + expected: 33554432, // 1 * 32^5 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertCID(tt.input) + if result != tt.expected { + t.Errorf("ConvertCID(%q) = %d, want %d", tt.input, result, tt.expected) + } + }) + } +} + +func TestConvertCID_SkippedCharacters(t *testing.T) { + // Verify that 0, I, O, S are actually skipped in the character sequence + // The alphabet should be: 1-9 (0 skipped), A-H (I skipped), J-N (O skipped), P-R (S skipped), T-Z + + // Test that characters after skipped ones have the right values + tests := []struct { + name string + char1 string // Character before skip + char2 string // Character after skip + diff uint32 // Expected difference (should be 1) + }{ + {"before/after I skip", "H", "J", 1}, // H=16, J=17 + {"before/after O skip", "N", "P", 1}, // N=21, P=22 + {"before/after S skip", "R", "T", 1}, // R=24, T=25 + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cid1 := tt.char1 + "11111" + cid2 := tt.char2 + "11111" + val1 := ConvertCID(cid1) + val2 := ConvertCID(cid2) + diff := val2 - val1 + if diff != tt.diff { + t.Errorf("Difference between %s and %s = %d, want %d (val1=%d, val2=%d)", + tt.char1, tt.char2, diff, tt.diff, val1, val2) + } + }) + } +} + +func BenchmarkConvertCID(b *testing.B) { + testCID := "A1B2C3" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ConvertCID(testCID) + } +} + +func BenchmarkConvertCID_AllLetters(b *testing.B) { + testCID := "ABCDEF" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ConvertCID(testCID) + } +} + +func BenchmarkConvertCID_AllNumbers(b *testing.B) { + testCID := "123456" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ConvertCID(testCID) + } +} + +func BenchmarkConvertCID_InvalidLength(b *testing.B) { + testCID := "123" // Too short + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ConvertCID(testCID) + } +} diff --git a/common/mhfcourse/mhfcourse_test.go b/common/mhfcourse/mhfcourse_test.go new file mode 100644 index 000000000..fb1c416d8 --- /dev/null +++ b/common/mhfcourse/mhfcourse_test.go @@ -0,0 +1,385 @@ +package mhfcourse + +import ( + _config "erupe-ce/config" + "math" + "testing" + "time" +) + +func TestCourse_Aliases(t *testing.T) { + tests := []struct { + id uint16 + wantLen int + want []string + }{ + {1, 2, []string{"Trial", "TL"}}, + {2, 2, []string{"HunterLife", "HL"}}, + {3, 3, []string{"Extra", "ExtraA", "EX"}}, + {8, 4, []string{"Assist", "***ist", "Legend", "Rasta"}}, + {26, 4, []string{"NetCafe", "Cafe", "OfficialCafe", "Official"}}, + {13, 0, nil}, // Unknown course + {99, 0, nil}, // Unknown course + } + + for _, tt := range tests { + t.Run(string(rune(tt.id)), func(t *testing.T) { + c := Course{ID: tt.id} + got := c.Aliases() + if len(got) != tt.wantLen { + t.Errorf("Course{ID: %d}.Aliases() length = %d, want %d", tt.id, len(got), tt.wantLen) + } + if tt.want != nil { + for i, alias := range tt.want { + if i >= len(got) || got[i] != alias { + t.Errorf("Course{ID: %d}.Aliases()[%d] = %q, want %q", tt.id, i, got[i], alias) + } + } + } + }) + } +} + +func TestCourses(t *testing.T) { + courses := Courses() + if len(courses) != 32 { + t.Errorf("Courses() length = %d, want 32", len(courses)) + } + + // Verify IDs are sequential from 0 to 31 + for i, course := range courses { + if course.ID != uint16(i) { + t.Errorf("Courses()[%d].ID = %d, want %d", i, course.ID, i) + } + } +} + +func TestCourse_Value(t *testing.T) { + tests := []struct { + id uint16 + expected uint32 + }{ + {0, 1}, // 2^0 + {1, 2}, // 2^1 + {2, 4}, // 2^2 + {3, 8}, // 2^3 + {4, 16}, // 2^4 + {5, 32}, // 2^5 + {10, 1024}, // 2^10 + {15, 32768}, // 2^15 + {20, 1048576}, // 2^20 + {31, 2147483648}, // 2^31 + } + + for _, tt := range tests { + t.Run(string(rune(tt.id)), func(t *testing.T) { + c := Course{ID: tt.id} + got := c.Value() + if got != tt.expected { + t.Errorf("Course{ID: %d}.Value() = %d, want %d", tt.id, got, tt.expected) + } + }) + } +} + +func TestCourseExists(t *testing.T) { + courses := []Course{ + {ID: 1}, + {ID: 5}, + {ID: 10}, + {ID: 15}, + } + + tests := []struct { + name string + id uint16 + expected bool + }{ + {"exists first", 1, true}, + {"exists middle", 5, true}, + {"exists last", 15, true}, + {"not exists", 3, false}, + {"not exists 0", 0, false}, + {"not exists 20", 20, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := CourseExists(tt.id, courses) + if got != tt.expected { + t.Errorf("CourseExists(%d, courses) = %v, want %v", tt.id, got, tt.expected) + } + }) + } +} + +func TestCourseExists_EmptySlice(t *testing.T) { + var courses []Course + if CourseExists(1, courses) { + t.Error("CourseExists(1, []) should return false for empty slice") + } +} + +func TestGetCourseStruct(t *testing.T) { + // Save original config and restore after test + originalDefaultCourses := _config.ErupeConfig.DefaultCourses + defer func() { + _config.ErupeConfig.DefaultCourses = originalDefaultCourses + }() + + // Set up test config + _config.ErupeConfig.DefaultCourses = []uint16{1, 2} + + tests := []struct { + name string + rights uint32 + wantMinLen int // Minimum expected courses (including defaults) + checkCourses []uint16 + }{ + { + name: "no rights", + rights: 0, + wantMinLen: 2, // Just default courses + checkCourses: []uint16{1, 2}, + }, + { + name: "course 3 only", + rights: 8, // 2^3 + wantMinLen: 3, // defaults + course 3 + checkCourses: []uint16{1, 2, 3}, + }, + { + name: "course 1", + rights: 2, // 2^1 + wantMinLen: 2, + checkCourses: []uint16{1, 2}, + }, + { + name: "multiple courses", + rights: 2 + 8 + 32, // courses 1, 3, 5 + wantMinLen: 4, + checkCourses: []uint16{1, 2, 3, 5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + courses, newRights := GetCourseStruct(tt.rights) + + if len(courses) < tt.wantMinLen { + t.Errorf("GetCourseStruct(%d) returned %d courses, want at least %d", tt.rights, len(courses), tt.wantMinLen) + } + + // Verify expected courses are present + for _, id := range tt.checkCourses { + found := false + for _, c := range courses { + if c.ID == id { + found = true + break + } + } + if !found { + t.Errorf("GetCourseStruct(%d) missing expected course ID %d", tt.rights, id) + } + } + + // Verify newRights is a valid sum of course values + if newRights < tt.rights { + t.Logf("GetCourseStruct(%d) newRights = %d (may include additional courses)", tt.rights, newRights) + } + }) + } +} + +func TestGetCourseStruct_NetcafeCourse(t *testing.T) { + // Save original config + originalDefaultCourses := _config.ErupeConfig.DefaultCourses + defer func() { + _config.ErupeConfig.DefaultCourses = originalDefaultCourses + }() + _config.ErupeConfig.DefaultCourses = []uint16{} + + // Course 26 (NetCafe) should add course 25 + courses, _ := GetCourseStruct(1 << 26) + + hasNetcafe := false + hasCafeSP := false + hasRealNetcafe := false + for _, c := range courses { + if c.ID == 26 { + hasNetcafe = true + } + if c.ID == 25 { + hasCafeSP = true + } + if c.ID == 30 { + hasRealNetcafe = true + } + } + + if !hasNetcafe { + t.Error("Course 26 (NetCafe) should be present") + } + if !hasCafeSP { + t.Error("Course 25 should be added when course 26 is present") + } + if !hasRealNetcafe { + t.Error("Course 30 should be added when course 26 is present") + } +} + +func TestGetCourseStruct_NCourse(t *testing.T) { + // Save original config + originalDefaultCourses := _config.ErupeConfig.DefaultCourses + defer func() { + _config.ErupeConfig.DefaultCourses = originalDefaultCourses + }() + _config.ErupeConfig.DefaultCourses = []uint16{} + + // Course 9 should add course 30 + courses, _ := GetCourseStruct(1 << 9) + + hasNCourse := false + hasRealNetcafe := false + for _, c := range courses { + if c.ID == 9 { + hasNCourse = true + } + if c.ID == 30 { + hasRealNetcafe = true + } + } + + if !hasNCourse { + t.Error("Course 9 (N) should be present") + } + if !hasRealNetcafe { + t.Error("Course 30 should be added when course 9 is present") + } +} + +func TestGetCourseStruct_HidenCourse(t *testing.T) { + // Save original config + originalDefaultCourses := _config.ErupeConfig.DefaultCourses + defer func() { + _config.ErupeConfig.DefaultCourses = originalDefaultCourses + }() + _config.ErupeConfig.DefaultCourses = []uint16{} + + // Course 10 (Hiden) should add course 31 + courses, _ := GetCourseStruct(1 << 10) + + hasHiden := false + hasHidenExtra := false + for _, c := range courses { + if c.ID == 10 { + hasHiden = true + } + if c.ID == 31 { + hasHidenExtra = true + } + } + + if !hasHiden { + t.Error("Course 10 (Hiden) should be present") + } + if !hasHidenExtra { + t.Error("Course 31 should be added when course 10 is present") + } +} + +func TestGetCourseStruct_ExpiryDate(t *testing.T) { + // Save original config + originalDefaultCourses := _config.ErupeConfig.DefaultCourses + defer func() { + _config.ErupeConfig.DefaultCourses = originalDefaultCourses + }() + _config.ErupeConfig.DefaultCourses = []uint16{} + + courses, _ := GetCourseStruct(1 << 3) + + expectedExpiry := time.Date(2030, 1, 1, 0, 0, 0, 0, time.FixedZone("UTC+9", 9*60*60)) + + for _, c := range courses { + if c.ID == 3 && !c.Expiry.IsZero() { + if !c.Expiry.Equal(expectedExpiry) { + t.Errorf("Course expiry = %v, want %v", c.Expiry, expectedExpiry) + } + } + } +} + +func TestGetCourseStruct_ReturnsRecalculatedRights(t *testing.T) { + // Save original config + originalDefaultCourses := _config.ErupeConfig.DefaultCourses + defer func() { + _config.ErupeConfig.DefaultCourses = originalDefaultCourses + }() + _config.ErupeConfig.DefaultCourses = []uint16{} + + courses, newRights := GetCourseStruct(2 + 8 + 32) // courses 1, 3, 5 + + // Calculate expected rights from returned courses + var expectedRights uint32 + for _, c := range courses { + expectedRights += c.Value() + } + + if newRights != expectedRights { + t.Errorf("GetCourseStruct() newRights = %d, want %d (sum of returned course values)", newRights, expectedRights) + } +} + +func TestCourse_ValueMatchesPowerOfTwo(t *testing.T) { + // Verify that Value() correctly implements 2^ID + for id := uint16(0); id < 32; id++ { + c := Course{ID: id} + expected := uint32(math.Pow(2, float64(id))) + got := c.Value() + if got != expected { + t.Errorf("Course{ID: %d}.Value() = %d, want %d", id, got, expected) + } + } +} + +func BenchmarkCourse_Value(b *testing.B) { + c := Course{ID: 15} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = c.Value() + } +} + +func BenchmarkCourseExists(b *testing.B) { + courses := []Course{ + {ID: 1}, {ID: 2}, {ID: 3}, {ID: 4}, {ID: 5}, + {ID: 10}, {ID: 15}, {ID: 20}, {ID: 25}, {ID: 30}, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CourseExists(15, courses) + } +} + +func BenchmarkGetCourseStruct(b *testing.B) { + // Save original config + originalDefaultCourses := _config.ErupeConfig.DefaultCourses + defer func() { + _config.ErupeConfig.DefaultCourses = originalDefaultCourses + }() + _config.ErupeConfig.DefaultCourses = []uint16{1, 2} + + rights := uint32(2 + 8 + 32 + 128 + 512) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = GetCourseStruct(rights) + } +} + +func BenchmarkCourses(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Courses() + } +} diff --git a/common/mhfitem/mhfitem_test.go b/common/mhfitem/mhfitem_test.go new file mode 100644 index 000000000..c92e561eb --- /dev/null +++ b/common/mhfitem/mhfitem_test.go @@ -0,0 +1,551 @@ +package mhfitem + +import ( + "bytes" + "erupe-ce/common/byteframe" + "erupe-ce/common/token" + _config "erupe-ce/config" + "testing" +) + +func TestReadWarehouseItem(t *testing.T) { + bf := byteframe.NewByteFrame() + bf.WriteUint32(12345) // WarehouseID + bf.WriteUint16(100) // ItemID + bf.WriteUint16(5) // Quantity + bf.WriteUint32(999999) // Unk0 + + bf.Seek(0, 0) + item := ReadWarehouseItem(bf) + + if item.WarehouseID != 12345 { + t.Errorf("WarehouseID = %d, want 12345", item.WarehouseID) + } + if item.Item.ItemID != 100 { + t.Errorf("ItemID = %d, want 100", item.Item.ItemID) + } + if item.Quantity != 5 { + t.Errorf("Quantity = %d, want 5", item.Quantity) + } + if item.Unk0 != 999999 { + t.Errorf("Unk0 = %d, want 999999", item.Unk0) + } +} + +func TestReadWarehouseItem_ZeroWarehouseID(t *testing.T) { + // When WarehouseID is 0, it should be replaced with a random value + bf := byteframe.NewByteFrame() + bf.WriteUint32(0) // WarehouseID = 0 + bf.WriteUint16(100) // ItemID + bf.WriteUint16(5) // Quantity + bf.WriteUint32(0) // Unk0 + + bf.Seek(0, 0) + item := ReadWarehouseItem(bf) + + if item.WarehouseID == 0 { + t.Error("WarehouseID should be replaced with random value when input is 0") + } +} + +func TestMHFItemStack_ToBytes(t *testing.T) { + item := MHFItemStack{ + WarehouseID: 12345, + Item: MHFItem{ItemID: 100}, + Quantity: 5, + Unk0: 999999, + } + + data := item.ToBytes() + if len(data) != 12 { // 4 + 2 + 2 + 4 + t.Errorf("ToBytes() length = %d, want 12", len(data)) + } + + // Read it back + bf := byteframe.NewByteFrameFromBytes(data) + readItem := ReadWarehouseItem(bf) + + if readItem.WarehouseID != item.WarehouseID { + t.Errorf("WarehouseID = %d, want %d", readItem.WarehouseID, item.WarehouseID) + } + if readItem.Item.ItemID != item.Item.ItemID { + t.Errorf("ItemID = %d, want %d", readItem.Item.ItemID, item.Item.ItemID) + } + if readItem.Quantity != item.Quantity { + t.Errorf("Quantity = %d, want %d", readItem.Quantity, item.Quantity) + } + if readItem.Unk0 != item.Unk0 { + t.Errorf("Unk0 = %d, want %d", readItem.Unk0, item.Unk0) + } +} + +func TestSerializeWarehouseItems(t *testing.T) { + items := []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 5, Unk0: 0}, + {WarehouseID: 2, Item: MHFItem{ItemID: 200}, Quantity: 10, Unk0: 0}, + } + + data := SerializeWarehouseItems(items) + bf := byteframe.NewByteFrameFromBytes(data) + + count := bf.ReadUint16() + if count != 2 { + t.Errorf("count = %d, want 2", count) + } + + bf.ReadUint16() // Skip unused + + for i := 0; i < 2; i++ { + item := ReadWarehouseItem(bf) + if item.WarehouseID != items[i].WarehouseID { + t.Errorf("item[%d] WarehouseID = %d, want %d", i, item.WarehouseID, items[i].WarehouseID) + } + if item.Item.ItemID != items[i].Item.ItemID { + t.Errorf("item[%d] ItemID = %d, want %d", i, item.Item.ItemID, items[i].Item.ItemID) + } + } +} + +func TestSerializeWarehouseItems_Empty(t *testing.T) { + items := []MHFItemStack{} + data := SerializeWarehouseItems(items) + bf := byteframe.NewByteFrameFromBytes(data) + + count := bf.ReadUint16() + if count != 0 { + t.Errorf("count = %d, want 0", count) + } +} + +func TestDiffItemStacks(t *testing.T) { + tests := []struct { + name string + old []MHFItemStack + update []MHFItemStack + wantLen int + checkFn func(t *testing.T, result []MHFItemStack) + }{ + { + name: "update existing quantity", + old: []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 5}, + }, + update: []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 10}, + }, + wantLen: 1, + checkFn: func(t *testing.T, result []MHFItemStack) { + if result[0].Quantity != 10 { + t.Errorf("Quantity = %d, want 10", result[0].Quantity) + } + }, + }, + { + name: "add new item", + old: []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 5}, + }, + update: []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 5}, + {WarehouseID: 0, Item: MHFItem{ItemID: 200}, Quantity: 3}, // WarehouseID 0 = new + }, + wantLen: 2, + checkFn: func(t *testing.T, result []MHFItemStack) { + hasNewItem := false + for _, item := range result { + if item.Item.ItemID == 200 { + hasNewItem = true + if item.WarehouseID == 0 { + t.Error("New item should have generated WarehouseID") + } + } + } + if !hasNewItem { + t.Error("New item should be in result") + } + }, + }, + { + name: "remove item (quantity 0)", + old: []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 5}, + {WarehouseID: 2, Item: MHFItem{ItemID: 200}, Quantity: 10}, + }, + update: []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 0}, // Removed + }, + wantLen: 1, + checkFn: func(t *testing.T, result []MHFItemStack) { + for _, item := range result { + if item.WarehouseID == 1 { + t.Error("Item with quantity 0 should be removed") + } + } + }, + }, + { + name: "empty old, add new", + old: []MHFItemStack{}, + update: []MHFItemStack{{WarehouseID: 0, Item: MHFItem{ItemID: 100}, Quantity: 5}}, + wantLen: 1, + checkFn: func(t *testing.T, result []MHFItemStack) { + if len(result) != 1 || result[0].Item.ItemID != 100 { + t.Error("Should add new item to empty list") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := DiffItemStacks(tt.old, tt.update) + if len(result) != tt.wantLen { + t.Errorf("DiffItemStacks() length = %d, want %d", len(result), tt.wantLen) + } + if tt.checkFn != nil { + tt.checkFn(t, result) + } + }) + } +} + +func TestReadWarehouseEquipment(t *testing.T) { + // Save original config + originalMode := _config.ErupeConfig.RealClientMode + defer func() { + _config.ErupeConfig.RealClientMode = originalMode + }() + _config.ErupeConfig.RealClientMode = _config.Z1 + + bf := byteframe.NewByteFrame() + bf.WriteUint32(12345) // WarehouseID + bf.WriteUint8(1) // ItemType + bf.WriteUint8(2) // Unk0 + bf.WriteUint16(100) // ItemID + bf.WriteUint16(5) // Level + + // Write 3 decorations + bf.WriteUint16(201) + bf.WriteUint16(202) + bf.WriteUint16(203) + + // Write 3 sigils (G1+) + for i := 0; i < 3; i++ { + // 3 effects per sigil + for j := 0; j < 3; j++ { + bf.WriteUint16(uint16(300 + i*10 + j)) // Effect ID + } + for j := 0; j < 3; j++ { + bf.WriteUint16(uint16(1 + j)) // Effect Level + } + bf.WriteUint8(10) + bf.WriteUint8(11) + bf.WriteUint8(12) + bf.WriteUint8(13) + } + + // Unk1 (Z1+) + bf.WriteUint16(9999) + + bf.Seek(0, 0) + equipment := ReadWarehouseEquipment(bf) + + if equipment.WarehouseID != 12345 { + t.Errorf("WarehouseID = %d, want 12345", equipment.WarehouseID) + } + if equipment.ItemType != 1 { + t.Errorf("ItemType = %d, want 1", equipment.ItemType) + } + if equipment.ItemID != 100 { + t.Errorf("ItemID = %d, want 100", equipment.ItemID) + } + if equipment.Level != 5 { + t.Errorf("Level = %d, want 5", equipment.Level) + } + if equipment.Decorations[0].ItemID != 201 { + t.Errorf("Decoration[0] = %d, want 201", equipment.Decorations[0].ItemID) + } + if equipment.Sigils[0].Effects[0].ID != 300 { + t.Errorf("Sigil[0].Effect[0].ID = %d, want 300", equipment.Sigils[0].Effects[0].ID) + } + if equipment.Unk1 != 9999 { + t.Errorf("Unk1 = %d, want 9999", equipment.Unk1) + } +} + +func TestReadWarehouseEquipment_ZeroWarehouseID(t *testing.T) { + // Save original config + originalMode := _config.ErupeConfig.RealClientMode + defer func() { + _config.ErupeConfig.RealClientMode = originalMode + }() + _config.ErupeConfig.RealClientMode = _config.Z1 + + bf := byteframe.NewByteFrame() + bf.WriteUint32(0) // WarehouseID = 0 + bf.WriteUint8(1) + bf.WriteUint8(2) + bf.WriteUint16(100) + bf.WriteUint16(5) + // Write decorations + for i := 0; i < 3; i++ { + bf.WriteUint16(0) + } + // Write sigils + for i := 0; i < 3; i++ { + for j := 0; j < 6; j++ { + bf.WriteUint16(0) + } + bf.WriteUint8(0) + bf.WriteUint8(0) + bf.WriteUint8(0) + bf.WriteUint8(0) + } + bf.WriteUint16(0) + + bf.Seek(0, 0) + equipment := ReadWarehouseEquipment(bf) + + if equipment.WarehouseID == 0 { + t.Error("WarehouseID should be replaced with random value when input is 0") + } +} + +func TestMHFEquipment_ToBytes(t *testing.T) { + // Save original config + originalMode := _config.ErupeConfig.RealClientMode + defer func() { + _config.ErupeConfig.RealClientMode = originalMode + }() + _config.ErupeConfig.RealClientMode = _config.Z1 + + equipment := MHFEquipment{ + WarehouseID: 12345, + ItemType: 1, + Unk0: 2, + ItemID: 100, + Level: 5, + Decorations: []MHFItem{{ItemID: 201}, {ItemID: 202}, {ItemID: 203}}, + Sigils: make([]MHFSigil, 3), + Unk1: 9999, + } + for i := 0; i < 3; i++ { + equipment.Sigils[i].Effects = make([]MHFSigilEffect, 3) + } + + data := equipment.ToBytes() + bf := byteframe.NewByteFrameFromBytes(data) + readEquipment := ReadWarehouseEquipment(bf) + + if readEquipment.WarehouseID != equipment.WarehouseID { + t.Errorf("WarehouseID = %d, want %d", readEquipment.WarehouseID, equipment.WarehouseID) + } + if readEquipment.ItemID != equipment.ItemID { + t.Errorf("ItemID = %d, want %d", readEquipment.ItemID, equipment.ItemID) + } + if readEquipment.Level != equipment.Level { + t.Errorf("Level = %d, want %d", readEquipment.Level, equipment.Level) + } + if readEquipment.Unk1 != equipment.Unk1 { + t.Errorf("Unk1 = %d, want %d", readEquipment.Unk1, equipment.Unk1) + } +} + +func TestSerializeWarehouseEquipment(t *testing.T) { + // Save original config + originalMode := _config.ErupeConfig.RealClientMode + defer func() { + _config.ErupeConfig.RealClientMode = originalMode + }() + _config.ErupeConfig.RealClientMode = _config.Z1 + + equipment := []MHFEquipment{ + { + WarehouseID: 1, + ItemType: 1, + ItemID: 100, + Level: 5, + Decorations: []MHFItem{{ItemID: 0}, {ItemID: 0}, {ItemID: 0}}, + Sigils: make([]MHFSigil, 3), + }, + { + WarehouseID: 2, + ItemType: 2, + ItemID: 200, + Level: 10, + Decorations: []MHFItem{{ItemID: 0}, {ItemID: 0}, {ItemID: 0}}, + Sigils: make([]MHFSigil, 3), + }, + } + for i := range equipment { + for j := 0; j < 3; j++ { + equipment[i].Sigils[j].Effects = make([]MHFSigilEffect, 3) + } + } + + data := SerializeWarehouseEquipment(equipment) + bf := byteframe.NewByteFrameFromBytes(data) + + count := bf.ReadUint16() + if count != 2 { + t.Errorf("count = %d, want 2", count) + } +} + +func TestMHFEquipment_RoundTrip(t *testing.T) { + // Test that we can write and read back the same equipment + originalMode := _config.ErupeConfig.RealClientMode + defer func() { + _config.ErupeConfig.RealClientMode = originalMode + }() + _config.ErupeConfig.RealClientMode = _config.Z1 + + original := MHFEquipment{ + WarehouseID: 99999, + ItemType: 5, + Unk0: 10, + ItemID: 500, + Level: 25, + Decorations: []MHFItem{{ItemID: 1}, {ItemID: 2}, {ItemID: 3}}, + Sigils: make([]MHFSigil, 3), + Unk1: 12345, + } + for i := 0; i < 3; i++ { + original.Sigils[i].Effects = []MHFSigilEffect{ + {ID: uint16(100 + i), Level: 1}, + {ID: uint16(200 + i), Level: 2}, + {ID: uint16(300 + i), Level: 3}, + } + } + + // Write to bytes + data := original.ToBytes() + + // Read back + bf := byteframe.NewByteFrameFromBytes(data) + recovered := ReadWarehouseEquipment(bf) + + // Compare + if recovered.WarehouseID != original.WarehouseID { + t.Errorf("WarehouseID = %d, want %d", recovered.WarehouseID, original.WarehouseID) + } + if recovered.ItemType != original.ItemType { + t.Errorf("ItemType = %d, want %d", recovered.ItemType, original.ItemType) + } + if recovered.ItemID != original.ItemID { + t.Errorf("ItemID = %d, want %d", recovered.ItemID, original.ItemID) + } + if recovered.Level != original.Level { + t.Errorf("Level = %d, want %d", recovered.Level, original.Level) + } + for i := 0; i < 3; i++ { + if recovered.Decorations[i].ItemID != original.Decorations[i].ItemID { + t.Errorf("Decoration[%d] = %d, want %d", i, recovered.Decorations[i].ItemID, original.Decorations[i].ItemID) + } + } +} + +func BenchmarkReadWarehouseItem(b *testing.B) { + bf := byteframe.NewByteFrame() + bf.WriteUint32(12345) + bf.WriteUint16(100) + bf.WriteUint16(5) + bf.WriteUint32(0) + data := bf.Data() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrameFromBytes(data) + _ = ReadWarehouseItem(bf) + } +} + +func BenchmarkDiffItemStacks(b *testing.B) { + old := []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 5}, + {WarehouseID: 2, Item: MHFItem{ItemID: 200}, Quantity: 10}, + {WarehouseID: 3, Item: MHFItem{ItemID: 300}, Quantity: 15}, + } + update := []MHFItemStack{ + {WarehouseID: 1, Item: MHFItem{ItemID: 100}, Quantity: 8}, + {WarehouseID: 0, Item: MHFItem{ItemID: 400}, Quantity: 3}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = DiffItemStacks(old, update) + } +} + +func BenchmarkSerializeWarehouseItems(b *testing.B) { + items := make([]MHFItemStack, 100) + for i := range items { + items[i] = MHFItemStack{ + WarehouseID: uint32(i), + Item: MHFItem{ItemID: uint16(i)}, + Quantity: uint16(i % 99), + } + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = SerializeWarehouseItems(items) + } +} + +func TestMHFItemStack_ToBytes_RoundTrip(t *testing.T) { + original := MHFItemStack{ + WarehouseID: 12345, + Item: MHFItem{ItemID: 999}, + Quantity: 42, + Unk0: 777, + } + + data := original.ToBytes() + bf := byteframe.NewByteFrameFromBytes(data) + recovered := ReadWarehouseItem(bf) + + if !bytes.Equal(original.ToBytes(), recovered.ToBytes()) { + t.Error("Round-trip serialization failed") + } +} + +func TestDiffItemStacks_PreserveOldWarehouseID(t *testing.T) { + // Verify that when updating existing items, the old WarehouseID is preserved + old := []MHFItemStack{ + {WarehouseID: 555, Item: MHFItem{ItemID: 100}, Quantity: 5}, + } + update := []MHFItemStack{ + {WarehouseID: 555, Item: MHFItem{ItemID: 100}, Quantity: 10}, + } + + result := DiffItemStacks(old, update) + if len(result) != 1 { + t.Fatalf("Expected 1 item, got %d", len(result)) + } + if result[0].WarehouseID != 555 { + t.Errorf("WarehouseID = %d, want 555", result[0].WarehouseID) + } + if result[0].Quantity != 10 { + t.Errorf("Quantity = %d, want 10", result[0].Quantity) + } +} + +func TestDiffItemStacks_GeneratesNewWarehouseID(t *testing.T) { + // Verify that new items get a generated WarehouseID + old := []MHFItemStack{} + update := []MHFItemStack{ + {WarehouseID: 0, Item: MHFItem{ItemID: 100}, Quantity: 5}, + } + + // Reset RNG for consistent test + token.RNG = token.NewRNG() + + result := DiffItemStacks(old, update) + if len(result) != 1 { + t.Fatalf("Expected 1 item, got %d", len(result)) + } + if result[0].WarehouseID == 0 { + t.Error("New item should have generated WarehouseID, got 0") + } +} diff --git a/common/mhfmon/mhfmon_test.go b/common/mhfmon/mhfmon_test.go new file mode 100644 index 000000000..b2560840c --- /dev/null +++ b/common/mhfmon/mhfmon_test.go @@ -0,0 +1,371 @@ +package mhfmon + +import ( + "testing" +) + +func TestMonsters_Length(t *testing.T) { + // Verify that the Monsters slice has entries + actualLen := len(Monsters) + if actualLen == 0 { + t.Fatal("Monsters slice is empty") + } + // The slice has 177 entries (some constants may not have entries) + if actualLen < 170 { + t.Errorf("Monsters length = %d, seems too small", actualLen) + } +} + +func TestMonsters_IndexMatchesConstant(t *testing.T) { + // Test that the index in the slice matches the constant value + tests := []struct { + index int + name string + large bool + }{ + {Mon0, "Mon0", false}, + {Rathian, "Rathian", true}, + {Fatalis, "Fatalis", true}, + {Kelbi, "Kelbi", false}, + {Rathalos, "Rathalos", true}, + {Diablos, "Diablos", true}, + {Rajang, "Rajang", true}, + {Zinogre, "Zinogre", true}, + {Deviljho, "Deviljho", true}, + {KingShakalaka, "King Shakalaka", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.index >= len(Monsters) { + t.Fatalf("Index %d out of bounds", tt.index) + } + monster := Monsters[tt.index] + if monster.Name != tt.name { + t.Errorf("Monsters[%d].Name = %q, want %q", tt.index, monster.Name, tt.name) + } + if monster.Large != tt.large { + t.Errorf("Monsters[%d].Large = %v, want %v", tt.index, monster.Large, tt.large) + } + }) + } +} + +func TestMonsters_AllLargeMonsters(t *testing.T) { + // Verify some known large monsters + largeMonsters := []int{ + Rathian, + Fatalis, + YianKutKu, + LaoShanLung, + Cephadrome, + Rathalos, + Diablos, + Khezu, + Gravios, + Tigrex, + Zinogre, + Deviljho, + Brachydios, + } + + for _, idx := range largeMonsters { + if !Monsters[idx].Large { + t.Errorf("Monsters[%d] (%s) should be marked as large", idx, Monsters[idx].Name) + } + } +} + +func TestMonsters_AllSmallMonsters(t *testing.T) { + // Verify some known small monsters + smallMonsters := []int{ + Kelbi, + Mosswine, + Bullfango, + Felyne, + Aptonoth, + Genprey, + Velociprey, + Melynx, + Hornetaur, + Apceros, + Ioprey, + Giaprey, + Cephalos, + Blango, + Conga, + Remobra, + GreatThunderbug, + Shakalaka, + } + + for _, idx := range smallMonsters { + if Monsters[idx].Large { + t.Errorf("Monsters[%d] (%s) should be marked as small", idx, Monsters[idx].Name) + } + } +} + +func TestMonsters_Constants(t *testing.T) { + // Test that constants have expected values + tests := []struct { + constant int + expected int + }{ + {Mon0, 0}, + {Rathian, 1}, + {Fatalis, 2}, + {Kelbi, 3}, + {Rathalos, 11}, + {Diablos, 14}, + {Rajang, 53}, + {Zinogre, 146}, + {Deviljho, 147}, + {Brachydios, 148}, + {KingShakalaka, 176}, + } + + for _, tt := range tests { + if tt.constant != tt.expected { + t.Errorf("Constant = %d, want %d", tt.constant, tt.expected) + } + } +} + +func TestMonsters_NameConsistency(t *testing.T) { + // Test that specific monsters have correct names + tests := []struct { + index int + expectedName string + }{ + {Rathian, "Rathian"}, + {Rathalos, "Rathalos"}, + {YianKutKu, "Yian Kut-Ku"}, + {LaoShanLung, "Lao-Shan Lung"}, + {KushalaDaora, "Kushala Daora"}, + {Tigrex, "Tigrex"}, + {Rajang, "Rajang"}, + {Zinogre, "Zinogre"}, + {Deviljho, "Deviljho"}, + {Brachydios, "Brachydios"}, + {Nargacuga, "Nargacuga"}, + {GoreMagala, "Gore Magala"}, + {ShagaruMagala, "Shagaru Magala"}, + {KingShakalaka, "King Shakalaka"}, + } + + for _, tt := range tests { + t.Run(tt.expectedName, func(t *testing.T) { + if Monsters[tt.index].Name != tt.expectedName { + t.Errorf("Monsters[%d].Name = %q, want %q", tt.index, Monsters[tt.index].Name, tt.expectedName) + } + }) + } +} + +func TestMonsters_SubspeciesNames(t *testing.T) { + // Test subspecies have appropriate names + tests := []struct { + index int + expectedName string + }{ + {PinkRathian, "Pink Rathian"}, + {AzureRathalos, "Azure Rathalos"}, + {SilverRathalos, "Silver Rathalos"}, + {GoldRathian, "Gold Rathian"}, + {BlackDiablos, "Black Diablos"}, + {WhiteMonoblos, "White Monoblos"}, + {RedKhezu, "Red Khezu"}, + {CrimsonFatalis, "Crimson Fatalis"}, + {WhiteFatalis, "White Fatalis"}, + {StygianZinogre, "Stygian Zinogre"}, + {SavageDeviljho, "Savage Deviljho"}, + } + + for _, tt := range tests { + t.Run(tt.expectedName, func(t *testing.T) { + if Monsters[tt.index].Name != tt.expectedName { + t.Errorf("Monsters[%d].Name = %q, want %q", tt.index, Monsters[tt.index].Name, tt.expectedName) + } + }) + } +} + +func TestMonsters_PlaceholderMonsters(t *testing.T) { + // Test that placeholder monsters exist + placeholders := []int{Mon0, Mon18, Mon29, Mon32, Mon72, Mon86, Mon87, Mon88, Mon118, Mon133, Mon134, Mon135, Mon136, Mon137, Mon138, Mon156, Mon168, Mon171} + + for _, idx := range placeholders { + if idx >= len(Monsters) { + t.Errorf("Placeholder monster index %d out of bounds", idx) + continue + } + // Placeholder monsters should be marked as small (non-large) + if Monsters[idx].Large { + t.Errorf("Placeholder Monsters[%d] (%s) should not be marked as large", idx, Monsters[idx].Name) + } + } +} + +func TestMonsters_FrontierMonsters(t *testing.T) { + // Test some MH Frontier-specific monsters + frontierMonsters := []struct { + index int + name string + }{ + {Espinas, "Espinas"}, + {Berukyurosu, "Berukyurosu"}, + {Pariapuria, "Pariapuria"}, + {Raviente, "Raviente"}, + {Dyuragaua, "Dyuragaua"}, + {Doragyurosu, "Doragyurosu"}, + {Gurenzeburu, "Gurenzeburu"}, + {Rukodiora, "Rukodiora"}, + {Gogomoa, "Gogomoa"}, + {Disufiroa, "Disufiroa"}, + {Rebidiora, "Rebidiora"}, + {MiRu, "Mi-Ru"}, + {Shantien, "Shantien"}, + {Zerureusu, "Zerureusu"}, + {GarubaDaora, "Garuba Daora"}, + {Harudomerugu, "Harudomerugu"}, + {Toridcless, "Toridcless"}, + {Guanzorumu, "Guanzorumu"}, + {Egyurasu, "Egyurasu"}, + {Bogabadorumu, "Bogabadorumu"}, + } + + for _, tt := range frontierMonsters { + t.Run(tt.name, func(t *testing.T) { + if tt.index >= len(Monsters) { + t.Fatalf("Index %d out of bounds", tt.index) + } + if Monsters[tt.index].Name != tt.name { + t.Errorf("Monsters[%d].Name = %q, want %q", tt.index, Monsters[tt.index].Name, tt.name) + } + // Most Frontier monsters should be large + if !Monsters[tt.index].Large { + t.Logf("Frontier monster %s is marked as small", tt.name) + } + }) + } +} + +func TestMonsters_DuremudiraVariants(t *testing.T) { + // Test Duremudira variants + tests := []struct { + index int + name string + }{ + {Block1Duremudira, "1st Block Duremudira"}, + {Block2Duremudira, "2nd Block Duremudira"}, + {MusouDuremudira, "Musou Duremudira"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if Monsters[tt.index].Name != tt.name { + t.Errorf("Monsters[%d].Name = %q, want %q", tt.index, Monsters[tt.index].Name, tt.name) + } + if !Monsters[tt.index].Large { + t.Errorf("Duremudira variant should be marked as large") + } + }) + } +} + +func TestMonsters_RalienteVariants(t *testing.T) { + // Test Raviente variants + tests := []struct { + index int + name string + }{ + {Raviente, "Raviente"}, + {BerserkRaviente, "Berserk Raviente"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if Monsters[tt.index].Name != tt.name { + t.Errorf("Monsters[%d].Name = %q, want %q", tt.index, Monsters[tt.index].Name, tt.name) + } + if !Monsters[tt.index].Large { + t.Errorf("Raviente variant should be marked as large") + } + }) + } +} + +func TestMonsters_NoHoles(t *testing.T) { + // Verify that there are no nil entries or empty names (except for placeholder "MonXX" entries) + for i, monster := range Monsters { + if monster.Name == "" { + t.Errorf("Monsters[%d] has empty name", i) + } + } +} + +func TestMonster_Struct(t *testing.T) { + // Test that Monster struct is properly defined + m := Monster{ + Name: "Test Monster", + Large: true, + } + + if m.Name != "Test Monster" { + t.Errorf("Name = %q, want %q", m.Name, "Test Monster") + } + if !m.Large { + t.Error("Large should be true") + } +} + +func BenchmarkAccessMonster(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Monsters[Rathalos] + } +} + +func BenchmarkAccessMonsterName(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Monsters[Zinogre].Name + } +} + +func BenchmarkAccessMonsterLarge(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Monsters[Deviljho].Large + } +} + +func TestMonsters_CrossoverMonsters(t *testing.T) { + // Test crossover monsters (from other games) + tests := []struct { + index int + name string + }{ + {Zinogre, "Zinogre"}, // From MH Portable 3rd + {Deviljho, "Deviljho"}, // From MH3 + {Brachydios, "Brachydios"}, // From MH3G + {Barioth, "Barioth"}, // From MH3 + {Uragaan, "Uragaan"}, // From MH3 + {Nargacuga, "Nargacuga"}, // From MH Freedom Unite + {GoreMagala, "Gore Magala"}, // From MH4 + {Amatsu, "Amatsu"}, // From MH Portable 3rd + {Seregios, "Seregios"}, // From MH4G + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if Monsters[tt.index].Name != tt.name { + t.Errorf("Monsters[%d].Name = %q, want %q", tt.index, Monsters[tt.index].Name, tt.name) + } + if !Monsters[tt.index].Large { + t.Errorf("Crossover large monster %s should be marked as large", tt.name) + } + }) + } +} diff --git a/common/pascalstring/pascalstring_test.go b/common/pascalstring/pascalstring_test.go new file mode 100644 index 000000000..8c4e145c0 --- /dev/null +++ b/common/pascalstring/pascalstring_test.go @@ -0,0 +1,369 @@ +package pascalstring + +import ( + "bytes" + "erupe-ce/common/byteframe" + "testing" +) + +func TestUint8_NoTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Hello" + + Uint8(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint8() + expectedLength := uint8(len(testString) + 1) // +1 for null terminator + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + // Should be "Hello\x00" + expected := []byte("Hello\x00") + if !bytes.Equal(data, expected) { + t.Errorf("data = %v, want %v", data, expected) + } +} + +func TestUint8_WithTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + // ASCII string (no special characters) + testString := "Test" + + Uint8(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint8() + + if length == 0 { + t.Error("length should not be 0 for ASCII string") + } + + data := bf.ReadBytes(uint(length)) + // Should end with null terminator + if data[len(data)-1] != 0 { + t.Error("data should end with null terminator") + } +} + +func TestUint8_EmptyString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "" + + Uint8(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint8() + + if length != 1 { // Just null terminator + t.Errorf("length = %d, want 1", length) + } + + data := bf.ReadBytes(uint(length)) + if data[0] != 0 { + t.Error("empty string should produce just null terminator") + } +} + +func TestUint16_NoTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "World" + + Uint16(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint16() + expectedLength := uint16(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + expected := []byte("World\x00") + if !bytes.Equal(data, expected) { + t.Errorf("data = %v, want %v", data, expected) + } +} + +func TestUint16_WithTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Test" + + Uint16(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint16() + + if length == 0 { + t.Error("length should not be 0 for ASCII string") + } + + data := bf.ReadBytes(uint(length)) + if data[len(data)-1] != 0 { + t.Error("data should end with null terminator") + } +} + +func TestUint16_EmptyString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "" + + Uint16(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint16() + + if length != 1 { + t.Errorf("length = %d, want 1", length) + } +} + +func TestUint32_NoTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Testing" + + Uint32(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint32() + expectedLength := uint32(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + expected := []byte("Testing\x00") + if !bytes.Equal(data, expected) { + t.Errorf("data = %v, want %v", data, expected) + } +} + +func TestUint32_WithTransform(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Test" + + Uint32(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint32() + + if length == 0 { + t.Error("length should not be 0 for ASCII string") + } + + data := bf.ReadBytes(uint(length)) + if data[len(data)-1] != 0 { + t.Error("data should end with null terminator") + } +} + +func TestUint32_EmptyString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "" + + Uint32(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint32() + + if length != 1 { + t.Errorf("length = %d, want 1", length) + } +} + +func TestUint8_LongString(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "This is a longer test string with more characters" + + Uint8(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint8() + expectedLength := uint8(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + if !bytes.HasSuffix(data, []byte{0}) { + t.Error("data should end with null terminator") + } + if !bytes.HasPrefix(data, []byte("This is")) { + t.Error("data should start with expected string") + } +} + +func TestUint16_LongString(t *testing.T) { + bf := byteframe.NewByteFrame() + // Create a string longer than 255 to test uint16 + testString := "" + for i := 0; i < 300; i++ { + testString += "A" + } + + Uint16(bf, testString, false) + + bf.Seek(0, 0) + length := bf.ReadUint16() + expectedLength := uint16(len(testString) + 1) + + if length != expectedLength { + t.Errorf("length = %d, want %d", length, expectedLength) + } + + data := bf.ReadBytes(uint(length)) + if !bytes.HasSuffix(data, []byte{0}) { + t.Error("data should end with null terminator") + } +} + +func TestAllFunctions_NullTermination(t *testing.T) { + tests := []struct { + name string + writeFn func(*byteframe.ByteFrame, string, bool) + readSize func(*byteframe.ByteFrame) uint + }{ + { + name: "Uint8", + writeFn: func(bf *byteframe.ByteFrame, s string, t bool) { + Uint8(bf, s, t) + }, + readSize: func(bf *byteframe.ByteFrame) uint { + return uint(bf.ReadUint8()) + }, + }, + { + name: "Uint16", + writeFn: func(bf *byteframe.ByteFrame, s string, t bool) { + Uint16(bf, s, t) + }, + readSize: func(bf *byteframe.ByteFrame) uint { + return uint(bf.ReadUint16()) + }, + }, + { + name: "Uint32", + writeFn: func(bf *byteframe.ByteFrame, s string, t bool) { + Uint32(bf, s, t) + }, + readSize: func(bf *byteframe.ByteFrame) uint { + return uint(bf.ReadUint32()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + bf := byteframe.NewByteFrame() + testString := "Test" + + tt.writeFn(bf, testString, false) + + bf.Seek(0, 0) + size := tt.readSize(bf) + data := bf.ReadBytes(size) + + // Verify null termination + if data[len(data)-1] != 0 { + t.Errorf("%s: data should end with null terminator", tt.name) + } + + // Verify length includes null terminator + if size != uint(len(testString)+1) { + t.Errorf("%s: size = %d, want %d", tt.name, size, len(testString)+1) + } + }) + } +} + +func TestTransform_JapaneseCharacters(t *testing.T) { + // Test with Japanese characters that should be transformed to Shift-JIS + bf := byteframe.NewByteFrame() + testString := "テスト" // "Test" in Japanese katakana + + Uint16(bf, testString, true) + + bf.Seek(0, 0) + length := bf.ReadUint16() + + if length == 0 { + t.Error("Transformed Japanese string should have non-zero length") + } + + // The transformed Shift-JIS should be different length than UTF-8 + // UTF-8: 9 bytes (3 chars * 3 bytes each), Shift-JIS: 6 bytes (3 chars * 2 bytes each) + 1 null + data := bf.ReadBytes(uint(length)) + if data[len(data)-1] != 0 { + t.Error("Transformed string should end with null terminator") + } +} + +func TestTransform_InvalidUTF8(t *testing.T) { + // This test verifies graceful handling of encoding errors + // When transformation fails, the functions should write length 0 + + bf := byteframe.NewByteFrame() + // Create a string with invalid UTF-8 sequence + // Note: Go strings are generally valid UTF-8, but we can test the error path + testString := "Valid ASCII" + + Uint8(bf, testString, true) + // Should succeed for ASCII characters + + bf.Seek(0, 0) + length := bf.ReadUint8() + if length == 0 { + t.Error("ASCII string should transform successfully") + } +} + +func BenchmarkUint8_NoTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint8(bf, testString, false) + } +} + +func BenchmarkUint8_WithTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint8(bf, testString, true) + } +} + +func BenchmarkUint16_NoTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint16(bf, testString, false) + } +} + +func BenchmarkUint32_NoTransform(b *testing.B) { + testString := "Hello, World!" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint32(bf, testString, false) + } +} + +func BenchmarkUint16_Japanese(b *testing.B) { + testString := "テストメッセージ" + b.ResetTimer() + for i := 0; i < b.N; i++ { + bf := byteframe.NewByteFrame() + Uint16(bf, testString, true) + } +} diff --git a/common/stringstack/stringstack_test.go b/common/stringstack/stringstack_test.go new file mode 100644 index 000000000..3bfcf7656 --- /dev/null +++ b/common/stringstack/stringstack_test.go @@ -0,0 +1,343 @@ +package stringstack + +import ( + "testing" +) + +func TestNew(t *testing.T) { + s := New() + if s == nil { + t.Fatal("New() returned nil") + } + if len(s.stack) != 0 { + t.Errorf("New() stack length = %d, want 0", len(s.stack)) + } +} + +func TestStringStack_Set(t *testing.T) { + s := New() + s.Set("first") + + if len(s.stack) != 1 { + t.Errorf("Set() stack length = %d, want 1", len(s.stack)) + } + if s.stack[0] != "first" { + t.Errorf("stack[0] = %q, want %q", s.stack[0], "first") + } +} + +func TestStringStack_Set_Replaces(t *testing.T) { + s := New() + s.Push("item1") + s.Push("item2") + s.Push("item3") + + // Set should replace the entire stack + s.Set("new_item") + + if len(s.stack) != 1 { + t.Errorf("Set() stack length = %d, want 1", len(s.stack)) + } + if s.stack[0] != "new_item" { + t.Errorf("stack[0] = %q, want %q", s.stack[0], "new_item") + } +} + +func TestStringStack_Push(t *testing.T) { + s := New() + s.Push("first") + s.Push("second") + s.Push("third") + + if len(s.stack) != 3 { + t.Errorf("Push() stack length = %d, want 3", len(s.stack)) + } + if s.stack[0] != "first" { + t.Errorf("stack[0] = %q, want %q", s.stack[0], "first") + } + if s.stack[1] != "second" { + t.Errorf("stack[1] = %q, want %q", s.stack[1], "second") + } + if s.stack[2] != "third" { + t.Errorf("stack[2] = %q, want %q", s.stack[2], "third") + } +} + +func TestStringStack_Pop(t *testing.T) { + s := New() + s.Push("first") + s.Push("second") + s.Push("third") + + // Pop should return LIFO (last in, first out) + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v, want nil", err) + } + if val != "third" { + t.Errorf("Pop() = %q, want %q", val, "third") + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v, want nil", err) + } + if val != "second" { + t.Errorf("Pop() = %q, want %q", val, "second") + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v, want nil", err) + } + if val != "first" { + t.Errorf("Pop() = %q, want %q", val, "first") + } + + if len(s.stack) != 0 { + t.Errorf("stack length = %d, want 0 after popping all items", len(s.stack)) + } +} + +func TestStringStack_Pop_Empty(t *testing.T) { + s := New() + + val, err := s.Pop() + if err == nil { + t.Error("Pop() on empty stack should return error") + } + if val != "" { + t.Errorf("Pop() on empty stack returned %q, want empty string", val) + } + + expectedError := "no items on stack" + if err.Error() != expectedError { + t.Errorf("Pop() error = %q, want %q", err.Error(), expectedError) + } +} + +func TestStringStack_LIFO_Behavior(t *testing.T) { + s := New() + items := []string{"A", "B", "C", "D", "E"} + + for _, item := range items { + s.Push(item) + } + + // Pop should return in reverse order (LIFO) + for i := len(items) - 1; i >= 0; i-- { + val, err := s.Pop() + if err != nil { + t.Fatalf("Pop() error = %v", err) + } + if val != items[i] { + t.Errorf("Pop() = %q, want %q", val, items[i]) + } + } +} + +func TestStringStack_PushAfterPop(t *testing.T) { + s := New() + s.Push("first") + s.Push("second") + + val, _ := s.Pop() + if val != "second" { + t.Errorf("Pop() = %q, want %q", val, "second") + } + + s.Push("third") + + val, _ = s.Pop() + if val != "third" { + t.Errorf("Pop() = %q, want %q", val, "third") + } + + val, _ = s.Pop() + if val != "first" { + t.Errorf("Pop() = %q, want %q", val, "first") + } +} + +func TestStringStack_EmptyStrings(t *testing.T) { + s := New() + s.Push("") + s.Push("text") + s.Push("") + + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "" { + t.Errorf("Pop() = %q, want empty string", val) + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "text" { + t.Errorf("Pop() = %q, want %q", val, "text") + } + + val, err = s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "" { + t.Errorf("Pop() = %q, want empty string", val) + } +} + +func TestStringStack_LongStrings(t *testing.T) { + s := New() + longString := "" + for i := 0; i < 1000; i++ { + longString += "A" + } + + s.Push(longString) + val, err := s.Pop() + + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != longString { + t.Error("Pop() returned different string than pushed") + } + if len(val) != 1000 { + t.Errorf("Pop() string length = %d, want 1000", len(val)) + } +} + +func TestStringStack_ManyItems(t *testing.T) { + s := New() + count := 1000 + + // Push many items + for i := 0; i < count; i++ { + s.Push("item") + } + + if len(s.stack) != count { + t.Errorf("stack length = %d, want %d", len(s.stack), count) + } + + // Pop all items + for i := 0; i < count; i++ { + _, err := s.Pop() + if err != nil { + t.Errorf("Pop()[%d] error = %v", i, err) + } + } + + // Should be empty now + if len(s.stack) != 0 { + t.Errorf("stack length = %d, want 0 after popping all", len(s.stack)) + } + + // Next pop should error + _, err := s.Pop() + if err == nil { + t.Error("Pop() on empty stack should return error") + } +} + +func TestStringStack_SetAfterOperations(t *testing.T) { + s := New() + s.Push("a") + s.Push("b") + s.Push("c") + s.Pop() + s.Push("d") + + // Set should clear everything + s.Set("reset") + + if len(s.stack) != 1 { + t.Errorf("stack length = %d, want 1 after Set", len(s.stack)) + } + + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != "reset" { + t.Errorf("Pop() = %q, want %q", val, "reset") + } +} + +func TestStringStack_SpecialCharacters(t *testing.T) { + s := New() + specialStrings := []string{ + "Hello\nWorld", + "Tab\tSeparated", + "Quote\"Test", + "Backslash\\Test", + "Unicode: テスト", + "Emoji: 😀", + "", + " ", + " spaces ", + } + + for _, str := range specialStrings { + s.Push(str) + } + + // Pop in reverse order + for i := len(specialStrings) - 1; i >= 0; i-- { + val, err := s.Pop() + if err != nil { + t.Errorf("Pop() error = %v", err) + } + if val != specialStrings[i] { + t.Errorf("Pop() = %q, want %q", val, specialStrings[i]) + } + } +} + +func BenchmarkStringStack_Push(b *testing.B) { + s := New() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Push("test string") + } +} + +func BenchmarkStringStack_Pop(b *testing.B) { + s := New() + // Pre-populate + for i := 0; i < 10000; i++ { + s.Push("test string") + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + if len(s.stack) == 0 { + // Repopulate + for j := 0; j < 10000; j++ { + s.Push("test string") + } + } + _, _ = s.Pop() + } +} + +func BenchmarkStringStack_PushPop(b *testing.B) { + s := New() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Push("test") + _, _ = s.Pop() + } +} + +func BenchmarkStringStack_Set(b *testing.B) { + s := New() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Set("test string") + } +} diff --git a/common/stringsupport/string_convert_test.go b/common/stringsupport/string_convert_test.go new file mode 100644 index 000000000..adfc434f4 --- /dev/null +++ b/common/stringsupport/string_convert_test.go @@ -0,0 +1,491 @@ +package stringsupport + +import ( + "bytes" + "testing" +) + +func TestUTF8ToSJIS(t *testing.T) { + tests := []struct { + name string + input string + }{ + {"ascii", "Hello World"}, + {"numbers", "12345"}, + {"symbols", "!@#$%"}, + {"japanese_hiragana", "あいうえお"}, + {"japanese_katakana", "アイウエオ"}, + {"japanese_kanji", "日本語"}, + {"mixed", "Hello世界"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := UTF8ToSJIS(tt.input) + if len(result) == 0 && len(tt.input) > 0 { + t.Error("UTF8ToSJIS returned empty result for non-empty input") + } + }) + } +} + +func TestSJISToUTF8(t *testing.T) { + // Test ASCII characters (which are the same in SJIS and UTF-8) + asciiBytes := []byte("Hello World") + result := SJISToUTF8(asciiBytes) + if result != "Hello World" { + t.Errorf("SJISToUTF8() = %q, want %q", result, "Hello World") + } +} + +func TestUTF8ToSJIS_RoundTrip(t *testing.T) { + // Test round-trip conversion for ASCII + original := "Hello World 123" + sjis := UTF8ToSJIS(original) + back := SJISToUTF8(sjis) + + if back != original { + t.Errorf("Round-trip failed: got %q, want %q", back, original) + } +} + +func TestToNGWord(t *testing.T) { + tests := []struct { + name string + input string + minLen int + checkFn func(t *testing.T, result []uint16) + }{ + { + name: "ascii characters", + input: "ABC", + minLen: 3, + checkFn: func(t *testing.T, result []uint16) { + if result[0] != uint16('A') { + t.Errorf("result[0] = %d, want %d", result[0], 'A') + } + }, + }, + { + name: "numbers", + input: "123", + minLen: 3, + checkFn: func(t *testing.T, result []uint16) { + if result[0] != uint16('1') { + t.Errorf("result[0] = %d, want %d", result[0], '1') + } + }, + }, + { + name: "japanese characters", + input: "あ", + minLen: 1, + checkFn: func(t *testing.T, result []uint16) { + if len(result) == 0 { + t.Error("result should not be empty") + } + }, + }, + { + name: "empty string", + input: "", + minLen: 0, + checkFn: func(t *testing.T, result []uint16) { + if len(result) != 0 { + t.Errorf("result length = %d, want 0", len(result)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ToNGWord(tt.input) + if len(result) < tt.minLen { + t.Errorf("ToNGWord() length = %d, want at least %d", len(result), tt.minLen) + } + if tt.checkFn != nil { + tt.checkFn(t, result) + } + }) + } +} + +func TestPaddedString(t *testing.T) { + tests := []struct { + name string + input string + size uint + transform bool + wantLen uint + }{ + {"short string", "Hello", 10, false, 10}, + {"exact size", "Test", 5, false, 5}, + {"longer than size", "This is a long string", 10, false, 10}, + {"empty string", "", 5, false, 5}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := PaddedString(tt.input, tt.size, tt.transform) + if uint(len(result)) != tt.wantLen { + t.Errorf("PaddedString() length = %d, want %d", len(result), tt.wantLen) + } + // Verify last byte is null + if result[len(result)-1] != 0 { + t.Error("PaddedString() should end with null byte") + } + }) + } +} + +func TestPaddedString_NullTermination(t *testing.T) { + result := PaddedString("Test", 10, false) + if result[9] != 0 { + t.Error("Last byte should be null") + } + // First 4 bytes should be "Test" + if !bytes.Equal(result[0:4], []byte("Test")) { + t.Errorf("First 4 bytes = %v, want %v", result[0:4], []byte("Test")) + } +} + +func TestCSVAdd(t *testing.T) { + tests := []struct { + name string + csv string + value int + expected string + }{ + {"add to empty", "", 1, "1"}, + {"add to existing", "1,2,3", 4, "1,2,3,4"}, + {"add duplicate", "1,2,3", 2, "1,2,3"}, + {"add to single", "5", 10, "5,10"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CSVAdd(tt.csv, tt.value) + if result != tt.expected { + t.Errorf("CSVAdd(%q, %d) = %q, want %q", tt.csv, tt.value, result, tt.expected) + } + }) + } +} + +func TestCSVRemove(t *testing.T) { + tests := []struct { + name string + csv string + value int + check func(t *testing.T, result string) + }{ + { + name: "remove from middle", + csv: "1,2,3,4,5", + value: 3, + check: func(t *testing.T, result string) { + if CSVContains(result, 3) { + t.Error("Result should not contain 3") + } + if CSVLength(result) != 4 { + t.Errorf("Result length = %d, want 4", CSVLength(result)) + } + }, + }, + { + name: "remove from start", + csv: "1,2,3", + value: 1, + check: func(t *testing.T, result string) { + if CSVContains(result, 1) { + t.Error("Result should not contain 1") + } + }, + }, + { + name: "remove non-existent", + csv: "1,2,3", + value: 99, + check: func(t *testing.T, result string) { + if CSVLength(result) != 3 { + t.Errorf("Length should remain 3, got %d", CSVLength(result)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CSVRemove(tt.csv, tt.value) + tt.check(t, result) + }) + } +} + +func TestCSVContains(t *testing.T) { + tests := []struct { + name string + csv string + value int + expected bool + }{ + {"contains in middle", "1,2,3,4,5", 3, true}, + {"contains at start", "1,2,3", 1, true}, + {"contains at end", "1,2,3", 3, true}, + {"does not contain", "1,2,3", 5, false}, + {"empty csv", "", 1, false}, + {"single value match", "42", 42, true}, + {"single value no match", "42", 43, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CSVContains(tt.csv, tt.value) + if result != tt.expected { + t.Errorf("CSVContains(%q, %d) = %v, want %v", tt.csv, tt.value, result, tt.expected) + } + }) + } +} + +func TestCSVLength(t *testing.T) { + tests := []struct { + name string + csv string + expected int + }{ + {"empty", "", 0}, + {"single", "1", 1}, + {"multiple", "1,2,3,4,5", 5}, + {"two", "10,20", 2}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CSVLength(tt.csv) + if result != tt.expected { + t.Errorf("CSVLength(%q) = %d, want %d", tt.csv, result, tt.expected) + } + }) + } +} + +func TestCSVElems(t *testing.T) { + tests := []struct { + name string + csv string + expected []int + }{ + {"empty", "", []int{}}, + {"single", "42", []int{42}}, + {"multiple", "1,2,3,4,5", []int{1, 2, 3, 4, 5}}, + {"negative numbers", "-1,0,1", []int{-1, 0, 1}}, + {"large numbers", "100,200,300", []int{100, 200, 300}}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CSVElems(tt.csv) + if len(result) != len(tt.expected) { + t.Errorf("CSVElems(%q) length = %d, want %d", tt.csv, len(result), len(tt.expected)) + } + for i, v := range tt.expected { + if i >= len(result) || result[i] != v { + t.Errorf("CSVElems(%q)[%d] = %d, want %d", tt.csv, i, result[i], v) + } + } + }) + } +} + +func TestCSVGetIndex(t *testing.T) { + csv := "10,20,30,40,50" + + tests := []struct { + name string + index int + expected int + }{ + {"first", 0, 10}, + {"middle", 2, 30}, + {"last", 4, 50}, + {"out of bounds", 10, 0}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CSVGetIndex(csv, tt.index) + if result != tt.expected { + t.Errorf("CSVGetIndex(%q, %d) = %d, want %d", csv, tt.index, result, tt.expected) + } + }) + } +} + +func TestCSVSetIndex(t *testing.T) { + tests := []struct { + name string + csv string + index int + value int + check func(t *testing.T, result string) + }{ + { + name: "set first", + csv: "10,20,30", + index: 0, + value: 99, + check: func(t *testing.T, result string) { + if CSVGetIndex(result, 0) != 99 { + t.Errorf("Index 0 = %d, want 99", CSVGetIndex(result, 0)) + } + }, + }, + { + name: "set middle", + csv: "10,20,30", + index: 1, + value: 88, + check: func(t *testing.T, result string) { + if CSVGetIndex(result, 1) != 88 { + t.Errorf("Index 1 = %d, want 88", CSVGetIndex(result, 1)) + } + }, + }, + { + name: "set last", + csv: "10,20,30", + index: 2, + value: 77, + check: func(t *testing.T, result string) { + if CSVGetIndex(result, 2) != 77 { + t.Errorf("Index 2 = %d, want 77", CSVGetIndex(result, 2)) + } + }, + }, + { + name: "set out of bounds", + csv: "10,20,30", + index: 10, + value: 99, + check: func(t *testing.T, result string) { + // Should not modify the CSV + if CSVLength(result) != 3 { + t.Errorf("CSV length changed when setting out of bounds") + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CSVSetIndex(tt.csv, tt.index, tt.value) + tt.check(t, result) + }) + } +} + +func TestCSV_CompleteWorkflow(t *testing.T) { + // Test a complete workflow + csv := "" + + // Add elements + csv = CSVAdd(csv, 10) + csv = CSVAdd(csv, 20) + csv = CSVAdd(csv, 30) + + if CSVLength(csv) != 3 { + t.Errorf("Length = %d, want 3", CSVLength(csv)) + } + + // Check contains + if !CSVContains(csv, 20) { + t.Error("Should contain 20") + } + + // Get element + if CSVGetIndex(csv, 1) != 20 { + t.Errorf("Index 1 = %d, want 20", CSVGetIndex(csv, 1)) + } + + // Set element + csv = CSVSetIndex(csv, 1, 99) + if CSVGetIndex(csv, 1) != 99 { + t.Errorf("Index 1 = %d, want 99 after set", CSVGetIndex(csv, 1)) + } + + // Remove element + csv = CSVRemove(csv, 99) + if CSVContains(csv, 99) { + t.Error("Should not contain 99 after removal") + } + + if CSVLength(csv) != 2 { + t.Errorf("Length = %d, want 2 after removal", CSVLength(csv)) + } +} + +func BenchmarkCSVAdd(b *testing.B) { + csv := "1,2,3,4,5" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CSVAdd(csv, 6) + } +} + +func BenchmarkCSVContains(b *testing.B) { + csv := "1,2,3,4,5,6,7,8,9,10" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CSVContains(csv, 5) + } +} + +func BenchmarkCSVRemove(b *testing.B) { + csv := "1,2,3,4,5,6,7,8,9,10" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CSVRemove(csv, 5) + } +} + +func BenchmarkCSVElems(b *testing.B) { + csv := "1,2,3,4,5,6,7,8,9,10" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = CSVElems(csv) + } +} + +func BenchmarkUTF8ToSJIS(b *testing.B) { + text := "Hello World テスト" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = UTF8ToSJIS(text) + } +} + +func BenchmarkSJISToUTF8(b *testing.B) { + text := []byte("Hello World") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = SJISToUTF8(text) + } +} + +func BenchmarkPaddedString(b *testing.B) { + text := "Test String" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = PaddedString(text, 50, false) + } +} + +func BenchmarkToNGWord(b *testing.B) { + text := "TestString" + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = ToNGWord(text) + } +} diff --git a/common/token/token_test.go b/common/token/token_test.go new file mode 100644 index 000000000..4d7487492 --- /dev/null +++ b/common/token/token_test.go @@ -0,0 +1,340 @@ +package token + +import ( + "math/rand" + "testing" + "time" +) + +func TestGenerate_Length(t *testing.T) { + tests := []struct { + name string + length int + }{ + {"zero length", 0}, + {"short", 5}, + {"medium", 32}, + {"long", 100}, + {"very long", 1000}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := Generate(tt.length) + if len(result) != tt.length { + t.Errorf("Generate(%d) length = %d, want %d", tt.length, len(result), tt.length) + } + }) + } +} + +func TestGenerate_CharacterSet(t *testing.T) { + // Verify that generated tokens only contain alphanumeric characters + validChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + validCharMap := make(map[rune]bool) + for _, c := range validChars { + validCharMap[c] = true + } + + token := Generate(1000) // Large sample + for _, c := range token { + if !validCharMap[c] { + t.Errorf("Generate() produced invalid character: %c", c) + } + } +} + +func TestGenerate_Randomness(t *testing.T) { + // Generate multiple tokens and verify they're different + tokens := make(map[string]bool) + count := 100 + length := 32 + + for i := 0; i < count; i++ { + token := Generate(length) + if tokens[token] { + t.Errorf("Generate() produced duplicate token: %s", token) + } + tokens[token] = true + } + + if len(tokens) != count { + t.Errorf("Generated %d unique tokens, want %d", len(tokens), count) + } +} + +func TestGenerate_ContainsUppercase(t *testing.T) { + // With enough characters, should contain at least one uppercase letter + token := Generate(1000) + hasUpper := false + for _, c := range token { + if c >= 'A' && c <= 'Z' { + hasUpper = true + break + } + } + if !hasUpper { + t.Error("Generate(1000) should contain at least one uppercase letter") + } +} + +func TestGenerate_ContainsLowercase(t *testing.T) { + // With enough characters, should contain at least one lowercase letter + token := Generate(1000) + hasLower := false + for _, c := range token { + if c >= 'a' && c <= 'z' { + hasLower = true + break + } + } + if !hasLower { + t.Error("Generate(1000) should contain at least one lowercase letter") + } +} + +func TestGenerate_ContainsDigit(t *testing.T) { + // With enough characters, should contain at least one digit + token := Generate(1000) + hasDigit := false + for _, c := range token { + if c >= '0' && c <= '9' { + hasDigit = true + break + } + } + if !hasDigit { + t.Error("Generate(1000) should contain at least one digit") + } +} + +func TestGenerate_Distribution(t *testing.T) { + // Test that characters are reasonably distributed + token := Generate(6200) // 62 chars * 100 = good sample size + charCount := make(map[rune]int) + + for _, c := range token { + charCount[c]++ + } + + // With 62 valid characters and 6200 samples, average should be 100 per char + // We'll accept a range to account for randomness + minExpected := 50 // Allow some variance + maxExpected := 150 + + for c, count := range charCount { + if count < minExpected || count > maxExpected { + t.Logf("Character %c appeared %d times (outside expected range %d-%d)", c, count, minExpected, maxExpected) + } + } + + // Just verify we have a good spread of characters + if len(charCount) < 50 { + t.Errorf("Only %d different characters used, want at least 50", len(charCount)) + } +} + +func TestNewRNG(t *testing.T) { + rng := NewRNG() + if rng == nil { + t.Fatal("NewRNG() returned nil") + } + + // Test that it produces different values on subsequent calls + val1 := rng.Intn(1000000) + val2 := rng.Intn(1000000) + + if val1 == val2 { + // This is possible but unlikely, let's try a few more times + same := true + for i := 0; i < 10; i++ { + if rng.Intn(1000000) != val1 { + same = false + break + } + } + if same { + t.Error("NewRNG() produced same value 12 times in a row") + } + } +} + +func TestRNG_GlobalVariable(t *testing.T) { + // Test that the global RNG variable is initialized + if RNG == nil { + t.Fatal("Global RNG is nil") + } + + // Test that it works + val := RNG.Intn(100) + if val < 0 || val >= 100 { + t.Errorf("RNG.Intn(100) = %d, out of range [0, 100)", val) + } +} + +func TestRNG_Uint32(t *testing.T) { + // Test that RNG can generate uint32 values + val1 := RNG.Uint32() + val2 := RNG.Uint32() + + // They should be different (with very high probability) + if val1 == val2 { + // Try a few more times + same := true + for i := 0; i < 10; i++ { + if RNG.Uint32() != val1 { + same = false + break + } + } + if same { + t.Error("RNG.Uint32() produced same value 12 times") + } + } +} + +func TestGenerate_Concurrency(t *testing.T) { + // Test that Generate works correctly when called concurrently + done := make(chan string, 100) + + for i := 0; i < 100; i++ { + go func() { + token := Generate(32) + done <- token + }() + } + + tokens := make(map[string]bool) + for i := 0; i < 100; i++ { + token := <-done + if len(token) != 32 { + t.Errorf("Token length = %d, want 32", len(token)) + } + tokens[token] = true + } + + // Should have many unique tokens (allow some small chance of duplicates) + if len(tokens) < 95 { + t.Errorf("Only %d unique tokens from 100 concurrent calls", len(tokens)) + } +} + +func TestGenerate_EmptyString(t *testing.T) { + token := Generate(0) + if token != "" { + t.Errorf("Generate(0) = %q, want empty string", token) + } +} + +func TestGenerate_OnlyAlphanumeric(t *testing.T) { + // Verify no special characters + token := Generate(1000) + for i, c := range token { + isValid := (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') + if !isValid { + t.Errorf("Token[%d] = %c (invalid character)", i, c) + } + } +} + +func TestNewRNG_DifferentSeeds(t *testing.T) { + // Create two RNGs at different times and verify they produce different sequences + rng1 := NewRNG() + time.Sleep(1 * time.Millisecond) // Ensure different seed + rng2 := NewRNG() + + val1 := rng1.Intn(1000000) + val2 := rng2.Intn(1000000) + + // They should be different with high probability + if val1 == val2 { + // Try again + val1 = rng1.Intn(1000000) + val2 = rng2.Intn(1000000) + if val1 == val2 { + t.Log("Two RNGs created at different times produced same first two values (possible but unlikely)") + } + } +} + +func BenchmarkGenerate_Short(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Generate(8) + } +} + +func BenchmarkGenerate_Medium(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Generate(32) + } +} + +func BenchmarkGenerate_Long(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = Generate(128) + } +} + +func BenchmarkNewRNG(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = NewRNG() + } +} + +func BenchmarkRNG_Intn(b *testing.B) { + rng := NewRNG() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = rng.Intn(62) + } +} + +func BenchmarkRNG_Uint32(b *testing.B) { + rng := NewRNG() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = rng.Uint32() + } +} + +func TestGenerate_ConsistentCharacterSet(t *testing.T) { + // Verify the character set matches what's defined in the code + expectedChars := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + if len(expectedChars) != 62 { + t.Errorf("Expected character set length = %d, want 62", len(expectedChars)) + } + + // Count each type + lowercase := 0 + uppercase := 0 + digits := 0 + for _, c := range expectedChars { + if c >= 'a' && c <= 'z' { + lowercase++ + } else if c >= 'A' && c <= 'Z' { + uppercase++ + } else if c >= '0' && c <= '9' { + digits++ + } + } + + if lowercase != 26 { + t.Errorf("Lowercase count = %d, want 26", lowercase) + } + if uppercase != 26 { + t.Errorf("Uppercase count = %d, want 26", uppercase) + } + if digits != 10 { + t.Errorf("Digits count = %d, want 10", digits) + } +} + +func TestRNG_Type(t *testing.T) { + // Verify RNG is of type *rand.Rand + var _ *rand.Rand = RNG + var _ *rand.Rand = NewRNG() +}