diff --git a/common/decryption/jpk.go b/common/decryption/jpk.go index 24940f178..b4a60a4c9 100644 --- a/common/decryption/jpk.go +++ b/common/decryption/jpk.go @@ -10,15 +10,16 @@ import ( "io" ) -var mShiftIndex = 0 -var mFlag = byte(0) +// jpkState holds the mutable bit-reader state for a single JPK decompression. +// This is local to each call, making concurrent UnpackSimple calls safe. +type jpkState struct { + shiftIndex int + flag byte +} // UnpackSimple decompresses a JPK type-3 compressed byte slice. If the data // does not start with the JKR magic header it is returned unchanged. func UnpackSimple(data []byte) []byte { - mShiftIndex = 0 - mFlag = byte(0) - bf := byteframe.NewByteFrameFromBytes(data) bf.SetLE() header := bf.ReadUint32() @@ -33,7 +34,8 @@ func UnpackSimple(data []byte) []byte { outSize := bf.ReadInt32() outBuffer := make([]byte, outSize) _, _ = bf.Seek(int64(startOffset), io.SeekStart) - ProcessDecode(bf, outBuffer) + s := &jpkState{} + s.processDecode(bf, outBuffer) return outBuffer } @@ -45,16 +47,21 @@ func UnpackSimple(data []byte) []byte { // ProcessDecode runs the JPK LZ-style decompression loop, reading compressed // tokens from data and writing decompressed bytes into outBuffer. func ProcessDecode(data *byteframe.ByteFrame, outBuffer []byte) { + s := &jpkState{} + s.processDecode(data, outBuffer) +} + +func (s *jpkState) processDecode(data *byteframe.ByteFrame, outBuffer []byte) { outIndex := 0 for int(data.Index()) < len(data.Data()) && outIndex < len(outBuffer)-1 { - if JPKBitShift(data) == 0 { + if s.bitShift(data) == 0 { outBuffer[outIndex] = ReadByte(data) outIndex++ continue } else { - if JPKBitShift(data) == 0 { - length := (JPKBitShift(data) << 1) | JPKBitShift(data) + if s.bitShift(data) == 0 { + length := (s.bitShift(data) << 1) | s.bitShift(data) off := ReadByte(data) JPKCopy(outBuffer, int(off), int(length)+3, &outIndex) continue @@ -67,8 +74,8 @@ func ProcessDecode(data *byteframe.ByteFrame, outBuffer []byte) { JPKCopy(outBuffer, off, length+2, &outIndex) continue } else { - if JPKBitShift(data) == 0 { - length := (JPKBitShift(data) << 3) | (JPKBitShift(data) << 2) | (JPKBitShift(data) << 1) | JPKBitShift(data) + if s.bitShift(data) == 0 { + length := (s.bitShift(data) << 3) | (s.bitShift(data) << 2) | (s.bitShift(data) << 1) | s.bitShift(data) JPKCopy(outBuffer, off, int(length)+2+8, &outIndex) continue } else { @@ -89,17 +96,17 @@ func ProcessDecode(data *byteframe.ByteFrame, outBuffer []byte) { } } -// JPKBitShift reads one bit from the compressed stream's flag byte, refilling +// bitShift reads one bit from the compressed stream's flag byte, refilling // the flag from the next byte in data when all 8 bits have been consumed. -func JPKBitShift(data *byteframe.ByteFrame) byte { - mShiftIndex-- +func (s *jpkState) bitShift(data *byteframe.ByteFrame) byte { + s.shiftIndex-- - if mShiftIndex < 0 { - mShiftIndex = 7 - mFlag = ReadByte(data) + if s.shiftIndex < 0 { + s.shiftIndex = 7 + s.flag = ReadByte(data) } - return (byte)((mFlag >> mShiftIndex) & 1) + return (s.flag >> s.shiftIndex) & 1 } // JPKCopy copies length bytes from a previous position in outBuffer (determined diff --git a/common/decryption/jpk_test.go b/common/decryption/jpk_test.go index a9b0542b4..c824b6b30 100644 --- a/common/decryption/jpk_test.go +++ b/common/decryption/jpk_test.go @@ -81,21 +81,40 @@ func TestUnpackSimple_JPKHeader(t *testing.T) { } func TestJPKBitShift_Initialization(t *testing.T) { - // Test that the function doesn't crash with bad initial global state - mShiftIndex = 10 - mFlag = 0xFF - - // Create data without JPK header (will return as-is) - // Need at least 4 bytes since UnpackSimple reads a uint32 header + // Test that bitShift correctly initializes from zero state bf := byteframe.NewByteFrame() - bf.WriteUint32(0xAABBCCDD) // Not a JPK header + bf.WriteUint8(0xFF) // All bits set + bf.WriteUint8(0x00) // No bits set - data := bf.Data() - result := UnpackSimple(data) + _, _ = bf.Seek(0, io.SeekStart) + s := &jpkState{} - // Without JPK header, should return data as-is - if !bytes.Equal(result, data) { - t.Error("UnpackSimple with non-JPK data should return input as-is") + // First call should read 0xFF as flag and return bit 7 = 1 + bit := s.bitShift(bf) + if bit != 1 { + t.Errorf("bitShift() first bit of 0xFF = %d, want 1", bit) + } +} + +func TestUnpackSimple_ConcurrentSafety(t *testing.T) { + // Verify that concurrent UnpackSimple calls don't race. + // Non-JPK data is returned as-is; the important thing is no data race. + input := []byte{0x00, 0x01, 0x02, 0x03} + + done := make(chan struct{}) + for i := 0; i < 8; i++ { + go func() { + defer func() { done <- struct{}{} }() + for j := 0; j < 100; j++ { + result := UnpackSimple(input) + if !bytes.Equal(result, input) { + t.Errorf("concurrent UnpackSimple returned wrong data") + } + } + }() + } + for i := 0; i < 8; i++ { + <-done } }