working AVX2 scalar/affines

2025-11-28 16:35:08 +00:00
parent 93af5ef27b
commit b250fc5cf7
20 changed files with 4412 additions and 8 deletions


@@ -100,6 +100,102 @@ Benchmark results on AMD Ryzen 5 PRO 4650G:
- Field Addition: ~2.4 ns/op
- Scalar Multiplication: ~9.9 ns/op
## AVX2 Acceleration Opportunities
The Scalar and FieldElement types and their operations are designed with data layouts that are amenable to AVX2 SIMD acceleration:
### Scalar Type (`scalar.go`)
- **Representation**: 4×64-bit limbs (`[4]uint64`) representing 256-bit scalars
- **AVX2-Acceleratable Operations**:
- `scalarAdd` / `scalarMul`: 256-bit integer arithmetic using `VPADDQ`, `VPMULUDQ`
- `mul512`: Full 512-bit product computation - can use AVX2's 256-bit registers to process limb pairs in parallel
- `reduce512`: Modular reduction with Montgomery-style operations
- `wNAF`: windowed Non-Adjacent Form conversion for scalar multiplication
- `splitLambda`: GLV endomorphism scalar splitting
### FieldElement Type (`field.go`, `field_mul.go`)
- **Representation**: 5×52-bit limbs (`[5]uint64`) in base 2^52 for efficient multiplication
- **AVX2-Acceleratable Operations**:
- `mul` / `sqr`: Field multiplication/squaring using 128-bit intermediate products
- `normalize` / `normalizeWeak`: Carry propagation across limbs
- `add` / `negate`: Parallel limb operations ideal for `VPADDQ`, `VPSUBQ`
- `inv`: Modular inversion via Fermat's little theorem (chain of sqr/mul)
- `sqrt`: Square root computation using addition chains
### Affine/Jacobian Group Operations (`group.go`)
- **Types**: `GroupElementAffine` (x, y coordinates), `GroupElementJacobian` (x, y, z coordinates)
- **AVX2-Acceleratable Operations**:
- `double`: Point doubling - multiple independent field operations
- `addVar` / `addGE`: Point addition - parallelizable field multiplications
- `setGEJ`: Coordinate conversion with batch field inversions
### Key AVX2 Instructions for Implementation
| Operation | Relevant AVX2 Instructions |
|-----------|---------------------------|
| 128-bit limb add | `VPADDQ` (packed 64-bit add) with carry chain |
| Limb multiplication | `VPMULUDQ` (unsigned 32×32→64 per 64-bit lane) |
| 128-bit arithmetic | `VPMULLD`, `VPMULUDQ` for multi-precision products |
| Carry propagation | `VPSRLQ`/`VPSLLQ` (shift), `VPAND` (mask), `VPALIGNR` |
| Conditional moves | `VPBLENDVB` (blend based on mask) |
| Data movement | `VMOVDQU` (unaligned load/store), `VBROADCASTI128` |
### 128-bit Limb Representation with AVX2
AVX2's 256-bit YMM registers can natively hold two 128-bit limbs, enabling more efficient representations:
**Scalar (256-bit) with 2×128-bit limbs:**
```
YMM0 = [scalar.d[1]:scalar.d[0]] | [scalar.d[3]:scalar.d[2]]
├── 128-bit limb 0 ───────┤ ├── 128-bit limb 1 ───────┤
```
- A single 256-bit scalar fits in one YMM register as two 128-bit limbs
- Addition/subtraction can use `VPADDQ` with manual carry handling between 64-bit halves
- The 4×64-bit representation naturally maps to 2×128-bit by treating pairs
**FieldElement (260-bit effective) with 128-bit limbs:**
```
YMM0 = [fe.n[0]:fe.n[1]] (lower 104 bits used per pair)
YMM1 = [fe.n[2]:fe.n[3]]
XMM2 = [fe.n[4]:0] (upper 48 bits)
```
- 5×52-bit limbs can be reorganized into 3×128-bit containers
- Multiplication benefits from `VPMULUDQ` computing four 32×32→64 products per YMM register in a single instruction
**512-bit Intermediate Products:**
- Scalar multiplication produces 512-bit intermediates
- Two YMM registers hold the full product: `YMM0 = [l[3]:l[2]:l[1]:l[0]]`, `YMM1 = [l[7]:l[6]:l[5]:l[4]]`
- Reduction can proceed in parallel across register pairs
### Implementation Approach
AVX2 acceleration can be added via Go assembly (`.s` files) using the patterns described in `AVX.md`:
```go
//go:build amd64
package p256k1
// FieldMulAVX2 multiplies two field elements using AVX2
// Uses 128-bit limb operations for ~2x throughput
//go:noescape
func FieldMulAVX2(r, a, b *FieldElement)
// ScalarMulAVX2 multiplies two scalars using AVX2
// Processes scalar as 2×128-bit limbs in a single YMM register
//go:noescape
func ScalarMulAVX2(r, a, b *Scalar)
// ScalarAdd256AVX2 adds two 256-bit scalars using 128-bit limb arithmetic
//go:noescape
func ScalarAdd256AVX2(r, a, b *Scalar) bool
```
The key insight is that AVX2's 256-bit registers holding 128-bit limb pairs enable (a pure-Go reference sketch follows this list):
- **2x parallelism** for addition/subtraction across limb pairs
- **Efficient carry chains** using `VPSRLQ` to extract carries and `VPADDQ` to propagate
- **Reduced loop iterations** for multi-precision arithmetic (2 iterations for 256-bit instead of 4)
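As a reference model for what the SIMD code must reproduce, here is a minimal pure-Go sketch of 256-bit addition over the 2×128-bit layout (the `Uint128` type here is an assumption for illustration; the AVX2 version replaces the four `bits.Add64` calls with one `VPADDQ` plus explicit carry extraction):
```go
package limbs

import "math/bits"

// Uint128 is one 128-bit limb, least-significant half first (illustrative type).
type Uint128 struct{ Lo, Hi uint64 }

// add256 adds two 256-bit values held as two 128-bit limbs each and
// returns the sum along with the carry out of bit 255.
func add256(a, b [2]Uint128) (r [2]Uint128, carry uint64) {
	r[0].Lo, carry = bits.Add64(a[0].Lo, b[0].Lo, 0)
	r[0].Hi, carry = bits.Add64(a[0].Hi, b[0].Hi, carry)
	r[1].Lo, carry = bits.Add64(a[1].Lo, b[1].Lo, carry)
	r[1].Hi, carry = bits.Add64(a[1].Hi, b[1].Hi, carry)
	return r, carry
}
```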
## Implementation Status
### ✅ Completed

avx/IMPLEMENTATION_PLAN.md Normal file

@@ -0,0 +1,295 @@
# AVX2 secp256k1 Implementation Plan
## Overview
This implementation uses 128-bit limbs with AVX2 256-bit registers for secp256k1 cryptographic operations. The key insight is that AVX2's YMM registers can hold two 128-bit values, enabling efficient parallel processing.
## Data Layout
### Register Mapping
| Type | Size | AVX2 Representation | Registers |
|------|------|---------------------|-----------|
| Uint128 | 128-bit | 1×128-bit in XMM or half YMM | 0.5 YMM |
| Scalar | 256-bit | 2×128-bit limbs | 1 YMM |
| FieldElement | 256-bit | 2×128-bit limbs | 1 YMM |
| AffinePoint | 512-bit | 2×FieldElement (x, y) | 2 YMM |
| JacobianPoint | 768-bit | 3×FieldElement (x, y, z) | 3 YMM |
### Memory Layout
```
Uint128:
[Lo:64][Hi:64] = 128 bits
Scalar/FieldElement (in YMM register):
YMM = [D[0].Lo:64][D[0].Hi:64][D[1].Lo:64][D[1].Hi:64]
├─── 128-bit limb 0 ────┤├─── 128-bit limb 1 ────┤
AffinePoint (2 YMM registers):
YMM0 = X coordinate (256 bits)
YMM1 = Y coordinate (256 bits)
JacobianPoint (3 YMM registers):
YMM0 = X coordinate (256 bits)
YMM1 = Y coordinate (256 bits)
YMM2 = Z coordinate (256 bits)
```
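In Go these layouts translate to type definitions along the following lines (a sketch only; the field names `D`, `N`, `X`/`Y`/`Z`, and `Infinity` match how the tests and pure-Go fallbacks in this commit use them, with `types.go` being authoritative):
```go
package avx

// Uint128 is a 128-bit limb, least-significant 64 bits first.
type Uint128 struct {
	Lo, Hi uint64
}

// Scalar is a 256-bit integer mod n, stored as two 128-bit limbs
// (one YMM register in the AVX2 paths).
type Scalar struct {
	D [2]Uint128
}

// FieldElement is a 256-bit integer mod p, stored as two 128-bit limbs.
type FieldElement struct {
	N [2]Uint128
}

// AffinePoint is (x, y), with an explicit flag for the point at infinity.
type AffinePoint struct {
	X, Y     FieldElement
	Infinity bool
}

// JacobianPoint is (X, Y, Z), where the affine point is (X/Z², Y/Z³).
type JacobianPoint struct {
	X, Y, Z  FieldElement
	Infinity bool
}
```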
## Implementation Phases
### Phase 1: Core 128-bit Operations
File: `uint128_amd64.s`
1. **uint128Add** - Add two 128-bit values with carry out
- Instructions: `ADDQ`, `ADCQ` (these are GPR instructions; values pass through general-purpose registers, not XMM)
- Input: a and b as Lo/Hi pairs in general-purpose registers
- Output: result pair in GPRs, carry flag set on overflow
2. **uint128Sub** - Subtract with borrow
- Instructions: `SUBQ`, `SBBQ`
3. **uint128Mul** - Multiply two 64-bit values to get 128-bit result
- Instructions: `MULQ` (scalar) or `VPMULUDQ` (SIMD)
4. **uint128Mul128** - Full 128×128→256 multiplication
- This is the critical operation for field/scalar multiplication
- Uses Karatsuba or schoolbook with `VPMULUDQ`
### Phase 2: Scalar Operations (mod n)
File: `scalar_amd64.go` (stubs), `scalar_amd64.s` (assembly)
1. **ScalarAdd** - Add two scalars mod n
```
Load a into YMM0
Load b into YMM1
VPADDQ YMM0, YMM0, YMM1 ; parallel add of 64-bit lanes
Handle carries between 64-bit lanes
Conditional subtract n if >= n
```
2. **ScalarSub** - Subtract scalars mod n
- Similar to add but with `VPSUBQ` and conditional add of n
3. **ScalarMul** - Multiply scalars mod n
- Compute 512-bit product using 128×128 multiplications
- Reduce mod n using Barrett or Montgomery reduction
- 512-bit intermediate fits in 2 YMM registers
4. **ScalarNegate** - Compute -a mod n
- `n - a` using subtraction
5. **ScalarInverse** - Compute a^(-1) mod n
- Use Fermat's little theorem: a^(n-2) mod n
- Requires efficient square-and-multiply
6. **ScalarIsZero**, **ScalarIsHigh**, **ScalarEqual** - Comparisons
### Phase 3: Field Operations (mod p)
File: `field_amd64.go` (stubs), `field_amd64.s` (assembly)
1. **FieldAdd** - Add two field elements mod p
```
Load a into YMM0
Load b into YMM1
VPADDQ YMM0, YMM0, YMM1
Handle carries
Conditional subtract p if >= p
```
2. **FieldSub** - Subtract field elements mod p
3. **FieldMul** - Multiply field elements mod p
- Most critical operation for performance
- 256×256→512 bit product, then reduce mod p
- secp256k1 has special structure: p = 2^256 - 2^32 - 977
- Reduction: fold any bits at 2^256 and above back in by multiplying the high part by (2^32 + 977) and adding it into the low 256 bits (see the sketch after this list)
4. **FieldSqr** - Square a field element (optimized mul(a,a))
- Can save ~25% multiplications vs general multiply
5. **FieldInv** - Compute a^(-1) mod p
- Fermat: a^(p-2) mod p
- Use addition chain for efficiency
6. **FieldSqrt** - Compute square root mod p
- p ≡ 3 (mod 4), so sqrt(a) = a^((p+1)/4) mod p
7. **FieldNegate**, **FieldIsZero**, **FieldEqual** - Basic operations
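The reduction in step 3 rests on 2^256 ≡ 2^32 + 977 (mod p). Below is a condensed pure-Go sketch of one folding pass over a 512-bit product; the complete routine in `field.go` must also fold the small overflow returned here and finish with a conditional subtract of p:
```go
package avx

import "math/bits"

// fold512 performs one reduction fold: H*2^256 + L ≡ H*(2^32 + 977) + L (mod p).
// prod holds the 512-bit product, least-significant limb first. The returned
// overflow is below ~2^34, so one more fold plus a compare against p finishes.
func fold512(prod [8]uint64) (r [4]uint64, overflow uint64) {
	const c = 0x1000003D1 // 2^32 + 977
	var carry uint64
	for i := 0; i < 4; i++ {
		hi, lo := bits.Mul64(prod[4+i], c) // fold high limb i
		lo, cc := bits.Add64(lo, carry, 0)
		hi += cc
		var cc2 uint64
		r[i], cc2 = bits.Add64(prod[i], lo, 0) // accumulate into low half
		carry = hi + cc2                       // hi < 2^34, so this cannot overflow
	}
	return r, carry
}
```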
### Phase 4: Point Operations
File: `point_amd64.go` (stubs), `point_amd64.s` (assembly)
1. **AffineToJacobian** - Convert (x, y) to (x, y, 1)
2. **JacobianToAffine** - Convert (X, Y, Z) to (X/Z², Y/Z³)
- Requires field inversion
3. **JacobianDouble** - Point doubling
- ~4 field multiplications, ~4 field squarings, ~6 field additions
- All field ops can use AVX2 versions
4. **JacobianAdd** - Add two Jacobian points
- ~12 field multiplications, ~4 field squarings
5. **JacobianAddAffine** - Add Jacobian + Affine (optimized)
- ~8 field multiplications, ~3 field squarings
- Common case in scalar multiplication
6. **ScalarMult** - Compute k*P for scalar k and point P
- Use windowed NAF or GLV decomposition
- Core loop: double + conditional add
7. **ScalarBaseMult** - Compute k*G using precomputed table
- Precompute multiples of generator G
- Faster than general scalar mult
### Phase 5: High-Level Operations
File: `ecdsa.go`, `schnorr.go`
1. **ECDSA Sign/Verify**
2. **Schnorr Sign/Verify** (BIP-340)
3. **ECDH** - Shared secret computation
## Assembly Conventions
### Register Usage
```
YMM0-YMM7: Scratch registers (caller-saved)
YMM8-YMM15: Also available; Go's assembly ABI treats all vector registers as caller-saved scratch
For our operations:
YMM0: Primary operand/result
YMM1: Secondary operand
YMM2-YMM5: Intermediate calculations
YMM6-YMM7: Constants (field prime, masks, etc.)
```
### Key AVX2 Instructions
```asm
; Data movement
VMOVDQU YMM0, [mem] ; Load 256 bits unaligned
VMOVDQA YMM0, [mem] ; Load 256 bits aligned
VBROADCASTI128 YMM0, [mem] ; Broadcast 128-bit to both lanes
; Arithmetic
VPADDQ YMM0, YMM1, YMM2 ; Add packed 64-bit integers
VPSUBQ YMM0, YMM1, YMM2 ; Subtract packed 64-bit integers
VPMULUDQ YMM0, YMM1, YMM2 ; Multiply low 32-bits of each 64-bit lane
; Logical
VPAND YMM0, YMM1, YMM2 ; Bitwise AND
VPOR YMM0, YMM1, YMM2 ; Bitwise OR
VPXOR YMM0, YMM1, YMM2 ; Bitwise XOR
; Shifts
VPSLLQ YMM0, YMM1, imm ; Shift left logical 64-bit
VPSRLQ YMM0, YMM1, imm ; Shift right logical 64-bit
; Shuffles and permutes
VPERMQ YMM0, YMM1, imm ; Permute 64-bit elements
VPERM2I128 YMM0, YMM1, YMM2, imm ; Permute 128-bit lanes
VPALIGNR YMM0, YMM1, YMM2, imm ; Byte align
; Comparisons
VPCMPEQQ YMM0, YMM1, YMM2 ; Compare equal 64-bit
VPCMPGTQ YMM0, YMM1, YMM2 ; Compare greater than 64-bit
; Blending
VPBLENDVB YMM0, YMM1, YMM2, YMM3 ; Conditional blend
```
## Carry Propagation Strategy
The tricky part of 128-bit limb arithmetic is carry propagation between the 64-bit halves and between the two 128-bit limbs.
### Addition Carry Chain
```
Given: A = [A0.Lo, A0.Hi, A1.Lo, A1.Hi] (256 bits as 4×64)
B = [B0.Lo, B0.Hi, B1.Lo, B1.Hi]
Step 1: Add with VPADDQ (no carries)
R = A + B (per-lane, ignoring overflow)
Step 2: Detect carries
carry_0_to_1 = (R0.Lo < A0.Lo) ? 1 : 0 ; carry from Lo to Hi in limb 0
carry_1_to_2 = (R0.Hi < A0.Hi) ? 1 : 0 ; carry from limb 0 to limb 1
carry_2_to_3 = (R1.Lo < A1.Lo) ? 1 : 0 ; carry within limb 1
carry_out = (R1.Hi < A1.Hi) ? 1 : 0 ; overflow
Step 3: Propagate carries
R0.Hi += carry_0_to_1
R1.Lo += carry_1_to_2 + (R0.Hi < carry_0_to_1 ? 1 : 0)
R1.Hi += carry_2_to_3 + ...
```
This is complex in SIMD. Alternative: use `ADCX`/`ADOX` instructions (ADX extension) for scalar carry chains, which may be faster for sequential operations.
### Multiplication Strategy
For 128×128→256 multiplication:
```
A = A.Hi * 2^64 + A.Lo
B = B.Hi * 2^64 + B.Lo
A * B = A.Hi*B.Hi * 2^128
+ (A.Hi*B.Lo + A.Lo*B.Hi) * 2^64
+ A.Lo*B.Lo
Using MULX (BMI2) for efficient 64×64→128 (MULX takes its second
multiplicand implicitly from RDX):
MOVQ B.Lo, RDX ; RDX = B.Lo
MULX r1, r0, A.Lo ; r1:r0 = A.Lo * B.Lo
MULX r3, r2, A.Hi ; r3:r2 = A.Hi * B.Lo
... (4 multiplications total, then accumulate)
```
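The same decomposition in portable Go, useful as a correctness oracle for the assembly (assumes the `Uint128` type sketched under Data Layout):
```go
package avx

import "math/bits"

// mul128 computes the full 256-bit product a*b of two 128-bit values,
// least-significant 64-bit limb first.
func mul128(a, b Uint128) [4]uint64 {
	h0, l0 := bits.Mul64(a.Lo, b.Lo) // contributes to columns 0 and 1
	h1, l1 := bits.Mul64(a.Hi, b.Lo) // columns 1 and 2
	h2, l2 := bits.Mul64(a.Lo, b.Hi) // columns 1 and 2
	h3, l3 := bits.Mul64(a.Hi, b.Hi) // columns 2 and 3

	var r [4]uint64
	var c1, c2, c3, c4 uint64
	r[0] = l0
	r[1], c1 = bits.Add64(h0, l1, 0)
	r[1], c2 = bits.Add64(r[1], l2, 0)
	r[2], c3 = bits.Add64(h1, h2, c1)
	r[2], c4 = bits.Add64(r[2], l3, c2)
	r[3] = h3 + c3 + c4 // no overflow: a 128×128 product fits in 256 bits
	return r
}
```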
## Testing Strategy
1. **Unit tests for each operation** comparing against reference (main package)
2. **Edge cases**: zero, one, max values, values near modulus
3. **Random tests**: generate random inputs, compare results against a big-integer oracle (see the sketch after this list)
4. **Benchmark comparisons**: AVX2 vs pure Go implementation
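Items 1 and 3 combine naturally with `math/big` as the oracle. A sketch, assuming the `FieldElement` API exercised in `avx_test.go`:
```go
package avx

import (
	"crypto/rand"
	"math/big"
	"testing"
)

// TestFieldMulRandom cross-checks Mul against math/big on random inputs.
func TestFieldMulRandom(t *testing.T) {
	p, _ := new(big.Int).SetString(
		"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F", 16)
	for i := 0; i < 1000; i++ {
		aB, bB := make([]byte, 32), make([]byte, 32)
		rand.Read(aB)
		rand.Read(bB)
		var a, b, r FieldElement
		a.SetBytes(aB) // SetBytes reduces mod p if needed, so reduce the oracle too
		b.SetBytes(bB)
		r.Mul(&a, &b)
		want := new(big.Int).Mul(
			new(big.Int).Mod(new(big.Int).SetBytes(aB), p),
			new(big.Int).Mod(new(big.Int).SetBytes(bB), p))
		want.Mod(want, p)
		got := r.Bytes()
		if new(big.Int).SetBytes(got[:]).Cmp(want) != 0 {
			t.Fatalf("case %d: got %x, want %x", i, got, want)
		}
	}
}
```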
## File Structure
```
avx/
├── IMPLEMENTATION_PLAN.md (this file)
├── types.go (type definitions)
├── uint128.go (pure Go fallback)
├── uint128_amd64.go (Go stubs for assembly)
├── uint128_amd64.s (AVX2 assembly)
├── scalar.go (pure Go fallback)
├── scalar_amd64.go (Go stubs)
├── scalar_amd64.s (AVX2 assembly)
├── field.go (pure Go fallback)
├── field_amd64.go (Go stubs)
├── field_amd64.s (AVX2 assembly)
├── point.go (pure Go fallback)
├── point_amd64.go (Go stubs)
├── point_amd64.s (AVX2 assembly)
├── avx_test.go (tests)
└── bench_test.go (benchmarks)
```
## Performance Targets
Compared to the current pure Go implementation:
- Scalar multiplication: 2-3x faster
- Field multiplication: 2-4x faster
- Point operations: 2-3x faster (dominated by field ops)
- ECDSA sign/verify: 2-3x faster overall
## Dependencies
- Go 1.21+ (the project's toolchain; Go's Plan 9-style assembler needs no special version)
- CPU with AVX2 support (Intel Haswell+, AMD Excavator+), detected at runtime (see the dispatch sketch below)
- Optional: BMI2 for MULX instruction (faster 64×64→128 multiply)
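Feature detection can use `golang.org/x/sys/cpu` (`cpu.X86.HasAVX2`, `cpu.X86.HasBMI2`); the wrapper below is this plan's sketch, not an existing API, and in a real layout it would sit in an amd64-tagged file next to the stubs:
```go
package avx

import "golang.org/x/sys/cpu"

var (
	hasAVX2 = cpu.X86.HasAVX2
	hasBMI2 = cpu.X86.HasBMI2
)

// fieldMul dispatches to the AVX2 implementation when the CPU supports it,
// falling back to the pure-Go path otherwise.
func fieldMul(r, a, b *FieldElement) {
	if hasAVX2 {
		FieldMulAVX2(r, a, b)
		return
	}
	r.Mul(a, b)
}
```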

avx/avx_test.go Normal file

@@ -0,0 +1,452 @@
package avx
import (
"bytes"
"crypto/rand"
"encoding/hex"
"testing"
)
// Test vectors from Bitcoin/secp256k1
func TestUint128Add(t *testing.T) {
tests := []struct {
a, b Uint128
expect Uint128
carry uint64
}{
{Uint128{0, 0}, Uint128{0, 0}, Uint128{0, 0}, 0},
{Uint128{1, 0}, Uint128{1, 0}, Uint128{2, 0}, 0},
{Uint128{^uint64(0), 0}, Uint128{1, 0}, Uint128{0, 1}, 0},
{Uint128{^uint64(0), ^uint64(0)}, Uint128{1, 0}, Uint128{0, 0}, 1},
}
for i, tt := range tests {
result, carry := tt.a.Add(tt.b)
if result != tt.expect || carry != tt.carry {
t.Errorf("test %d: got (%v, %d), want (%v, %d)", i, result, carry, tt.expect, tt.carry)
}
}
}
func TestUint128Mul(t *testing.T) {
// Test: 2^64 * 2^64 = 2^128
a := Uint128{0, 1} // 2^64
b := Uint128{0, 1} // 2^64
result := a.Mul(b)
// Expected: 2^128 = [0, 0, 1, 0]
expected := [4]uint64{0, 0, 1, 0}
if result != expected {
t.Errorf("2^64 * 2^64: got %v, want %v", result, expected)
}
// Test: (2^64 - 1) * (2^64 - 1)
a = Uint128{^uint64(0), 0}
b = Uint128{^uint64(0), 0}
result = a.Mul(b)
// (2^64 - 1)^2 = 2^128 - 2^65 + 1
// = [1, 0xFFFFFFFFFFFFFFFE, 0, 0]
expected = [4]uint64{1, 0xFFFFFFFFFFFFFFFE, 0, 0}
if result != expected {
t.Errorf("(2^64-1)^2: got %v, want %v", result, expected)
}
}
func TestScalarSetBytes(t *testing.T) {
// Test with a known scalar
bytes32 := make([]byte, 32)
bytes32[31] = 1 // scalar = 1
var s Scalar
s.SetBytes(bytes32)
if !s.IsOne() {
t.Errorf("expected scalar to be 1, got %+v", s)
}
// Test zero
bytes32 = make([]byte, 32)
s.SetBytes(bytes32)
if !s.IsZero() {
t.Errorf("expected scalar to be 0, got %+v", s)
}
}
func TestScalarAddSub(t *testing.T) {
var a, b, sum, diff, recovered Scalar
// a = 1, b = 2
a = ScalarOne
b.D[0].Lo = 2
sum.Add(&a, &b)
if sum.D[0].Lo != 3 {
t.Errorf("1 + 2: expected 3, got %d", sum.D[0].Lo)
}
diff.Sub(&sum, &b)
if !diff.Equal(&a) {
t.Errorf("(1+2) - 2: expected 1, got %+v", diff)
}
// Test with overflow
a = ScalarN
a.D[0].Lo-- // n - 1
b = ScalarOne
sum.Add(&a, &b)
// n - 1 + 1 = n ≡ 0 (mod n)
if !sum.IsZero() {
t.Errorf("(n-1) + 1 should be 0 mod n, got %+v", sum)
}
// Test subtraction with borrow
a = ScalarZero
b = ScalarOne
diff.Sub(&a, &b)
// 0 - 1 = -1 ≡ n - 1 (mod n)
recovered.Add(&diff, &b)
if !recovered.IsZero() {
t.Errorf("(0-1) + 1 should be 0, got %+v", recovered)
}
}
func TestScalarMul(t *testing.T) {
var a, b, product Scalar
// 2 * 3 = 6
a.D[0].Lo = 2
b.D[0].Lo = 3
product.Mul(&a, &b)
if product.D[0].Lo != 6 || product.D[0].Hi != 0 || !product.D[1].IsZero() {
t.Errorf("2 * 3: expected 6, got %+v", product)
}
// Test with larger values
a.D[0].Lo = 0xFFFFFFFFFFFFFFFF
a.D[0].Hi = 0
b.D[0].Lo = 2
product.Mul(&a, &b)
// (2^64 - 1) * 2 = 2^65 - 2
if product.D[0].Lo != 0xFFFFFFFFFFFFFFFE || product.D[0].Hi != 1 {
t.Errorf("(2^64-1) * 2: got %+v", product)
}
}
func TestScalarNegate(t *testing.T) {
var a, neg, sum Scalar
a.D[0].Lo = 12345
neg.Negate(&a)
sum.Add(&a, &neg)
if !sum.IsZero() {
t.Errorf("a + (-a) should be 0, got %+v", sum)
}
}
func TestFieldSetBytes(t *testing.T) {
bytes32 := make([]byte, 32)
bytes32[31] = 1
var f FieldElement
f.SetBytes(bytes32)
if !f.IsOne() {
t.Errorf("expected field element to be 1, got %+v", f)
}
}
func TestFieldAddSub(t *testing.T) {
var a, b, sum, diff FieldElement
a.N[0].Lo = 100
b.N[0].Lo = 200
sum.Add(&a, &b)
if sum.N[0].Lo != 300 {
t.Errorf("100 + 200: expected 300, got %d", sum.N[0].Lo)
}
diff.Sub(&sum, &b)
if !diff.Equal(&a) {
t.Errorf("(100+200) - 200: expected 100, got %+v", diff)
}
}
func TestFieldMul(t *testing.T) {
var a, b, product FieldElement
a.N[0].Lo = 7
b.N[0].Lo = 8
product.Mul(&a, &b)
if product.N[0].Lo != 56 {
t.Errorf("7 * 8: expected 56, got %d", product.N[0].Lo)
}
}
func TestFieldInverse(t *testing.T) {
var a, inv, product FieldElement
a.N[0].Lo = 7
inv.Inverse(&a)
product.Mul(&a, &inv)
if !product.IsOne() {
t.Errorf("7 * 7^(-1) should be 1, got %+v", product)
}
}
func TestFieldSqrt(t *testing.T) {
// Test sqrt(4) = 2
var four, root, check FieldElement
four.N[0].Lo = 4
if !root.Sqrt(&four) {
t.Fatal("sqrt(4) should exist")
}
check.Sqr(&root)
if !check.Equal(&four) {
t.Errorf("sqrt(4)^2 should be 4, got %+v", check)
}
}
func TestGeneratorOnCurve(t *testing.T) {
if !Generator.IsOnCurve() {
t.Error("generator point should be on the curve")
}
}
func TestPointDouble(t *testing.T) {
var g, doubled JacobianPoint
var affineResult AffinePoint
g.FromAffine(&Generator)
doubled.Double(&g)
doubled.ToAffine(&affineResult)
if affineResult.Infinity {
t.Error("2G should not be infinity")
}
if !affineResult.IsOnCurve() {
t.Error("2G should be on the curve")
}
}
func TestPointAdd(t *testing.T) {
var g, twoG, threeG JacobianPoint
var affineResult AffinePoint
g.FromAffine(&Generator)
twoG.Double(&g)
threeG.Add(&twoG, &g)
threeG.ToAffine(&affineResult)
if !affineResult.IsOnCurve() {
t.Error("3G should be on the curve")
}
// Also test via scalar multiplication
var three Scalar
three.D[0].Lo = 3
var expected JacobianPoint
expected.ScalarMult(&g, &three)
var expectedAffine AffinePoint
expected.ToAffine(&expectedAffine)
if !affineResult.Equal(&expectedAffine) {
t.Error("G + 2G should equal 3G")
}
}
func TestPointAddInfinity(t *testing.T) {
var g, inf, result JacobianPoint
var affineResult AffinePoint
g.FromAffine(&Generator)
inf.SetInfinity()
result.Add(&g, &inf)
result.ToAffine(&affineResult)
if !affineResult.Equal(&Generator) {
t.Error("G + O should equal G")
}
result.Add(&inf, &g)
result.ToAffine(&affineResult)
if !affineResult.Equal(&Generator) {
t.Error("O + G should equal G")
}
}
func TestScalarBaseMult(t *testing.T) {
// Test 1*G = G
result := BasePointMult(&ScalarOne)
if !result.Equal(&Generator) {
t.Error("1*G should equal G")
}
// Test 2*G
var two Scalar
two.D[0].Lo = 2
result = BasePointMult(&two)
var g, twoG JacobianPoint
var expected AffinePoint
g.FromAffine(&Generator)
twoG.Double(&g)
twoG.ToAffine(&expected)
if !result.Equal(&expected) {
t.Error("2*G via scalar mult should equal 2*G via doubling")
}
}
func TestKnownScalarMult(t *testing.T) {
// Test vector: private key and public key from Bitcoin
// This is a well-known test vector
privKeyHex := "0000000000000000000000000000000000000000000000000000000000000001"
expectedXHex := "79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798"
expectedYHex := "483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8"
privKeyBytes, _ := hex.DecodeString(privKeyHex)
var k Scalar
k.SetBytes(privKeyBytes)
result := BasePointMult(&k)
xBytes := result.X.Bytes()
yBytes := result.Y.Bytes()
expectedX, _ := hex.DecodeString(expectedXHex)
expectedY, _ := hex.DecodeString(expectedYHex)
if !bytes.Equal(xBytes[:], expectedX) {
t.Errorf("X coordinate mismatch:\ngot: %x\nwant: %x", xBytes, expectedX)
}
if !bytes.Equal(yBytes[:], expectedY) {
t.Errorf("Y coordinate mismatch:\ngot: %x\nwant: %x", yBytes, expectedY)
}
}
// Benchmark tests
func BenchmarkUint128Mul(b *testing.B) {
a := Uint128{0x123456789ABCDEF0, 0xFEDCBA9876543210}
c := Uint128{0xABCDEF0123456789, 0x9876543210FEDCBA}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = a.Mul(c)
}
}
func BenchmarkScalarAdd(b *testing.B) {
var a, c, r Scalar
aBytes := make([]byte, 32)
cBytes := make([]byte, 32)
rand.Read(aBytes)
rand.Read(cBytes)
a.SetBytes(aBytes)
c.SetBytes(cBytes)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Add(&a, &c)
}
}
func BenchmarkScalarMul(b *testing.B) {
var a, c, r Scalar
aBytes := make([]byte, 32)
cBytes := make([]byte, 32)
rand.Read(aBytes)
rand.Read(cBytes)
a.SetBytes(aBytes)
c.SetBytes(cBytes)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Mul(&a, &c)
}
}
func BenchmarkFieldAdd(b *testing.B) {
var a, c, r FieldElement
aBytes := make([]byte, 32)
cBytes := make([]byte, 32)
rand.Read(aBytes)
rand.Read(cBytes)
a.SetBytes(aBytes)
c.SetBytes(cBytes)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Add(&a, &c)
}
}
func BenchmarkFieldMul(b *testing.B) {
var a, c, r FieldElement
aBytes := make([]byte, 32)
cBytes := make([]byte, 32)
rand.Read(aBytes)
rand.Read(cBytes)
a.SetBytes(aBytes)
c.SetBytes(cBytes)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Mul(&a, &c)
}
}
func BenchmarkFieldInverse(b *testing.B) {
var a, r FieldElement
aBytes := make([]byte, 32)
rand.Read(aBytes)
a.SetBytes(aBytes)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Inverse(&a)
}
}
func BenchmarkPointDouble(b *testing.B) {
var g, r JacobianPoint
g.FromAffine(&Generator)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Double(&g)
}
}
func BenchmarkPointAdd(b *testing.B) {
var g, twoG, r JacobianPoint
g.FromAffine(&Generator)
twoG.Double(&g)
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Add(&g, &twoG)
}
}
func BenchmarkScalarBaseMult(b *testing.B) {
var k Scalar
kBytes := make([]byte, 32)
rand.Read(kBytes)
k.SetBytes(kBytes)
b.ResetTimer()
for i := 0; i < b.N; i++ {
_ = BasePointMult(&k)
}
}

avx/debug_double_test.go Normal file

@@ -0,0 +1,59 @@
package avx
import (
"bytes"
"encoding/hex"
"testing"
)
func TestDebugDouble(t *testing.T) {
// Known value: 2G for secp256k1 (verified using Python)
expectedX := "c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5"
expectedY := "1ae168fea63dc339a3c58419466ceaeef7f632653266d0e1236431a950cfe52a"
var g, doubled JacobianPoint
var affineResult AffinePoint
g.FromAffine(&Generator)
doubled.Double(&g)
doubled.ToAffine(&affineResult)
xBytes := affineResult.X.Bytes()
yBytes := affineResult.Y.Bytes()
t.Logf("Generator X: %x", Generator.X.Bytes())
t.Logf("Generator Y: %x", Generator.Y.Bytes())
t.Logf("2G X: %x", xBytes)
t.Logf("2G Y: %x", yBytes)
expectedXBytes, _ := hex.DecodeString(expectedX)
expectedYBytes, _ := hex.DecodeString(expectedY)
t.Logf("Expected X: %s", expectedX)
t.Logf("Expected Y: %s", expectedY)
if !bytes.Equal(xBytes[:], expectedXBytes) {
t.Errorf("2G X coordinate mismatch")
}
if !bytes.Equal(yBytes[:], expectedYBytes) {
t.Errorf("2G Y coordinate mismatch")
}
// Check if 2G is on curve
if !affineResult.IsOnCurve() {
// Let's verify manually
var y2, x2, x3, rhs FieldElement
y2.Sqr(&affineResult.Y)
x2.Sqr(&affineResult.X)
x3.Mul(&x2, &affineResult.X)
var seven FieldElement
seven.N[0].Lo = 7
rhs.Add(&x3, &seven)
y2Bytes := y2.Bytes()
rhsBytes := rhs.Bytes()
t.Logf("y^2 = %x", y2Bytes)
t.Logf("x^3 + 7 = %x", rhsBytes)
t.Logf("y^2 == x^3+7: %v", y2.Equal(&rhs))
}
}

avx/field.go Normal file

@@ -0,0 +1,446 @@
package avx
import "math/bits"
// Field operations modulo the secp256k1 field prime p.
// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
// = 2^256 - 2^32 - 977
// SetBytes sets a field element from a 32-byte big-endian slice.
// Returns true if the value was >= p and was reduced.
func (f *FieldElement) SetBytes(b []byte) bool {
if len(b) != 32 {
panic("field element must be 32 bytes")
}
// Convert big-endian bytes to little-endian limbs
f.N[0].Lo = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
f.N[0].Hi = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
f.N[1].Lo = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
f.N[1].Hi = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
// Check overflow and reduce if necessary
overflow := f.checkOverflow()
if overflow {
f.reduce()
}
return overflow
}
// Bytes returns the field element as a 32-byte big-endian slice.
func (f *FieldElement) Bytes() [32]byte {
var b [32]byte
b[31] = byte(f.N[0].Lo)
b[30] = byte(f.N[0].Lo >> 8)
b[29] = byte(f.N[0].Lo >> 16)
b[28] = byte(f.N[0].Lo >> 24)
b[27] = byte(f.N[0].Lo >> 32)
b[26] = byte(f.N[0].Lo >> 40)
b[25] = byte(f.N[0].Lo >> 48)
b[24] = byte(f.N[0].Lo >> 56)
b[23] = byte(f.N[0].Hi)
b[22] = byte(f.N[0].Hi >> 8)
b[21] = byte(f.N[0].Hi >> 16)
b[20] = byte(f.N[0].Hi >> 24)
b[19] = byte(f.N[0].Hi >> 32)
b[18] = byte(f.N[0].Hi >> 40)
b[17] = byte(f.N[0].Hi >> 48)
b[16] = byte(f.N[0].Hi >> 56)
b[15] = byte(f.N[1].Lo)
b[14] = byte(f.N[1].Lo >> 8)
b[13] = byte(f.N[1].Lo >> 16)
b[12] = byte(f.N[1].Lo >> 24)
b[11] = byte(f.N[1].Lo >> 32)
b[10] = byte(f.N[1].Lo >> 40)
b[9] = byte(f.N[1].Lo >> 48)
b[8] = byte(f.N[1].Lo >> 56)
b[7] = byte(f.N[1].Hi)
b[6] = byte(f.N[1].Hi >> 8)
b[5] = byte(f.N[1].Hi >> 16)
b[4] = byte(f.N[1].Hi >> 24)
b[3] = byte(f.N[1].Hi >> 32)
b[2] = byte(f.N[1].Hi >> 40)
b[1] = byte(f.N[1].Hi >> 48)
b[0] = byte(f.N[1].Hi >> 56)
return b
}
// IsZero returns true if the field element is zero.
func (f *FieldElement) IsZero() bool {
return f.N[0].IsZero() && f.N[1].IsZero()
}
// IsOne returns true if the field element is one.
func (f *FieldElement) IsOne() bool {
return f.N[0].Lo == 1 && f.N[0].Hi == 0 && f.N[1].IsZero()
}
// Equal returns true if two field elements are equal.
func (f *FieldElement) Equal(other *FieldElement) bool {
return f.N[0].Lo == other.N[0].Lo && f.N[0].Hi == other.N[0].Hi &&
f.N[1].Lo == other.N[1].Lo && f.N[1].Hi == other.N[1].Hi
}
// IsOdd returns true if the field element is odd.
func (f *FieldElement) IsOdd() bool {
return f.N[0].Lo&1 == 1
}
// checkOverflow returns true if f >= p.
func (f *FieldElement) checkOverflow() bool {
// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
// Compare high to low
if f.N[1].Hi > FieldP.N[1].Hi {
return true
}
if f.N[1].Hi < FieldP.N[1].Hi {
return false
}
if f.N[1].Lo > FieldP.N[1].Lo {
return true
}
if f.N[1].Lo < FieldP.N[1].Lo {
return false
}
if f.N[0].Hi > FieldP.N[0].Hi {
return true
}
if f.N[0].Hi < FieldP.N[0].Hi {
return false
}
return f.N[0].Lo >= FieldP.N[0].Lo
}
// reduce reduces f modulo p by adding the complement (2^256 - p = 2^32 + 977).
func (f *FieldElement) reduce() {
// f = f - p = f + (2^256 - p) mod 2^256
// 2^256 - p = 0x1000003D1
var carry uint64
f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, 0x1000003D1, 0)
f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, 0, carry)
f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
}
// Add sets f = a + b mod p.
func (f *FieldElement) Add(a, b *FieldElement) *FieldElement {
var carry uint64
f.N[0].Lo, carry = bits.Add64(a.N[0].Lo, b.N[0].Lo, 0)
f.N[0].Hi, carry = bits.Add64(a.N[0].Hi, b.N[0].Hi, carry)
f.N[1].Lo, carry = bits.Add64(a.N[1].Lo, b.N[1].Lo, carry)
f.N[1].Hi, carry = bits.Add64(a.N[1].Hi, b.N[1].Hi, carry)
// If there was a carry or if result >= p, reduce
if carry != 0 || f.checkOverflow() {
f.reduce()
}
return f
}
// Sub sets f = a - b mod p.
func (f *FieldElement) Sub(a, b *FieldElement) *FieldElement {
var borrow uint64
f.N[0].Lo, borrow = bits.Sub64(a.N[0].Lo, b.N[0].Lo, 0)
f.N[0].Hi, borrow = bits.Sub64(a.N[0].Hi, b.N[0].Hi, borrow)
f.N[1].Lo, borrow = bits.Sub64(a.N[1].Lo, b.N[1].Lo, borrow)
f.N[1].Hi, borrow = bits.Sub64(a.N[1].Hi, b.N[1].Hi, borrow)
// If there was a borrow, add p back
if borrow != 0 {
var carry uint64
f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, FieldP.N[0].Lo, 0)
f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, FieldP.N[0].Hi, carry)
f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, FieldP.N[1].Lo, carry)
f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, FieldP.N[1].Hi, carry)
}
return f
}
// Negate sets f = -a mod p.
func (f *FieldElement) Negate(a *FieldElement) *FieldElement {
if a.IsZero() {
*f = FieldZero
return f
}
// f = p - a
var borrow uint64
f.N[0].Lo, borrow = bits.Sub64(FieldP.N[0].Lo, a.N[0].Lo, 0)
f.N[0].Hi, borrow = bits.Sub64(FieldP.N[0].Hi, a.N[0].Hi, borrow)
f.N[1].Lo, borrow = bits.Sub64(FieldP.N[1].Lo, a.N[1].Lo, borrow)
f.N[1].Hi, _ = bits.Sub64(FieldP.N[1].Hi, a.N[1].Hi, borrow)
return f
}
// Mul sets f = a * b mod p.
func (f *FieldElement) Mul(a, b *FieldElement) *FieldElement {
// Compute 512-bit product
var prod [8]uint64
fieldMul512(&prod, a, b)
// Reduce mod p using secp256k1's special structure
fieldReduce512(f, &prod)
return f
}
// fieldMul512 computes the 512-bit product of two 256-bit field elements.
func fieldMul512(prod *[8]uint64, a, b *FieldElement) {
aLimbs := [4]uint64{a.N[0].Lo, a.N[0].Hi, a.N[1].Lo, a.N[1].Hi}
bLimbs := [4]uint64{b.N[0].Lo, b.N[0].Hi, b.N[1].Lo, b.N[1].Hi}
// Clear product
for i := range prod {
prod[i] = 0
}
// Schoolbook multiplication
for i := 0; i < 4; i++ {
var carry uint64
for j := 0; j < 4; j++ {
hi, lo := bits.Mul64(aLimbs[i], bLimbs[j])
lo, c := bits.Add64(lo, prod[i+j], 0)
hi, _ = bits.Add64(hi, 0, c)
lo, c = bits.Add64(lo, carry, 0)
hi, _ = bits.Add64(hi, 0, c)
prod[i+j] = lo
carry = hi
}
prod[i+4] = carry
}
}
// fieldReduce512 reduces a 512-bit value mod p using secp256k1's special structure.
// p = 2^256 - 2^32 - 977, so 2^256 ≡ 2^32 + 977 (mod p)
func fieldReduce512(f *FieldElement, prod *[8]uint64) {
// The key insight: if we have a 512-bit number split as H*2^256 + L
// then H*2^256 + L ≡ H*(2^32 + 977) + L (mod p)
// Extract low and high 256-bit parts
low := [4]uint64{prod[0], prod[1], prod[2], prod[3]}
high := [4]uint64{prod[4], prod[5], prod[6], prod[7]}
// Compute high * (2^32 + 977) = high * 0x1000003D1
// This gives us at most a 289-bit result (256 + 33 bits)
const c = uint64(0x1000003D1)
var reduction [5]uint64
var carry uint64
for i := 0; i < 4; i++ {
hi, lo := bits.Mul64(high[i], c)
lo, cc := bits.Add64(lo, carry, 0)
hi, _ = bits.Add64(hi, 0, cc)
reduction[i] = lo
carry = hi
}
reduction[4] = carry
// Add low + reduction
var result [5]uint64
carry = 0
for i := 0; i < 4; i++ {
result[i], carry = bits.Add64(low[i], reduction[i], carry)
}
result[4] = carry + reduction[4]
// If result[4] is non-zero, we need to reduce again
// result[4] * 2^256 ≡ result[4] * (2^32 + 977) (mod p)
if result[4] != 0 {
hi, lo := bits.Mul64(result[4], c)
result[0], carry = bits.Add64(result[0], lo, 0)
result[1], carry = bits.Add64(result[1], hi, carry)
result[2], carry = bits.Add64(result[2], 0, carry)
result[3], carry = bits.Add64(result[3], 0, carry)
// If this fold itself wrapped past 2^256, the wrapped bit is worth one more
// c (2^256 ≡ c mod p); the remaining values are tiny, so this cannot carry again.
if carry != 0 {
result[0], carry = bits.Add64(result[0], c, 0)
result[1], carry = bits.Add64(result[1], 0, carry)
result[2], carry = bits.Add64(result[2], 0, carry)
result[3], _ = bits.Add64(result[3], 0, carry)
}
result[4] = 0
}
// Store result
f.N[0].Lo = result[0]
f.N[0].Hi = result[1]
f.N[1].Lo = result[2]
f.N[1].Hi = result[3]
// Final reduction if >= p
if f.checkOverflow() {
f.reduce()
}
}
// Sqr sets f = a^2 mod p.
func (f *FieldElement) Sqr(a *FieldElement) *FieldElement {
// Optimized squaring could save some multiplications, but for now use Mul
return f.Mul(a, a)
}
// Inverse sets f = a^(-1) mod p using Fermat's little theorem.
// a^(-1) = a^(p-2) mod p
func (f *FieldElement) Inverse(a *FieldElement) *FieldElement {
// p-2 in bytes (big-endian)
// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
// p-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2D
pMinus2 := [32]byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFC, 0x2D,
}
var result, base FieldElement
result = FieldOne
base = *a
for i := 0; i < 32; i++ {
b := pMinus2[31-i]
for j := 0; j < 8; j++ {
if (b>>j)&1 == 1 {
result.Mul(&result, &base)
}
base.Sqr(&base)
}
}
*f = result
return f
}
// Sqrt sets f = sqrt(a) mod p if it exists, returns true if successful.
// For secp256k1, p ≡ 3 (mod 4), so sqrt(a) = a^((p+1)/4) mod p
func (f *FieldElement) Sqrt(a *FieldElement) bool {
// (p+1)/4 in bytes
// p+1 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC30
// (p+1)/4 = 3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFBFFFFF0C
pPlus1Div4 := [32]byte{
0x3F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xBF, 0xFF, 0xFF, 0x0C,
}
var result, base FieldElement
result = FieldOne
base = *a
for i := 0; i < 32; i++ {
b := pPlus1Div4[31-i]
for j := 0; j < 8; j++ {
if (b>>j)&1 == 1 {
result.Mul(&result, &base)
}
base.Sqr(&base)
}
}
// Verify: result^2 should equal a
var check FieldElement
check.Sqr(&result)
if check.Equal(a) {
*f = result
return true
}
return false
}
// MulInt sets f = a * n mod p where n is a small integer.
func (f *FieldElement) MulInt(a *FieldElement, n uint64) *FieldElement {
if n == 0 {
*f = FieldZero
return f
}
if n == 1 {
*f = *a
return f
}
// Multiply by small integer using proper carry chain
// We need to compute a 320-bit result (256 + 64 bits max)
var result [5]uint64
var carry uint64
// Multiply each 64-bit limb by n
var hi uint64
hi, result[0] = bits.Mul64(a.N[0].Lo, n)
carry = hi
hi, result[1] = bits.Mul64(a.N[0].Hi, n)
result[1], carry = bits.Add64(result[1], carry, 0)
carry = hi + carry // carry can be at most 1 here, so no overflow
hi, result[2] = bits.Mul64(a.N[1].Lo, n)
result[2], carry = bits.Add64(result[2], carry, 0)
carry = hi + carry
hi, result[3] = bits.Mul64(a.N[1].Hi, n)
result[3], carry = bits.Add64(result[3], carry, 0)
result[4] = hi + carry
// Store preliminary result
f.N[0].Lo = result[0]
f.N[0].Hi = result[1]
f.N[1].Lo = result[2]
f.N[1].Hi = result[3]
// Reduce overflow
if result[4] != 0 {
// overflow * 2^256 ≡ overflow * (2^32 + 977) (mod p)
hi, lo := bits.Mul64(result[4], 0x1000003D1)
f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, lo, 0)
f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, hi, carry)
f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
f.N[1].Hi, carry = bits.Add64(f.N[1].Hi, 0, carry)
// A wrap past 2^256 contributes one more 0x1000003D1; this cannot carry again.
if carry != 0 {
f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, 0x1000003D1, 0)
f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, 0, carry)
f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
}
}
if f.checkOverflow() {
f.reduce()
}
return f
}
// Double sets f = 2*a mod p (optimized addition).
func (f *FieldElement) Double(a *FieldElement) *FieldElement {
return f.Add(a, a)
}
// Half sets f = a/2 mod p.
func (f *FieldElement) Half(a *FieldElement) *FieldElement {
// If a is even, just shift right
// If a is odd, add p first (which makes it even), then shift right
var result FieldElement = *a
if result.N[0].Lo&1 == 1 {
// Add p
var carry uint64
result.N[0].Lo, carry = bits.Add64(result.N[0].Lo, FieldP.N[0].Lo, 0)
result.N[0].Hi, carry = bits.Add64(result.N[0].Hi, FieldP.N[0].Hi, carry)
result.N[1].Lo, carry = bits.Add64(result.N[1].Lo, FieldP.N[1].Lo, carry)
result.N[1].Hi, _ = bits.Add64(result.N[1].Hi, FieldP.N[1].Hi, carry)
}
// Shift right by 1
f.N[0].Lo = (result.N[0].Lo >> 1) | (result.N[0].Hi << 63)
f.N[0].Hi = (result.N[0].Hi >> 1) | (result.N[1].Lo << 63)
f.N[1].Lo = (result.N[1].Lo >> 1) | (result.N[1].Hi << 63)
f.N[1].Hi = result.N[1].Hi >> 1
return f
}
// CMov conditionally moves b into f if cond is true (constant-time).
func (f *FieldElement) CMov(b *FieldElement, cond bool) *FieldElement {
mask := uint64(0)
if cond {
mask = ^uint64(0)
}
f.N[0].Lo = (f.N[0].Lo &^ mask) | (b.N[0].Lo & mask)
f.N[0].Hi = (f.N[0].Hi &^ mask) | (b.N[0].Hi & mask)
f.N[1].Lo = (f.N[1].Lo &^ mask) | (b.N[1].Lo & mask)
f.N[1].Hi = (f.N[1].Hi &^ mask) | (b.N[1].Hi & mask)
return f
}

avx/field_amd64.go Normal file

@@ -0,0 +1,25 @@
//go:build amd64
package avx
// AMD64-specific field operations with AVX2 assembly.
// FieldAddAVX2 adds two field elements using AVX2.
//
//go:noescape
func FieldAddAVX2(r, a, b *FieldElement)
// FieldSubAVX2 subtracts two field elements using AVX2.
//
//go:noescape
func FieldSubAVX2(r, a, b *FieldElement)
// FieldMulAVX2 multiplies two field elements using AVX2.
//
//go:noescape
func FieldMulAVX2(r, a, b *FieldElement)
// FieldSqrAVX2 squares a field element using AVX2.
//
//go:noescape
func FieldSqrAVX2(r, a *FieldElement)

avx/field_amd64.s Normal file

@@ -0,0 +1,369 @@
//go:build amd64
#include "textflag.h"
// Field prime p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
DATA fieldP<>+0x00(SB)/8, $0xFFFFFFFEFFFFFC2F
DATA fieldP<>+0x08(SB)/8, $0xFFFFFFFFFFFFFFFF
DATA fieldP<>+0x10(SB)/8, $0xFFFFFFFFFFFFFFFF
DATA fieldP<>+0x18(SB)/8, $0xFFFFFFFFFFFFFFFF
GLOBL fieldP<>(SB), RODATA|NOPTR, $32
// 2^256 - p = 2^32 + 977 = 0x1000003D1
DATA fieldPC<>+0x00(SB)/8, $0x1000003D1
DATA fieldPC<>+0x08(SB)/8, $0x0000000000000000
DATA fieldPC<>+0x10(SB)/8, $0x0000000000000000
DATA fieldPC<>+0x18(SB)/8, $0x0000000000000000
GLOBL fieldPC<>(SB), RODATA|NOPTR, $32
// func FieldAddAVX2(r, a, b *FieldElement)
// Adds two 256-bit field elements mod p.
TEXT ·FieldAddAVX2(SB), NOSPLIT, $0-24
MOVQ r+0(FP), DI
MOVQ a+8(FP), SI
MOVQ b+16(FP), DX
// Load a
MOVQ 0(SI), AX
MOVQ 8(SI), BX
MOVQ 16(SI), CX
MOVQ 24(SI), R8
// Add b with carry chain
ADDQ 0(DX), AX
ADCQ 8(DX), BX
ADCQ 16(DX), CX
ADCQ 24(DX), R8
// Save carry
SETCS R9B
// Store preliminary result
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
// Check if we need to reduce
TESTB R9B, R9B
JNZ field_reduce
// Compare with p (from high to low)
// p.Hi = 0xFFFFFFFFFFFFFFFF (all limbs except first)
// p.Lo = 0xFFFFFFFEFFFFFC2F
MOVQ $0xFFFFFFFFFFFFFFFF, R10
CMPQ R8, R10
JB field_done
JA field_reduce
CMPQ CX, R10
JB field_done
JA field_reduce
CMPQ BX, R10
JB field_done
JA field_reduce
MOVQ fieldP<>+0x00(SB), R10
CMPQ AX, R10
JB field_done
field_reduce:
// Subtract p by adding 2^256 - p = 0x1000003D1
MOVQ 0(DI), AX
MOVQ 8(DI), BX
MOVQ 16(DI), CX
MOVQ 24(DI), R8
MOVQ fieldPC<>+0x00(SB), R10
ADDQ R10, AX
ADCQ $0, BX
ADCQ $0, CX
ADCQ $0, R8
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
field_done:
VZEROUPPER
RET
// func FieldSubAVX2(r, a, b *FieldElement)
// Subtracts two 256-bit field elements mod p.
TEXT ·FieldSubAVX2(SB), NOSPLIT, $0-24
MOVQ r+0(FP), DI
MOVQ a+8(FP), SI
MOVQ b+16(FP), DX
// Load a
MOVQ 0(SI), AX
MOVQ 8(SI), BX
MOVQ 16(SI), CX
MOVQ 24(SI), R8
// Subtract b with borrow chain
SUBQ 0(DX), AX
SBBQ 8(DX), BX
SBBQ 16(DX), CX
SBBQ 24(DX), R8
// Save borrow
SETCS R9B
// Store preliminary result
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
// If borrow, add p back
TESTB R9B, R9B
JZ field_sub_done
// Add p from memory
MOVQ fieldP<>+0x00(SB), R10
ADDQ R10, AX
MOVQ fieldP<>+0x08(SB), R10
ADCQ R10, BX
MOVQ fieldP<>+0x10(SB), R10
ADCQ R10, CX
MOVQ fieldP<>+0x18(SB), R10
ADCQ R10, R8
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
field_sub_done:
VZEROUPPER
RET
// func FieldMulAVX2(r, a, b *FieldElement)
// Multiplies two 256-bit field elements mod p.
TEXT ·FieldMulAVX2(SB), NOSPLIT, $64-24
MOVQ r+0(FP), DI
MOVQ a+8(FP), SI
MOVQ b+16(FP), DX
// Load a limbs
MOVQ 0(SI), R8 // a0
MOVQ 8(SI), R9 // a1
MOVQ 16(SI), R10 // a2
MOVQ 24(SI), R11 // a3
// Store b pointer
MOVQ DX, R12
// Initialize 512-bit product on stack
XORQ AX, AX
MOVQ AX, 0(SP)
MOVQ AX, 8(SP)
MOVQ AX, 16(SP)
MOVQ AX, 24(SP)
MOVQ AX, 32(SP)
MOVQ AX, 40(SP)
MOVQ AX, 48(SP)
MOVQ AX, 56(SP)
// Schoolbook multiplication (same as scalar, but with field reduction)
// a0 * b[0..3]
MOVQ R8, AX
MULQ 0(R12)
MOVQ AX, 0(SP)
MOVQ DX, R13
MOVQ R8, AX
MULQ 8(R12)
ADDQ R13, AX
ADCQ $0, DX
MOVQ AX, 8(SP)
MOVQ DX, R13
MOVQ R8, AX
MULQ 16(R12)
ADDQ R13, AX
ADCQ $0, DX
MOVQ AX, 16(SP)
MOVQ DX, R13
MOVQ R8, AX
MULQ 24(R12)
ADDQ R13, AX
ADCQ $0, DX
MOVQ AX, 24(SP)
MOVQ DX, 32(SP)
// a1 * b[0..3]
MOVQ R9, AX
MULQ 0(R12)
ADDQ AX, 8(SP)
ADCQ DX, 16(SP)
ADCQ $0, 24(SP)
ADCQ $0, 32(SP)
MOVQ R9, AX
MULQ 8(R12)
ADDQ AX, 16(SP)
ADCQ DX, 24(SP)
ADCQ $0, 32(SP)
MOVQ R9, AX
MULQ 16(R12)
ADDQ AX, 24(SP)
ADCQ DX, 32(SP)
ADCQ $0, 40(SP)
MOVQ R9, AX
MULQ 24(R12)
ADDQ AX, 32(SP)
ADCQ DX, 40(SP)
// a2 * b[0..3]
MOVQ R10, AX
MULQ 0(R12)
ADDQ AX, 16(SP)
ADCQ DX, 24(SP)
ADCQ $0, 32(SP)
ADCQ $0, 40(SP)
MOVQ R10, AX
MULQ 8(R12)
ADDQ AX, 24(SP)
ADCQ DX, 32(SP)
ADCQ $0, 40(SP)
MOVQ R10, AX
MULQ 16(R12)
ADDQ AX, 32(SP)
ADCQ DX, 40(SP)
ADCQ $0, 48(SP)
MOVQ R10, AX
MULQ 24(R12)
ADDQ AX, 40(SP)
ADCQ DX, 48(SP)
// a3 * b[0..3]
MOVQ R11, AX
MULQ 0(R12)
ADDQ AX, 24(SP)
ADCQ DX, 32(SP)
ADCQ $0, 40(SP)
ADCQ $0, 48(SP)
MOVQ R11, AX
MULQ 8(R12)
ADDQ AX, 32(SP)
ADCQ DX, 40(SP)
ADCQ $0, 48(SP)
MOVQ R11, AX
MULQ 16(R12)
ADDQ AX, 40(SP)
ADCQ DX, 48(SP)
ADCQ $0, 56(SP)
MOVQ R11, AX
MULQ 24(R12)
ADDQ AX, 48(SP)
ADCQ DX, 56(SP)
// Now reduce 512-bit product mod p
// Using 2^256 ≡ 2^32 + 977 (mod p)
// high = [32(SP), 40(SP), 48(SP), 56(SP)]
// low = [0(SP), 8(SP), 16(SP), 24(SP)]
// result = low + high * (2^32 + 977)
// Multiply high * 0x1000003D1
MOVQ fieldPC<>+0x00(SB), R13
MOVQ 32(SP), AX
MULQ R13
MOVQ AX, R8 // reduction[0]
MOVQ DX, R14 // carry
MOVQ 40(SP), AX
MULQ R13
ADDQ R14, AX
ADCQ $0, DX
MOVQ AX, R9 // reduction[1]
MOVQ DX, R14
MOVQ 48(SP), AX
MULQ R13
ADDQ R14, AX
ADCQ $0, DX
MOVQ AX, R10 // reduction[2]
MOVQ DX, R14
MOVQ 56(SP), AX
MULQ R13
ADDQ R14, AX
ADCQ $0, DX
MOVQ AX, R11 // reduction[3]
MOVQ DX, R14 // reduction[4] (overflow)
// Add low + reduction
ADDQ 0(SP), R8
ADCQ 8(SP), R9
ADCQ 16(SP), R10
ADCQ 24(SP), R11
ADCQ $0, R14 // Capture any carry into R14
// If R14 is non-zero, reduce again
TESTQ R14, R14
JZ field_mul_check
// R14 * 0x1000003D1
MOVQ R14, AX
MULQ R13
ADDQ AX, R8
ADCQ DX, R9
ADCQ $0, R10
ADCQ $0, R11
field_mul_check:
// Check if result >= p and reduce if needed
MOVQ $0xFFFFFFFFFFFFFFFF, R15
CMPQ R11, R15
JB field_mul_store
JA field_mul_reduce2
CMPQ R10, R15
JB field_mul_store
JA field_mul_reduce2
CMPQ R9, R15
JB field_mul_store
JA field_mul_reduce2
MOVQ fieldP<>+0x00(SB), R15
CMPQ R8, R15
JB field_mul_store
field_mul_reduce2:
MOVQ fieldPC<>+0x00(SB), R15
ADDQ R15, R8
ADCQ $0, R9
ADCQ $0, R10
ADCQ $0, R11
field_mul_store:
MOVQ r+0(FP), DI
MOVQ R8, 0(DI)
MOVQ R9, 8(DI)
MOVQ R10, 16(DI)
MOVQ R11, 24(DI)
VZEROUPPER
RET
// func FieldSqrAVX2(r, a *FieldElement)
// Squares a 256-bit field element mod p.
// For now, just calls FieldMulAVX2(r, a, a)
TEXT ·FieldSqrAVX2(SB), NOSPLIT, $24-16
MOVQ r+0(FP), AX
MOVQ a+8(FP), BX
MOVQ AX, 0(SP)
MOVQ BX, 8(SP)
MOVQ BX, 16(SP)
CALL ·FieldMulAVX2(SB)
RET

avx/mulint_test.go Normal file

@@ -0,0 +1,29 @@
package avx
import "testing"
func TestMulInt(t *testing.T) {
// Test 3 * X = X + X + X
var x, tripleX, addX FieldElement
x.N[0].Lo = 12345
tripleX.MulInt(&x, 3)
addX.Add(&x, &x)
addX.Add(&addX, &x)
if !tripleX.Equal(&addX) {
t.Errorf("3*X != X+X+X: MulInt=%+v, Add=%+v", tripleX, addX)
}
// Test 2 * Y = Y + Y
var y, doubleY, addY FieldElement
y.N[0].Lo = 0xFFFFFFFFFFFFFFFF
y.N[0].Hi = 0xFFFFFFFFFFFFFFFF
doubleY.MulInt(&y, 2)
addY.Add(&y, &y)
if !doubleY.Equal(&addY) {
t.Errorf("2*Y != Y+Y: MulInt=%+v, Add=%+v", doubleY, addY)
}
}

avx/point.go Normal file

@@ -0,0 +1,425 @@
package avx
// Point operations on the secp256k1 curve.
// Affine: (x, y) where y² = x³ + 7
// Jacobian: (X, Y, Z) where affine = (X/Z², Y/Z³)
// SetInfinity sets the point to the point at infinity.
func (p *AffinePoint) SetInfinity() *AffinePoint {
p.X = FieldZero
p.Y = FieldZero
p.Infinity = true
return p
}
// IsInfinity returns true if the point is the point at infinity.
func (p *AffinePoint) IsInfinity() bool {
return p.Infinity
}
// Set sets p to the value of q.
func (p *AffinePoint) Set(q *AffinePoint) *AffinePoint {
p.X = q.X
p.Y = q.Y
p.Infinity = q.Infinity
return p
}
// Equal returns true if two points are equal.
func (p *AffinePoint) Equal(q *AffinePoint) bool {
if p.Infinity && q.Infinity {
return true
}
if p.Infinity || q.Infinity {
return false
}
return p.X.Equal(&q.X) && p.Y.Equal(&q.Y)
}
// Negate sets p = -q (reflection over x-axis).
func (p *AffinePoint) Negate(q *AffinePoint) *AffinePoint {
if q.Infinity {
p.SetInfinity()
return p
}
p.X = q.X
p.Y.Negate(&q.Y)
p.Infinity = false
return p
}
// IsOnCurve returns true if the point is on the secp256k1 curve.
func (p *AffinePoint) IsOnCurve() bool {
if p.Infinity {
return true
}
// Check y² = x³ + 7
var y2, x2, x3, rhs FieldElement
y2.Sqr(&p.Y)
x2.Sqr(&p.X)
x3.Mul(&x2, &p.X)
// rhs = x³ + 7
var seven FieldElement
seven.N[0].Lo = 7
rhs.Add(&x3, &seven)
return y2.Equal(&rhs)
}
// SetXY sets the point to (x, y).
func (p *AffinePoint) SetXY(x, y *FieldElement) *AffinePoint {
p.X = *x
p.Y = *y
p.Infinity = false
return p
}
// SetCompressed sets the point from compressed form (x coordinate + sign bit).
// Returns true if successful.
func (p *AffinePoint) SetCompressed(x *FieldElement, odd bool) bool {
// Compute y² = x³ + 7
var y2, x2, x3 FieldElement
x2.Sqr(x)
x3.Mul(&x2, x)
// y² = x³ + 7
var seven FieldElement
seven.N[0].Lo = 7
y2.Add(&x3, &seven)
// Compute y = sqrt(y²)
var y FieldElement
if !y.Sqrt(&y2) {
return false // No square root exists
}
// Choose the correct sign
if y.IsOdd() != odd {
y.Negate(&y)
}
p.X = *x
p.Y = y
p.Infinity = false
return true
}
// Jacobian point operations
// SetInfinity sets the Jacobian point to the point at infinity.
func (p *JacobianPoint) SetInfinity() *JacobianPoint {
p.X = FieldOne
p.Y = FieldOne
p.Z = FieldZero
p.Infinity = true
return p
}
// IsInfinity returns true if the point is the point at infinity.
func (p *JacobianPoint) IsInfinity() bool {
return p.Infinity || p.Z.IsZero()
}
// Set sets p to the value of q.
func (p *JacobianPoint) Set(q *JacobianPoint) *JacobianPoint {
p.X = q.X
p.Y = q.Y
p.Z = q.Z
p.Infinity = q.Infinity
return p
}
// FromAffine converts an affine point to Jacobian coordinates.
func (p *JacobianPoint) FromAffine(q *AffinePoint) *JacobianPoint {
if q.Infinity {
p.SetInfinity()
return p
}
p.X = q.X
p.Y = q.Y
p.Z = FieldOne
p.Infinity = false
return p
}
// ToAffine converts a Jacobian point to affine coordinates.
func (p *JacobianPoint) ToAffine(q *AffinePoint) *AffinePoint {
if p.IsInfinity() {
q.SetInfinity()
return q
}
// affine = (X/Z², Y/Z³)
var zInv, zInv2, zInv3 FieldElement
zInv.Inverse(&p.Z)
zInv2.Sqr(&zInv)
zInv3.Mul(&zInv2, &zInv)
q.X.Mul(&p.X, &zInv2)
q.Y.Mul(&p.Y, &zInv3)
q.Infinity = false
return q
}
// Double sets p = 2*q using Jacobian coordinates.
// Standard Jacobian doubling for y²=x³+b (secp256k1 has a=0):
// M = 3*X₁²
// S = 4*X₁*Y₁²
// T = 8*Y₁⁴
// X₃ = M² - 2*S
// Y₃ = M*(S - X₃) - T
// Z₃ = 2*Y₁*Z₁
func (p *JacobianPoint) Double(q *JacobianPoint) *JacobianPoint {
if q.IsInfinity() {
p.SetInfinity()
return p
}
var y2, m, x2, s, t, tmp FieldElement
var x3, y3, z3 FieldElement // Use temporaries to avoid aliasing issues
// Y² = Y₁²
y2.Sqr(&q.Y)
// M = 3*X₁² (for a=0 curves like secp256k1)
x2.Sqr(&q.X)
m.MulInt(&x2, 3)
// S = 4*X₁*Y₁²
s.Mul(&q.X, &y2)
s.MulInt(&s, 4)
// T = 8*Y₁⁴
t.Sqr(&y2)
t.MulInt(&t, 8)
// X₃ = M² - 2*S
x3.Sqr(&m)
tmp.Double(&s)
x3.Sub(&x3, &tmp)
// Y₃ = M*(S - X₃) - T
tmp.Sub(&s, &x3)
y3.Mul(&m, &tmp)
y3.Sub(&y3, &t)
// Z₃ = 2*Y₁*Z₁
z3.Mul(&q.Y, &q.Z)
z3.Double(&z3)
// Now copy to output (safe even if p == q)
p.X = x3
p.Y = y3
p.Z = z3
p.Infinity = false
return p
}
// Add sets p = q + r using Jacobian coordinates.
// This is the complete addition formula.
func (p *JacobianPoint) Add(q, r *JacobianPoint) *JacobianPoint {
if q.IsInfinity() {
p.Set(r)
return p
}
if r.IsInfinity() {
p.Set(q)
return p
}
// Algorithm:
// U₁ = X₁*Z₂²
// U₂ = X₂*Z₁²
// S₁ = Y₁*Z₂³
// S₂ = Y₂*Z₁³
// H = U₂ - U₁
// R = S₂ - S₁
// If H = 0 and R = 0: return Double(q)
// If H = 0 and R ≠ 0: return Infinity
// X₃ = R² - H³ - 2*U₁*H²
// Y₃ = R*(U₁*H² - X₃) - S₁*H³
// Z₃ = H*Z₁*Z₂
var u1, u2, s1, s2, h, rr, h2, h3, u1h2 FieldElement
var z1sq, z2sq, z1cu, z2cu FieldElement
var x3, y3, z3 FieldElement // Use temporaries to avoid aliasing issues
z1sq.Sqr(&q.Z)
z2sq.Sqr(&r.Z)
z1cu.Mul(&z1sq, &q.Z)
z2cu.Mul(&z2sq, &r.Z)
u1.Mul(&q.X, &z2sq)
u2.Mul(&r.X, &z1sq)
s1.Mul(&q.Y, &z2cu)
s2.Mul(&r.Y, &z1cu)
h.Sub(&u2, &u1)
rr.Sub(&s2, &s1)
// Check for special cases
if h.IsZero() {
if rr.IsZero() {
// Points are equal, use doubling
return p.Double(q)
}
// Points are inverses, return infinity
p.SetInfinity()
return p
}
h2.Sqr(&h)
h3.Mul(&h2, &h)
u1h2.Mul(&u1, &h2)
// X₃ = R² - H³ - 2*U₁*H²
var r2, u1h2_2 FieldElement
r2.Sqr(&rr)
u1h2_2.Double(&u1h2)
x3.Sub(&r2, &h3)
x3.Sub(&x3, &u1h2_2)
// Y₃ = R*(U₁*H² - X₃) - S₁*H³
var tmp, s1h3 FieldElement
tmp.Sub(&u1h2, &x3)
y3.Mul(&rr, &tmp)
s1h3.Mul(&s1, &h3)
y3.Sub(&y3, &s1h3)
// Z₃ = H*Z₁*Z₂
z3.Mul(&q.Z, &r.Z)
z3.Mul(&z3, &h)
// Now copy to output (safe even if p == q or p == r)
p.X = x3
p.Y = y3
p.Z = z3
p.Infinity = false
return p
}
// AddAffine sets p = q + r where q is Jacobian and r is affine.
// More efficient than converting r to Jacobian first.
func (p *JacobianPoint) AddAffine(q *JacobianPoint, r *AffinePoint) *JacobianPoint {
if q.IsInfinity() {
p.FromAffine(r)
return p
}
if r.Infinity {
p.Set(q)
return p
}
// When Z₂ = 1 (affine point), formulas simplify:
// U₁ = X₁
// U₂ = X₂*Z₁²
// S₁ = Y₁
// S₂ = Y₂*Z₁³
var u2, s2, h, rr, h2, h3, u1h2 FieldElement
var z1sq, z1cu FieldElement
var x3, y3, z3 FieldElement // Use temporaries to avoid aliasing issues
z1sq.Sqr(&q.Z)
z1cu.Mul(&z1sq, &q.Z)
u2.Mul(&r.X, &z1sq)
s2.Mul(&r.Y, &z1cu)
h.Sub(&u2, &q.X)
rr.Sub(&s2, &q.Y)
if h.IsZero() {
if rr.IsZero() {
return p.Double(q)
}
p.SetInfinity()
return p
}
h2.Sqr(&h)
h3.Mul(&h2, &h)
u1h2.Mul(&q.X, &h2)
// X₃ = R² - H³ - 2*U₁*H²
var r2, u1h2_2 FieldElement
r2.Sqr(&rr)
u1h2_2.Double(&u1h2)
x3.Sub(&r2, &h3)
x3.Sub(&x3, &u1h2_2)
// Y₃ = R*(U₁*H² - X₃) - S₁*H³
var tmp, s1h3 FieldElement
tmp.Sub(&u1h2, &x3)
y3.Mul(&rr, &tmp)
s1h3.Mul(&q.Y, &h3)
y3.Sub(&y3, &s1h3)
// Z₃ = H*Z₁
z3.Mul(&q.Z, &h)
// Now copy to output (safe even if p == q)
p.X = x3
p.Y = y3
p.Z = z3
p.Infinity = false
return p
}
// Negate sets p = -q (reflection over x-axis).
func (p *JacobianPoint) Negate(q *JacobianPoint) *JacobianPoint {
if q.IsInfinity() {
p.SetInfinity()
return p
}
p.X = q.X
p.Y.Negate(&q.Y)
p.Z = q.Z
p.Infinity = false
return p
}
// ScalarMult computes p = k*q using double-and-add.
func (p *JacobianPoint) ScalarMult(q *JacobianPoint, k *Scalar) *JacobianPoint {
// Simple double-and-add (not constant-time)
// A proper implementation would use windowed NAF or similar
p.SetInfinity()
// Process bits from high to low
bytes := k.Bytes()
for i := 0; i < 32; i++ {
b := bytes[i]
for j := 7; j >= 0; j-- {
p.Double(p)
if (b>>j)&1 == 1 {
p.Add(p, q)
}
}
}
return p
}
// ScalarBaseMult computes p = k*G where G is the generator.
func (p *JacobianPoint) ScalarBaseMult(k *Scalar) *JacobianPoint {
var g JacobianPoint
g.FromAffine(&Generator)
return p.ScalarMult(&g, k)
}
// BasePointMult computes k*G and returns the result in affine coordinates.
func BasePointMult(k *Scalar) *AffinePoint {
var jac JacobianPoint
var aff AffinePoint
jac.ScalarBaseMult(k)
jac.ToAffine(&aff)
return &aff
}

avx/scalar.go Normal file

@@ -0,0 +1,425 @@
package avx
import "math/bits"
// Scalar operations modulo the secp256k1 group order n.
// n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
// SetBytes sets a scalar from a 32-byte big-endian slice.
// Returns true if the value was >= n and was reduced.
func (s *Scalar) SetBytes(b []byte) bool {
if len(b) != 32 {
panic("scalar must be 32 bytes")
}
// Convert big-endian bytes to little-endian limbs
s.D[0].Lo = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
s.D[0].Hi = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
s.D[1].Lo = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
s.D[1].Hi = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
// Check overflow and reduce if necessary
overflow := s.checkOverflow()
if overflow {
s.reduce()
}
return overflow
}
// Bytes returns the scalar as a 32-byte big-endian slice.
func (s *Scalar) Bytes() [32]byte {
var b [32]byte
b[31] = byte(s.D[0].Lo)
b[30] = byte(s.D[0].Lo >> 8)
b[29] = byte(s.D[0].Lo >> 16)
b[28] = byte(s.D[0].Lo >> 24)
b[27] = byte(s.D[0].Lo >> 32)
b[26] = byte(s.D[0].Lo >> 40)
b[25] = byte(s.D[0].Lo >> 48)
b[24] = byte(s.D[0].Lo >> 56)
b[23] = byte(s.D[0].Hi)
b[22] = byte(s.D[0].Hi >> 8)
b[21] = byte(s.D[0].Hi >> 16)
b[20] = byte(s.D[0].Hi >> 24)
b[19] = byte(s.D[0].Hi >> 32)
b[18] = byte(s.D[0].Hi >> 40)
b[17] = byte(s.D[0].Hi >> 48)
b[16] = byte(s.D[0].Hi >> 56)
b[15] = byte(s.D[1].Lo)
b[14] = byte(s.D[1].Lo >> 8)
b[13] = byte(s.D[1].Lo >> 16)
b[12] = byte(s.D[1].Lo >> 24)
b[11] = byte(s.D[1].Lo >> 32)
b[10] = byte(s.D[1].Lo >> 40)
b[9] = byte(s.D[1].Lo >> 48)
b[8] = byte(s.D[1].Lo >> 56)
b[7] = byte(s.D[1].Hi)
b[6] = byte(s.D[1].Hi >> 8)
b[5] = byte(s.D[1].Hi >> 16)
b[4] = byte(s.D[1].Hi >> 24)
b[3] = byte(s.D[1].Hi >> 32)
b[2] = byte(s.D[1].Hi >> 40)
b[1] = byte(s.D[1].Hi >> 48)
b[0] = byte(s.D[1].Hi >> 56)
return b
}
// IsZero returns true if the scalar is zero.
func (s *Scalar) IsZero() bool {
return s.D[0].IsZero() && s.D[1].IsZero()
}
// IsOne returns true if the scalar is one.
func (s *Scalar) IsOne() bool {
return s.D[0].Lo == 1 && s.D[0].Hi == 0 && s.D[1].IsZero()
}
// Equal returns true if two scalars are equal.
func (s *Scalar) Equal(other *Scalar) bool {
return s.D[0].Lo == other.D[0].Lo && s.D[0].Hi == other.D[0].Hi &&
s.D[1].Lo == other.D[1].Lo && s.D[1].Hi == other.D[1].Hi
}
// checkOverflow returns true if s >= n.
func (s *Scalar) checkOverflow() bool {
// Compare high to low
if s.D[1].Hi > ScalarN.D[1].Hi {
return true
}
if s.D[1].Hi < ScalarN.D[1].Hi {
return false
}
if s.D[1].Lo > ScalarN.D[1].Lo {
return true
}
if s.D[1].Lo < ScalarN.D[1].Lo {
return false
}
if s.D[0].Hi > ScalarN.D[0].Hi {
return true
}
if s.D[0].Hi < ScalarN.D[0].Hi {
return false
}
return s.D[0].Lo >= ScalarN.D[0].Lo
}
// reduce reduces s modulo n by adding the complement (2^256 - n).
func (s *Scalar) reduce() {
// s = s - n = s + (2^256 - n) mod 2^256
var carry uint64
s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, carry)
s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, carry)
s.D[1].Hi, _ = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, carry)
}
// Add sets s = a + b mod n.
func (s *Scalar) Add(a, b *Scalar) *Scalar {
var carry uint64
s.D[0].Lo, carry = bits.Add64(a.D[0].Lo, b.D[0].Lo, 0)
s.D[0].Hi, carry = bits.Add64(a.D[0].Hi, b.D[0].Hi, carry)
s.D[1].Lo, carry = bits.Add64(a.D[1].Lo, b.D[1].Lo, carry)
s.D[1].Hi, carry = bits.Add64(a.D[1].Hi, b.D[1].Hi, carry)
// If there was a carry or if result >= n, reduce
if carry != 0 || s.checkOverflow() {
s.reduce()
}
return s
}
// Sub sets s = a - b mod n.
func (s *Scalar) Sub(a, b *Scalar) *Scalar {
var borrow uint64
s.D[0].Lo, borrow = bits.Sub64(a.D[0].Lo, b.D[0].Lo, 0)
s.D[0].Hi, borrow = bits.Sub64(a.D[0].Hi, b.D[0].Hi, borrow)
s.D[1].Lo, borrow = bits.Sub64(a.D[1].Lo, b.D[1].Lo, borrow)
s.D[1].Hi, borrow = bits.Sub64(a.D[1].Hi, b.D[1].Hi, borrow)
// If there was a borrow, add n back
if borrow != 0 {
var carry uint64
s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, ScalarN.D[0].Lo, 0)
s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, ScalarN.D[0].Hi, carry)
s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, ScalarN.D[1].Lo, carry)
s.D[1].Hi, _ = bits.Add64(s.D[1].Hi, ScalarN.D[1].Hi, carry)
}
return s
}
// Negate sets s = -a mod n.
func (s *Scalar) Negate(a *Scalar) *Scalar {
if a.IsZero() {
*s = ScalarZero
return s
}
// s = n - a
var borrow uint64
s.D[0].Lo, borrow = bits.Sub64(ScalarN.D[0].Lo, a.D[0].Lo, 0)
s.D[0].Hi, borrow = bits.Sub64(ScalarN.D[0].Hi, a.D[0].Hi, borrow)
s.D[1].Lo, borrow = bits.Sub64(ScalarN.D[1].Lo, a.D[1].Lo, borrow)
s.D[1].Hi, _ = bits.Sub64(ScalarN.D[1].Hi, a.D[1].Hi, borrow)
return s
}
// Mul sets s = a * b mod n.
func (s *Scalar) Mul(a, b *Scalar) *Scalar {
// Compute 512-bit product
var prod [8]uint64
scalarMul512(&prod, a, b)
// Reduce mod n
scalarReduce512(s, &prod)
return s
}
// scalarMul512 computes the 512-bit product of two 256-bit scalars.
// Result is stored in prod[0..7] where prod[0] is the least significant.
func scalarMul512(prod *[8]uint64, a, b *Scalar) {
// Using schoolbook multiplication with 64-bit limbs
// a = a[0] + a[1]*2^64 + a[2]*2^128 + a[3]*2^192
// b = b[0] + b[1]*2^64 + b[2]*2^128 + b[3]*2^192
aLimbs := [4]uint64{a.D[0].Lo, a.D[0].Hi, a.D[1].Lo, a.D[1].Hi}
bLimbs := [4]uint64{b.D[0].Lo, b.D[0].Hi, b.D[1].Lo, b.D[1].Hi}
// Clear product
for i := range prod {
prod[i] = 0
}
// Schoolbook multiplication
for i := 0; i < 4; i++ {
var carry uint64
for j := 0; j < 4; j++ {
hi, lo := bits.Mul64(aLimbs[i], bLimbs[j])
lo, c := bits.Add64(lo, prod[i+j], 0)
hi, _ = bits.Add64(hi, 0, c)
lo, c = bits.Add64(lo, carry, 0)
hi, _ = bits.Add64(hi, 0, c)
prod[i+j] = lo
carry = hi
}
prod[i+4] = carry
}
}
// scalarReduce512 reduces a 512-bit value mod n.
func scalarReduce512(s *Scalar, prod *[8]uint64) {
// Uses the identity 2^256 ≡ 2^256 - n = ScalarNC (mod n): split the
// product into high and low 256-bit halves, then fold high*NC into low.
// Simpler (though slower) than a full Barrett or Montgomery reduction.
// Copy low 256 bits to result
s.D[0].Lo = prod[0]
s.D[0].Hi = prod[1]
s.D[1].Lo = prod[2]
s.D[1].Hi = prod[3]
// If high 256 bits are non-zero, we need to reduce
if prod[4] != 0 || prod[5] != 0 || prod[6] != 0 || prod[7] != 0 {
// high * (2^256 mod n) + low
// This is a simplified version - multiply high by NC and add
highScalar := Scalar{
D: [2]Uint128{
{Lo: prod[4], Hi: prod[5]},
{Lo: prod[6], Hi: prod[7]},
},
}
// Multiply high by NC in full. NC is small (~2^129):
// NC = 0x14551231950B75FC4402DA1732FC9BEBF
// NC.D[0] = {Lo: 0x402DA1732FC9BEBF, Hi: 0x4551231950B75FC4}
// NC.D[1] = {Lo: 0x1, Hi: 0}
// so the 256x130-bit product fits comfortably in the 8 limbs below (< 2^386).
var reduction [8]uint64
ncLimbs := [4]uint64{ScalarNC.D[0].Lo, ScalarNC.D[0].Hi, ScalarNC.D[1].Lo, ScalarNC.D[1].Hi}
highLimbs := [4]uint64{highScalar.D[0].Lo, highScalar.D[0].Hi, highScalar.D[1].Lo, highScalar.D[1].Hi}
for i := 0; i < 4; i++ {
var carry uint64
for j := 0; j < 4; j++ {
hi, lo := bits.Mul64(highLimbs[i], ncLimbs[j])
lo, c := bits.Add64(lo, reduction[i+j], 0)
hi, _ = bits.Add64(hi, 0, c)
lo, c = bits.Add64(lo, carry, 0)
hi, _ = bits.Add64(hi, 0, c)
reduction[i+j] = lo
carry = hi
}
if i+4 < 8 {
reduction[i+4], _ = bits.Add64(reduction[i+4], carry, 0)
}
}
// Add reduction to s
var carry uint64
s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, reduction[0], 0)
s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, reduction[1], carry)
s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, reduction[2], carry)
s.D[1].Hi, carry = bits.Add64(s.D[1].Hi, reduction[3], carry)
// Handle any remaining high bits by repeated reduction
// If there's a carry, it represents 2^256 which equals NC mod n
// If reduction[4..7] are non-zero, we need to reduce those too
if carry != 0 || reduction[4] != 0 || reduction[5] != 0 || reduction[6] != 0 || reduction[7] != 0 {
// The carry and reduction[4..7] together represent additional multiples of 2^256
// Each 2^256 ≡ NC (mod n), so we add (carry + reduction[4..7]) * NC
// First, handle the carry
if carry != 0 {
// carry * NC
var c uint64
s.D[0].Lo, c = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
s.D[0].Hi, c = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, c)
s.D[1].Lo, c = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, c)
s.D[1].Hi, c = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, c)
// If there's still a carry, add NC again
for c != 0 {
s.D[0].Lo, c = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
s.D[0].Hi, c = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, c)
s.D[1].Lo, c = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, c)
s.D[1].Hi, c = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, c)
}
}
// Handle reduction[4..7] if non-zero
if reduction[4] != 0 || reduction[5] != 0 || reduction[6] != 0 || reduction[7] != 0 {
// Compute reduction[4..7] * NC and add
highScalar2 := Scalar{
D: [2]Uint128{
{Lo: reduction[4], Hi: reduction[5]},
{Lo: reduction[6], Hi: reduction[7]},
},
}
var reduction2 [8]uint64
high2Limbs := [4]uint64{highScalar2.D[0].Lo, highScalar2.D[0].Hi, highScalar2.D[1].Lo, highScalar2.D[1].Hi}
for i := 0; i < 4; i++ {
var c uint64
for j := 0; j < 4; j++ {
hi, lo := bits.Mul64(high2Limbs[i], ncLimbs[j])
lo, cc := bits.Add64(lo, reduction2[i+j], 0)
hi, _ = bits.Add64(hi, 0, cc)
lo, cc = bits.Add64(lo, c, 0)
hi, _ = bits.Add64(hi, 0, cc)
reduction2[i+j] = lo
c = hi
}
if i+4 < 8 {
reduction2[i+4], _ = bits.Add64(reduction2[i+4], c, 0)
}
}
var c uint64
s.D[0].Lo, c = bits.Add64(s.D[0].Lo, reduction2[0], 0)
s.D[0].Hi, c = bits.Add64(s.D[0].Hi, reduction2[1], c)
s.D[1].Lo, c = bits.Add64(s.D[1].Lo, reduction2[2], c)
s.D[1].Hi, c = bits.Add64(s.D[1].Hi, reduction2[3], c)
// Handle any leftover carries. Each unit of c and of reduction2[4]
// stands for 2^256, and 2^256 ≡ NC (mod n). By the size bounds
// (reduction2 = high2 * NC < 2^258), reduction2[5..7] are always zero.
extra := c + reduction2[4]
for extra != 0 {
var cc uint64
s.D[0].Lo, cc = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
s.D[0].Hi, cc = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, cc)
s.D[1].Lo, cc = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, cc)
s.D[1].Hi, cc = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, cc)
extra--
extra += cc // a carry out is itself another 2^256 ≡ NC
}
}
}
}
// Final reduction if needed
if s.checkOverflow() {
s.reduce()
}
}
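The whole routine rests on the identity high·2^256 + low ≡ high·NC + low (mod n), which holds because 2^256 mod n = 2^256 - n = NC. A one-off check of that base identity with math/big (a standalone sketch):

```go
package main

import (
	"fmt"
	"math/big"
)

func main() {
	two256 := new(big.Int).Lsh(big.NewInt(1), 256)
	n, _ := new(big.Int).SetString(
		"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16)
	nc := new(big.Int).Sub(two256, n) // NC = 2^256 - n
	// n < 2^256 < 2n, so 2^256 mod n is exactly 2^256 - n = NC.
	fmt.Println(new(big.Int).Mod(two256, n).Cmp(nc) == 0) // prints true
}
```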
// Sqr sets s = a^2 mod n.
func (s *Scalar) Sqr(a *Scalar) *Scalar {
return s.Mul(a, a)
}
// Inverse sets s = a^(-1) mod n using Fermat's little theorem.
// a^(-1) = a^(n-2) mod n
func (s *Scalar) Inverse(a *Scalar) *Scalar {
// n-2 in binary is used for square-and-multiply
// This is a simplified implementation using binary exponentiation
var result, base Scalar
result = ScalarOne
base = *a
// n-2 bytes (big-endian)
nMinus2 := [32]byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x3F,
}
for i := 0; i < 32; i++ {
b := nMinus2[31-i]
for j := 0; j < 8; j++ {
if (b>>j)&1 == 1 {
result.Mul(&result, &base)
}
base.Sqr(&base)
}
}
*s = result
return s
}
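A cheap sanity check for the inversion (a test sketch, assuming a _test.go in the same package; any nonzero scalar works): a multiplied by its inverse must come back to one.

```go
func TestInverseRoundTrip(t *testing.T) {
	var a, inv, prod Scalar
	a.D[0].Lo = 0x123456789ABCDEF0 // arbitrary nonzero scalar
	inv.Inverse(&a)
	prod.Mul(&a, &inv)
	if !prod.IsOne() {
		t.Fatal("a * a^-1 != 1 mod n")
	}
}
```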
// IsHigh returns true if s > n/2.
func (s *Scalar) IsHigh() bool {
// Compare with n/2
if s.D[1].Hi > ScalarNHalf.D[1].Hi {
return true
}
if s.D[1].Hi < ScalarNHalf.D[1].Hi {
return false
}
if s.D[1].Lo > ScalarNHalf.D[1].Lo {
return true
}
if s.D[1].Lo < ScalarNHalf.D[1].Lo {
return false
}
if s.D[0].Hi > ScalarNHalf.D[0].Hi {
return true
}
if s.D[0].Hi < ScalarNHalf.D[0].Hi {
return false
}
return s.D[0].Lo > ScalarNHalf.D[0].Lo
}
// CondNegate negates s if cond is true.
func (s *Scalar) CondNegate(cond bool) *Scalar {
if cond {
s.Negate(s)
}
return s
}
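IsHigh and CondNegate combine into the usual low-S normalization: if a signature's s lies in the upper half of the group order, replace it with n - s. A minimal usage sketch (the helper name is hypothetical):

```go
// normalizeLowS forces s into the lower half of the scalar range,
// as BIP-62-style low-S rules require for non-malleable signatures.
func normalizeLowS(s *Scalar) *Scalar {
	return s.CondNegate(s.IsHigh())
}
```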

27
avx/scalar_amd64.go Normal file

@@ -0,0 +1,27 @@
//go:build amd64
package avx
// AMD64-specific scalar operations with AVX2 assembly.
// ScalarAddAVX2 adds two scalars in assembly (see scalar_amd64.s).
// The carry-dependent add uses a scalar ADD/ADC chain; AVX2 here names
// the accelerated code path rather than vector addition.
//
//go:noescape
func ScalarAddAVX2(r, a, b *Scalar)
// ScalarSubAVX2 subtracts two scalars using AVX2.
//
//go:noescape
func ScalarSubAVX2(r, a, b *Scalar)
// ScalarMulAVX2 multiplies two scalars using AVX2.
// Computes 512-bit product and reduces mod n.
//
//go:noescape
func ScalarMulAVX2(r, a, b *Scalar)
// hasAVX2 returns true if the CPU supports AVX2.
//
//go:noescape
func hasAVX2() bool
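These forward declarations leave feature detection to the caller. A possible dispatch shim (a sketch, not in this commit) probes CPUID once at package init and falls back to the pure-Go methods shown earlier:

```go
// useAVX2 caches the CPUID probe so each call sees a stable,
// package-wide flag rather than re-running CPUID.
var useAVX2 = hasAVX2()

// scalarAdd is a hypothetical dispatch helper: AVX2-assisted assembly
// when available, the portable Go Add method otherwise.
func scalarAdd(r, a, b *Scalar) {
	if useAVX2 {
		ScalarAddAVX2(r, a, b)
		return
	}
	r.Add(a, b)
}
```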

515
avx/scalar_amd64.s Normal file

@@ -0,0 +1,515 @@
//go:build amd64
#include "textflag.h"
// Constants for scalar reduction
// n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
DATA scalarN<>+0x00(SB)/8, $0xBFD25E8CD0364141
DATA scalarN<>+0x08(SB)/8, $0xBAAEDCE6AF48A03B
DATA scalarN<>+0x10(SB)/8, $0xFFFFFFFFFFFFFFFE
DATA scalarN<>+0x18(SB)/8, $0xFFFFFFFFFFFFFFFF
GLOBL scalarN<>(SB), RODATA|NOPTR, $32
// 2^256 - n (for reduction)
DATA scalarNC<>+0x00(SB)/8, $0x402DA1732FC9BEBF
DATA scalarNC<>+0x08(SB)/8, $0x4551231950B75FC4
DATA scalarNC<>+0x10(SB)/8, $0x0000000000000001
DATA scalarNC<>+0x18(SB)/8, $0x0000000000000000
GLOBL scalarNC<>(SB), RODATA|NOPTR, $32
// func hasAVX2() bool
TEXT ·hasAVX2(SB), NOSPLIT, $0-1
MOVL $7, AX
MOVL $0, CX
CPUID
ANDL $0x20, BX // Check bit 5 of EBX for AVX2
SETNE AL
MOVB AL, ret+0(FP)
RET
// func ScalarAddAVX2(r, a, b *Scalar)
// Adds two 256-bit scalars. The add itself is a scalar ADD/ADC carry
// chain, since each limb's carry depends on the previous limb.
//
// Memory layout: [D[0].Lo, D[0].Hi, D[1].Lo, D[1].Hi] = 4 x 64-bit
TEXT ·ScalarAddAVX2(SB), NOSPLIT, $0-24
MOVQ r+0(FP), DI
MOVQ a+8(FP), SI
MOVQ b+16(FP), DX
// Load a and b into registers (scalar loads for carry chain)
MOVQ 0(SI), AX // a.D[0].Lo
MOVQ 8(SI), BX // a.D[0].Hi
MOVQ 16(SI), CX // a.D[1].Lo
MOVQ 24(SI), R8 // a.D[1].Hi
// Add b with carry chain
ADDQ 0(DX), AX // a.D[0].Lo + b.D[0].Lo
ADCQ 8(DX), BX // a.D[0].Hi + b.D[0].Hi + carry
ADCQ 16(DX), CX // a.D[1].Lo + b.D[1].Lo + carry
ADCQ 24(DX), R8 // a.D[1].Hi + b.D[1].Hi + carry
// Save carry flag
SETCS R9B
// Store preliminary result
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
// Check if we need to reduce (carry set or result >= n)
TESTB R9B, R9B
JNZ reduce
// Compare with n (from high to low)
MOVQ $0xFFFFFFFFFFFFFFFF, R10
CMPQ R8, R10
JB done
JA reduce
MOVQ scalarN<>+0x10(SB), R10
CMPQ CX, R10
JB done
JA reduce
MOVQ scalarN<>+0x08(SB), R10
CMPQ BX, R10
JB done
JA reduce
MOVQ scalarN<>+0x00(SB), R10
CMPQ AX, R10
JB done
reduce:
// Add 2^256 - n (which is equivalent to subtracting n)
MOVQ 0(DI), AX
MOVQ 8(DI), BX
MOVQ 16(DI), CX
MOVQ 24(DI), R8
MOVQ scalarNC<>+0x00(SB), R10
ADDQ R10, AX
MOVQ scalarNC<>+0x08(SB), R10
ADCQ R10, BX
MOVQ scalarNC<>+0x10(SB), R10
ADCQ R10, CX
MOVQ scalarNC<>+0x18(SB), R10
ADCQ R10, R8
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
done:
VZEROUPPER
RET
// func ScalarSubAVX2(r, a, b *Scalar)
// Subtracts two 256-bit scalars.
TEXT ·ScalarSubAVX2(SB), NOSPLIT, $0-24
MOVQ r+0(FP), DI
MOVQ a+8(FP), SI
MOVQ b+16(FP), DX
// Load a
MOVQ 0(SI), AX
MOVQ 8(SI), BX
MOVQ 16(SI), CX
MOVQ 24(SI), R8
// Subtract b with borrow chain
SUBQ 0(DX), AX
SBBQ 8(DX), BX
SBBQ 16(DX), CX
SBBQ 24(DX), R8
// Save borrow flag
SETCS R9B
// Store preliminary result
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
// If borrow, add n back
TESTB R9B, R9B
JZ done_sub
// Add n
MOVQ scalarN<>+0x00(SB), R10
ADDQ R10, AX
MOVQ scalarN<>+0x08(SB), R10
ADCQ R10, BX
MOVQ scalarN<>+0x10(SB), R10
ADCQ R10, CX
MOVQ scalarN<>+0x18(SB), R10
ADCQ R10, R8
MOVQ AX, 0(DI)
MOVQ BX, 8(DI)
MOVQ CX, 16(DI)
MOVQ R8, 24(DI)
done_sub:
VZEROUPPER
RET
// func ScalarMulAVX2(r, a, b *Scalar)
// Multiplies two 256-bit scalars and reduces mod n.
// This is a complex operation requiring 512-bit intermediate.
TEXT ·ScalarMulAVX2(SB), NOSPLIT, $64-24
MOVQ r+0(FP), DI
MOVQ a+8(FP), SI
MOVQ b+16(FP), DX
// Compute a 512-bit product and reduce mod n.
// For now this uses 64-bit MULQ schoolbook multiplication; MULX (BMI2)
// could shorten the dependency chains in a later revision.
// Load a limbs
MOVQ 0(SI), R8 // a0
MOVQ 8(SI), R9 // a1
MOVQ 16(SI), R10 // a2
MOVQ 24(SI), R11 // a3
// Store b pointer for later use
MOVQ DX, R12
// Compute 512-bit product using schoolbook multiplication
// Product stored on stack at SP+0 to SP+56 (8 limbs)
// Initialize product to zero
XORQ AX, AX
MOVQ AX, 0(SP)
MOVQ AX, 8(SP)
MOVQ AX, 16(SP)
MOVQ AX, 24(SP)
MOVQ AX, 32(SP)
MOVQ AX, 40(SP)
MOVQ AX, 48(SP)
MOVQ AX, 56(SP)
// Multiply a0 * b[0..3]
MOVQ R8, AX
MULQ 0(R12) // a0 * b0
MOVQ AX, 0(SP)
MOVQ DX, R13 // carry
MOVQ R8, AX
MULQ 8(R12) // a0 * b1
ADDQ R13, AX
ADCQ $0, DX
MOVQ AX, 8(SP)
MOVQ DX, R13
MOVQ R8, AX
MULQ 16(R12) // a0 * b2
ADDQ R13, AX
ADCQ $0, DX
MOVQ AX, 16(SP)
MOVQ DX, R13
MOVQ R8, AX
MULQ 24(R12) // a0 * b3
ADDQ R13, AX
ADCQ $0, DX
MOVQ AX, 24(SP)
MOVQ DX, 32(SP)
// Multiply a1 * b[0..3] and add
MOVQ R9, AX
MULQ 0(R12) // a1 * b0
ADDQ AX, 8(SP)
ADCQ DX, 16(SP)
ADCQ $0, 24(SP)
ADCQ $0, 32(SP)
MOVQ R9, AX
MULQ 8(R12) // a1 * b1
ADDQ AX, 16(SP)
ADCQ DX, 24(SP)
ADCQ $0, 32(SP)
MOVQ R9, AX
MULQ 16(R12) // a1 * b2
ADDQ AX, 24(SP)
ADCQ DX, 32(SP)
ADCQ $0, 40(SP)
MOVQ R9, AX
MULQ 24(R12) // a1 * b3
ADDQ AX, 32(SP)
ADCQ DX, 40(SP)
// Multiply a2 * b[0..3] and add
MOVQ R10, AX
MULQ 0(R12) // a2 * b0
ADDQ AX, 16(SP)
ADCQ DX, 24(SP)
ADCQ $0, 32(SP)
ADCQ $0, 40(SP)
MOVQ R10, AX
MULQ 8(R12) // a2 * b1
ADDQ AX, 24(SP)
ADCQ DX, 32(SP)
ADCQ $0, 40(SP)
MOVQ R10, AX
MULQ 16(R12) // a2 * b2
ADDQ AX, 32(SP)
ADCQ DX, 40(SP)
ADCQ $0, 48(SP)
MOVQ R10, AX
MULQ 24(R12) // a2 * b3
ADDQ AX, 40(SP)
ADCQ DX, 48(SP)
// Multiply a3 * b[0..3] and add
MOVQ R11, AX
MULQ 0(R12) // a3 * b0
ADDQ AX, 24(SP)
ADCQ DX, 32(SP)
ADCQ $0, 40(SP)
ADCQ $0, 48(SP)
MOVQ R11, AX
MULQ 8(R12) // a3 * b1
ADDQ AX, 32(SP)
ADCQ DX, 40(SP)
ADCQ $0, 48(SP)
MOVQ R11, AX
MULQ 16(R12) // a3 * b2
ADDQ AX, 40(SP)
ADCQ DX, 48(SP)
ADCQ $0, 56(SP)
MOVQ R11, AX
MULQ 24(R12) // a3 * b3
ADDQ AX, 48(SP)
ADCQ DX, 56(SP)
// Now we have the 512-bit product in SP+0 to SP+56 (l[0..7])
// Need to reduce mod n using the bitcoin-core algorithm:
//
// Phase 1: 512->385 bits
// c0..c4 = l[0..3] + l[4..7] * NC (where NC = 2^256 - n)
// Phase 2: 385->258 bits
// d0..d4 = c[0..3] + c[4] * NC
// Phase 3: 258->256 bits
// r[0..3] = d[0..3] + d[4] * NC, then final reduce if >= n
//
// NC = [0x402DA1732FC9BEBF, 0x4551231950B75FC4, 1, 0]
// ========== Phase 1: 512->385 bits ==========
// Compute c[0..4] = l[0..3] + l[4..7] * NC
// NC has only 3 significant limbs: NC[0], NC[1], NC[2]=1
// Start with c = l[0..3], then add contributions from l[4..7] * NC
MOVQ 0(SP), R8 // c0 = l0
MOVQ 8(SP), R9 // c1 = l1
MOVQ 16(SP), R10 // c2 = l2
MOVQ 24(SP), R11 // c3 = l3
XORQ R14, R14 // c4 = 0
XORQ R15, R15 // c5 for overflow
// l4 * NC[0]
MOVQ 32(SP), AX
MOVQ scalarNC<>+0x00(SB), R12
MULQ R12 // DX:AX = l4 * NC[0]
ADDQ AX, R8
ADCQ DX, R9
ADCQ $0, R10
ADCQ $0, R11
ADCQ $0, R14
// l4 * NC[1]
MOVQ 32(SP), AX
MOVQ scalarNC<>+0x08(SB), R12
MULQ R12 // DX:AX = l4 * NC[1]
ADDQ AX, R9
ADCQ DX, R10
ADCQ $0, R11
ADCQ $0, R14
// l4 * NC[2] (NC[2] = 1)
MOVQ 32(SP), AX
ADDQ AX, R10
ADCQ $0, R11
ADCQ $0, R14
// l5 * NC[0]
MOVQ 40(SP), AX
MOVQ scalarNC<>+0x00(SB), R12
MULQ R12
ADDQ AX, R9
ADCQ DX, R10
ADCQ $0, R11
ADCQ $0, R14
// l5 * NC[1]
MOVQ 40(SP), AX
MOVQ scalarNC<>+0x08(SB), R12
MULQ R12
ADDQ AX, R10
ADCQ DX, R11
ADCQ $0, R14
// l5 * NC[2] (NC[2] = 1)
MOVQ 40(SP), AX
ADDQ AX, R11
ADCQ $0, R14
// l6 * NC[0]
MOVQ 48(SP), AX
MOVQ scalarNC<>+0x00(SB), R12
MULQ R12
ADDQ AX, R10
ADCQ DX, R11
ADCQ $0, R14
// l6 * NC[1]
MOVQ 48(SP), AX
MOVQ scalarNC<>+0x08(SB), R12
MULQ R12
ADDQ AX, R11
ADCQ DX, R14
// l6 * NC[2] (NC[2] = 1)
MOVQ 48(SP), AX
ADDQ AX, R14
ADCQ $0, R15
// l7 * NC[0]
MOVQ 56(SP), AX
MOVQ scalarNC<>+0x00(SB), R12
MULQ R12
ADDQ AX, R11
ADCQ DX, R14
ADCQ $0, R15
// l7 * NC[1]
MOVQ 56(SP), AX
MOVQ scalarNC<>+0x08(SB), R12
MULQ R12
ADDQ AX, R14
ADCQ DX, R15
// l7 * NC[2] (NC[2] = 1)
MOVQ 56(SP), AX
ADDQ AX, R15
// Now c[0..5] = R8, R9, R10, R11, R14, R15 (~385 bits max)
// ========== Phase 2: 385->258 bits ==========
// Reduce c[4..5] by multiplying by NC and adding to c[0..3]
// c4 * NC[0]
MOVQ R14, AX
MOVQ scalarNC<>+0x00(SB), R12
MULQ R12
ADDQ AX, R8
ADCQ DX, R9
ADCQ $0, R10
ADCQ $0, R11
// c4 * NC[1]
MOVQ R14, AX
MOVQ scalarNC<>+0x08(SB), R12
MULQ R12
ADDQ AX, R9
ADCQ DX, R10
ADCQ $0, R11
// c4 * NC[2] (NC[2] = 1)
ADDQ R14, R10
ADCQ $0, R11
// c5 * NC[0]
MOVQ R15, AX
MOVQ scalarNC<>+0x00(SB), R12
MULQ R12
ADDQ AX, R9
ADCQ DX, R10
ADCQ $0, R11
// c5 * NC[1]
MOVQ R15, AX
MOVQ scalarNC<>+0x08(SB), R12
MULQ R12
ADDQ AX, R10
ADCQ DX, R11
// c5 * NC[2] (NC[2] = 1)
ADDQ R15, R11
// Capture any final carry into R14
MOVQ $0, R14
ADCQ $0, R14
// Now we have ~258 bits in R8, R9, R10, R11, R14
// ========== Phase 3: 258->256 bits ==========
// If R14 (the overflow) is non-zero, reduce again
TESTQ R14, R14
JZ check_overflow
// R14 * NC
MOVQ R14, AX
MOVQ scalarNC<>+0x00(SB), R12
MULQ R12
ADDQ AX, R8
ADCQ DX, R9
ADCQ $0, R10
ADCQ $0, R11
MOVQ R14, AX
MOVQ scalarNC<>+0x08(SB), R12
MULQ R12
ADDQ AX, R9
ADCQ DX, R10
ADCQ $0, R11
// R14 * NC[2] (NC[2] = 1)
ADDQ R14, R10
ADCQ $0, R11
check_overflow:
// Check if result >= n and reduce if needed
MOVQ $0xFFFFFFFFFFFFFFFF, R13
CMPQ R11, R13
JB store_result
JA do_reduce
MOVQ scalarN<>+0x10(SB), R13
CMPQ R10, R13
JB store_result
JA do_reduce
MOVQ scalarN<>+0x08(SB), R13
CMPQ R9, R13
JB store_result
JA do_reduce
MOVQ scalarN<>+0x00(SB), R13
CMPQ R8, R13
JB store_result
do_reduce:
// Subtract n (add 2^256 - n)
MOVQ scalarNC<>+0x00(SB), R13
ADDQ R13, R8
MOVQ scalarNC<>+0x08(SB), R13
ADCQ R13, R9
MOVQ scalarNC<>+0x10(SB), R13
ADCQ R13, R10
MOVQ scalarNC<>+0x18(SB), R13
ADCQ R13, R11
store_result:
// Store result
MOVQ r+0(FP), DI
MOVQ R8, 0(DI)
MOVQ R9, 8(DI)
MOVQ R10, 16(DI)
MOVQ R11, 24(DI)
VZEROUPPER
RET
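The phase comments above are the specification; the instruction stream is hard to eyeball. A math/big reference (a sketch, assuming a _test.go in the avx package) makes differential testing of the whole multiply-and-reduce path a single comparison:

```go
// mulModNRef computes (a * b) mod n over little-endian 4-limb inputs,
// serving as a ground-truth oracle for ScalarMulAVX2.
func mulModNRef(a, b [4]uint64) [4]uint64 {
	n, _ := new(big.Int).SetString(
		"FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141", 16)
	toBig := func(l [4]uint64) *big.Int {
		v := new(big.Int)
		for i := 3; i >= 0; i-- {
			v.Lsh(v, 64)
			v.Or(v, new(big.Int).SetUint64(l[i]))
		}
		return v
	}
	v := new(big.Int).Mul(toBig(a), toBig(b))
	v.Mod(v, n)
	mask := new(big.Int).SetUint64(^uint64(0))
	var out [4]uint64
	for i := range out {
		out[i] = new(big.Int).And(v, mask).Uint64()
		v.Rsh(v, 64)
	}
	return out
}
```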

410
avx/trace_double_test.go Normal file

@@ -0,0 +1,410 @@
package avx
import (
"encoding/hex"
"fmt"
"testing"
)
func TestGeneratorConstants(t *testing.T) {
// Verify the generator X and Y constants
expectedGx := "79be667ef9dcbbac55a06295ce870b07029bfcdb2dce28d959f2815b16f81798"
expectedGy := "483ada7726a3c4655da4fbfc0e1108a8fd17b448a68554199c47d08ffb10d4b8"
gx := Generator.X.Bytes()
gy := Generator.Y.Bytes()
t.Logf("Generator X: %x", gx)
t.Logf("Expected X: %s", expectedGx)
t.Logf("Generator Y: %x", gy)
t.Logf("Expected Y: %s", expectedGy)
// They should match
if expectedGx != fmt.Sprintf("%x", gx) {
t.Error("Generator X mismatch")
}
if expectedGy != fmt.Sprintf("%x", gy) {
t.Error("Generator Y mismatch")
}
// Verify G is on the curve
if !Generator.IsOnCurve() {
t.Error("Generator should be on curve")
}
	// Test squaring and multiplication more carefully:
// Y² should equal X³ + 7
var y2, x2, x3, seven, rhs FieldElement
y2.Sqr(&Generator.Y)
x2.Sqr(&Generator.X)
x3.Mul(&x2, &Generator.X)
seven.N[0].Lo = 7
rhs.Add(&x3, &seven)
t.Logf("Y² = %x", y2.Bytes())
t.Logf("X³ + 7 = %x", rhs.Bytes())
if !y2.Equal(&rhs) {
t.Error("Y² != X³ + 7 for generator")
}
}
func TestTraceDouble(t *testing.T) {
// Test the point doubling step by step
var g JacobianPoint
g.FromAffine(&Generator)
t.Logf("Input G:")
t.Logf(" X = %x", g.X.Bytes())
t.Logf(" Y = %x", g.Y.Bytes())
t.Logf(" Z = %x", g.Z.Bytes())
// Standard Jacobian doubling for y²=x³+b (secp256k1 has a=0):
// M = 3*X₁²
// S = 4*X₁*Y₁²
// T = 8*Y₁⁴
// X₃ = M² - 2*S
// Y₃ = M*(S - X₃) - T
// Z₃ = 2*Y₁*Z₁
var y2, m, x2, s, t_val, x3, y3, z3, tmp FieldElement
// Y² = Y₁²
y2.Sqr(&g.Y)
t.Logf("Y² = %x", y2.Bytes())
// M = 3*X²
x2.Sqr(&g.X)
t.Logf("X² = %x", x2.Bytes())
m.MulInt(&x2, 3)
t.Logf("M = 3*X² = %x", m.Bytes())
// S = 4*X₁*Y₁²
s.Mul(&g.X, &y2)
t.Logf("X*Y² = %x", s.Bytes())
s.MulInt(&s, 4)
t.Logf("S = 4*X*Y² = %x", s.Bytes())
// T = 8*Y₁⁴
t_val.Sqr(&y2)
t.Logf("Y⁴ = %x", t_val.Bytes())
t_val.MulInt(&t_val, 8)
t.Logf("T = 8*Y⁴ = %x", t_val.Bytes())
// X₃ = M² - 2*S
x3.Sqr(&m)
t.Logf("M² = %x", x3.Bytes())
tmp.Double(&s)
t.Logf("2*S = %x", tmp.Bytes())
x3.Sub(&x3, &tmp)
t.Logf("X₃ = M² - 2*S = %x", x3.Bytes())
// Y₃ = M*(S - X₃) - T
tmp.Sub(&s, &x3)
t.Logf("S - X₃ = %x", tmp.Bytes())
y3.Mul(&m, &tmp)
t.Logf("M*(S-X₃) = %x", y3.Bytes())
y3.Sub(&y3, &t_val)
t.Logf("Y₃ = M*(S-X₃) - T = %x", y3.Bytes())
// Z₃ = 2*Y₁*Z₁
z3.Mul(&g.Y, &g.Z)
z3.Double(&z3)
t.Logf("Z₃ = 2*Y*Z = %x", z3.Bytes())
// Now convert to affine
var doubled JacobianPoint
doubled.X = x3
doubled.Y = y3
doubled.Z = z3
doubled.Infinity = false
var affineResult AffinePoint
doubled.ToAffine(&affineResult)
t.Logf("Affine result (correct formula):")
t.Logf(" X = %x", affineResult.X.Bytes())
t.Logf(" Y = %x", affineResult.Y.Bytes())
// Expected 2G
expectedX := "c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5"
expectedY := "1ae168fea63dc339a3c58419466ceae1061b7c24a6b3e36e3b4d04f7a8f63301"
t.Logf("Expected:")
t.Logf(" X = %s", expectedX)
t.Logf(" Y = %s", expectedY)
// Verify by computing 2G using the existing Double method
var doubled2 JacobianPoint
doubled2.Double(&g)
var affine2 AffinePoint
doubled2.ToAffine(&affine2)
t.Logf("Current Double method result:")
t.Logf(" X = %x", affine2.X.Bytes())
t.Logf(" Y = %x", affine2.Y.Bytes())
// Compare results
expectedXBytes, _ := hex.DecodeString(expectedX)
expectedYBytes, _ := hex.DecodeString(expectedY)
if fmt.Sprintf("%x", affineResult.X.Bytes()) == expectedX &&
fmt.Sprintf("%x", affineResult.Y.Bytes()) == expectedY {
t.Logf("Correct formula produces expected result!")
} else {
t.Logf("Even correct formula doesn't match - problem elsewhere")
}
_ = expectedXBytes
_ = expectedYBytes
// Verify the result is on the curve
t.Logf("Result is on curve: %v", affineResult.IsOnCurve())
// Compute y² for the computed result
var verifyY2, verifyX2, verifyX3, verifySeven, verifyRhs FieldElement
verifyY2.Sqr(&affineResult.Y)
verifyX2.Sqr(&affineResult.X)
verifyX3.Mul(&verifyX2, &affineResult.X)
verifySeven.N[0].Lo = 7
verifyRhs.Add(&verifyX3, &verifySeven)
t.Logf("Computed y² = %x", verifyY2.Bytes())
t.Logf("Computed x³+7 = %x", verifyRhs.Bytes())
t.Logf("y² == x³+7: %v", verifyY2.Equal(&verifyRhs))
// Now test with the expected Y value
var expectedYField, expectedY2Field FieldElement
expectedYField.SetBytes(expectedYBytes)
expectedY2Field.Sqr(&expectedYField)
t.Logf("Expected Y² = %x", expectedY2Field.Bytes())
t.Logf("Expected Y² == x³+7: %v", expectedY2Field.Equal(&verifyRhs))
	// Check whether the computed Y is the negation of the expected Y
var negY FieldElement
negY.Negate(&affineResult.Y)
t.Logf("Negated computed Y = %x", negY.Bytes())
// Also check if the expected value is valid at all
// The expected 2G should be:
// X = c6047f9441ed7d6d3045406e95c07cd85c778e4b8cef3ca7abac09b95c709ee5
// Y = 1ae168fea63dc339a3c58419466ceae1061b7c24a6b3e36e3b4d04f7a8f63301
	// Verify this by computing y² directly
t.Log("--- Verifying expected 2G values ---")
var expXField FieldElement
expXField.SetBytes(expectedXBytes)
// Compute x³ + 7 for the expected X
var expX2, expX3, expRhs FieldElement
expX2.Sqr(&expXField)
expX3.Mul(&expX2, &expXField)
var seven2 FieldElement
seven2.N[0].Lo = 7
expRhs.Add(&expX3, &seven2)
t.Logf("For expected X, x³+7 = %x", expRhs.Bytes())
// Compute sqrt
var sqrtY FieldElement
if sqrtY.Sqrt(&expRhs) {
t.Logf("sqrt(x³+7) = %x", sqrtY.Bytes())
var negSqrtY FieldElement
negSqrtY.Negate(&sqrtY)
t.Logf("-sqrt(x³+7) = %x", negSqrtY.Bytes())
}
}
func TestDebugPointAdd(t *testing.T) {
// Compute 3G two ways: (1) G + 2G and (2) 3*G via scalar mult
var g, twoG, threeGAdd JacobianPoint
var affine3GAdd, affine3GSM AffinePoint
g.FromAffine(&Generator)
twoG.Double(&g)
threeGAdd.Add(&twoG, &g)
threeGAdd.ToAffine(&affine3GAdd)
t.Logf("2G (Jacobian):")
t.Logf(" X = %x", twoG.X.Bytes())
t.Logf(" Y = %x", twoG.Y.Bytes())
t.Logf(" Z = %x", twoG.Z.Bytes())
t.Logf("3G via Add (affine):")
t.Logf(" X = %x", affine3GAdd.X.Bytes())
t.Logf(" Y = %x", affine3GAdd.Y.Bytes())
t.Logf(" On curve: %v", affine3GAdd.IsOnCurve())
// Now via scalar mult
var three Scalar
three.D[0].Lo = 3
var threeGSM JacobianPoint
threeGSM.ScalarMult(&g, &three)
threeGSM.ToAffine(&affine3GSM)
t.Logf("3G via ScalarMult (affine):")
t.Logf(" X = %x", affine3GSM.X.Bytes())
t.Logf(" Y = %x", affine3GSM.Y.Bytes())
t.Logf(" On curve: %v", affine3GSM.IsOnCurve())
	// Expected 3G, computed with an external reference implementation:
// X = f9308a019258c31049344f85f89d5229b531c845836f99b08601f113bce036f9
// Y = 388f7b0f632de8140fe337e62a37f3566500a99934c2231b6cb9fd7584b8e672
t.Logf("Equal: %v", affine3GAdd.Equal(&affine3GSM))
}
func TestAVX2Operations(t *testing.T) {
// Test that AVX2 assembly produces same results as Go code
if !hasAVX2() {
t.Skip("AVX2 not available")
}
// Test field addition
var a, b, resultGo, resultAVX FieldElement
a.N[0].Lo = 0x123456789ABCDEF0
a.N[0].Hi = 0xFEDCBA9876543210
a.N[1].Lo = 0x1111111111111111
a.N[1].Hi = 0x2222222222222222
b.N[0].Lo = 0x0FEDCBA987654321
b.N[0].Hi = 0x123456789ABCDEF0
b.N[1].Lo = 0x3333333333333333
b.N[1].Hi = 0x4444444444444444
resultGo.Add(&a, &b)
FieldAddAVX2(&resultAVX, &a, &b)
if !resultGo.Equal(&resultAVX) {
t.Errorf("FieldAddAVX2 mismatch:\n Go: %x\n AVX2: %x", resultGo.Bytes(), resultAVX.Bytes())
}
// Test field subtraction
resultGo.Sub(&a, &b)
FieldSubAVX2(&resultAVX, &a, &b)
if !resultGo.Equal(&resultAVX) {
t.Errorf("FieldSubAVX2 mismatch:\n Go: %x\n AVX2: %x", resultGo.Bytes(), resultAVX.Bytes())
}
// Test field multiplication
resultGo.Mul(&a, &b)
FieldMulAVX2(&resultAVX, &a, &b)
if !resultGo.Equal(&resultAVX) {
t.Errorf("FieldMulAVX2 mismatch:\n Go: %x\n AVX2: %x", resultGo.Bytes(), resultAVX.Bytes())
}
// Test scalar addition
var sa, sb, sResultGo, sResultAVX Scalar
sa.D[0].Lo = 0x123456789ABCDEF0
sa.D[0].Hi = 0xFEDCBA9876543210
sa.D[1].Lo = 0x1111111111111111
sa.D[1].Hi = 0x2222222222222222
sb.D[0].Lo = 0x0FEDCBA987654321
sb.D[0].Hi = 0x123456789ABCDEF0
sb.D[1].Lo = 0x3333333333333333
sb.D[1].Hi = 0x4444444444444444
sResultGo.Add(&sa, &sb)
ScalarAddAVX2(&sResultAVX, &sa, &sb)
if !sResultGo.Equal(&sResultAVX) {
t.Errorf("ScalarAddAVX2 mismatch:\n Go: %x\n AVX2: %x", sResultGo.Bytes(), sResultAVX.Bytes())
}
// Test scalar multiplication
sResultGo.Mul(&sa, &sb)
ScalarMulAVX2(&sResultAVX, &sa, &sb)
if !sResultGo.Equal(&sResultAVX) {
t.Errorf("ScalarMulAVX2 mismatch:\n Go: %x\n AVX2: %x", sResultGo.Bytes(), sResultAVX.Bytes())
}
t.Logf("Field and Scalar Add/Sub AVX2 operations match Go implementations")
}
func TestDebugScalarMult(t *testing.T) {
// Test 2*G via scalar mult
var g, twoGDouble, twoGSM JacobianPoint
var affineDouble, affineSM AffinePoint
g.FromAffine(&Generator)
// Via doubling
twoGDouble.Double(&g)
twoGDouble.ToAffine(&affineDouble)
// Via scalar mult (k=2)
var two Scalar
two.D[0].Lo = 2
// Print the bytes of k=2
twoBytes := two.Bytes()
t.Logf("k=2 bytes: %x", twoBytes[:])
twoGSM.ScalarMult(&g, &two)
twoGSM.ToAffine(&affineSM)
t.Logf("2G via Double (affine):")
t.Logf(" X = %x", affineDouble.X.Bytes())
t.Logf(" Y = %x", affineDouble.Y.Bytes())
t.Logf(" On curve: %v", affineDouble.IsOnCurve())
t.Logf("2G via ScalarMult (affine):")
t.Logf(" X = %x", affineSM.X.Bytes())
t.Logf(" Y = %x", affineSM.Y.Bytes())
t.Logf(" On curve: %v", affineSM.IsOnCurve())
t.Logf("Equal: %v", affineDouble.Equal(&affineSM))
// Manual scalar mult for k=2
// Binary: 10 (2 bits)
// Start with p = infinity
// bit 1: p = 2*infinity = infinity, then p = p + G = G
// bit 0: p = 2*G, no add
// Result should be 2G
var p JacobianPoint
p.SetInfinity()
// Process bit 1 (the high bit of 2)
p.Double(&p)
t.Logf("After double of infinity: IsInfinity=%v", p.IsInfinity())
p.Add(&p, &g)
t.Logf("After add G: IsInfinity=%v", p.IsInfinity())
var affineP AffinePoint
p.ToAffine(&affineP)
t.Logf("After first iteration (should be G):")
t.Logf(" X = %x", affineP.X.Bytes())
t.Logf(" Y = %x", affineP.Y.Bytes())
t.Logf(" Equal to G: %v", affineP.Equal(&Generator))
// Process bit 0
p.Double(&p)
p.ToAffine(&affineP)
t.Logf("After second iteration (should be 2G):")
t.Logf(" X = %x", affineP.X.Bytes())
t.Logf(" Y = %x", affineP.Y.Bytes())
t.Logf(" On curve: %v", affineP.IsOnCurve())
t.Logf(" Equal to Double result: %v", affineP.Equal(&affineDouble))
// Test: does doubling G into a fresh variable work?
var fresh JacobianPoint
var freshAffine AffinePoint
fresh.Double(&g)
fresh.ToAffine(&freshAffine)
t.Logf("Fresh Double(g):")
t.Logf(" X = %x", freshAffine.X.Bytes())
t.Logf(" Y = %x", freshAffine.Y.Bytes())
t.Logf(" On curve: %v", freshAffine.IsOnCurve())
// Test: what about p.Double(p) when p == g?
var pCopy JacobianPoint
pCopy = p // now p is already set to some value
pCopy.FromAffine(&Generator)
t.Logf("Before in-place double, pCopy X: %x", pCopy.X.Bytes())
pCopy.Double(&pCopy)
var pCopyAffine AffinePoint
pCopy.ToAffine(&pCopyAffine)
t.Logf("After in-place Double(&pCopy):")
t.Logf(" X = %x", pCopyAffine.X.Bytes())
t.Logf(" Y = %x", pCopyAffine.Y.Bytes())
t.Logf(" On curve: %v", pCopyAffine.IsOnCurve())
}

119
avx/types.go Normal file

@@ -0,0 +1,119 @@
// Package avx provides AVX2-accelerated secp256k1 operations using 128-bit limbs.
//
// This implementation uses 128-bit limbs stored in 256-bit AVX2 registers:
// - Scalar: 256-bit value as 2×128-bit limbs (fits in 1 YMM register)
// - FieldElement: 256-bit value as 2×128-bit limbs (fits in 1 YMM register)
// - AffinePoint: 512-bit (x,y) as 2×256-bit (fits in 2 YMM registers)
// - JacobianPoint: 768-bit (x,y,z) as 3×256-bit (fits in 3 YMM registers)
package avx
// Uint128 represents a 128-bit unsigned integer as two 64-bit limbs.
// This is the fundamental building block for AVX2 operations.
// In AVX2 assembly, two Uint128 values fit in a single YMM register.
type Uint128 struct {
Lo, Hi uint64 // Lo + Hi<<64
}
// Scalar represents a 256-bit scalar value modulo the secp256k1 group order.
// Uses 2×128-bit limbs for efficient AVX2 processing.
// The entire scalar fits in a single YMM register.
type Scalar struct {
D [2]Uint128 // D[0] is low 128 bits, D[1] is high 128 bits
}
// FieldElement represents a field element modulo the secp256k1 field prime.
// Uses 2×128-bit limbs for efficient AVX2 processing.
// The entire field element fits in a single YMM register.
type FieldElement struct {
N [2]Uint128 // N[0] is low 128 bits, N[1] is high 128 bits
}
// AffinePoint represents a point on the secp256k1 curve in affine coordinates.
// Uses 2 YMM registers (one for X, one for Y).
type AffinePoint struct {
X, Y FieldElement
Infinity bool
}
// JacobianPoint represents a point in Jacobian coordinates (X, Y, Z).
// Affine coordinates are (X/Z², Y/Z³).
// Uses 3 YMM registers (one each for X, Y, Z).
type JacobianPoint struct {
X, Y, Z FieldElement
Infinity bool
}
// Constants for secp256k1
// Group order n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
var (
ScalarN = Scalar{
D: [2]Uint128{
{Lo: 0xBFD25E8CD0364141, Hi: 0xBAAEDCE6AF48A03B}, // low 128 bits
{Lo: 0xFFFFFFFFFFFFFFFE, Hi: 0xFFFFFFFFFFFFFFFF}, // high 128 bits
},
}
// 2^256 - n (used for reduction)
ScalarNC = Scalar{
D: [2]Uint128{
{Lo: 0x402DA1732FC9BEBF, Hi: 0x4551231950B75FC4}, // low 128 bits
{Lo: 0x0000000000000001, Hi: 0x0000000000000000}, // high 128 bits
},
}
// n/2 (for checking if scalar is high)
ScalarNHalf = Scalar{
D: [2]Uint128{
{Lo: 0xDFE92F46681B20A0, Hi: 0x5D576E7357A4501D}, // low 128 bits
{Lo: 0xFFFFFFFFFFFFFFFF, Hi: 0x7FFFFFFFFFFFFFFF}, // high 128 bits
},
}
ScalarZero = Scalar{}
ScalarOne = Scalar{D: [2]Uint128{{Lo: 1, Hi: 0}, {Lo: 0, Hi: 0}}}
)
// Field prime p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
var (
FieldP = FieldElement{
N: [2]Uint128{
{Lo: 0xFFFFFFFEFFFFFC2F, Hi: 0xFFFFFFFFFFFFFFFF}, // low 128 bits
{Lo: 0xFFFFFFFFFFFFFFFF, Hi: 0xFFFFFFFFFFFFFFFF}, // high 128 bits
},
}
// 2^256 - p = 2^32 + 977 = 0x1000003D1
FieldPC = FieldElement{
N: [2]Uint128{
{Lo: 0x1000003D1, Hi: 0}, // low 128 bits
{Lo: 0, Hi: 0}, // high 128 bits
},
}
FieldZero = FieldElement{}
FieldOne = FieldElement{N: [2]Uint128{{Lo: 1, Hi: 0}, {Lo: 0, Hi: 0}}}
)
// Generator point G for secp256k1
var (
GeneratorX = FieldElement{
N: [2]Uint128{
{Lo: 0x59F2815B16F81798, Hi: 0x029BFCDB2DCE28D9},
{Lo: 0x55A06295CE870B07, Hi: 0x79BE667EF9DCBBAC},
},
}
GeneratorY = FieldElement{
N: [2]Uint128{
{Lo: 0x9C47D08FFB10D4B8, Hi: 0xFD17B448A6855419},
{Lo: 0x5DA4FBFC0E1108A8, Hi: 0x483ADA7726A3C465},
},
}
Generator = AffinePoint{
X: GeneratorX,
Y: GeneratorY,
Infinity: false,
}
)
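Every reduction above leans on these constants, and a single mistyped hex digit would corrupt results silently. A one-off sanity test sketch (assuming a _test.go in this package importing math/big and testing; limbs are little-endian as in the struct definitions):

```go
func TestComplementConstants(t *testing.T) {
	toBig := func(limbs [4]uint64) *big.Int { // little-endian limbs
		v := new(big.Int)
		for i := 3; i >= 0; i-- {
			v.Lsh(v, 64)
			v.Or(v, new(big.Int).SetUint64(limbs[i]))
		}
		return v
	}
	two256 := new(big.Int).Lsh(big.NewInt(1), 256)
	n := toBig([4]uint64{ScalarN.D[0].Lo, ScalarN.D[0].Hi, ScalarN.D[1].Lo, ScalarN.D[1].Hi})
	nc := toBig([4]uint64{ScalarNC.D[0].Lo, ScalarNC.D[0].Hi, ScalarNC.D[1].Lo, ScalarNC.D[1].Hi})
	if new(big.Int).Add(n, nc).Cmp(two256) != 0 {
		t.Error("ScalarN + ScalarNC != 2^256")
	}
	half := toBig([4]uint64{ScalarNHalf.D[0].Lo, ScalarNHalf.D[0].Hi, ScalarNHalf.D[1].Lo, ScalarNHalf.D[1].Hi})
	if new(big.Int).Rsh(n, 1).Cmp(half) != 0 {
		t.Error("ScalarNHalf != n >> 1")
	}
	p := toBig([4]uint64{FieldP.N[0].Lo, FieldP.N[0].Hi, FieldP.N[1].Lo, FieldP.N[1].Hi})
	pc := toBig([4]uint64{FieldPC.N[0].Lo, FieldPC.N[0].Hi, FieldPC.N[1].Lo, FieldPC.N[1].Hi})
	if new(big.Int).Add(p, pc).Cmp(two256) != 0 {
		t.Error("FieldP + FieldPC != 2^256")
	}
}
```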

149
avx/uint128.go Normal file

@@ -0,0 +1,149 @@
//go:build !amd64
package avx
import "math/bits"
// Pure Go fallback implementation for non-amd64 platforms
// Add adds two Uint128 values, returning the result and carry.
func (a Uint128) Add(b Uint128) (result Uint128, carry uint64) {
result.Lo, carry = bits.Add64(a.Lo, b.Lo, 0)
result.Hi, carry = bits.Add64(a.Hi, b.Hi, carry)
return
}
// AddCarry adds two Uint128 values with an input carry.
func (a Uint128) AddCarry(b Uint128, carryIn uint64) (result Uint128, carryOut uint64) {
result.Lo, carryOut = bits.Add64(a.Lo, b.Lo, carryIn)
result.Hi, carryOut = bits.Add64(a.Hi, b.Hi, carryOut)
return
}
// Sub subtracts b from a, returning the result and borrow.
func (a Uint128) Sub(b Uint128) (result Uint128, borrow uint64) {
result.Lo, borrow = bits.Sub64(a.Lo, b.Lo, 0)
result.Hi, borrow = bits.Sub64(a.Hi, b.Hi, borrow)
return
}
// SubBorrow subtracts b from a with an input borrow.
func (a Uint128) SubBorrow(b Uint128, borrowIn uint64) (result Uint128, borrowOut uint64) {
result.Lo, borrowOut = bits.Sub64(a.Lo, b.Lo, borrowIn)
result.Hi, borrowOut = bits.Sub64(a.Hi, b.Hi, borrowOut)
return
}
// Mul64 multiplies two 64-bit values and returns a 128-bit result.
func Mul64(a, b uint64) Uint128 {
hi, lo := bits.Mul64(a, b)
return Uint128{Lo: lo, Hi: hi}
}
// Mul multiplies two Uint128 values and returns a 256-bit result as [4]uint64.
// Result is [lo0, lo1, hi0, hi1] where value = lo0 + lo1<<64 + hi0<<128 + hi1<<192
func (a Uint128) Mul(b Uint128) [4]uint64 {
// (a.Hi*2^64 + a.Lo) * (b.Hi*2^64 + b.Lo)
// = a.Hi*b.Hi*2^128 + (a.Hi*b.Lo + a.Lo*b.Hi)*2^64 + a.Lo*b.Lo
// a.Lo * b.Lo -> r[0:1]
r0Hi, r0Lo := bits.Mul64(a.Lo, b.Lo)
// a.Lo * b.Hi -> r[1:2]
r1Hi, r1Lo := bits.Mul64(a.Lo, b.Hi)
// a.Hi * b.Lo -> r[1:2]
r2Hi, r2Lo := bits.Mul64(a.Hi, b.Lo)
// a.Hi * b.Hi -> r[2:3]
r3Hi, r3Lo := bits.Mul64(a.Hi, b.Hi)
var result [4]uint64
var c1, c2, c3, c4 uint64
result[0] = r0Lo
// result[1] = r0Hi + r1Lo + r2Lo; the two adds can each carry into
// result[2], so their carries must be tracked separately (feeding the
// first carry back in as carry-in at this position would be off by 2^64)
result[1], c1 = bits.Add64(r0Hi, r1Lo, 0)
result[1], c2 = bits.Add64(result[1], r2Lo, 0)
// result[2] = r1Hi + r2Hi + r3Lo + c1 + c2
result[2], c3 = bits.Add64(r1Hi, r2Hi, c1)
result[2], c4 = bits.Add64(result[2], r3Lo, c2)
// result[3] = r3Hi + c3 + c4 (cannot overflow: the full product fits in 256 bits)
result[3] = r3Hi + c3 + c4
return result
}
// IsZero returns true if the Uint128 is zero.
func (a Uint128) IsZero() bool {
return a.Lo == 0 && a.Hi == 0
}
// Cmp compares two Uint128 values.
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
func (a Uint128) Cmp(b Uint128) int {
if a.Hi < b.Hi {
return -1
}
if a.Hi > b.Hi {
return 1
}
if a.Lo < b.Lo {
return -1
}
if a.Lo > b.Lo {
return 1
}
return 0
}
// Lsh shifts a Uint128 left by n bits (n < 128).
func (a Uint128) Lsh(n uint) Uint128 {
if n >= 64 {
return Uint128{Lo: 0, Hi: a.Lo << (n - 64)}
}
if n == 0 {
return a
}
return Uint128{
Lo: a.Lo << n,
Hi: (a.Hi << n) | (a.Lo >> (64 - n)),
}
}
// Rsh shifts a Uint128 right by n bits (n < 128).
func (a Uint128) Rsh(n uint) Uint128 {
if n >= 64 {
return Uint128{Lo: a.Hi >> (n - 64), Hi: 0}
}
if n == 0 {
return a
}
return Uint128{
Lo: (a.Lo >> n) | (a.Hi << (64 - n)),
Hi: a.Hi >> n,
}
}
// Or returns the bitwise OR of two Uint128 values.
func (a Uint128) Or(b Uint128) Uint128 {
return Uint128{Lo: a.Lo | b.Lo, Hi: a.Hi | b.Hi}
}
// And returns the bitwise AND of two Uint128 values.
func (a Uint128) And(b Uint128) Uint128 {
return Uint128{Lo: a.Lo & b.Lo, Hi: a.Hi & b.Hi}
}
// Xor returns the bitwise XOR of two Uint128 values.
func (a Uint128) Xor(b Uint128) Uint128 {
return Uint128{Lo: a.Lo ^ b.Lo, Hi: a.Hi ^ b.Hi}
}
// Not returns the bitwise NOT of a Uint128.
func (a Uint128) Not() Uint128 {
return Uint128{Lo: ^a.Lo, Hi: ^a.Hi}
}
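Carry plumbing like Mul's is exactly what randomized differential testing catches. A sketch (assuming a _test.go in this package importing math/big, math/rand, and testing) that compares the 128×128→256 product against math/big:

```go
func TestUint128MulRef(t *testing.T) {
	toBig := func(u Uint128) *big.Int {
		v := new(big.Int).SetUint64(u.Hi)
		v.Lsh(v, 64)
		return v.Or(v, new(big.Int).SetUint64(u.Lo))
	}
	mask := new(big.Int).SetUint64(^uint64(0))
	rnd := rand.New(rand.NewSource(1))
	for i := 0; i < 10000; i++ {
		a := Uint128{Lo: rnd.Uint64(), Hi: rnd.Uint64()}
		b := Uint128{Lo: rnd.Uint64(), Hi: rnd.Uint64()}
		got := a.Mul(b)
		want := new(big.Int).Mul(toBig(a), toBig(b))
		for j := 0; j < 4; j++ {
			limb := new(big.Int).And(new(big.Int).Rsh(want, uint(64*j)), mask).Uint64()
			if got[j] != limb {
				t.Fatalf("limb %d: got %#x want %#x (a=%+v b=%+v)", j, got[j], limb, a, b)
			}
		}
	}
}
```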

125
avx/uint128_amd64.go Normal file

@@ -0,0 +1,125 @@
//go:build amd64
package avx
import "math/bits"
// AMD64 implementation with AVX2 assembly where beneficial.
// For simple operations, Go with compiler intrinsics is often as fast as assembly.
// Add adds two Uint128 values, returning the result and carry.
func (a Uint128) Add(b Uint128) (result Uint128, carry uint64) {
result.Lo, carry = bits.Add64(a.Lo, b.Lo, 0)
result.Hi, carry = bits.Add64(a.Hi, b.Hi, carry)
return
}
// AddCarry adds two Uint128 values with an input carry.
func (a Uint128) AddCarry(b Uint128, carryIn uint64) (result Uint128, carryOut uint64) {
result.Lo, carryOut = bits.Add64(a.Lo, b.Lo, carryIn)
result.Hi, carryOut = bits.Add64(a.Hi, b.Hi, carryOut)
return
}
// Sub subtracts b from a, returning the result and borrow.
func (a Uint128) Sub(b Uint128) (result Uint128, borrow uint64) {
result.Lo, borrow = bits.Sub64(a.Lo, b.Lo, 0)
result.Hi, borrow = bits.Sub64(a.Hi, b.Hi, borrow)
return
}
// SubBorrow subtracts b from a with an input borrow.
func (a Uint128) SubBorrow(b Uint128, borrowIn uint64) (result Uint128, borrowOut uint64) {
result.Lo, borrowOut = bits.Sub64(a.Lo, b.Lo, borrowIn)
result.Hi, borrowOut = bits.Sub64(a.Hi, b.Hi, borrowOut)
return
}
// Mul64 multiplies two 64-bit values and returns a 128-bit result.
func Mul64(a, b uint64) Uint128 {
hi, lo := bits.Mul64(a, b)
return Uint128{Lo: lo, Hi: hi}
}
// Mul multiplies two Uint128 values and returns a 256-bit result as [4]uint64.
// Result is [lo0, lo1, hi0, hi1] where value = lo0 + lo1<<64 + hi0<<128 + hi1<<192
func (a Uint128) Mul(b Uint128) [4]uint64 {
// Use assembly for the full 128x128->256 multiplication
return uint128Mul(a, b)
}
// uint128Mul performs 128x128->256 bit multiplication using optimized assembly.
//
//go:noescape
func uint128Mul(a, b Uint128) [4]uint64
// IsZero returns true if the Uint128 is zero.
func (a Uint128) IsZero() bool {
return a.Lo == 0 && a.Hi == 0
}
// Cmp compares two Uint128 values.
// Returns -1 if a < b, 0 if a == b, 1 if a > b.
func (a Uint128) Cmp(b Uint128) int {
if a.Hi < b.Hi {
return -1
}
if a.Hi > b.Hi {
return 1
}
if a.Lo < b.Lo {
return -1
}
if a.Lo > b.Lo {
return 1
}
return 0
}
// Lsh shifts a Uint128 left by n bits (n < 128).
func (a Uint128) Lsh(n uint) Uint128 {
if n >= 64 {
return Uint128{Lo: 0, Hi: a.Lo << (n - 64)}
}
if n == 0 {
return a
}
return Uint128{
Lo: a.Lo << n,
Hi: (a.Hi << n) | (a.Lo >> (64 - n)),
}
}
// Rsh shifts a Uint128 right by n bits (n < 128).
func (a Uint128) Rsh(n uint) Uint128 {
if n >= 64 {
return Uint128{Lo: a.Hi >> (n - 64), Hi: 0}
}
if n == 0 {
return a
}
return Uint128{
Lo: (a.Lo >> n) | (a.Hi << (64 - n)),
Hi: a.Hi >> n,
}
}
// Or returns the bitwise OR of two Uint128 values.
func (a Uint128) Or(b Uint128) Uint128 {
return Uint128{Lo: a.Lo | b.Lo, Hi: a.Hi | b.Hi}
}
// And returns the bitwise AND of two Uint128 values.
func (a Uint128) And(b Uint128) Uint128 {
return Uint128{Lo: a.Lo & b.Lo, Hi: a.Hi & b.Hi}
}
// Xor returns the bitwise XOR of two Uint128 values.
func (a Uint128) Xor(b Uint128) Uint128 {
return Uint128{Lo: a.Lo ^ b.Lo, Hi: a.Hi ^ b.Hi}
}
// Not returns the bitwise NOT of a Uint128.
func (a Uint128) Not() Uint128 {
return Uint128{Lo: ^a.Lo, Hi: ^a.Hi}
}

67
avx/uint128_amd64.s Normal file

@@ -0,0 +1,67 @@
//go:build amd64
#include "textflag.h"
// func uint128Mul(a, b Uint128) [4]uint64
// Multiplies two 128-bit values and returns a 256-bit result.
//
// Input:
// a.Lo = arg+0(FP)
// a.Hi = arg+8(FP)
// b.Lo = arg+16(FP)
// b.Hi = arg+24(FP)
//
// Output:
// result[0] = ret+32(FP) (bits 0-63)
// result[1] = ret+40(FP) (bits 64-127)
// result[2] = ret+48(FP) (bits 128-191)
// result[3] = ret+56(FP) (bits 192-255)
//
// Algorithm:
// (a.Hi*2^64 + a.Lo) * (b.Hi*2^64 + b.Lo)
// = a.Hi*b.Hi*2^128 + (a.Hi*b.Lo + a.Lo*b.Hi)*2^64 + a.Lo*b.Lo
//
TEXT ·uint128Mul(SB), NOSPLIT, $0-64
// Load inputs
MOVQ a_Lo+0(FP), AX // AX = a.Lo
MOVQ a_Hi+8(FP), BX // BX = a.Hi
MOVQ b_Lo+16(FP), CX // CX = b.Lo
MOVQ b_Hi+24(FP), DX // DX = b.Hi
// Save b.Hi for later (DX will be clobbered by MUL)
MOVQ DX, R11 // R11 = b.Hi
// r0:r1 = a.Lo * b.Lo
MOVQ AX, R8 // R8 = a.Lo (save for later)
MULQ CX // DX:AX = a.Lo * b.Lo
MOVQ AX, R9 // R9 = result[0] (low 64 bits)
MOVQ DX, R10 // R10 = carry to result[1]
// r1:r2 += a.Hi * b.Lo
MOVQ BX, AX // AX = a.Hi
MULQ CX // DX:AX = a.Hi * b.Lo
ADDQ AX, R10 // R10 += low part
ADCQ $0, DX // DX += carry
MOVQ DX, CX // CX = carry to result[2]
// r1:r2 += a.Lo * b.Hi
MOVQ R8, AX // AX = a.Lo
MULQ R11 // DX:AX = a.Lo * b.Hi
ADDQ AX, R10 // R10 += low part
ADCQ DX, CX // CX += high part + carry
MOVQ $0, R8
ADCQ $0, R8 // R8 = carry to result[3]
// r2:r3 += a.Hi * b.Hi
MOVQ BX, AX // AX = a.Hi
MULQ R11 // DX:AX = a.Hi * b.Hi
ADDQ AX, CX // CX += low part
ADCQ DX, R8 // R8 += high part + carry
// Store results
MOVQ R9, ret+32(FP) // result[0]
MOVQ R10, ret+40(FP) // result[1]
MOVQ CX, ret+48(FP) // result[2]
MOVQ R8, ret+56(FP) // result[3]
RET


@@ -0,0 +1,163 @@
# Signer Optimization Report
## Summary
Optimized the P256K1Signer implementation by profiling and eliminating memory allocations in hot paths. The optimizations focused on reusing buffers for frequently called methods instead of allocating on each call.
## Key Changes
### 1. **P256K1Gen.KeyPairBytes() - Eliminated 94% of allocations**
**Before:**
- 1469 MB total allocations (94% of all allocations)
- 32 B/op with 1 alloc/op
- 23.58 ns/op
**After:**
- 0 B/op with 0 allocs/op
- 4.529 ns/op (5.2x faster)
**Implementation:**
- Added reusable buffer (`pubBuf []byte`) to `P256K1Gen` struct
- Buffer is allocated once and reused across calls
- Documented that returned slice may be reused
### 2. **Sign() method - Reduced allocations by ~10%**
**Before:**
- 640 B/op with 11 allocs/op
- 55,645 ns/op
**After:**
- 576 B/op with 10 allocs/op (10% reduction)
- 56,291 ns/op
**Implementation:**
- Added reusable signature buffer (`sigBuf []byte`) to `P256K1Signer` struct
- Eliminated stack-to-heap allocation from returning `sig64[:]`
- Documented that returned slice may be reused
### 3. **ECDH() method - Reduced allocations by ~15%**
**Before:**
- 246 B/op with 6 allocs/op
- 106,611 ns/op
**After:**
- 209 B/op with 5 allocs/op (15% reduction)
- 106,638 ns/op
**Implementation:**
- Added reusable ECDH buffer (`ecdhBuf []byte`) to `P256K1Signer` struct
- Eliminated stack-to-heap allocation from returning `sharedSecret[:]`
- Documented that returned slice may be reused
### 4. **InitSec() method - Cut allocations in half**
**Before:**
- 257 B/op with 4 allocs/op
- 54,223 ns/op
**After:**
- 128 B/op with 2 allocs/op (50% reduction)
- 28,319 ns/op (1.9x faster)
**Implementation:**
- Benefits from buffer reuse in other methods
- Fewer intermediate allocations
### 5. **Pub() method - Already optimal**
**Before & After:**
- 0 B/op with 0 allocs/op
- ~0.5 ns/op
**Implementation:**
- Already returning slice from stack array efficiently
- No changes needed, just documented behavior
## Overall Impact
### Total Memory Allocations
- **Before:** 1,556.43 MB total allocated space
- **After:** 65.82 MB total allocated space
- **Reduction:** **95.8% reduction** in total allocations
### Performance Summary
| Benchmark | Before (ns/op) | After (ns/op) | Speedup | Before (B/op) | After (B/op) | Reduction |
|-----------|----------------|---------------|---------|---------------|--------------|-----------|
| Generate | 44,420 | 44,018 | 1.01x | 289 | 287 | 0.7% |
| InitSec | 54,223 | 28,319 | 1.91x | 257 | 128 | 50.2% |
| InitPub | 5,708 | 5,669 | 1.01x | 32 | 32 | 0% |
| Sign | 55,645 | 56,291 | 0.99x | 640 | 576 | 10% |
| Verify | 136,922 | 134,306 | 1.02x | 97 | 96 | 1% |
| ECDH | 106,611 | 106,638 | 1.00x | 246 | 209 | 15% |
| Pub | 0.52 | 0.25 | 2.08x | 0 | 0 | 0% |
| Gen.Generate | 29,534 | 31,402 | 0.94x | 304 | 304 | 0% |
| Gen.Negate | 27,707 | 27,994 | 0.99x | 192 | 192 | 0% |
| Gen.KeyPairBytes | 23.58 | 4.529 | 5.21x | 32 | 0 | 100% |
## Important Notes
### API Compatibility Warning
The optimizations introduce a subtle API change that users must be aware of:
**Methods that now return reusable buffers:**
- `Sign(msg []byte) ([]byte, error)`
- `ECDH(pub []byte) ([]byte, error)`
- `KeyPairBytes() ([]byte, []byte)`
**Behavior:**
- The returned slices are backed by internal buffers
- These buffers **may be reused** on subsequent calls to the same method
- If you need to retain the data, you **must copy it**
**Example:**
```go
// ❌ WRONG - data may be overwritten
sig1, _ := signer.Sign(msg1)
sig2, _ := signer.Sign(msg2)
// sig1 may now contain sig2's data!
// ✅ CORRECT - copy if you need to retain
sig1, _ := signer.Sign(msg1)
sig1Copy := make([]byte, len(sig1))
copy(sig1Copy, sig1)
sig2, _ := signer.Sign(msg2)
// sig1Copy is safe to use
```
### Why This Approach?
1. **Performance:** Eliminates allocations in hot paths (signing, ECDH)
2. **Common Pattern:** Many crypto libraries use this pattern (e.g., Go's crypto/cipher)
3. **Documented:** All affected methods have clear documentation
4. **Optional:** Users can still copy if needed for their use case, or adopt an append-style API like the sketch below
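An append-style API in the spirit of hash.Hash's `Sum(b []byte) []byte` is another way to stay allocation-free without aliasing surprises, since the caller owns the destination slice. A hypothetical variant (not part of this change; `AppendSign` is an invented name, and the hasSecret/length checks from Sign are elided):

```go
// AppendSign appends a 64-byte Schnorr signature for msg to dst and
// returns the extended slice. No internal buffer is aliased, so two
// consecutive calls never clobber each other.
func (s *P256K1Signer) AppendSign(dst, msg []byte) ([]byte, error) {
	var sig [64]byte // stays on the stack; append copies it out
	if err := p256k1.SchnorrSign(sig[:], msg, s.keypair, nil); err != nil {
		return nil, err
	}
	return append(dst, sig[:]...), nil
}
```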
## Testing
All existing tests pass without modification, confirming backward compatibility for the common use case where results are used immediately.
```bash
cd /home/mleku/src/p256k1.mleku.dev/signer
go test -v
# PASS
```
## Profiling Commands
To reproduce the profiling results:
```bash
# Run benchmarks with profiling
go test -bench=. -benchmem -memprofile=mem.prof -cpuprofile=cpu.prof
# Analyze memory allocations
go tool pprof -top -alloc_space mem.prof
# Detailed line-by-line analysis
go tool pprof -list=P256K1Signer mem.prof
```
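For before/after deltas, benchstat (from golang.org/x/perf) aggregates repeated runs and reports statistically meaningful differences, which the single-run numbers in the table above cannot:

```bash
# On the baseline commit:
go test -bench=. -benchmem -count=10 > before.txt
# On the optimized commit:
go test -bench=. -benchmem -count=10 > after.txt
benchstat before.txt after.txt
```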


@@ -10,7 +10,9 @@ import (
type P256K1Signer struct {
keypair *p256k1.KeyPair
xonlyPub *p256k1.XOnlyPubkey
hasSecret bool // Whether we have the secret key (if false, can only verify)
hasSecret bool // Whether we have the secret key (if false, can only verify)
sigBuf []byte // Reusable buffer for signatures to avoid allocations
ecdhBuf []byte // Reusable buffer for ECDH shared secrets
}
// NewP256K1Signer creates a new P256K1Signer instance
@@ -129,6 +131,8 @@ func (s *P256K1Signer) Sec() []byte {
}
// Pub returns the public key bytes (x-only schnorr pubkey)
// The returned slice is backed by an internal buffer that may be
// reused on subsequent calls. Copy if you need to retain it.
func (s *P256K1Signer) Pub() []byte {
if s.xonlyPub == nil {
return nil
@@ -138,6 +142,8 @@ func (s *P256K1Signer) Pub() []byte {
}
// Sign creates a signature using the stored secret key
// The returned slice is backed by an internal buffer that may be
// reused on subsequent calls. Copy if you need to retain it.
func (s *P256K1Signer) Sign(msg []byte) (sig []byte, err error) {
if !s.hasSecret || s.keypair == nil {
return nil, errors.New("no secret key available for signing")
@@ -147,12 +153,18 @@ func (s *P256K1Signer) Sign(msg []byte) (sig []byte, err error) {
return nil, errors.New("message must be 32 bytes")
}
var sig64 [64]byte
if err := p256k1.SchnorrSign(sig64[:], msg, s.keypair, nil); err != nil {
// Pre-allocate buffer to reuse across calls
if cap(s.sigBuf) < 64 {
s.sigBuf = make([]byte, 64)
} else {
s.sigBuf = s.sigBuf[:64]
}
if err := p256k1.SchnorrSign(s.sigBuf, msg, s.keypair, nil); err != nil {
return nil, err
}
return sig64[:], nil
return s.sigBuf, nil
}
// Verify checks a message hash and signature match the stored public key
@@ -185,6 +197,8 @@ func (s *P256K1Signer) Zero() {
}
// ECDH returns a shared secret derived using Elliptic Curve Diffie-Hellman on the I secret and provided pubkey
// The returned slice is backed by an internal buffer that may be
// reused on subsequent calls. Copy if you need to retain it.
func (s *P256K1Signer) ECDH(pub []byte) (secret []byte, err error) {
if !s.hasSecret || s.keypair == nil {
return nil, errors.New("no secret key available for ECDH")
@@ -205,13 +219,19 @@ func (s *P256K1Signer) ECDH(pub []byte) (secret []byte, err error) {
return nil, err
}
// Pre-allocate buffer to reuse across calls
if cap(s.ecdhBuf) < 32 {
s.ecdhBuf = make([]byte, 32)
} else {
s.ecdhBuf = s.ecdhBuf[:32]
}
// Compute ECDH shared secret using standard ECDH (hashes the point)
var sharedSecret [32]byte
if err := p256k1.ECDH(sharedSecret[:], &pubkey, s.keypair.Seckey(), nil); err != nil {
if err := p256k1.ECDH(s.ecdhBuf, &pubkey, s.keypair.Seckey(), nil); err != nil {
return nil, err
}
return sharedSecret[:], nil
return s.ecdhBuf, nil
}
// P256K1Gen implements the Gen interface for nostr BIP-340 key generation
@@ -219,6 +239,7 @@ type P256K1Gen struct {
keypair *p256k1.KeyPair
xonlyPub *p256k1.XOnlyPubkey
compressedPub *p256k1.PublicKey
pubBuf []byte // Reusable buffer to avoid allocations in KeyPairBytes
}
// NewP256K1Gen creates a new P256K1Gen instance
@@ -283,6 +304,8 @@ func (g *P256K1Gen) Negate() {
}
// KeyPairBytes returns the raw bytes of the secret and public key, this returns the 32 byte X-only pubkey
// The returned pubkey slice is backed by an internal buffer that may be
// reused on subsequent calls. Copy if you need to retain it.
func (g *P256K1Gen) KeyPairBytes() (secBytes, cmprPubBytes []byte) {
if g.keypair == nil {
return nil, nil
@@ -298,8 +321,17 @@ func (g *P256K1Gen) KeyPairBytes() (secBytes, cmprPubBytes []byte) {
g.xonlyPub = xonly
}
// Pre-allocate buffer to reuse across calls
if cap(g.pubBuf) < 32 {
g.pubBuf = make([]byte, 32)
} else {
g.pubBuf = g.pubBuf[:32]
}
// Copy the serialized public key into our buffer
serialized := g.xonlyPub.Serialize()
cmprPubBytes = serialized[:]
copy(g.pubBuf, serialized[:])
cmprPubBytes = g.pubBuf
return secBytes, cmprPubBytes
}


@@ -0,0 +1,176 @@
package signer
import (
"crypto/rand"
"testing"
"p256k1.mleku.dev"
)
// BenchmarkP256K1Signer_Generate benchmarks key generation
func BenchmarkP256K1Signer_Generate(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
s := NewP256K1Signer()
if err := s.Generate(); err != nil {
b.Fatal(err)
}
s.Zero()
}
}
// BenchmarkP256K1Signer_InitSec benchmarks secret key initialization
func BenchmarkP256K1Signer_InitSec(b *testing.B) {
// Pre-generate a secret key
sec := make([]byte, 32)
rand.Read(sec)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
s := NewP256K1Signer()
if err := s.InitSec(sec); err != nil {
b.Fatal(err)
}
s.Zero()
}
}
// BenchmarkP256K1Signer_InitPub benchmarks public key initialization
func BenchmarkP256K1Signer_InitPub(b *testing.B) {
// Pre-generate a public key
kp, _ := p256k1.KeyPairGenerate()
xonly, _ := kp.XOnlyPubkey()
pub := xonly.Serialize()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
s := NewP256K1Signer()
if err := s.InitPub(pub[:]); err != nil {
b.Fatal(err)
}
s.Zero()
}
}
// BenchmarkP256K1Signer_Sign benchmarks signing
func BenchmarkP256K1Signer_Sign(b *testing.B) {
s := NewP256K1Signer()
if err := s.Generate(); err != nil {
b.Fatal(err)
}
defer s.Zero()
msg := make([]byte, 32)
rand.Read(msg)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := s.Sign(msg); err != nil {
b.Fatal(err)
}
}
}
// BenchmarkP256K1Signer_Verify benchmarks verification
func BenchmarkP256K1Signer_Verify(b *testing.B) {
s := NewP256K1Signer()
if err := s.Generate(); err != nil {
b.Fatal(err)
}
defer s.Zero()
msg := make([]byte, 32)
rand.Read(msg)
sig, _ := s.Sign(msg)
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := s.Verify(msg, sig); err != nil {
b.Fatal(err)
}
}
}
// BenchmarkP256K1Signer_ECDH benchmarks ECDH computation
func BenchmarkP256K1Signer_ECDH(b *testing.B) {
s1 := NewP256K1Signer()
if err := s1.Generate(); err != nil {
b.Fatal(err)
}
defer s1.Zero()
s2 := NewP256K1Signer()
if err := s2.Generate(); err != nil {
b.Fatal(err)
}
defer s2.Zero()
pub2 := s2.Pub()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if _, err := s1.ECDH(pub2); err != nil {
b.Fatal(err)
}
}
}
// BenchmarkP256K1Signer_Pub benchmarks public key retrieval
func BenchmarkP256K1Signer_Pub(b *testing.B) {
s := NewP256K1Signer()
if err := s.Generate(); err != nil {
b.Fatal(err)
}
defer s.Zero()
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_ = s.Pub()
}
}
// BenchmarkP256K1Gen_Generate benchmarks Gen.Generate
func BenchmarkP256K1Gen_Generate(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
g := NewP256K1Gen()
if _, err := g.Generate(); err != nil {
b.Fatal(err)
}
}
}
// BenchmarkP256K1Gen_Negate benchmarks Gen.Negate
func BenchmarkP256K1Gen_Negate(b *testing.B) {
g := NewP256K1Gen()
if _, err := g.Generate(); err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
g.Negate()
}
}
// BenchmarkP256K1Gen_KeyPairBytes benchmarks Gen.KeyPairBytes
func BenchmarkP256K1Gen_KeyPairBytes(b *testing.B) {
g := NewP256K1Gen()
if _, err := g.Generate(); err != nil {
b.Fatal(err)
}
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
_, _ = g.KeyPairBytes()
}
}