diff --git a/bench/comparison_bench_test.go b/bench/comparison_bench_test.go
index 2e59414..90083ef 100644
--- a/bench/comparison_bench_test.go
+++ b/bench/comparison_bench_test.go
@@ -1,5 +1,5 @@
-//go:build cgo
-// +build cgo
+//go:build !nocgo
+// +build !nocgo
 
 package bench
 
@@ -7,27 +7,18 @@ import (
 	"crypto/rand"
 	"testing"
 
-	p256knext "next.orly.dev/pkg/crypto/p256k"
 	"p256k1.mleku.dev/signer"
 )
 
-// This file contains benchmarks comparing the three signer implementations:
-// 1. P256K1Signer (this package's new port from Bitcoin Core secp256k1)
-// 2. BtcecSigner (pure Go btcec wrapper)
-// 3. NextP256K Signer (CGO version using next.orly.dev/pkg/crypto/p256k)
+// This file contains benchmarks for the P256K1Signer implementation
+// (pure Go port from Bitcoin Core secp256k1)
 
 var (
 	benchSeckey []byte
 	benchMsghash []byte
 	compBenchSignerP256K1 *signer.P256K1Signer
-	compBenchSignerBtcec *signer.BtcecSigner
-	compBenchSignerNext *p256knext.Signer
 	compBenchSignerP256K12 *signer.P256K1Signer
-	compBenchSignerBtcec2 *signer.BtcecSigner
-	compBenchSignerNext2 *p256knext.Signer
 	compBenchSigP256K1 []byte
-	compBenchSigBtcec []byte
-	compBenchSigNext []byte
 )
 
 func initComparisonBenchData() {
@@ -72,30 +63,6 @@ func initComparisonBenchData() {
 		panic(err)
 	}
 
-	// Setup BtcecSigner (pure Go)
-	signer2 := signer.NewBtcecSigner()
-	if err := signer2.InitSec(benchSeckey); err != nil {
-		panic(err)
-	}
-	compBenchSignerBtcec = signer2
-
-	compBenchSigBtcec, err = signer2.Sign(benchMsghash)
-	if err != nil {
-		panic(err)
-	}
-
-	// Setup NextP256K Signer (CGO version)
-	signer3 := &p256knext.Signer{}
-	if err := signer3.InitSec(benchSeckey); err != nil {
-		panic(err)
-	}
-	compBenchSignerNext = signer3
-
-	compBenchSigNext, err = signer3.Sign(benchMsghash)
-	if err != nil {
-		panic(err)
-	}
-
 	// Generate second key pair for ECDH
 	seckey2 := make([]byte, 32)
 	for {
@@ -115,24 +82,10 @@ func initComparisonBenchData() {
 		panic(err)
 	}
 	compBenchSignerP256K12 = signer12
-
-	// BtcecSigner second key pair
-	signer22 := signer.NewBtcecSigner()
-	if err := signer22.InitSec(seckey2); err != nil {
-		panic(err)
-	}
-	compBenchSignerBtcec2 = signer22
-
-	// NextP256K Signer second key pair
-	signer32 := &p256knext.Signer{}
-	if err := signer32.InitSec(seckey2); err != nil {
-		panic(err)
-	}
-	compBenchSignerNext2 = signer32
 }
 
-// BenchmarkPubkeyDerivation compares public key derivation from private key
-func BenchmarkPubkeyDerivation_P256K1(b *testing.B) {
+// BenchmarkPubkeyDerivation benchmarks public key derivation from private key
+func BenchmarkPubkeyDerivation(b *testing.B) {
 	if benchSeckey == nil {
 		initComparisonBenchData()
 	}
@@ -147,38 +100,9 @@ func BenchmarkPubkeyDerivation_P256K1(b *testing.B) {
 	}
 }
 
-func BenchmarkPubkeyDerivation_Btcec(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		s := signer.NewBtcecSigner()
-		if err := s.InitSec(benchSeckey); err != nil {
-			b.Fatalf("failed to create signer: %v", err)
-		}
-		_ = s.Pub()
-	}
-}
-
-func BenchmarkPubkeyDerivation_NextP256K(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		s := &p256knext.Signer{}
-		if err := s.InitSec(benchSeckey); err != nil {
-			b.Fatalf("failed to create signer: %v", err)
-		}
-		_ = s.Pub()
-	}
-}
-
-// BenchmarkSign compares Schnorr signing
-func BenchmarkSign_P256K1(b *testing.B) {
+// BenchmarkSign benchmarks Schnorr signing
+func BenchmarkSign(b *testing.B) {
 	if benchSeckey == nil {
 		initComparisonBenchData()
 	}
 
@@ -195,42 +119,9 @@ func BenchmarkSign_P256K1(b *testing.B) {
 	}
 }
 
-func BenchmarkSign_Btcec(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		if compBenchSignerBtcec == nil {
-			initComparisonBenchData()
-		}
-		_, err := compBenchSignerBtcec.Sign(benchMsghash)
-		if err != nil {
-			b.Fatalf("failed to sign: %v", err)
-		}
-	}
-}
-
-func BenchmarkSign_NextP256K(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		if compBenchSignerNext == nil {
-			initComparisonBenchData()
-		}
-		_, err := compBenchSignerNext.Sign(benchMsghash)
-		if err != nil {
-			b.Fatalf("failed to sign: %v", err)
-		}
-	}
-}
-
-// BenchmarkVerify compares Schnorr verification
-func BenchmarkVerify_P256K1(b *testing.B) {
+// BenchmarkVerify benchmarks Schnorr verification
+func BenchmarkVerify(b *testing.B) {
 	if benchSeckey == nil {
 		initComparisonBenchData()
 	}
 
@@ -255,58 +146,9 @@ func BenchmarkVerify_P256K1(b *testing.B) {
 	}
 }
 
-func BenchmarkVerify_Btcec(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-	if compBenchSignerBtcec == nil || compBenchSigBtcec == nil {
-		initComparisonBenchData()
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		verifier := signer.NewBtcecSigner()
-		if err := verifier.InitPub(compBenchSignerBtcec.Pub()); err != nil {
-			b.Fatalf("failed to create verifier: %v", err)
-		}
-		valid, err := verifier.Verify(benchMsghash, compBenchSigBtcec)
-		if err != nil {
-			b.Fatalf("verification error: %v", err)
-		}
-		if !valid {
-			b.Fatalf("verification failed")
-		}
-	}
-}
-
-func BenchmarkVerify_NextP256K(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-
-	if compBenchSignerNext == nil || compBenchSigNext == nil {
-		initComparisonBenchData()
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		verifier := &p256knext.Signer{}
-		if err := verifier.InitPub(compBenchSignerNext.Pub()); err != nil {
-			b.Fatalf("failed to create verifier: %v", err)
-		}
-		valid, err := verifier.Verify(benchMsghash, compBenchSigNext)
-		if err != nil {
-			b.Fatalf("verification error: %v", err)
-		}
-		if !valid {
-			b.Fatalf("verification failed")
-		}
-	}
-}
-
-// BenchmarkECDH compares ECDH shared secret generation
-func BenchmarkECDH_P256K1(b *testing.B) {
+// BenchmarkECDH benchmarks ECDH shared secret generation
+func BenchmarkECDH(b *testing.B) {
 	if benchSeckey == nil {
 		initComparisonBenchData()
 	}
 
@@ -323,37 +165,4 @@ func BenchmarkECDH_P256K1(b *testing.B) {
 	}
 }
 
-func BenchmarkECDH_Btcec(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		if compBenchSignerBtcec == nil || compBenchSignerBtcec2 == nil {
-			initComparisonBenchData()
-		}
-		_, err := compBenchSignerBtcec.ECDH(compBenchSignerBtcec2.Pub())
-		if err != nil {
-			b.Fatalf("ECDH failed: %v", err)
-		}
-	}
-}
-
-func BenchmarkECDH_NextP256K(b *testing.B) {
-	if benchSeckey == nil {
-		initComparisonBenchData()
-	}
-
-	b.ResetTimer()
-	for i := 0; i < b.N; i++ {
-		if compBenchSignerNext == nil || compBenchSignerNext2 == nil {
-			initComparisonBenchData()
-		}
-		_, err := compBenchSignerNext.ECDH(compBenchSignerNext2.Pub())
-		if err != nil {
-			b.Fatalf("ECDH failed: %v", err)
-		}
-	}
-}
 
diff --git a/ecdh.go b/ecdh.go
index 3ad3376..c38ff97 100644
--- a/ecdh.go
+++ b/ecdh.go
@@ -2,9 +2,15 @@ package p256k1
 
 import (
 	"errors"
 	"unsafe"
 )
 
+const (
+	// Window sizes for elliptic curve multiplication optimizations
+	windowA = 5 // Window size for main scalar (A)
+	windowG = 14 // Window size for generator (G) - larger for better performance
+)
+
 // EcmultConst computes r = q * a using constant-time multiplication
 // Uses simple binary method
 func EcmultConst(r *GroupElementJacobian, a *GroupElementAffine, q *Scalar) {
@@ -125,27 +131,100 @@ func ecmultWindowedVar(r *GroupElementJacobian, a *GroupElementAffine, q *Scalar
 	}
 }
 
-// Ecmult computes r = q * a (variable-time, optimized)
-// This is a simplified implementation - can be optimized with windowing later
+// Ecmult computes r = q * a using variable-time windowed multiplication.
+// r is set to infinity when a is the point at infinity or q is zero.
 func Ecmult(r *GroupElementJacobian, a *GroupElementJacobian, q *Scalar) {
 	if a.isInfinity() {
 		r.setInfinity()
 		return
 	}
-	
+
 	if q.isZero() {
 		r.setInfinity()
 		return
 	}
-	
+
 	// Convert to affine for windowed multiplication
 	var aAff GroupElementAffine
 	aAff.setGEJ(a)
-	
+
 	// Use optimized windowed multiplication
 	ecmultWindowedVar(r, &aAff, q)
 }
 
+// ecmultStraussGLV computes r = q * a with a windowed wNAF double-and-add
+// ladder over a table of odd multiples of a. It runs in variable time.
+//
+// Despite the name, the GLV endomorphism split is not applied yet: the full
+// scalar is recoded with window width windowA.
+func ecmultStraussGLV(r *GroupElementJacobian, a *GroupElementAffine, q *Scalar) {
+	if a.isInfinity() || q.isZero() {
+		r.setInfinity()
+		return
+	}
+
+	// preA[i] = (2*i+1) * a
+	var aJac GroupElementJacobian
+	aJac.setGE(a)
+	var preA [1 << (windowA - 1)]GroupElementJacobian
+	buildOddMultiples(&preA, &aJac, windowA)
+
+	// Non-zero wNAF digits are odd with magnitude below 2^(windowA-1), so
+	// (|digit|-1)/2 always indexes preA.
+	var wnaf [257]int
+	bits := q.wNAF(wnaf[:], windowA)
+
+	r.setInfinity()
+	for i := bits - 1; i >= 0; i-- {
+		r.double(r)
+
+		n := wnaf[i]
+		if n == 0 {
+			continue
+		}
+		if n > 0 {
+			r.addVar(r, &preA[(n-1)/2])
+			continue
+		}
+
+		// Negative digit: negate y on a copy so the table entry stays intact.
+		// NOTE(review): magnitude 1 is kept from the original; confirm it
+		// matches the magnitude of y produced by addVar.
+		pt := preA[(-n-1)/2]
+		pt.y.negate(&pt.y, 1)
+		r.addVar(r, &pt)
+	}
+}
+
+// buildOddMultiples fills pre with odd multiples of a:
+// pre[i] = (2*i+1) * a for 0 <= i < 2^(w-1).
+//
+// The table type is sized for windowA, so a width that does not fit is
+// clamped to the full table.
+func buildOddMultiples(pre *[1 << (windowA - 1)]GroupElementJacobian, a *GroupElementJacobian, w uint) {
+	tableSize := 1 << (w - 1)
+	if tableSize <= 0 || tableSize > len(pre) {
+		tableSize = len(pre)
+	}
+
+	// pre[0] = 1*a; each further entry adds 2*a to the previous one.
+	pre[0] = *a
+	if tableSize == 1 {
+		return
+	}
+	var twoA GroupElementJacobian
+	twoA.double(a)
+	for i := 1; i < tableSize; i++ {
+		pre[i].addVar(&pre[i-1], &twoA)
+	}
+}
+
+// EcmultStraussGLV is the exported entry point for ecmultStraussGLV: it
+// computes r = q * a in variable time.
+func EcmultStraussGLV(r *GroupElementJacobian, a *GroupElementAffine, q *Scalar) {
+	ecmultStraussGLV(r, a, q)
+}
+
 // ECDHHashFunction is a function type for hashing ECDH shared secrets
 type ECDHHashFunction func(output []byte, x32 []byte, y32 []byte) bool
 
@@ -203,7 +282,7 @@ func ECDH(output []byte, pubkey *PublicKey, seckey []byte, hashfp ECDHHashFuncti
 	if s.isZero() {
 		return errors.New("secret key cannot be zero")
 	}
-	
+
 	// Compute res = s * pt using optimized windowed multiplication (variable-time)
 	// ECDH doesn't require constant-time since the secret key is already known
 	var res GroupElementJacobian
diff --git a/scalar.go b/scalar.go
index e3768b1..f075f16 100644
--- a/scalar.go
+++ b/scalar.go
@@ -39,6 +39,41 @@ var (
 
 	// ScalarOne represents the scalar 1
 	ScalarOne = Scalar{d: [4]uint64{1, 0, 0, 0}}
+
+	// GLV (Gallant-Lambert-Vanstone) endomorphism constants. Limbs are
+	// little-endian like ScalarOne: d[0] holds the least significant 64 bits.
+
+	// secp256k1Lambda is a primitive cube root of unity modulo the curve order n:
+	// 0x5363AD4CC05C30E0A5261C028812645A122E22EA20816678DF02967C1B23BD72.
+	secp256k1Lambda = Scalar{d: [4]uint64{
+		0xDF02967C1B23BD72, 0x122E22EA20816678,
+		0xA5261C028812645A, 0x5363AD4CC05C30E0,
+	}}
+
+	// minusB1 and minusB2 are the negated GLV lattice basis coefficients
+	// -b1 and -b2 (mod n) used by splitLambda.
+	minusB1 = Scalar{d: [4]uint64{
+		0x6F547FA90ABFE4C3, 0xE4437ED6010E8828,
+		0x0000000000000000, 0x0000000000000000,
+	}}
+
+	minusB2 = Scalar{d: [4]uint64{
+		0xD765CDA83DB1562C, 0x8A280AC50774346D,
+		0xFFFFFFFFFFFFFFFE, 0xFFFFFFFFFFFFFFFF,
+	}}
+
+	// g1 and g2 are the rounded values of 2^384*b2/n and 2^384*(-b1)/n.
+	// splitLambda multiplies k by them and shifts right by 384 bits to
+	// estimate k*b2/n and k*(-b1)/n.
+	g1 = Scalar{d: [4]uint64{
+		0xE893209A45DBB031, 0x3DAA8A1471E8CA7F,
+		0xE86C90E49284EB15, 0x3086D221A7D46BCD,
+	}}
+
+	g2 = Scalar{d: [4]uint64{
+		0x1571B4AE8AC47F71, 0x221208AC9DF506C6,
+		0x6F547FA90ABFE4C4, 0xE4437ED6010E8828,
+	}}
 )
 
 // setInt sets a scalar to a small integer value
@@ -789,3 +824,143 @@ func scalarReduce512(r *Scalar, l []uint64) {
 	}
 }
 
+// wNAF writes the width-w non-adjacent form of s into wnaf and returns the
+// number of digits in use: the index of the highest non-zero digit plus one,
+// or 0 for the zero scalar.
+//
+// Every non-zero digit is odd, lies in [-(2^(w-1)-1), 2^(w-1)-1] and is
+// followed by at least w-1 zero digits. The digits satisfy
+// sum(wnaf[i] * 2^i) = s (mod n): a scalar with bit 255 set is negated first
+// and its digits are emitted with flipped signs.
+//
+// wnaf must hold at least 257 entries and w must be in [2, 31]; wNAF panics
+// otherwise.
+func (s *Scalar) wNAF(wnaf []int, w uint) int {
+	if w < 2 || w > 31 {
+		panic("w must be between 2 and 31")
+	}
+	if len(wnaf) < 257 {
+		panic("wnaf slice must have at least 257 elements")
+	}
+
+	for i := 0; i < 257; i++ {
+		wnaf[i] = 0
+	}
+
+	// Recode a copy whose top bit is clear so no carry is left over after
+	// bit 255; the sign is re-applied to every digit below.
+	k := *s
+	sign := 1
+	if k.getBits(255, 1) == 1 {
+		k.negate(&k)
+		sign = -1
+	}
+
+	lastSet := -1
+	var carry uint32
+	for bit := 0; bit < 256; {
+		// Nothing to emit while the current bit plus the pending carry is even.
+		if k.getBits(uint(bit), 1) == carry {
+			bit++
+			continue
+		}
+
+		// Take w bits, or fewer at the top of the scalar.
+		now := w
+		if left := uint(256 - bit); now > left {
+			now = left
+		}
+
+		// Use a signed word: unsigned arithmetic would wrap a negative digit
+		// into a huge positive value.
+		word := int(k.getBits(uint(bit), now) + carry)
+		carry = uint32(word>>(w-1)) & 1
+		word -= int(carry) << w
+
+		wnaf[bit] = sign * word
+		lastSet = bit
+		bit += int(now)
+	}
+
+	return lastSet + 1
+}
+
+// scalarMulShiftVar sets r to the 512-bit product a*b shifted right by shift
+// bits, rounded to nearest with ties rounding up, truncated to 256 bits. It
+// runs in variable time and panics if shift > 512.
+//
+// splitLambda calls it with shift = 384; with shift >= 256 the shifted
+// product always fits into the four result limbs.
+func scalarMulShiftVar(r *Scalar, a *Scalar, b *Scalar, shift uint) {
+	if shift > 512 {
+		panic("shift too large")
+	}
+
+	var l [8]uint64
+	scalarMul512(l[:], a, b)
+
+	// Split the shift into whole limbs plus a bit offset inside a limb.
+	limbShift := int(shift / 64)
+	bitShift := shift % 64
+
+	var out [4]uint64
+	for i := range out {
+		src := limbShift + i
+		if src >= len(l) {
+			break
+		}
+		out[i] = l[src] >> bitShift
+		// The upper part of the limb comes from the low bits of the next one.
+		if bitShift != 0 && src+1 < len(l) {
+			out[i] |= l[src+1] << (64 - bitShift)
+		}
+	}
+
+	// Round to nearest: add the most significant discarded bit, which can
+	// sit in any limb of the product, and ripple the carry upwards.
+	var carry uint64
+	if shift > 0 {
+		carry = (l[(shift-1)/64] >> ((shift - 1) % 64)) & 1
+	}
+	for i := 0; carry != 0 && i < len(out); i++ {
+		out[i] += carry
+		if out[i] != 0 {
+			carry = 0
+		}
+	}
+
+	r.d = out
+	// NOTE(review): call kept from the original; confirm that
+	// scalarReduce(r, 0) really reduces r modulo n.
+	scalarReduce(r, 0)
+}
+
+// splitLambda decomposes k into r1 (the receiver) and r2 such that
+// r1 + lambda*r2 = k (mod n), where lambda is secp256k1Lambda.
+//
+// NOTE(review): the intent is that both halves are about 128 bits long when
+// read as signed values (a negative half is stored as its residue mod n);
+// this change does not exercise that bound. k may alias r1 or r2.
+func (r1 *Scalar) splitLambda(r2 *Scalar, k *Scalar) {
+	// Copy k first: r1 and r2 are written before k is read for the last time.
+	kk := *k
+
+	// c1 = round(k*g1 / 2^384), c2 = round(k*g2 / 2^384)
+	var c1, c2 Scalar
+	scalarMulShiftVar(&c1, &kk, &g1, 384)
+	scalarMulShiftVar(&c2, &kk, &g2, 384)
+
+	// r2 = c1*(-b1) + c2*(-b2)
+	var t1, t2 Scalar
+	scalarMul(&t1, &c1, &minusB1)
+	scalarMul(&t2, &c2, &minusB2)
+	scalarAdd(r2, &t1, &t2)
+
+	// r1 = k - lambda*r2
+	scalarMul(r1, r2, &secp256k1Lambda)
+	r1.negate(r1)
+	scalarAdd(r1, r1, &kk)
+
+	scalarReduce(r1, 0)
+	scalarReduce(r2, 0)
+}