Add context tests and implement generator multiplication context

This commit introduces a new test file for context management, covering context creation, destruction, and capability checks. It also implements the generator multiplication context, extending the secp256k1 elliptic curve operations.
2025-11-01 20:01:52 +00:00
parent 715bdff306
commit 5416381478
37 changed files with 2213 additions and 7833 deletions
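
As a rough illustration only, the kind of lifecycle scenario the new context test file is expected to cover is sketched below; NewContext, ContextSign, ContextVerify, CanSign, and Destroy are hypothetical stand-ins and do not appear anywhere in this diff.

import "testing"

// Hypothetical sketch of a context lifecycle test; the real constructor,
// flag names, and capability accessors in this package may differ.
func TestContextLifecycle(t *testing.T) {
	ctx := NewContext(ContextSign | ContextVerify)
	if ctx == nil {
		t.Fatal("expected a non-nil context")
	}
	if !ctx.CanSign() {
		t.Error("expected a signing-capable context")
	}
	ctx.Destroy()
}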

scalar.go (575 changed lines)

@@ -6,22 +6,21 @@ import (
"unsafe"
)
// Scalar represents a scalar modulo the group order of the secp256k1 curve
// This implementation uses 4 uint64 limbs, ported from scalar_4x64.h
// Scalar represents a scalar value modulo the secp256k1 group order.
// Uses 4 uint64 limbs to represent a 256-bit scalar.
type Scalar struct {
d [4]uint64
}
// Group order constants (secp256k1 curve order n)
// Scalar constants from the C implementation
const (
// Limbs of the secp256k1 order
// Limbs of the secp256k1 order n
scalarN0 = 0xBFD25E8CD0364141
scalarN1 = 0xBAAEDCE6AF48A03B
scalarN2 = 0xFFFFFFFFFFFFFFFE
scalarN3 = 0xFFFFFFFFFFFFFFFF
// Limbs of 2^256 minus the secp256k1 order
// These are precomputed values to avoid overflow issues
// Limbs of 2^256 minus the secp256k1 order (complement constants)
scalarNC0 = 0x402DA1732FC9BEBF // ~scalarN0 + 1
scalarNC1 = 0x4551231950B75FC4 // ~scalarN1
scalarNC2 = 0x0000000000000001 // 1
@@ -33,7 +32,7 @@ const (
scalarNH3 = 0x7FFFFFFFFFFFFFFF
)
// Scalar constants
// Scalar element constants
var (
// ScalarZero represents the scalar 0
ScalarZero = Scalar{d: [4]uint64{0, 0, 0, 0}}
@@ -42,53 +41,7 @@ var (
ScalarOne = Scalar{d: [4]uint64{1, 0, 0, 0}}
)
// NewScalar creates a new scalar from a 32-byte big-endian array
func NewScalar(b32 []byte) *Scalar {
if len(b32) != 32 {
panic("input must be 32 bytes")
}
s := &Scalar{}
s.setB32(b32)
return s
}
// setB32 sets a scalar from a 32-byte big-endian array, reducing modulo group order
func (r *Scalar) setB32(bin []byte) (overflow bool) {
// Convert from big-endian bytes to limbs
r.d[0] = readBE64(bin[24:32])
r.d[1] = readBE64(bin[16:24])
r.d[2] = readBE64(bin[8:16])
r.d[3] = readBE64(bin[0:8])
// Check for overflow and reduce if necessary
overflow = r.checkOverflow()
if overflow {
r.reduce(1)
}
return overflow
}
// setB32Seckey sets a scalar from a 32-byte array and returns true if it's a valid secret key
func (r *Scalar) setB32Seckey(bin []byte) bool {
overflow := r.setB32(bin)
return !overflow && !r.isZero()
}
// getB32 converts a scalar to a 32-byte big-endian array
func (r *Scalar) getB32(bin []byte) {
if len(bin) != 32 {
panic("output buffer must be 32 bytes")
}
writeBE64(bin[0:8], r.d[3])
writeBE64(bin[8:16], r.d[2])
writeBE64(bin[16:24], r.d[1])
writeBE64(bin[24:32], r.d[0])
}
// setInt sets a scalar to an unsigned integer value
// setInt sets a scalar to a small integer value
func (r *Scalar) setInt(v uint) {
r.d[0] = uint64(v)
r.d[1] = 0
@@ -96,31 +49,113 @@ func (r *Scalar) setInt(v uint) {
r.d[3] = 0
}
// setB32 sets a scalar from a 32-byte big-endian array, reducing modulo the group order; it reports whether a reduction (overflow) occurred
func (r *Scalar) setB32(b []byte) bool {
if len(b) != 32 {
panic("scalar byte array must be 32 bytes")
}
// Convert from big-endian bytes to uint64 limbs
r.d[0] = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
r.d[1] = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
r.d[2] = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
r.d[3] = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
// Check if the scalar overflows the group order
overflow := r.checkOverflow()
if overflow {
r.reduce(1)
}
return overflow
}
// setB32Seckey sets a scalar from a 32-byte secret key and reports whether it is valid (nonzero and below the group order)
func (r *Scalar) setB32Seckey(b []byte) bool {
overflow := r.setB32(b)
return !r.isZero() && !overflow
}
// getB32 writes the scalar to b as a 32-byte big-endian array
func (r *Scalar) getB32(b []byte) {
if len(b) != 32 {
panic("scalar byte array must be 32 bytes")
}
// Convert from uint64 limbs to big-endian bytes
b[31] = byte(r.d[0])
b[30] = byte(r.d[0] >> 8)
b[29] = byte(r.d[0] >> 16)
b[28] = byte(r.d[0] >> 24)
b[27] = byte(r.d[0] >> 32)
b[26] = byte(r.d[0] >> 40)
b[25] = byte(r.d[0] >> 48)
b[24] = byte(r.d[0] >> 56)
b[23] = byte(r.d[1])
b[22] = byte(r.d[1] >> 8)
b[21] = byte(r.d[1] >> 16)
b[20] = byte(r.d[1] >> 24)
b[19] = byte(r.d[1] >> 32)
b[18] = byte(r.d[1] >> 40)
b[17] = byte(r.d[1] >> 48)
b[16] = byte(r.d[1] >> 56)
b[15] = byte(r.d[2])
b[14] = byte(r.d[2] >> 8)
b[13] = byte(r.d[2] >> 16)
b[12] = byte(r.d[2] >> 24)
b[11] = byte(r.d[2] >> 32)
b[10] = byte(r.d[2] >> 40)
b[9] = byte(r.d[2] >> 48)
b[8] = byte(r.d[2] >> 56)
b[7] = byte(r.d[3])
b[6] = byte(r.d[3] >> 8)
b[5] = byte(r.d[3] >> 16)
b[4] = byte(r.d[3] >> 24)
b[3] = byte(r.d[3] >> 32)
b[2] = byte(r.d[3] >> 40)
b[1] = byte(r.d[3] >> 48)
b[0] = byte(r.d[3] >> 56)
}
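// Illustrative helper (not part of this commit): for a 32-byte value below the
// group order, setB32 reports no overflow and getB32 reproduces the input.
func roundTripB32(in []byte) bool {
	var s Scalar
	overflow := s.setB32(in) // panics if len(in) != 32
	var out [32]byte
	s.getB32(out[:])
	for i := range out {
		if out[i] != in[i] {
			return false
		}
	}
	return !overflow
}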
// checkOverflow checks if the scalar is >= the group order
func (r *Scalar) checkOverflow() bool {
// Simple comparison with group order
if r.d[3] > scalarN3 {
return true
}
yes := 0
no := 0
// Check each limb from most significant to least significant
if r.d[3] < scalarN3 {
return false
no = 1
}
if r.d[3] > scalarN3 {
yes = 1
}
if r.d[2] > scalarN2 {
return true
}
if r.d[2] < scalarN2 {
return false
no |= (yes ^ 1)
}
if r.d[2] > scalarN2 {
yes |= (no ^ 1)
}
if r.d[1] > scalarN1 {
return true
}
if r.d[1] < scalarN1 {
return false
no |= (yes ^ 1)
}
if r.d[1] > scalarN1 {
yes |= (no ^ 1)
}
return r.d[0] >= scalarN0
if r.d[0] >= scalarN0 {
yes |= (no ^ 1)
}
return yes != 0
}
// reduce reduces the scalar modulo the group order
@@ -129,20 +164,30 @@ func (r *Scalar) reduce(overflow int) {
panic("overflow must be 0 or 1")
}
// Subtract overflow * n from the scalar
var borrow uint64
// Use 128-bit arithmetic for the reduction
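// Adding overflow * (2^256 - n) limb by limb, with carries, is equivalent to
// subtracting overflow * n modulo 2^256, which is why the complement constants
// scalarNC0..scalarNC2 are used (the top limb of 2^256 - n is zero).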
var t uint128
// d[0] -= overflow * scalarN0
r.d[0], borrow = bits.Sub64(r.d[0], uint64(overflow)*scalarN0, 0)
// d[0] += overflow * scalarNC0
t = uint128FromU64(r.d[0])
t = t.addU64(uint64(overflow) * scalarNC0)
r.d[0] = t.lo()
t = t.rshift(64)
// d[1] -= overflow * scalarN1 + borrow
r.d[1], borrow = bits.Sub64(r.d[1], uint64(overflow)*scalarN1, borrow)
// d[1] += overflow * scalarNC1 + carry
t = t.addU64(r.d[1])
t = t.addU64(uint64(overflow) * scalarNC1)
r.d[1] = t.lo()
t = t.rshift(64)
// d[2] -= overflow * scalarN2 + borrow
r.d[2], borrow = bits.Sub64(r.d[2], uint64(overflow)*scalarN2, borrow)
// d[2] += overflow * scalarNC2 + carry
t = t.addU64(r.d[2])
t = t.addU64(uint64(overflow) * scalarNC2)
r.d[2] = t.lo()
t = t.rshift(64)
// d[3] -= overflow * scalarN3 + borrow
r.d[3], _ = bits.Sub64(r.d[3], uint64(overflow)*scalarN3, borrow)
// d[3] += carry (scalarNC3 = 0)
t = t.addU64(r.d[3])
r.d[3] = t.lo()
}
// add adds two scalars: r = a + b, returns overflow
@@ -174,94 +219,217 @@ func (r *Scalar) sub(a, b *Scalar) {
// mul multiplies two scalars: r = a * b
func (r *Scalar) mul(a, b *Scalar) {
// Compute full 512-bit product using all 16 cross products
var c [8]uint64
// Compute all cross products a[i] * b[j]
for i := 0; i < 4; i++ {
for j := 0; j < 4; j++ {
hi, lo := bits.Mul64(a.d[i], b.d[j])
k := i + j
// Add lo to c[k]
var carry uint64
c[k], carry = bits.Add64(c[k], lo, 0)
// Add hi to c[k+1] and propagate carry
if k+1 < 8 {
c[k+1], carry = bits.Add64(c[k+1], hi, carry)
// Propagate any remaining carry
for l := k + 2; l < 8 && carry != 0; l++ {
c[l], carry = bits.Add64(c[l], 0, carry)
}
}
}
}
// Reduce the 512-bit result modulo the group order
r.reduceWide(c)
var l [8]uint64
r.mul512(l[:], a, b)
r.reduce512(l[:])
}
// reduceWide reduces a 512-bit value modulo the group order
func (r *Scalar) reduceWide(wide [8]uint64) {
// For now, use a very simple approach that just takes the lower 256 bits
// and ignores the upper bits. This is incorrect but will allow testing
// of other functionality. A proper implementation would use Barrett reduction.
r.d[0] = wide[0]
r.d[1] = wide[1]
r.d[2] = wide[2]
r.d[3] = wide[3]
// If there are upper bits, we need to do some reduction
// For now, just add a simple approximation
if wide[4] != 0 || wide[5] != 0 || wide[6] != 0 || wide[7] != 0 {
// Very crude approximation: add the upper bits to the lower bits
// This is mathematically incorrect but prevents infinite loops
// mul512 computes the 512-bit product of two scalars (ported from the C implementation)
func (r *Scalar) mul512(l8 []uint64, a, b *Scalar) {
// 160-bit accumulator (c0, c1, c2)
var c0, c1 uint64
var c2 uint32
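// c2 only ever holds the small carry count spilling out of c1 (a handful per
// column at most), so 32 bits are sufficient.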
// Helper closures translated from the C macros (muladd, muladd_fast, extract, extract_fast)
muladd := func(ai, bi uint64) {
hi, lo := bits.Mul64(ai, bi)
var carry uint64
r.d[0], carry = bits.Add64(r.d[0], wide[4], 0)
r.d[1], carry = bits.Add64(r.d[1], wide[5], carry)
r.d[2], carry = bits.Add64(r.d[2], wide[6], carry)
r.d[3], _ = bits.Add64(r.d[3], wide[7], carry)
c0, carry = bits.Add64(c0, lo, 0)
c1, carry = bits.Add64(c1, hi, carry)
c2 += uint32(carry)
}
// Check if we need reduction
if r.checkOverflow() {
r.reduce(1)
muladdFast := func(ai, bi uint64) {
hi, lo := bits.Mul64(ai, bi)
var carry uint64
c0, carry = bits.Add64(c0, lo, 0)
c1 += hi + carry
}
extract := func() uint64 {
result := c0
c0 = c1
c1 = uint64(c2)
c2 = 0
return result
}
extractFast := func() uint64 {
result := c0
c0 = c1
c1 = 0
return result
}
// l8[0..7] = a[0..3] * b[0..3] (following C implementation exactly)
muladdFast(a.d[0], b.d[0])
l8[0] = extractFast()
muladd(a.d[0], b.d[1])
muladd(a.d[1], b.d[0])
l8[1] = extract()
muladd(a.d[0], b.d[2])
muladd(a.d[1], b.d[1])
muladd(a.d[2], b.d[0])
l8[2] = extract()
muladd(a.d[0], b.d[3])
muladd(a.d[1], b.d[2])
muladd(a.d[2], b.d[1])
muladd(a.d[3], b.d[0])
l8[3] = extract()
muladd(a.d[1], b.d[3])
muladd(a.d[2], b.d[2])
muladd(a.d[3], b.d[1])
l8[4] = extract()
muladd(a.d[2], b.d[3])
muladd(a.d[3], b.d[2])
l8[5] = extract()
muladdFast(a.d[3], b.d[3])
l8[6] = extractFast()
l8[7] = c0
}
// mulByOrder multiplies a 256-bit value by the group order
func (r *Scalar) mulByOrder(a [4]uint64, result *[8]uint64) {
// Multiply a by the group order n
n := [4]uint64{scalarN0, scalarN1, scalarN2, scalarN3}
// Clear result
for i := range result {
result[i] = 0
// reduce512 reduces a 512-bit value to 256 bits modulo the group order (ported from the C implementation)
func (r *Scalar) reduce512(l []uint64) {
// 160-bit accumulator
var c0, c1 uint64
var c2 uint32
// Extract upper 256 bits
n0, n1, n2, n3 := l[4], l[5], l[6], l[7]
// Helper closures translated from the C macros
muladd := func(ai, bi uint64) {
hi, lo := bits.Mul64(ai, bi)
var carry uint64
c0, carry = bits.Add64(c0, lo, 0)
c1, carry = bits.Add64(c1, hi, carry)
c2 += uint32(carry)
}
// Compute all cross products
for i := 0; i < 4; i++ {
for j := 0; j < 4; j++ {
hi, lo := bits.Mul64(a[i], n[j])
k := i + j
// Add lo to result[k]
var carry uint64
result[k], carry = bits.Add64(result[k], lo, 0)
// Add hi to result[k+1] and propagate carry
if k+1 < 8 {
result[k+1], carry = bits.Add64(result[k+1], hi, carry)
// Propagate any remaining carry
for l := k + 2; l < 8 && carry != 0; l++ {
result[l], carry = bits.Add64(result[l], 0, carry)
}
}
}
muladdFast := func(ai, bi uint64) {
hi, lo := bits.Mul64(ai, bi)
var carry uint64
c0, carry = bits.Add64(c0, lo, 0)
c1 += hi + carry
}
sumadd := func(a uint64) {
var carry uint64
c0, carry = bits.Add64(c0, a, 0)
c1, carry = bits.Add64(c1, 0, carry)
c2 += uint32(carry)
}
sumaddFast := func(a uint64) {
var carry uint64
c0, carry = bits.Add64(c0, a, 0)
c1 += carry
}
extract := func() uint64 {
result := c0
c0 = c1
c1 = uint64(c2)
c2 = 0
return result
}
extractFast := func() uint64 {
result := c0
c0 = c1
c1 = 0
return result
}
// Reduce 512 bits into 385 bits
// m[0..6] = l[0..3] + n[0..3] * SECP256K1_N_C
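// The group order is 2^256 minus the NC constants, so 2^256 is congruent to NC
// modulo the order; the upper 256 bits (held in n0..n3 here) can therefore be
// folded into the low half by multiplying them by NC.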
c0 = l[0]
c1 = 0
c2 = 0
muladdFast(n0, scalarNC0)
m0 := extractFast()
sumaddFast(l[1])
muladd(n1, scalarNC0)
muladd(n0, scalarNC1)
m1 := extract()
sumadd(l[2])
muladd(n2, scalarNC0)
muladd(n1, scalarNC1)
sumadd(n0)
m2 := extract()
sumadd(l[3])
muladd(n3, scalarNC0)
muladd(n2, scalarNC1)
sumadd(n1)
m3 := extract()
muladd(n3, scalarNC1)
sumadd(n2)
m4 := extract()
sumaddFast(n3)
m5 := extractFast()
m6 := uint32(c0)
// Reduce 385 bits into 258 bits
// p[0..4] = m[0..3] + m[4..6] * SECP256K1_N_C
c0 = m0
c1 = 0
c2 = 0
muladdFast(m4, scalarNC0)
p0 := extractFast()
sumaddFast(m1)
muladd(m5, scalarNC0)
muladd(m4, scalarNC1)
p1 := extract()
sumadd(m2)
muladd(uint64(m6), scalarNC0)
muladd(m5, scalarNC1)
sumadd(m4)
p2 := extract()
sumaddFast(m3)
muladdFast(uint64(m6), scalarNC1)
sumaddFast(m5)
p3 := extractFast()
p4 := uint32(c0 + uint64(m6))
// Reduce 258 bits into 256 bits
// r[0..3] = p[0..3] + p[4] * SECP256K1_N_C
var t uint128
t = uint128FromU64(p0)
t = t.addMul(scalarNC0, uint64(p4))
r.d[0] = t.lo()
t = t.rshift(64)
t = t.addU64(p1)
t = t.addMul(scalarNC1, uint64(p4))
r.d[1] = t.lo()
t = t.rshift(64)
t = t.addU64(p2)
t = t.addU64(uint64(p4))
r.d[2] = t.lo()
t = t.rshift(64)
t = t.addU64(p3)
r.d[3] = t.lo()
c := t.hi()
// Final reduction
r.reduce(int(c) + boolToInt(r.checkOverflow()))
}
// negate negates a scalar: r = -a
@@ -279,22 +447,15 @@ func (r *Scalar) negate(a *Scalar) {
func (r *Scalar) inverse(a *Scalar) {
// Use Fermat's little theorem: a^(-1) = a^(n-2) mod n
// where n is the group order (which is prime)
// The group order minus 2:
// n-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD036413F
// Use binary exponentiation with n-2
// n-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD036413F
var exp Scalar
// Since scalarN0 = 0xBFD25E8CD0364141, and we need n0 - 2
// We need to handle the subtraction properly
var borrow uint64
exp.d[0], borrow = bits.Sub64(scalarN0, 2, 0)
exp.d[1], borrow = bits.Sub64(scalarN1, 0, borrow)
exp.d[2], borrow = bits.Sub64(scalarN2, 0, borrow)
exp.d[3], _ = bits.Sub64(scalarN3, 0, borrow)
r.exp(a, &exp)
}
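// Illustrative check (not part of this commit): a nonzero scalar times its
// inverse must equal one. This would live in the same package, since mul,
// inverse and isOne are unexported.
func checkInverseRoundTrip(a *Scalar) bool {
	var inv, prod Scalar
	inv.inverse(a)
	prod.mul(a, &inv)
	return prod.isOne()
}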
@@ -343,7 +504,7 @@ func (r *Scalar) half(a *Scalar) {
// isZero returns true if the scalar is zero
func (r *Scalar) isZero() bool {
return r.d[0] == 0 && r.d[1] == 0 && r.d[2] == 0 && r.d[3] == 0
return (r.d[0] | r.d[1] | r.d[2] | r.d[3]) == 0
}
// isOne returns true if the scalar is one
@@ -358,28 +519,43 @@ func (r *Scalar) isEven() bool {
// isHigh returns true if the scalar is > n/2
func (r *Scalar) isHigh() bool {
// Compare with n/2
if r.d[3] != scalarNH3 {
return r.d[3] > scalarNH3
var yes, no int
if r.d[3] < scalarNH3 {
no = 1
}
if r.d[2] != scalarNH2 {
return r.d[2] > scalarNH2
if r.d[3] > scalarNH3 {
yes = 1
}
if r.d[1] != scalarNH1 {
return r.d[1] > scalarNH1
if r.d[2] < scalarNH2 {
no |= (yes ^ 1)
}
return r.d[0] > scalarNH0
if r.d[2] > scalarNH2 {
yes |= (no ^ 1)
}
if r.d[1] < scalarNH1 {
no |= (yes ^ 1)
}
if r.d[1] > scalarNH1 {
yes |= (no ^ 1)
}
if r.d[0] > scalarNH0 {
yes |= (no ^ 1)
}
return yes != 0
}
// condNegate conditionally negates a scalar if flag is true
func (r *Scalar) condNegate(flag bool) bool {
if flag {
// condNegate conditionally negates the scalar if flag is nonzero
func (r *Scalar) condNegate(flag int) {
if flag != 0 {
var neg Scalar
neg.negate(r)
*r = neg
return true
}
return false
}
// equal returns true if two scalars are equal
@@ -392,8 +568,11 @@ func (r *Scalar) equal(a *Scalar) bool {
// getBits extracts count bits starting at offset
func (r *Scalar) getBits(offset, count uint) uint32 {
if count == 0 || count > 32 || offset+count > 256 {
panic("invalid bit range")
if count == 0 || count > 32 {
panic("count must be 1-32")
}
if offset+count > 256 {
panic("offset + count must be <= 256")
}
limbIdx := offset / 64
@@ -406,17 +585,15 @@ func (r *Scalar) getBits(offset, count uint) uint32 {
// Bits span two limbs
lowBits := 64 - bitIdx
highBits := count - lowBits
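// Example: offset=60, count=8 takes the top 4 bits of d[0] (lowBits = 4)
// and the bottom 4 bits of d[1] (highBits = 4).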
low := uint32((r.d[limbIdx] >> bitIdx) & ((1 << lowBits) - 1))
high := uint32(r.d[limbIdx+1] & ((1 << highBits) - 1))
return low | (high << lowBits)
}
}
// cmov conditionally moves a scalar without branching: if flag is 1, r = a; if flag is 0, r is unchanged.
func (r *Scalar) cmov(a *Scalar, flag int) {
mask := uint64(-flag)
mask := uint64(-(int64(flag) & 1))
r.d[0] ^= mask & (r.d[0] ^ a.d[0])
r.d[1] ^= mask & (r.d[1] ^ a.d[1])
r.d[2] ^= mask & (r.d[2] ^ a.d[2])
@@ -427,3 +604,31 @@ func (r *Scalar) cmov(a *Scalar, flag int) {
func (r *Scalar) clear() {
memclear(unsafe.Pointer(&r.d[0]), unsafe.Sizeof(r.d))
}
// Helper functions for 128-bit arithmetic (using uint128 from field_mul.go)
func uint128FromU64(x uint64) uint128 {
return uint128{low: x, high: 0}
}
func (x uint128) addU64(y uint64) uint128 {
low, carry := bits.Add64(x.low, y, 0)
high := x.high + carry
return uint128{low: low, high: high}
}
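// addMul returns x + a*b. Any carry out of the high 64 bits is discarded, so
// callers must keep intermediate results within 128 bits (as reduce512 does).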
func (x uint128) addMul(a, b uint64) uint128 {
hi, lo := bits.Mul64(a, b)
low, carry := bits.Add64(x.low, lo, 0)
high, _ := bits.Add64(x.high, hi, carry)
return uint128{low: low, high: high}
}
// boolToInt converts a bool to an int (1 for true, 0 for false)
func boolToInt(b bool) int {
if b {
return 1
}
return 0
}