Files
p256k1/field_mul.go
mleku 5416381478 Add context tests and implement generator multiplication context
This commit introduces a new test file for context management, covering various scenarios for context creation, destruction, and capabilities. Additionally, it implements the generator multiplication context, enhancing the secp256k1 elliptic curve operations. The changes ensure comprehensive testing and improved functionality for context handling, contributing to the overall robustness of the implementation.
2025-11-01 20:01:52 +00:00

634 lines
14 KiB
Go

package p256k1
import "math/bits"
// uint128 represents a 128-bit unsigned integer for field arithmetic
type uint128 struct {
high, low uint64
}
// mulU64ToU128 multiplies two uint64 values and returns a uint128
func mulU64ToU128(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
return uint128{high: hi, low: lo}
}
// addMulU128 computes c + a*b and returns the result as uint128
func addMulU128(c uint128, a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
// Add lo to c.low
newLo, carry := bits.Add64(c.low, lo, 0)
// Add hi and carry to c.high
newHi, _ := bits.Add64(c.high, hi, carry)
return uint128{high: newHi, low: newLo}
}
// addU128 adds a uint64 to a uint128
func addU128(c uint128, a uint64) uint128 {
newLo, carry := bits.Add64(c.low, a, 0)
newHi, _ := bits.Add64(c.high, 0, carry)
return uint128{high: newHi, low: newLo}
}
// lo returns the lower 64 bits
func (u uint128) lo() uint64 {
return u.low
}
// hi returns the upper 64 bits
func (u uint128) hi() uint64 {
return u.high
}
// rshift shifts the uint128 right by n bits
func (u uint128) rshift(n uint) uint128 {
if n >= 64 {
return uint128{high: 0, low: u.high >> (n - 64)}
}
return uint128{
high: u.high >> n,
low: (u.low >> n) | (u.high << (64 - n)),
}
}
// mul multiplies two field elements: r = a * b
// This implementation follows the C secp256k1_fe_mul_inner algorithm
func (r *FieldElement) mul(a, b *FieldElement) {
// Normalize inputs if magnitude is too high
var aNorm, bNorm FieldElement
aNorm = *a
bNorm = *b
if aNorm.magnitude > 8 {
aNorm.normalizeWeak()
}
if bNorm.magnitude > 8 {
bNorm.normalizeWeak()
}
// Extract limbs for easier access
a0, a1, a2, a3, a4 := aNorm.n[0], aNorm.n[1], aNorm.n[2], aNorm.n[3], aNorm.n[4]
b0, b1, b2, b3, b4 := bNorm.n[0], bNorm.n[1], bNorm.n[2], bNorm.n[3], bNorm.n[4]
const M = 0xFFFFFFFFFFFFF // 2^52 - 1
const R = fieldReductionConstantShifted // 0x1000003D10
// Following the C implementation algorithm exactly
// [... a b c] is shorthand for ... + a<<104 + b<<52 + c<<0 mod n
// Compute p3 = a0*b3 + a1*b2 + a2*b1 + a3*b0
var c, d uint128
d = mulU64ToU128(a0, b3)
d = addMulU128(d, a1, b2)
d = addMulU128(d, a2, b1)
d = addMulU128(d, a3, b0)
// Compute p8 = a4*b4
c = mulU64ToU128(a4, b4)
// d += R * c_lo; c >>= 64
d = addMulU128(d, R, c.lo())
c = c.rshift(64)
// Extract t3 and shift d
t3 := d.lo() & M
d = d.rshift(52)
// Compute p4 = a0*b4 + a1*b3 + a2*b2 + a3*b1 + a4*b0
d = addMulU128(d, a0, b4)
d = addMulU128(d, a1, b3)
d = addMulU128(d, a2, b2)
d = addMulU128(d, a3, b1)
d = addMulU128(d, a4, b0)
// d += (R << 12) * c_lo
d = addMulU128(d, R<<12, c.lo())
// Extract t4 and tx
t4 := d.lo() & M
d = d.rshift(52)
tx := t4 >> 48
t4 &= (M >> 4)
// Compute p0 = a0*b0
c = mulU64ToU128(a0, b0)
// Compute p5 = a1*b4 + a2*b3 + a3*b2 + a4*b1
d = addMulU128(d, a1, b4)
d = addMulU128(d, a2, b3)
d = addMulU128(d, a3, b2)
d = addMulU128(d, a4, b1)
// Extract u0
u0 := d.lo() & M
d = d.rshift(52)
u0 = (u0 << 4) | tx
// c += u0 * (R >> 4)
c = addMulU128(c, u0, R>>4)
// r[0]
r.n[0] = c.lo() & M
c = c.rshift(52)
// Compute p1 = a0*b1 + a1*b0
c = addMulU128(c, a0, b1)
c = addMulU128(c, a1, b0)
// Compute p6 = a2*b4 + a3*b3 + a4*b2
d = addMulU128(d, a2, b4)
d = addMulU128(d, a3, b3)
d = addMulU128(d, a4, b2)
// c += R * (d & M); d >>= 52
c = addMulU128(c, R, d.lo()&M)
d = d.rshift(52)
// r[1]
r.n[1] = c.lo() & M
c = c.rshift(52)
// Compute p2 = a0*b2 + a1*b1 + a2*b0
c = addMulU128(c, a0, b2)
c = addMulU128(c, a1, b1)
c = addMulU128(c, a2, b0)
// Compute p7 = a3*b4 + a4*b3
d = addMulU128(d, a3, b4)
d = addMulU128(d, a4, b3)
// c += R * d_lo; d >>= 64
c = addMulU128(c, R, d.lo())
d = d.rshift(64)
// r[2]
r.n[2] = c.lo() & M
c = c.rshift(52)
// c += (R << 12) * d_lo + t3
c = addMulU128(c, R<<12, d.lo())
c = addU128(c, t3)
// r[3]
r.n[3] = c.lo() & M
c = c.rshift(52)
// r[4]
r.n[4] = c.lo() + t4
// Set magnitude and normalization
r.magnitude = 1
r.normalized = false
}
// reduceFromWide reduces a 520-bit (10 limb) value modulo the field prime
func (r *FieldElement) reduceFromWide(t [10]uint64) {
// The field prime is p = 2^256 - 2^32 - 977 = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
// We use the fact that 2^256 ≡ 2^32 + 977 (mod p)
// First, handle the upper limbs (t[5] through t[9])
// Each represents a multiple of 2^(52*i) where i >= 5
// Reduction constant for secp256k1: 2^32 + 977 = 0x1000003D1
const M = uint64(0x1000003D1)
// Start from the highest limb and work down
for i := 9; i >= 5; i-- {
if t[i] == 0 {
continue
}
// t[i] * 2^(52*i) ≡ t[i] * 2^(52*(i-5)) * 2^(52*5) ≡ t[i] * 2^(52*(i-5)) * 2^260
// Since 2^256 ≡ M (mod p), we have 2^260 ≡ 2^4 * M ≡ 16 * M (mod p)
// For i=5: 2^260 ≡ 16*M (mod p)
// For i=6: 2^312 ≡ 2^52 * 16*M ≡ 2^56 * M (mod p)
// etc.
shift := uint(52 * (i - 5) + 4) // Additional 4 bits for the 16 factor
// Multiply t[i] by the appropriate power of M
var carry uint64
if shift < 64 {
// Simple case: can multiply directly
factor := M << shift
hi, lo := bits.Mul64(t[i], factor)
// Add to appropriate position
pos := 0
t[pos], carry = bits.Add64(t[pos], lo, 0)
if pos+1 < 10 {
t[pos+1], carry = bits.Add64(t[pos+1], hi, carry)
}
// Propagate carry
for j := pos + 2; j < 10 && carry != 0; j++ {
t[j], carry = bits.Add64(t[j], 0, carry)
}
} else {
// Need to handle larger shifts by distributing across limbs
hi, lo := bits.Mul64(t[i], M)
limbShift := shift / 52
bitShift := shift % 52
if bitShift == 0 {
// Aligned to limb boundary
if limbShift < 10 {
t[limbShift], carry = bits.Add64(t[limbShift], lo, 0)
if limbShift+1 < 10 {
t[limbShift+1], carry = bits.Add64(t[limbShift+1], hi, carry)
}
}
} else {
// Need to split across limbs
loShifted := lo << bitShift
hiShifted := (lo >> (64 - bitShift)) | (hi << bitShift)
if limbShift < 10 {
t[limbShift], carry = bits.Add64(t[limbShift], loShifted, 0)
if limbShift+1 < 10 {
t[limbShift+1], carry = bits.Add64(t[limbShift+1], hiShifted, carry)
}
}
}
// Propagate any remaining carry
for j := int(limbShift) + 2; j < 10 && carry != 0; j++ {
t[j], carry = bits.Add64(t[j], 0, carry)
}
}
t[i] = 0 // Clear the processed limb
}
// Now we have a value in t[0..4] that may still be >= p
// Convert to 5x52 format and normalize
r.n[0] = t[0] & limb0Max
r.n[1] = ((t[0] >> 52) | (t[1] << 12)) & limb0Max
r.n[2] = ((t[1] >> 40) | (t[2] << 24)) & limb0Max
r.n[3] = ((t[2] >> 28) | (t[3] << 36)) & limb0Max
r.n[4] = ((t[3] >> 16) | (t[4] << 48)) & limb4Max
r.magnitude = 1
r.normalized = false
// Final reduction if needed
if r.n[4] == limb4Max && r.n[3] == limb0Max && r.n[2] == limb0Max &&
r.n[1] == limb0Max && r.n[0] >= fieldModulusLimb0 {
r.reduce()
}
}
// sqr squares a field element: r = a^2
// This implementation follows the C secp256k1_fe_sqr_inner algorithm
func (r *FieldElement) sqr(a *FieldElement) {
// Normalize input if magnitude is too high
var aNorm FieldElement
aNorm = *a
if aNorm.magnitude > 8 {
aNorm.normalizeWeak()
}
// Extract limbs for easier access
a0, a1, a2, a3, a4 := aNorm.n[0], aNorm.n[1], aNorm.n[2], aNorm.n[3], aNorm.n[4]
const M = 0xFFFFFFFFFFFFF // 2^52 - 1
const R = fieldReductionConstantShifted // 0x1000003D10
// Following the C implementation algorithm exactly
// Compute p3 = 2*a0*a3 + 2*a1*a2
var c, d uint128
d = mulU64ToU128(a0*2, a3)
d = addMulU128(d, a1*2, a2)
// Compute p8 = a4*a4
c = mulU64ToU128(a4, a4)
// d += R * c_lo; c >>= 64
d = addMulU128(d, R, c.lo())
c = c.rshift(64)
// Extract t3 and shift d
t3 := d.lo() & M
d = d.rshift(52)
// Compute p4 = a0*a4*2 + a1*a3*2 + a2*a2
a4 *= 2
d = addMulU128(d, a0, a4)
d = addMulU128(d, a1*2, a3)
d = addMulU128(d, a2, a2)
// d += (R << 12) * c_lo
d = addMulU128(d, R<<12, c.lo())
// Extract t4 and tx
t4 := d.lo() & M
d = d.rshift(52)
tx := t4 >> 48
t4 &= (M >> 4)
// Compute p0 = a0*a0
c = mulU64ToU128(a0, a0)
// Compute p5 = a1*a4 + a2*a3*2
d = addMulU128(d, a1, a4)
d = addMulU128(d, a2*2, a3)
// Extract u0
u0 := d.lo() & M
d = d.rshift(52)
u0 = (u0 << 4) | tx
// c += u0 * (R >> 4)
c = addMulU128(c, u0, R>>4)
// r[0]
r.n[0] = c.lo() & M
c = c.rshift(52)
// Compute p1 = a0*a1*2
a0 *= 2
c = addMulU128(c, a0, a1)
// Compute p6 = a2*a4 + a3*a3
d = addMulU128(d, a2, a4)
d = addMulU128(d, a3, a3)
// c += R * (d & M); d >>= 52
c = addMulU128(c, R, d.lo()&M)
d = d.rshift(52)
// r[1]
r.n[1] = c.lo() & M
c = c.rshift(52)
// Compute p2 = a0*a2 + a1*a1
c = addMulU128(c, a0, a2)
c = addMulU128(c, a1, a1)
// Compute p7 = a3*a4
d = addMulU128(d, a3, a4)
// c += R * d_lo; d >>= 64
c = addMulU128(c, R, d.lo())
d = d.rshift(64)
// r[2]
r.n[2] = c.lo() & M
c = c.rshift(52)
// c += (R << 12) * d_lo + t3
c = addMulU128(c, R<<12, d.lo())
c = addU128(c, t3)
// r[3]
r.n[3] = c.lo() & M
c = c.rshift(52)
// r[4]
r.n[4] = c.lo() + t4
// Set magnitude and normalization
r.magnitude = 1
r.normalized = false
}
// inv computes the modular inverse of a field element using Fermat's little theorem
// This implements a^(p-2) mod p where p is the secp256k1 field prime
// This follows secp256k1_fe_inv_var which normalizes the input first
func (r *FieldElement) inv(a *FieldElement) {
// Normalize input first (as per secp256k1_fe_inv_var)
var aNorm FieldElement
aNorm = *a
aNorm.normalize()
// For field F_p, a^(-1) = a^(p-2) mod p
// The secp256k1 field prime is p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
// So p-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2D
// Use a simple but correct implementation: binary exponentiation
// Convert p-2 to bytes for bit-by-bit exponentiation
pMinus2 := []byte{
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFC, 0x2D,
}
// Initialize result to 1
r.setInt(1)
// Binary exponentiation
var base FieldElement
base = aNorm
for i := len(pMinus2) - 1; i >= 0; i-- {
b := pMinus2[i]
for j := 0; j < 8; j++ {
if (b >> j) & 1 == 1 {
r.mul(r, &base)
}
base.sqr(&base)
}
}
r.magnitude = 1
r.normalized = true
}
// sqrt computes the square root of a field element if it exists
// This follows the C secp256k1_fe_sqrt implementation exactly
func (r *FieldElement) sqrt(a *FieldElement) bool {
// Given that p is congruent to 3 mod 4, we can compute the square root of
// a mod p as the (p+1)/4'th power of a.
//
// As (p+1)/4 is an even number, it will have the same result for a and for
// (-a). Only one of these two numbers actually has a square root however,
// so we test at the end by squaring and comparing to the input.
var aNorm FieldElement
aNorm = *a
// Normalize input if magnitude is too high
if aNorm.magnitude > 8 {
aNorm.normalizeWeak()
} else {
aNorm.normalize()
}
// The binary representation of (p + 1)/4 has 3 blocks of 1s, with lengths in
// { 2, 22, 223 }. Use an addition chain to calculate 2^n - 1 for each block:
// 1, [2], 3, 6, 9, 11, [22], 44, 88, 176, 220, [223]
var x2, x3, x6, x9, x11, x22, x44, x88, x176, x220, x223, t1 FieldElement
// x2 = a^3
x2.sqr(&aNorm)
x2.mul(&x2, &aNorm)
// x3 = a^7
x3.sqr(&x2)
x3.mul(&x3, &aNorm)
// x6 = a^63
x6 = x3
for j := 0; j < 3; j++ {
x6.sqr(&x6)
}
x6.mul(&x6, &x3)
// x9 = a^511
x9 = x6
for j := 0; j < 3; j++ {
x9.sqr(&x9)
}
x9.mul(&x9, &x3)
// x11 = a^2047
x11 = x9
for j := 0; j < 2; j++ {
x11.sqr(&x11)
}
x11.mul(&x11, &x2)
// x22 = a^4194303
x22 = x11
for j := 0; j < 11; j++ {
x22.sqr(&x22)
}
x22.mul(&x22, &x11)
// x44 = a^17592186044415
x44 = x22
for j := 0; j < 22; j++ {
x44.sqr(&x44)
}
x44.mul(&x44, &x22)
// x88 = a^72057594037927935
x88 = x44
for j := 0; j < 44; j++ {
x88.sqr(&x88)
}
x88.mul(&x88, &x44)
// x176 = a^1180591620717411303423
x176 = x88
for j := 0; j < 88; j++ {
x176.sqr(&x176)
}
x176.mul(&x176, &x88)
// x220 = a^172543658669764094685868767685
x220 = x176
for j := 0; j < 44; j++ {
x220.sqr(&x220)
}
x220.mul(&x220, &x44)
// x223 = a^13479973333575319897333507543509815336818572211270286240551805124607
x223 = x220
for j := 0; j < 3; j++ {
x223.sqr(&x223)
}
x223.mul(&x223, &x3)
// The final result is then assembled using a sliding window over the blocks.
t1 = x223
for j := 0; j < 23; j++ {
t1.sqr(&t1)
}
t1.mul(&t1, &x22)
for j := 0; j < 6; j++ {
t1.sqr(&t1)
}
t1.mul(&t1, &x2)
t1.sqr(&t1)
r.sqr(&t1)
// Check that a square root was actually calculated
var check FieldElement
check.sqr(r)
check.normalize()
aNorm.normalize()
ret := check.equal(&aNorm)
// If sqrt(a) doesn't exist, compute sqrt(-a) instead (as per field.h comment)
if !ret {
var negA FieldElement
negA.negate(&aNorm, 1)
negA.normalize()
t1 = x223
for j := 0; j < 23; j++ {
t1.sqr(&t1)
}
t1.mul(&t1, &x22)
for j := 0; j < 6; j++ {
t1.sqr(&t1)
}
t1.mul(&t1, &x2)
t1.sqr(&t1)
r.sqr(&t1)
check.sqr(r)
check.normalize()
// Return whether sqrt(-a) exists
return check.equal(&negA)
}
return ret
}
// isSquare checks if a field element is a quadratic residue
func (a *FieldElement) isSquare() bool {
// Use Legendre symbol: a^((p-1)/2) mod p
// If result is 1, then a is a quadratic residue
var result FieldElement
result = *a
// Compute a^((p-1)/2) - simplified implementation
for i := 0; i < 127; i++ { // Approximate (p-1)/2 bit length
result.sqr(&result)
}
result.normalize()
return result.equal(&FieldElementOne)
}
// half computes r = a/2 mod p
func (r *FieldElement) half(a *FieldElement) {
// This follows the C secp256k1_fe_impl_half implementation exactly
*r = *a
t0, t1, t2, t3, t4 := r.n[0], r.n[1], r.n[2], r.n[3], r.n[4]
one := uint64(1)
// In C: mask = -(t0 & one) >> 12
// In Go, we need to convert to signed, negate, then convert back
mask := uint64(-int64(t0 & one)) >> 12
// Conditionally add field modulus if odd
t0 += 0xFFFFEFFFFFC2F & mask
t1 += mask
t2 += mask
t3 += mask
t4 += mask >> 4
// Right shift with carry propagation
r.n[0] = (t0 >> 1) + ((t1 & one) << 51)
r.n[1] = (t1 >> 1) + ((t2 & one) << 51)
r.n[2] = (t2 >> 1) + ((t3 & one) << 51)
r.n[3] = (t3 >> 1) + ((t4 & one) << 51)
r.n[4] = t4 >> 1
// Update magnitude as per C implementation
r.magnitude = (r.magnitude >> 1) + 1
r.normalized = false
}