447 lines
13 KiB
Go
447 lines
13 KiB
Go
package avx
|
|
|
|
import "math/bits"
|
|
|
|
// Field operations modulo the secp256k1 field prime p.
|
|
// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
|
|
// = 2^256 - 2^32 - 977
|
|
|
|
// SetBytes sets a field element from a 32-byte big-endian slice.
|
|
// Returns true if the value was >= p and was reduced.
|
|
func (f *FieldElement) SetBytes(b []byte) bool {
|
|
if len(b) != 32 {
|
|
panic("field element must be 32 bytes")
|
|
}
|
|
|
|
// Convert big-endian bytes to little-endian limbs
|
|
f.N[0].Lo = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
|
|
uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
|
|
f.N[0].Hi = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
|
|
uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
|
|
f.N[1].Lo = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
|
|
uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
|
|
f.N[1].Hi = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
|
|
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
|
|
|
|
// Check overflow and reduce if necessary
|
|
overflow := f.checkOverflow()
|
|
if overflow {
|
|
f.reduce()
|
|
}
|
|
return overflow
|
|
}
|
|
|
|
// Bytes returns the field element as a 32-byte big-endian slice.
|
|
func (f *FieldElement) Bytes() [32]byte {
|
|
var b [32]byte
|
|
b[31] = byte(f.N[0].Lo)
|
|
b[30] = byte(f.N[0].Lo >> 8)
|
|
b[29] = byte(f.N[0].Lo >> 16)
|
|
b[28] = byte(f.N[0].Lo >> 24)
|
|
b[27] = byte(f.N[0].Lo >> 32)
|
|
b[26] = byte(f.N[0].Lo >> 40)
|
|
b[25] = byte(f.N[0].Lo >> 48)
|
|
b[24] = byte(f.N[0].Lo >> 56)
|
|
|
|
b[23] = byte(f.N[0].Hi)
|
|
b[22] = byte(f.N[0].Hi >> 8)
|
|
b[21] = byte(f.N[0].Hi >> 16)
|
|
b[20] = byte(f.N[0].Hi >> 24)
|
|
b[19] = byte(f.N[0].Hi >> 32)
|
|
b[18] = byte(f.N[0].Hi >> 40)
|
|
b[17] = byte(f.N[0].Hi >> 48)
|
|
b[16] = byte(f.N[0].Hi >> 56)
|
|
|
|
b[15] = byte(f.N[1].Lo)
|
|
b[14] = byte(f.N[1].Lo >> 8)
|
|
b[13] = byte(f.N[1].Lo >> 16)
|
|
b[12] = byte(f.N[1].Lo >> 24)
|
|
b[11] = byte(f.N[1].Lo >> 32)
|
|
b[10] = byte(f.N[1].Lo >> 40)
|
|
b[9] = byte(f.N[1].Lo >> 48)
|
|
b[8] = byte(f.N[1].Lo >> 56)
|
|
|
|
b[7] = byte(f.N[1].Hi)
|
|
b[6] = byte(f.N[1].Hi >> 8)
|
|
b[5] = byte(f.N[1].Hi >> 16)
|
|
b[4] = byte(f.N[1].Hi >> 24)
|
|
b[3] = byte(f.N[1].Hi >> 32)
|
|
b[2] = byte(f.N[1].Hi >> 40)
|
|
b[1] = byte(f.N[1].Hi >> 48)
|
|
b[0] = byte(f.N[1].Hi >> 56)
|
|
|
|
return b
|
|
}
|
|
|
|
// IsZero returns true if the field element is zero.
|
|
func (f *FieldElement) IsZero() bool {
|
|
return f.N[0].IsZero() && f.N[1].IsZero()
|
|
}
|
|
|
|
// IsOne returns true if the field element is one.
|
|
func (f *FieldElement) IsOne() bool {
|
|
return f.N[0].Lo == 1 && f.N[0].Hi == 0 && f.N[1].IsZero()
|
|
}
|
|
|
|
// Equal returns true if two field elements are equal.
|
|
func (f *FieldElement) Equal(other *FieldElement) bool {
|
|
return f.N[0].Lo == other.N[0].Lo && f.N[0].Hi == other.N[0].Hi &&
|
|
f.N[1].Lo == other.N[1].Lo && f.N[1].Hi == other.N[1].Hi
|
|
}
|
|
|
|
// IsOdd returns true if the field element is odd.
|
|
func (f *FieldElement) IsOdd() bool {
|
|
return f.N[0].Lo&1 == 1
|
|
}
|
|
|
|
// checkOverflow returns true if f >= p.
|
|
func (f *FieldElement) checkOverflow() bool {
|
|
// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
|
|
// Compare high to low
|
|
if f.N[1].Hi > FieldP.N[1].Hi {
|
|
return true
|
|
}
|
|
if f.N[1].Hi < FieldP.N[1].Hi {
|
|
return false
|
|
}
|
|
if f.N[1].Lo > FieldP.N[1].Lo {
|
|
return true
|
|
}
|
|
if f.N[1].Lo < FieldP.N[1].Lo {
|
|
return false
|
|
}
|
|
if f.N[0].Hi > FieldP.N[0].Hi {
|
|
return true
|
|
}
|
|
if f.N[0].Hi < FieldP.N[0].Hi {
|
|
return false
|
|
}
|
|
return f.N[0].Lo >= FieldP.N[0].Lo
|
|
}
|
|
|
|
// reduce reduces f modulo p by adding the complement (2^256 - p = 2^32 + 977).
|
|
func (f *FieldElement) reduce() {
|
|
// f = f - p = f + (2^256 - p) mod 2^256
|
|
// 2^256 - p = 0x1000003D1
|
|
var carry uint64
|
|
f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, 0x1000003D1, 0)
|
|
f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, 0, carry)
|
|
f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
|
|
f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
|
|
}
|
|
|
|
// Add sets f = a + b mod p.
|
|
func (f *FieldElement) Add(a, b *FieldElement) *FieldElement {
|
|
var carry uint64
|
|
f.N[0].Lo, carry = bits.Add64(a.N[0].Lo, b.N[0].Lo, 0)
|
|
f.N[0].Hi, carry = bits.Add64(a.N[0].Hi, b.N[0].Hi, carry)
|
|
f.N[1].Lo, carry = bits.Add64(a.N[1].Lo, b.N[1].Lo, carry)
|
|
f.N[1].Hi, carry = bits.Add64(a.N[1].Hi, b.N[1].Hi, carry)
|
|
|
|
// If there was a carry or if result >= p, reduce
|
|
if carry != 0 || f.checkOverflow() {
|
|
f.reduce()
|
|
}
|
|
return f
|
|
}
|
|
|
|
// Sub sets f = a - b mod p.
|
|
func (f *FieldElement) Sub(a, b *FieldElement) *FieldElement {
|
|
var borrow uint64
|
|
f.N[0].Lo, borrow = bits.Sub64(a.N[0].Lo, b.N[0].Lo, 0)
|
|
f.N[0].Hi, borrow = bits.Sub64(a.N[0].Hi, b.N[0].Hi, borrow)
|
|
f.N[1].Lo, borrow = bits.Sub64(a.N[1].Lo, b.N[1].Lo, borrow)
|
|
f.N[1].Hi, borrow = bits.Sub64(a.N[1].Hi, b.N[1].Hi, borrow)
|
|
|
|
// If there was a borrow, add p back
|
|
if borrow != 0 {
|
|
var carry uint64
|
|
f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, FieldP.N[0].Lo, 0)
|
|
f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, FieldP.N[0].Hi, carry)
|
|
f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, FieldP.N[1].Lo, carry)
|
|
f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, FieldP.N[1].Hi, carry)
|
|
}
|
|
return f
|
|
}
|
|
|
|
// Negate sets f = -a mod p.
|
|
func (f *FieldElement) Negate(a *FieldElement) *FieldElement {
|
|
if a.IsZero() {
|
|
*f = FieldZero
|
|
return f
|
|
}
|
|
// f = p - a
|
|
var borrow uint64
|
|
f.N[0].Lo, borrow = bits.Sub64(FieldP.N[0].Lo, a.N[0].Lo, 0)
|
|
f.N[0].Hi, borrow = bits.Sub64(FieldP.N[0].Hi, a.N[0].Hi, borrow)
|
|
f.N[1].Lo, borrow = bits.Sub64(FieldP.N[1].Lo, a.N[1].Lo, borrow)
|
|
f.N[1].Hi, _ = bits.Sub64(FieldP.N[1].Hi, a.N[1].Hi, borrow)
|
|
return f
|
|
}
|
|
|
|
// Mul sets f = a * b mod p.
|
|
func (f *FieldElement) Mul(a, b *FieldElement) *FieldElement {
|
|
// Compute 512-bit product
|
|
var prod [8]uint64
|
|
fieldMul512(&prod, a, b)
|
|
|
|
// Reduce mod p using secp256k1's special structure
|
|
fieldReduce512(f, &prod)
|
|
return f
|
|
}
|
|
|
|
// fieldMul512 computes the 512-bit product of two 256-bit field elements.
|
|
func fieldMul512(prod *[8]uint64, a, b *FieldElement) {
|
|
aLimbs := [4]uint64{a.N[0].Lo, a.N[0].Hi, a.N[1].Lo, a.N[1].Hi}
|
|
bLimbs := [4]uint64{b.N[0].Lo, b.N[0].Hi, b.N[1].Lo, b.N[1].Hi}
|
|
|
|
// Clear product
|
|
for i := range prod {
|
|
prod[i] = 0
|
|
}
|
|
|
|
// Schoolbook multiplication
|
|
for i := 0; i < 4; i++ {
|
|
var carry uint64
|
|
for j := 0; j < 4; j++ {
|
|
hi, lo := bits.Mul64(aLimbs[i], bLimbs[j])
|
|
lo, c := bits.Add64(lo, prod[i+j], 0)
|
|
hi, _ = bits.Add64(hi, 0, c)
|
|
lo, c = bits.Add64(lo, carry, 0)
|
|
hi, _ = bits.Add64(hi, 0, c)
|
|
prod[i+j] = lo
|
|
carry = hi
|
|
}
|
|
prod[i+4] = carry
|
|
}
|
|
}
|
|
|
|
// fieldReduce512 reduces a 512-bit value mod p using secp256k1's special structure.
|
|
// p = 2^256 - 2^32 - 977, so 2^256 ≡ 2^32 + 977 (mod p)
|
|
func fieldReduce512(f *FieldElement, prod *[8]uint64) {
|
|
// The key insight: if we have a 512-bit number split as H*2^256 + L
|
|
// then H*2^256 + L ≡ H*(2^32 + 977) + L (mod p)
|
|
|
|
// Extract low and high 256-bit parts
|
|
low := [4]uint64{prod[0], prod[1], prod[2], prod[3]}
|
|
high := [4]uint64{prod[4], prod[5], prod[6], prod[7]}
|
|
|
|
// Compute high * (2^32 + 977) = high * 0x1000003D1
|
|
// This gives us at most a 289-bit result (256 + 33 bits)
|
|
const c = uint64(0x1000003D1)
|
|
|
|
var reduction [5]uint64
|
|
var carry uint64
|
|
|
|
for i := 0; i < 4; i++ {
|
|
hi, lo := bits.Mul64(high[i], c)
|
|
lo, cc := bits.Add64(lo, carry, 0)
|
|
hi, _ = bits.Add64(hi, 0, cc)
|
|
reduction[i] = lo
|
|
carry = hi
|
|
}
|
|
reduction[4] = carry
|
|
|
|
// Add low + reduction
|
|
var result [5]uint64
|
|
carry = 0
|
|
for i := 0; i < 4; i++ {
|
|
result[i], carry = bits.Add64(low[i], reduction[i], carry)
|
|
}
|
|
result[4] = carry + reduction[4]
|
|
|
|
// If result[4] is non-zero, we need to reduce again
|
|
// result[4] * 2^256 ≡ result[4] * (2^32 + 977) (mod p)
|
|
if result[4] != 0 {
|
|
hi, lo := bits.Mul64(result[4], c)
|
|
result[0], carry = bits.Add64(result[0], lo, 0)
|
|
result[1], carry = bits.Add64(result[1], hi, carry)
|
|
result[2], carry = bits.Add64(result[2], 0, carry)
|
|
result[3], _ = bits.Add64(result[3], 0, carry)
|
|
result[4] = 0
|
|
}
|
|
|
|
// Store result
|
|
f.N[0].Lo = result[0]
|
|
f.N[0].Hi = result[1]
|
|
f.N[1].Lo = result[2]
|
|
f.N[1].Hi = result[3]
|
|
|
|
// Final reduction if >= p
|
|
if f.checkOverflow() {
|
|
f.reduce()
|
|
}
|
|
}
|
|
|
|
// Sqr sets f = a^2 mod p.
|
|
func (f *FieldElement) Sqr(a *FieldElement) *FieldElement {
|
|
// Optimized squaring could save some multiplications, but for now use Mul
|
|
return f.Mul(a, a)
|
|
}
|
|
|
|
// Inverse sets f = a^(-1) mod p using Fermat's little theorem.
|
|
// a^(-1) = a^(p-2) mod p
|
|
func (f *FieldElement) Inverse(a *FieldElement) *FieldElement {
|
|
// p-2 in bytes (big-endian)
|
|
// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
|
|
// p-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2D
|
|
pMinus2 := [32]byte{
|
|
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
|
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
|
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
|
0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFC, 0x2D,
|
|
}
|
|
|
|
var result, base FieldElement
|
|
result = FieldOne
|
|
base = *a
|
|
|
|
for i := 0; i < 32; i++ {
|
|
b := pMinus2[31-i]
|
|
for j := 0; j < 8; j++ {
|
|
if (b>>j)&1 == 1 {
|
|
result.Mul(&result, &base)
|
|
}
|
|
base.Sqr(&base)
|
|
}
|
|
}
|
|
|
|
*f = result
|
|
return f
|
|
}
|
|
|
|
// Sqrt sets f = sqrt(a) mod p if it exists, returns true if successful.
|
|
// For secp256k1, p ≡ 3 (mod 4), so sqrt(a) = a^((p+1)/4) mod p
|
|
func (f *FieldElement) Sqrt(a *FieldElement) bool {
|
|
// (p+1)/4 in bytes
|
|
// p+1 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC30
|
|
// (p+1)/4 = 3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFBFFFFF0C
|
|
pPlus1Div4 := [32]byte{
|
|
0x3F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
|
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
|
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
|
|
0xFF, 0xFF, 0xFF, 0xFF, 0xBF, 0xFF, 0xFF, 0x0C,
|
|
}
|
|
|
|
var result, base FieldElement
|
|
result = FieldOne
|
|
base = *a
|
|
|
|
for i := 0; i < 32; i++ {
|
|
b := pPlus1Div4[31-i]
|
|
for j := 0; j < 8; j++ {
|
|
if (b>>j)&1 == 1 {
|
|
result.Mul(&result, &base)
|
|
}
|
|
base.Sqr(&base)
|
|
}
|
|
}
|
|
|
|
// Verify: result^2 should equal a
|
|
var check FieldElement
|
|
check.Sqr(&result)
|
|
|
|
if check.Equal(a) {
|
|
*f = result
|
|
return true
|
|
}
|
|
return false
|
|
}
|
|
|
|
// MulInt sets f = a * n mod p where n is a small integer.
|
|
func (f *FieldElement) MulInt(a *FieldElement, n uint64) *FieldElement {
|
|
if n == 0 {
|
|
*f = FieldZero
|
|
return f
|
|
}
|
|
if n == 1 {
|
|
*f = *a
|
|
return f
|
|
}
|
|
|
|
// Multiply by small integer using proper carry chain
|
|
// We need to compute a 320-bit result (256 + 64 bits max)
|
|
var result [5]uint64
|
|
var carry uint64
|
|
|
|
// Multiply each 64-bit limb by n
|
|
var hi uint64
|
|
hi, result[0] = bits.Mul64(a.N[0].Lo, n)
|
|
carry = hi
|
|
|
|
hi, result[1] = bits.Mul64(a.N[0].Hi, n)
|
|
result[1], carry = bits.Add64(result[1], carry, 0)
|
|
carry = hi + carry // carry can be at most 1 here, so no overflow
|
|
|
|
hi, result[2] = bits.Mul64(a.N[1].Lo, n)
|
|
result[2], carry = bits.Add64(result[2], carry, 0)
|
|
carry = hi + carry
|
|
|
|
hi, result[3] = bits.Mul64(a.N[1].Hi, n)
|
|
result[3], carry = bits.Add64(result[3], carry, 0)
|
|
result[4] = hi + carry
|
|
|
|
// Store preliminary result
|
|
f.N[0].Lo = result[0]
|
|
f.N[0].Hi = result[1]
|
|
f.N[1].Lo = result[2]
|
|
f.N[1].Hi = result[3]
|
|
|
|
// Reduce overflow
|
|
if result[4] != 0 {
|
|
// overflow * 2^256 ≡ overflow * (2^32 + 977) (mod p)
|
|
hi, lo := bits.Mul64(result[4], 0x1000003D1)
|
|
f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, lo, 0)
|
|
f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, hi, carry)
|
|
f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
|
|
f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
|
|
}
|
|
|
|
if f.checkOverflow() {
|
|
f.reduce()
|
|
}
|
|
return f
|
|
}
|
|
|
|
// Double sets f = 2*a mod p (optimized addition).
|
|
func (f *FieldElement) Double(a *FieldElement) *FieldElement {
|
|
return f.Add(a, a)
|
|
}
|
|
|
|
// Half sets f = a/2 mod p.
|
|
func (f *FieldElement) Half(a *FieldElement) *FieldElement {
|
|
// If a is even, just shift right
|
|
// If a is odd, add p first (which makes it even), then shift right
|
|
var result FieldElement = *a
|
|
|
|
if result.N[0].Lo&1 == 1 {
|
|
// Add p
|
|
var carry uint64
|
|
result.N[0].Lo, carry = bits.Add64(result.N[0].Lo, FieldP.N[0].Lo, 0)
|
|
result.N[0].Hi, carry = bits.Add64(result.N[0].Hi, FieldP.N[0].Hi, carry)
|
|
result.N[1].Lo, carry = bits.Add64(result.N[1].Lo, FieldP.N[1].Lo, carry)
|
|
result.N[1].Hi, _ = bits.Add64(result.N[1].Hi, FieldP.N[1].Hi, carry)
|
|
}
|
|
|
|
// Shift right by 1
|
|
f.N[0].Lo = (result.N[0].Lo >> 1) | (result.N[0].Hi << 63)
|
|
f.N[0].Hi = (result.N[0].Hi >> 1) | (result.N[1].Lo << 63)
|
|
f.N[1].Lo = (result.N[1].Lo >> 1) | (result.N[1].Hi << 63)
|
|
f.N[1].Hi = result.N[1].Hi >> 1
|
|
|
|
return f
|
|
}
|
|
|
|
// CMov conditionally moves b into f if cond is true (constant-time).
|
|
func (f *FieldElement) CMov(b *FieldElement, cond bool) *FieldElement {
|
|
mask := uint64(0)
|
|
if cond {
|
|
mask = ^uint64(0)
|
|
}
|
|
f.N[0].Lo = (f.N[0].Lo &^ mask) | (b.N[0].Lo & mask)
|
|
f.N[0].Hi = (f.N[0].Hi &^ mask) | (b.N[0].Hi & mask)
|
|
f.N[1].Lo = (f.N[1].Lo &^ mask) | (b.N[1].Lo & mask)
|
|
f.N[1].Hi = (f.N[1].Hi &^ mask) | (b.N[1].Hi & mask)
|
|
return f
|
|
}
|