From 5bf9d72518818875650cf2e44049b99fedee49d1 Mon Sep 17 00:00:00 2001 From: mleku Date: Thu, 26 Jun 2025 21:28:52 +0100 Subject: [PATCH] Add comprehensive tests for sha256 marshalling and unmarshalling Introduce tests to verify `MarshalBinary` and `UnmarshalBinary` behavior, covering block boundary cases, data integrity after state restoration, and error handling for invalid inputs. These tests ensure correctness for diverse data sizes, edge cases, and error scenarios. --- sha256/additional_test.go | 197 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) create mode 100644 sha256/additional_test.go diff --git a/sha256/additional_test.go b/sha256/additional_test.go new file mode 100644 index 0000000..8d9bec9 --- /dev/null +++ b/sha256/additional_test.go @@ -0,0 +1,197 @@ +package sha256 + +import ( + "bytes" + "fmt" + "testing" +) + +// Test for the UnmarshalBinary issue where d.nx calculation might be incorrect +func TestUnmarshalBinaryNxCalculation(t *testing.T) { + // Create a digest and write some data that doesn't align to block boundary + d1 := New().(*digest) + testData := []byte("hello world") // 11 bytes, not aligned to 64-byte boundary + d1.Write(testData) + + // Marshal the state + state, err := d1.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary failed: %v", err) + } + + // Create a new digest and unmarshal + d2 := New().(*digest) + err = d2.UnmarshalBinary(state) + if err != nil { + t.Fatalf("UnmarshalBinary failed: %v", err) + } + + // Check that nx values match + if d1.nx != d2.nx { + t.Errorf("nx mismatch after unmarshal: original=%d, unmarshaled=%d", d1.nx, d2.nx) + } + + // Check that the buffer contents match + if !bytes.Equal(d1.x[:d1.nx], d2.x[:d2.nx]) { + t.Errorf("buffer contents mismatch after unmarshal") + } + + // Continue writing and verify results match + moreData := []byte(" more data") + d1.Write(moreData) + d2.Write(moreData) + + sum1 := d1.Sum(nil) + sum2 := d2.Sum(nil) + + if !bytes.Equal(sum1, sum2) { + t.Errorf("final sums don't match: %x vs %x", sum1, sum2) + } +} + +// Test edge case with exactly block-sized data +func TestUnmarshalBinaryBlockBoundary(t *testing.T) { + // Create data that's exactly one block (64 bytes) + testData := make([]byte, BlockSize) + for i := range testData { + testData[i] = byte(i) + } + + d1 := New().(*digest) + d1.Write(testData) + + // Marshal and unmarshal + state, err := d1.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary failed: %v", err) + } + + d2 := New().(*digest) + err = d2.UnmarshalBinary(state) + if err != nil { + t.Fatalf("UnmarshalBinary failed: %v", err) + } + + // After writing exactly one block, nx should be 0 + if d1.nx != 0 || d2.nx != 0 { + t.Errorf("nx should be 0 after block boundary: d1.nx=%d, d2.nx=%d", d1.nx, d2.nx) + } + + // Verify final results match + sum1 := d1.Sum(nil) + sum2 := d2.Sum(nil) + + if !bytes.Equal(sum1, sum2) { + t.Errorf("final sums don't match: %x vs %x", sum1, sum2) + } +} + +// Test marshaling/unmarshaling with various data sizes +func TestMarshalUnmarshalVariousSizes(t *testing.T) { + sizes := []int{0, 1, 32, 63, 64, 65, 127, 128, 129} + + for _, size := range sizes { + t.Run(fmt.Sprintf("size_%d", size), func(t *testing.T) { + testData := make([]byte, size) + for i := range testData { + testData[i] = byte(i % 256) + } + + d1 := New().(*digest) + d1.Write(testData) + + // Marshal + state, err := d1.MarshalBinary() + if err != nil { + t.Fatalf("MarshalBinary failed for size %d: %v", size, err) + } + + // Unmarshal + d2 := New().(*digest) + err = d2.UnmarshalBinary(state) + if err != nil { + t.Fatalf("UnmarshalBinary failed for size %d: %v", size, err) + } + + // Verify state matches + if d1.nx != d2.nx { + t.Errorf("nx mismatch for size %d: %d vs %d", size, d1.nx, d2.nx) + } + + if d1.len != d2.len { + t.Errorf("len mismatch for size %d: %d vs %d", size, d1.len, d2.len) + } + + // Verify final results match + sum1 := d1.Sum(nil) + sum2 := d2.Sum(nil) + + if !bytes.Equal(sum1, sum2) { + t.Errorf("final sums don't match for size %d: %x vs %x", size, sum1, sum2) + } + }) + } +} + +// Test the getDigest function to ensure it correctly extracts digest values +func TestGetDigest(t *testing.T) { + // Create a test state array with known values + state := make([]byte, 512) + + // Fill with a pattern that we can verify + for i := 0; i < 512; i++ { + state[i] = byte(i % 256) + } + + // Test extracting digest for different indices + for index := 0; index < 16; index++ { + digest := getDigest(index, state) + + // Verify the digest has the expected size + if len(digest) != Size { + t.Errorf("digest size mismatch for index %d: got %d, want %d", index, len(digest), Size) + } + + // The digest should not be all zeros (unless the state was all zeros) + allZero := true + for _, b := range digest { + if b != 0 { + allZero = false + break + } + } + + // Since we filled state with a pattern, digest shouldn't be all zeros + if allZero { + t.Errorf("digest for index %d is all zeros, which is unexpected", index) + } + } +} + +// Test error handling in UnmarshalBinary +func TestUnmarshalBinaryErrorHandling(t *testing.T) { + d := New().(*digest) + + // Test with invalid magic + invalidMagic := []byte("bad\x03" + string(make([]byte, marshaledSize-4))) + err := d.UnmarshalBinary(invalidMagic) + if err == nil { + t.Error("expected error for invalid magic, got nil") + } + + // Test with invalid size + validMagic := []byte("sha\x03") + tooShort := append(validMagic, make([]byte, 10)...) + err = d.UnmarshalBinary(tooShort) + if err == nil { + t.Error("expected error for invalid size, got nil") + } + + // Test with exactly the minimum size but wrong magic + exactSize := make([]byte, marshaledSize) + copy(exactSize, "bad\x03") + err = d.UnmarshalBinary(exactSize) + if err == nil { + t.Error("expected error for wrong magic with correct size, got nil") + } +}