Files
p256k1/scalar_test.go

458 lines
9.2 KiB
Go

package p256k1
import (
"crypto/rand"
"testing"
)
func TestScalarBasics(t *testing.T) {
// Test zero scalar
var zero Scalar
zero.setInt(0)
if !zero.isZero() {
t.Error("Zero scalar should be zero")
}
// Test one scalar
var one Scalar
one.setInt(1)
if one.isZero() {
t.Error("One scalar should not be zero")
}
if !one.isOne() {
t.Error("One scalar should be one")
}
// Test equality
var one2 Scalar
one2.setInt(1)
if !one.equal(&one2) {
t.Error("Two ones should be equal")
}
}
func TestScalarSetB32(t *testing.T) {
testCases := []struct {
name string
bytes [32]byte
overflow bool
}{
{
name: "zero",
bytes: [32]byte{},
overflow: false,
},
{
name: "one",
bytes: [32]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1},
overflow: false,
},
{
name: "group_order_minus_one",
bytes: [32]byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x40,
},
overflow: false,
},
{
name: "group_order",
bytes: [32]byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x41,
},
overflow: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var s Scalar
overflow := s.setB32(tc.bytes[:])
if overflow != tc.overflow {
t.Errorf("Expected overflow %v, got %v", tc.overflow, overflow)
}
// Test round-trip for non-overflowing values
if !tc.overflow {
var result [32]byte
s.getB32(result[:])
// Values should match after round-trip
for i := 0; i < 32; i++ {
if result[i] != tc.bytes[i] {
t.Errorf("Round-trip failed at byte %d: expected %02x, got %02x", i, tc.bytes[i], result[i])
break
}
}
}
})
}
}
func TestScalarSetB32Seckey(t *testing.T) {
// Test valid secret key
validKey := [32]byte{
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01,
}
var s Scalar
if !s.setB32Seckey(validKey[:]) {
t.Error("Valid secret key should be accepted")
}
// Test zero key (invalid)
zeroKey := [32]byte{}
if s.setB32Seckey(zeroKey[:]) {
t.Error("Zero secret key should be rejected")
}
// Test overflowing key (invalid)
overflowKey := [32]byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x41,
}
if s.setB32Seckey(overflowKey[:]) {
t.Error("Overflowing secret key should be rejected")
}
}
func TestScalarArithmetic(t *testing.T) {
// Test addition
var a, b, c Scalar
a.setInt(5)
b.setInt(7)
c.add(&a, &b)
var expected Scalar
expected.setInt(12)
if !c.equal(&expected) {
t.Error("5 + 7 should equal 12")
}
// Test multiplication
var mult Scalar
mult.mul(&a, &b)
expected.setInt(35)
if !mult.equal(&expected) {
t.Error("5 * 7 should equal 35")
}
// Test negation
var neg Scalar
neg.negate(&a)
var sum Scalar
sum.add(&a, &neg)
if !sum.isZero() {
t.Error("a + (-a) should equal zero")
}
}
func TestScalarInverse(t *testing.T) {
// Test inverse of small numbers
for i := uint(1); i <= 10; i++ {
var a, inv, product Scalar
a.setInt(i)
inv.inverse(&a)
product.mul(&a, &inv)
if !product.isOne() {
t.Errorf("a * a^(-1) should equal 1 for a = %d", i)
}
}
// Test inverse of zero should not crash (though result is undefined)
var zero, inv Scalar
zero.setInt(0)
inv.inverse(&zero) // Should not crash
}
func TestScalarHalf(t *testing.T) {
// Test halving even numbers
var even, half Scalar
even.setInt(10)
half.half(&even)
var expected Scalar
expected.setInt(5)
if !half.equal(&expected) {
t.Error("10 / 2 should equal 5")
}
// Test halving odd numbers
var odd Scalar
odd.setInt(7)
half.half(&odd)
// 7/2 mod n should be (7 + n)/2 mod n
// This is more complex to verify, so we just check that 2*half = 7
var doubled Scalar
doubled.setInt(2)
doubled.mul(&doubled, &half)
if !doubled.equal(&odd) {
t.Error("2 * (7/2) should equal 7")
}
}
func TestScalarProperties(t *testing.T) {
// Test even/odd detection
var even, odd Scalar
even.setInt(42)
odd.setInt(43)
if !even.isEven() {
t.Error("42 should be even")
}
if odd.isEven() {
t.Error("43 should be odd")
}
// Test high/low detection (compared to n/2)
var low, high Scalar
low.setInt(1)
// Set high to a large value (close to group order)
highBytes := [32]byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x40,
}
high.setB32(highBytes[:])
if low.isHigh() {
t.Error("Small value should not be high")
}
if !high.isHigh() {
t.Error("Large value should be high")
}
}
func TestScalarConditionalNegate(t *testing.T) {
var s Scalar
s.setInt(42)
// Test conditional negate with false
negated := s.condNegate(false)
if negated {
t.Error("Should not negate when flag is false")
}
var expected Scalar
expected.setInt(42)
if !s.equal(&expected) {
t.Error("Value should not change when flag is false")
}
// Test conditional negate with true
negated = s.condNegate(true)
if !negated {
t.Error("Should negate when flag is true")
}
var neg Scalar
expected.setInt(42)
neg.negate(&expected)
if !s.equal(&neg) {
t.Error("Value should be negated when flag is true")
}
}
func TestScalarGetBits(t *testing.T) {
// Test bit extraction
var s Scalar
s.setInt(0b11010110) // 214 in binary
// Extract different bit ranges
bits := s.getBits(0, 4) // Lower 4 bits: 0110 = 6
if bits != 6 {
t.Errorf("Expected 6, got %d", bits)
}
bits = s.getBits(4, 4) // Next 4 bits: 1101 = 13
if bits != 13 {
t.Errorf("Expected 13, got %d", bits)
}
bits = s.getBits(1, 3) // 3 bits starting at position 1: 011 = 3
if bits != 3 {
t.Errorf("Expected 3, got %d", bits)
}
}
func TestScalarConditionalMove(t *testing.T) {
var a, b, result Scalar
a.setInt(10)
b.setInt(20)
result = a
// Test conditional move with flag = 0 (no move)
result.cmov(&b, 0)
if !result.equal(&a) {
t.Error("cmov with flag=0 should not change value")
}
// Test conditional move with flag = 1 (move)
result = a
result.cmov(&b, 1)
if !result.equal(&b) {
t.Error("cmov with flag=1 should change value")
}
}
func TestScalarClear(t *testing.T) {
var s Scalar
s.setInt(12345)
s.clear()
// After clearing, should be zero
if !s.isZero() {
t.Error("Cleared scalar should be zero")
}
}
func TestScalarRandomOperations(t *testing.T) {
// Test with random values
for i := 0; i < 50; i++ {
var bytes1, bytes2 [32]byte
rand.Read(bytes1[:])
rand.Read(bytes2[:])
var a, b Scalar
// Ensure we don't overflow
bytes1[0] &= 0x7F
bytes2[0] &= 0x7F
a.setB32(bytes1[:])
b.setB32(bytes2[:])
// Skip if either is zero (to avoid division by zero in inverse tests)
if a.isZero() || b.isZero() {
continue
}
// Test a + b - a = b
var sum, diff Scalar
sum.add(&a, &b)
var negA Scalar
negA.negate(&a)
diff.add(&sum, &negA)
if !diff.equal(&b) {
t.Errorf("Random test %d: (a + b) - a should equal b", i)
}
// Test a * b / a = b (if a != 0)
var product, quotient Scalar
product.mul(&a, &b)
var invA Scalar
invA.inverse(&a)
quotient.mul(&product, &invA)
if !quotient.equal(&b) {
t.Errorf("Random test %d: (a * b) / a should equal b", i)
}
}
}
func TestScalarEdgeCases(t *testing.T) {
// Test group order boundary
var n_minus_1 Scalar
n_minus_1_bytes := [32]byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x40,
}
n_minus_1.setB32(n_minus_1_bytes[:])
// Add 1 should give 0
var one, result Scalar
one.setInt(1)
result.add(&n_minus_1, &one)
if !result.isZero() {
t.Error("(n-1) + 1 should equal 0 in scalar arithmetic")
}
// Test -1 = n-1
var neg_one Scalar
neg_one.negate(&one)
if !neg_one.equal(&n_minus_1) {
t.Error("-1 should equal n-1")
}
}
// Benchmark tests
func BenchmarkScalarSetB32(b *testing.B) {
var bytes [32]byte
rand.Read(bytes[:])
bytes[0] &= 0x7F // Ensure no overflow
var s Scalar
b.ResetTimer()
for i := 0; i < b.N; i++ {
s.setB32(bytes[:])
}
}
func BenchmarkScalarAdd(b *testing.B) {
var a, c, result Scalar
a.setInt(12345)
c.setInt(67890)
b.ResetTimer()
for i := 0; i < b.N; i++ {
result.add(&a, &c)
}
}
func BenchmarkScalarMul(b *testing.B) {
var a, c, result Scalar
a.setInt(12345)
c.setInt(67890)
b.ResetTimer()
for i := 0; i < b.N; i++ {
result.mul(&a, &c)
}
}
func BenchmarkScalarInverse(b *testing.B) {
var a, result Scalar
a.setInt(12345)
b.ResetTimer()
for i := 0; i < b.N; i++ {
result.inverse(&a)
}
}
func BenchmarkScalarNegate(b *testing.B) {
var a, result Scalar
a.setInt(12345)
b.ResetTimer()
for i := 0; i < b.N; i++ {
result.negate(&a)
}
}