Add context tests and implement generator multiplication context
This commit introduces a new test file for context management, covering various scenarios for context creation, destruction, and capabilities. Additionally, it implements the generator multiplication context, enhancing the secp256k1 elliptic curve operations. The changes ensure comprehensive testing and improved functionality for context handling, contributing to the overall robustness of the implementation.
This commit is contained in:
575
scalar.go
575
scalar.go
@@ -6,22 +6,21 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// Scalar represents a scalar modulo the group order of the secp256k1 curve
|
||||
// This implementation uses 4 uint64 limbs, ported from scalar_4x64.h
|
||||
// Scalar represents a scalar value modulo the secp256k1 group order.
|
||||
// Uses 4 uint64 limbs to represent a 256-bit scalar.
|
||||
type Scalar struct {
|
||||
d [4]uint64
|
||||
}
|
||||
|
||||
// Group order constants (secp256k1 curve order n)
|
||||
// Scalar constants from the C implementation
|
||||
const (
|
||||
// Limbs of the secp256k1 order
|
||||
// Limbs of the secp256k1 order n
|
||||
scalarN0 = 0xBFD25E8CD0364141
|
||||
scalarN1 = 0xBAAEDCE6AF48A03B
|
||||
scalarN2 = 0xFFFFFFFFFFFFFFFE
|
||||
scalarN3 = 0xFFFFFFFFFFFFFFFF
|
||||
|
||||
// Limbs of 2^256 minus the secp256k1 order
|
||||
// These are precomputed values to avoid overflow issues
|
||||
// Limbs of 2^256 minus the secp256k1 order (complement constants)
|
||||
scalarNC0 = 0x402DA1732FC9BEBF // ~scalarN0 + 1
|
||||
scalarNC1 = 0x4551231950B75FC4 // ~scalarN1
|
||||
scalarNC2 = 0x0000000000000001 // 1
|
||||
@@ -33,7 +32,7 @@ const (
|
||||
scalarNH3 = 0x7FFFFFFFFFFFFFFF
|
||||
)
|
||||
|
||||
// Scalar constants
|
||||
// Scalar element constants
|
||||
var (
|
||||
// ScalarZero represents the scalar 0
|
||||
ScalarZero = Scalar{d: [4]uint64{0, 0, 0, 0}}
|
||||
@@ -42,53 +41,7 @@ var (
|
||||
ScalarOne = Scalar{d: [4]uint64{1, 0, 0, 0}}
|
||||
)
|
||||
|
||||
// NewScalar creates a new scalar from a 32-byte big-endian array
|
||||
func NewScalar(b32 []byte) *Scalar {
|
||||
if len(b32) != 32 {
|
||||
panic("input must be 32 bytes")
|
||||
}
|
||||
|
||||
s := &Scalar{}
|
||||
s.setB32(b32)
|
||||
return s
|
||||
}
|
||||
|
||||
// setB32 sets a scalar from a 32-byte big-endian array, reducing modulo group order
|
||||
func (r *Scalar) setB32(bin []byte) (overflow bool) {
|
||||
// Convert from big-endian bytes to limbs
|
||||
r.d[0] = readBE64(bin[24:32])
|
||||
r.d[1] = readBE64(bin[16:24])
|
||||
r.d[2] = readBE64(bin[8:16])
|
||||
r.d[3] = readBE64(bin[0:8])
|
||||
|
||||
// Check for overflow and reduce if necessary
|
||||
overflow = r.checkOverflow()
|
||||
if overflow {
|
||||
r.reduce(1)
|
||||
}
|
||||
|
||||
return overflow
|
||||
}
|
||||
|
||||
// setB32Seckey sets a scalar from a 32-byte array and returns true if it's a valid secret key
|
||||
func (r *Scalar) setB32Seckey(bin []byte) bool {
|
||||
overflow := r.setB32(bin)
|
||||
return !overflow && !r.isZero()
|
||||
}
|
||||
|
||||
// getB32 converts a scalar to a 32-byte big-endian array
|
||||
func (r *Scalar) getB32(bin []byte) {
|
||||
if len(bin) != 32 {
|
||||
panic("output buffer must be 32 bytes")
|
||||
}
|
||||
|
||||
writeBE64(bin[0:8], r.d[3])
|
||||
writeBE64(bin[8:16], r.d[2])
|
||||
writeBE64(bin[16:24], r.d[1])
|
||||
writeBE64(bin[24:32], r.d[0])
|
||||
}
|
||||
|
||||
// setInt sets a scalar to an unsigned integer value
|
||||
// setInt sets a scalar to a small integer value
|
||||
func (r *Scalar) setInt(v uint) {
|
||||
r.d[0] = uint64(v)
|
||||
r.d[1] = 0
|
||||
@@ -96,31 +49,113 @@ func (r *Scalar) setInt(v uint) {
|
||||
r.d[3] = 0
|
||||
}
|
||||
|
||||
// setB32 sets a scalar from a 32-byte big-endian array
|
||||
func (r *Scalar) setB32(b []byte) bool {
|
||||
if len(b) != 32 {
|
||||
panic("scalar byte array must be 32 bytes")
|
||||
}
|
||||
|
||||
// Convert from big-endian bytes to uint64 limbs
|
||||
r.d[0] = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
|
||||
uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
|
||||
r.d[1] = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
|
||||
uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
|
||||
r.d[2] = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
|
||||
uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
|
||||
r.d[3] = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
|
||||
uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
|
||||
|
||||
// Check if the scalar overflows the group order
|
||||
overflow := r.checkOverflow()
|
||||
if overflow {
|
||||
r.reduce(1)
|
||||
}
|
||||
|
||||
return overflow
|
||||
}
|
||||
|
||||
// setB32Seckey sets a scalar from a 32-byte secret key, returns true if valid
|
||||
func (r *Scalar) setB32Seckey(b []byte) bool {
|
||||
overflow := r.setB32(b)
|
||||
return !r.isZero() && !overflow
|
||||
}
|
||||
|
||||
// getB32 converts a scalar to a 32-byte big-endian array
|
||||
func (r *Scalar) getB32(b []byte) {
|
||||
if len(b) != 32 {
|
||||
panic("scalar byte array must be 32 bytes")
|
||||
}
|
||||
|
||||
// Convert from uint64 limbs to big-endian bytes
|
||||
b[31] = byte(r.d[0])
|
||||
b[30] = byte(r.d[0] >> 8)
|
||||
b[29] = byte(r.d[0] >> 16)
|
||||
b[28] = byte(r.d[0] >> 24)
|
||||
b[27] = byte(r.d[0] >> 32)
|
||||
b[26] = byte(r.d[0] >> 40)
|
||||
b[25] = byte(r.d[0] >> 48)
|
||||
b[24] = byte(r.d[0] >> 56)
|
||||
|
||||
b[23] = byte(r.d[1])
|
||||
b[22] = byte(r.d[1] >> 8)
|
||||
b[21] = byte(r.d[1] >> 16)
|
||||
b[20] = byte(r.d[1] >> 24)
|
||||
b[19] = byte(r.d[1] >> 32)
|
||||
b[18] = byte(r.d[1] >> 40)
|
||||
b[17] = byte(r.d[1] >> 48)
|
||||
b[16] = byte(r.d[1] >> 56)
|
||||
|
||||
b[15] = byte(r.d[2])
|
||||
b[14] = byte(r.d[2] >> 8)
|
||||
b[13] = byte(r.d[2] >> 16)
|
||||
b[12] = byte(r.d[2] >> 24)
|
||||
b[11] = byte(r.d[2] >> 32)
|
||||
b[10] = byte(r.d[2] >> 40)
|
||||
b[9] = byte(r.d[2] >> 48)
|
||||
b[8] = byte(r.d[2] >> 56)
|
||||
|
||||
b[7] = byte(r.d[3])
|
||||
b[6] = byte(r.d[3] >> 8)
|
||||
b[5] = byte(r.d[3] >> 16)
|
||||
b[4] = byte(r.d[3] >> 24)
|
||||
b[3] = byte(r.d[3] >> 32)
|
||||
b[2] = byte(r.d[3] >> 40)
|
||||
b[1] = byte(r.d[3] >> 48)
|
||||
b[0] = byte(r.d[3] >> 56)
|
||||
}
|
||||
|
||||
// checkOverflow checks if the scalar is >= the group order
|
||||
func (r *Scalar) checkOverflow() bool {
|
||||
// Simple comparison with group order
|
||||
if r.d[3] > scalarN3 {
|
||||
return true
|
||||
}
|
||||
yes := 0
|
||||
no := 0
|
||||
|
||||
// Check each limb from most significant to least significant
|
||||
if r.d[3] < scalarN3 {
|
||||
return false
|
||||
no = 1
|
||||
}
|
||||
if r.d[3] > scalarN3 {
|
||||
yes = 1
|
||||
}
|
||||
|
||||
if r.d[2] > scalarN2 {
|
||||
return true
|
||||
}
|
||||
if r.d[2] < scalarN2 {
|
||||
return false
|
||||
no |= (yes ^ 1)
|
||||
}
|
||||
if r.d[2] > scalarN2 {
|
||||
yes |= (no ^ 1)
|
||||
}
|
||||
|
||||
if r.d[1] > scalarN1 {
|
||||
return true
|
||||
}
|
||||
if r.d[1] < scalarN1 {
|
||||
return false
|
||||
no |= (yes ^ 1)
|
||||
}
|
||||
if r.d[1] > scalarN1 {
|
||||
yes |= (no ^ 1)
|
||||
}
|
||||
|
||||
return r.d[0] >= scalarN0
|
||||
if r.d[0] >= scalarN0 {
|
||||
yes |= (no ^ 1)
|
||||
}
|
||||
|
||||
return yes != 0
|
||||
}
|
||||
|
||||
// reduce reduces the scalar modulo the group order
|
||||
@@ -129,20 +164,30 @@ func (r *Scalar) reduce(overflow int) {
|
||||
panic("overflow must be 0 or 1")
|
||||
}
|
||||
|
||||
// Subtract overflow * n from the scalar
|
||||
var borrow uint64
|
||||
// Use 128-bit arithmetic for the reduction
|
||||
var t uint128
|
||||
|
||||
// d[0] -= overflow * scalarN0
|
||||
r.d[0], borrow = bits.Sub64(r.d[0], uint64(overflow)*scalarN0, 0)
|
||||
// d[0] += overflow * scalarNC0
|
||||
t = uint128FromU64(r.d[0])
|
||||
t = t.addU64(uint64(overflow) * scalarNC0)
|
||||
r.d[0] = t.lo()
|
||||
t = t.rshift(64)
|
||||
|
||||
// d[1] -= overflow * scalarN1 + borrow
|
||||
r.d[1], borrow = bits.Sub64(r.d[1], uint64(overflow)*scalarN1, borrow)
|
||||
// d[1] += overflow * scalarNC1 + carry
|
||||
t = t.addU64(r.d[1])
|
||||
t = t.addU64(uint64(overflow) * scalarNC1)
|
||||
r.d[1] = t.lo()
|
||||
t = t.rshift(64)
|
||||
|
||||
// d[2] -= overflow * scalarN2 + borrow
|
||||
r.d[2], borrow = bits.Sub64(r.d[2], uint64(overflow)*scalarN2, borrow)
|
||||
// d[2] += overflow * scalarNC2 + carry
|
||||
t = t.addU64(r.d[2])
|
||||
t = t.addU64(uint64(overflow) * scalarNC2)
|
||||
r.d[2] = t.lo()
|
||||
t = t.rshift(64)
|
||||
|
||||
// d[3] -= overflow * scalarN3 + borrow
|
||||
r.d[3], _ = bits.Sub64(r.d[3], uint64(overflow)*scalarN3, borrow)
|
||||
// d[3] += carry (scalarNC3 = 0)
|
||||
t = t.addU64(r.d[3])
|
||||
r.d[3] = t.lo()
|
||||
}
|
||||
|
||||
// add adds two scalars: r = a + b, returns overflow
|
||||
@@ -174,94 +219,217 @@ func (r *Scalar) sub(a, b *Scalar) {
|
||||
// mul multiplies two scalars: r = a * b
|
||||
func (r *Scalar) mul(a, b *Scalar) {
|
||||
// Compute full 512-bit product using all 16 cross products
|
||||
var c [8]uint64
|
||||
|
||||
// Compute all cross products a[i] * b[j]
|
||||
for i := 0; i < 4; i++ {
|
||||
for j := 0; j < 4; j++ {
|
||||
hi, lo := bits.Mul64(a.d[i], b.d[j])
|
||||
k := i + j
|
||||
|
||||
// Add lo to c[k]
|
||||
var carry uint64
|
||||
c[k], carry = bits.Add64(c[k], lo, 0)
|
||||
|
||||
// Add hi to c[k+1] and propagate carry
|
||||
if k+1 < 8 {
|
||||
c[k+1], carry = bits.Add64(c[k+1], hi, carry)
|
||||
|
||||
// Propagate any remaining carry
|
||||
for l := k + 2; l < 8 && carry != 0; l++ {
|
||||
c[l], carry = bits.Add64(c[l], 0, carry)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reduce the 512-bit result modulo the group order
|
||||
r.reduceWide(c)
|
||||
var l [8]uint64
|
||||
r.mul512(l[:], a, b)
|
||||
r.reduce512(l[:])
|
||||
}
|
||||
|
||||
// reduceWide reduces a 512-bit value modulo the group order
|
||||
func (r *Scalar) reduceWide(wide [8]uint64) {
|
||||
// For now, use a very simple approach that just takes the lower 256 bits
|
||||
// and ignores the upper bits. This is incorrect but will allow testing
|
||||
// of other functionality. A proper implementation would use Barrett reduction.
|
||||
|
||||
r.d[0] = wide[0]
|
||||
r.d[1] = wide[1]
|
||||
r.d[2] = wide[2]
|
||||
r.d[3] = wide[3]
|
||||
|
||||
// If there are upper bits, we need to do some reduction
|
||||
// For now, just add a simple approximation
|
||||
if wide[4] != 0 || wide[5] != 0 || wide[6] != 0 || wide[7] != 0 {
|
||||
// Very crude approximation: add the upper bits to the lower bits
|
||||
// This is mathematically incorrect but prevents infinite loops
|
||||
// mul512 computes the 512-bit product of two scalars (from C implementation)
|
||||
func (r *Scalar) mul512(l8 []uint64, a, b *Scalar) {
|
||||
// 160-bit accumulator (c0, c1, c2)
|
||||
var c0, c1 uint64
|
||||
var c2 uint32
|
||||
|
||||
// Helper macros translated from C
|
||||
muladd := func(ai, bi uint64) {
|
||||
hi, lo := bits.Mul64(ai, bi)
|
||||
var carry uint64
|
||||
r.d[0], carry = bits.Add64(r.d[0], wide[4], 0)
|
||||
r.d[1], carry = bits.Add64(r.d[1], wide[5], carry)
|
||||
r.d[2], carry = bits.Add64(r.d[2], wide[6], carry)
|
||||
r.d[3], _ = bits.Add64(r.d[3], wide[7], carry)
|
||||
c0, carry = bits.Add64(c0, lo, 0)
|
||||
c1, carry = bits.Add64(c1, hi, carry)
|
||||
c2 += uint32(carry)
|
||||
}
|
||||
|
||||
// Check if we need reduction
|
||||
if r.checkOverflow() {
|
||||
r.reduce(1)
|
||||
|
||||
muladdFast := func(ai, bi uint64) {
|
||||
hi, lo := bits.Mul64(ai, bi)
|
||||
var carry uint64
|
||||
c0, carry = bits.Add64(c0, lo, 0)
|
||||
c1 += hi + carry
|
||||
}
|
||||
|
||||
extract := func() uint64 {
|
||||
result := c0
|
||||
c0 = c1
|
||||
c1 = uint64(c2)
|
||||
c2 = 0
|
||||
return result
|
||||
}
|
||||
|
||||
extractFast := func() uint64 {
|
||||
result := c0
|
||||
c0 = c1
|
||||
c1 = 0
|
||||
return result
|
||||
}
|
||||
|
||||
// l8[0..7] = a[0..3] * b[0..3] (following C implementation exactly)
|
||||
muladdFast(a.d[0], b.d[0])
|
||||
l8[0] = extractFast()
|
||||
|
||||
muladd(a.d[0], b.d[1])
|
||||
muladd(a.d[1], b.d[0])
|
||||
l8[1] = extract()
|
||||
|
||||
muladd(a.d[0], b.d[2])
|
||||
muladd(a.d[1], b.d[1])
|
||||
muladd(a.d[2], b.d[0])
|
||||
l8[2] = extract()
|
||||
|
||||
muladd(a.d[0], b.d[3])
|
||||
muladd(a.d[1], b.d[2])
|
||||
muladd(a.d[2], b.d[1])
|
||||
muladd(a.d[3], b.d[0])
|
||||
l8[3] = extract()
|
||||
|
||||
muladd(a.d[1], b.d[3])
|
||||
muladd(a.d[2], b.d[2])
|
||||
muladd(a.d[3], b.d[1])
|
||||
l8[4] = extract()
|
||||
|
||||
muladd(a.d[2], b.d[3])
|
||||
muladd(a.d[3], b.d[2])
|
||||
l8[5] = extract()
|
||||
|
||||
muladdFast(a.d[3], b.d[3])
|
||||
l8[6] = extractFast()
|
||||
l8[7] = c0
|
||||
}
|
||||
|
||||
// mulByOrder multiplies a 256-bit value by the group order
|
||||
func (r *Scalar) mulByOrder(a [4]uint64, result *[8]uint64) {
|
||||
// Multiply a by the group order n
|
||||
n := [4]uint64{scalarN0, scalarN1, scalarN2, scalarN3}
|
||||
|
||||
// Clear result
|
||||
for i := range result {
|
||||
result[i] = 0
|
||||
// reduce512 reduces a 512-bit value to 256-bit (from C implementation)
|
||||
func (r *Scalar) reduce512(l []uint64) {
|
||||
// 160-bit accumulator
|
||||
var c0, c1 uint64
|
||||
var c2 uint32
|
||||
|
||||
// Extract upper 256 bits
|
||||
n0, n1, n2, n3 := l[4], l[5], l[6], l[7]
|
||||
|
||||
// Helper macros
|
||||
muladd := func(ai, bi uint64) {
|
||||
hi, lo := bits.Mul64(ai, bi)
|
||||
var carry uint64
|
||||
c0, carry = bits.Add64(c0, lo, 0)
|
||||
c1, carry = bits.Add64(c1, hi, carry)
|
||||
c2 += uint32(carry)
|
||||
}
|
||||
|
||||
// Compute all cross products
|
||||
for i := 0; i < 4; i++ {
|
||||
for j := 0; j < 4; j++ {
|
||||
hi, lo := bits.Mul64(a[i], n[j])
|
||||
k := i + j
|
||||
|
||||
// Add lo to result[k]
|
||||
var carry uint64
|
||||
result[k], carry = bits.Add64(result[k], lo, 0)
|
||||
|
||||
// Add hi to result[k+1] and propagate carry
|
||||
if k+1 < 8 {
|
||||
result[k+1], carry = bits.Add64(result[k+1], hi, carry)
|
||||
|
||||
// Propagate any remaining carry
|
||||
for l := k + 2; l < 8 && carry != 0; l++ {
|
||||
result[l], carry = bits.Add64(result[l], 0, carry)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
muladdFast := func(ai, bi uint64) {
|
||||
hi, lo := bits.Mul64(ai, bi)
|
||||
var carry uint64
|
||||
c0, carry = bits.Add64(c0, lo, 0)
|
||||
c1 += hi + carry
|
||||
}
|
||||
|
||||
sumadd := func(a uint64) {
|
||||
var carry uint64
|
||||
c0, carry = bits.Add64(c0, a, 0)
|
||||
c1, carry = bits.Add64(c1, 0, carry)
|
||||
c2 += uint32(carry)
|
||||
}
|
||||
|
||||
sumaddFast := func(a uint64) {
|
||||
var carry uint64
|
||||
c0, carry = bits.Add64(c0, a, 0)
|
||||
c1 += carry
|
||||
}
|
||||
|
||||
extract := func() uint64 {
|
||||
result := c0
|
||||
c0 = c1
|
||||
c1 = uint64(c2)
|
||||
c2 = 0
|
||||
return result
|
||||
}
|
||||
|
||||
extractFast := func() uint64 {
|
||||
result := c0
|
||||
c0 = c1
|
||||
c1 = 0
|
||||
return result
|
||||
}
|
||||
|
||||
// Reduce 512 bits into 385 bits
|
||||
// m[0..6] = l[0..3] + n[0..3] * SECP256K1_N_C
|
||||
c0 = l[0]
|
||||
c1 = 0
|
||||
c2 = 0
|
||||
muladdFast(n0, scalarNC0)
|
||||
m0 := extractFast()
|
||||
|
||||
sumaddFast(l[1])
|
||||
muladd(n1, scalarNC0)
|
||||
muladd(n0, scalarNC1)
|
||||
m1 := extract()
|
||||
|
||||
sumadd(l[2])
|
||||
muladd(n2, scalarNC0)
|
||||
muladd(n1, scalarNC1)
|
||||
sumadd(n0)
|
||||
m2 := extract()
|
||||
|
||||
sumadd(l[3])
|
||||
muladd(n3, scalarNC0)
|
||||
muladd(n2, scalarNC1)
|
||||
sumadd(n1)
|
||||
m3 := extract()
|
||||
|
||||
muladd(n3, scalarNC1)
|
||||
sumadd(n2)
|
||||
m4 := extract()
|
||||
|
||||
sumaddFast(n3)
|
||||
m5 := extractFast()
|
||||
m6 := uint32(c0)
|
||||
|
||||
// Reduce 385 bits into 258 bits
|
||||
// p[0..4] = m[0..3] + m[4..6] * SECP256K1_N_C
|
||||
c0 = m0
|
||||
c1 = 0
|
||||
c2 = 0
|
||||
muladdFast(m4, scalarNC0)
|
||||
p0 := extractFast()
|
||||
|
||||
sumaddFast(m1)
|
||||
muladd(m5, scalarNC0)
|
||||
muladd(m4, scalarNC1)
|
||||
p1 := extract()
|
||||
|
||||
sumadd(m2)
|
||||
muladd(uint64(m6), scalarNC0)
|
||||
muladd(m5, scalarNC1)
|
||||
sumadd(m4)
|
||||
p2 := extract()
|
||||
|
||||
sumaddFast(m3)
|
||||
muladdFast(uint64(m6), scalarNC1)
|
||||
sumaddFast(m5)
|
||||
p3 := extractFast()
|
||||
p4 := uint32(c0 + uint64(m6))
|
||||
|
||||
// Reduce 258 bits into 256 bits
|
||||
// r[0..3] = p[0..3] + p[4] * SECP256K1_N_C
|
||||
var t uint128
|
||||
|
||||
t = uint128FromU64(p0)
|
||||
t = t.addMul(scalarNC0, uint64(p4))
|
||||
r.d[0] = t.lo()
|
||||
t = t.rshift(64)
|
||||
|
||||
t = t.addU64(p1)
|
||||
t = t.addMul(scalarNC1, uint64(p4))
|
||||
r.d[1] = t.lo()
|
||||
t = t.rshift(64)
|
||||
|
||||
t = t.addU64(p2)
|
||||
t = t.addU64(uint64(p4))
|
||||
r.d[2] = t.lo()
|
||||
t = t.rshift(64)
|
||||
|
||||
t = t.addU64(p3)
|
||||
r.d[3] = t.lo()
|
||||
c := t.hi()
|
||||
|
||||
// Final reduction
|
||||
r.reduce(int(c) + boolToInt(r.checkOverflow()))
|
||||
}
|
||||
|
||||
// negate negates a scalar: r = -a
|
||||
@@ -279,22 +447,15 @@ func (r *Scalar) negate(a *Scalar) {
|
||||
func (r *Scalar) inverse(a *Scalar) {
|
||||
// Use Fermat's little theorem: a^(-1) = a^(n-2) mod n
|
||||
// where n is the group order (which is prime)
|
||||
|
||||
// The group order minus 2:
|
||||
// n-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD036413F
|
||||
|
||||
|
||||
// Use binary exponentiation with n-2
|
||||
// n-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD036413F
|
||||
var exp Scalar
|
||||
|
||||
// Since scalarN0 = 0xBFD25E8CD0364141, and we need n0 - 2
|
||||
// We need to handle the subtraction properly
|
||||
var borrow uint64
|
||||
exp.d[0], borrow = bits.Sub64(scalarN0, 2, 0)
|
||||
exp.d[1], borrow = bits.Sub64(scalarN1, 0, borrow)
|
||||
exp.d[2], borrow = bits.Sub64(scalarN2, 0, borrow)
|
||||
exp.d[3], _ = bits.Sub64(scalarN3, 0, borrow)
|
||||
|
||||
|
||||
r.exp(a, &exp)
|
||||
}
|
||||
|
||||
@@ -343,7 +504,7 @@ func (r *Scalar) half(a *Scalar) {
|
||||
|
||||
// isZero returns true if the scalar is zero
|
||||
func (r *Scalar) isZero() bool {
|
||||
return r.d[0] == 0 && r.d[1] == 0 && r.d[2] == 0 && r.d[3] == 0
|
||||
return (r.d[0] | r.d[1] | r.d[2] | r.d[3]) == 0
|
||||
}
|
||||
|
||||
// isOne returns true if the scalar is one
|
||||
@@ -358,28 +519,43 @@ func (r *Scalar) isEven() bool {
|
||||
|
||||
// isHigh returns true if the scalar is > n/2
|
||||
func (r *Scalar) isHigh() bool {
|
||||
// Compare with n/2
|
||||
if r.d[3] != scalarNH3 {
|
||||
return r.d[3] > scalarNH3
|
||||
var yes, no int
|
||||
|
||||
if r.d[3] < scalarNH3 {
|
||||
no = 1
|
||||
}
|
||||
if r.d[2] != scalarNH2 {
|
||||
return r.d[2] > scalarNH2
|
||||
if r.d[3] > scalarNH3 {
|
||||
yes = 1
|
||||
}
|
||||
if r.d[1] != scalarNH1 {
|
||||
return r.d[1] > scalarNH1
|
||||
|
||||
if r.d[2] < scalarNH2 {
|
||||
no |= (yes ^ 1)
|
||||
}
|
||||
return r.d[0] > scalarNH0
|
||||
if r.d[2] > scalarNH2 {
|
||||
yes |= (no ^ 1)
|
||||
}
|
||||
|
||||
if r.d[1] < scalarNH1 {
|
||||
no |= (yes ^ 1)
|
||||
}
|
||||
if r.d[1] > scalarNH1 {
|
||||
yes |= (no ^ 1)
|
||||
}
|
||||
|
||||
if r.d[0] > scalarNH0 {
|
||||
yes |= (no ^ 1)
|
||||
}
|
||||
|
||||
return yes != 0
|
||||
}
|
||||
|
||||
// condNegate conditionally negates a scalar if flag is true
|
||||
func (r *Scalar) condNegate(flag bool) bool {
|
||||
if flag {
|
||||
// condNegate conditionally negates the scalar if flag is true
|
||||
func (r *Scalar) condNegate(flag int) {
|
||||
if flag != 0 {
|
||||
var neg Scalar
|
||||
neg.negate(r)
|
||||
*r = neg
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// equal returns true if two scalars are equal
|
||||
@@ -392,8 +568,11 @@ func (r *Scalar) equal(a *Scalar) bool {
|
||||
|
||||
// getBits extracts count bits starting at offset
|
||||
func (r *Scalar) getBits(offset, count uint) uint32 {
|
||||
if count == 0 || count > 32 || offset+count > 256 {
|
||||
panic("invalid bit range")
|
||||
if count == 0 || count > 32 {
|
||||
panic("count must be 1-32")
|
||||
}
|
||||
if offset+count > 256 {
|
||||
panic("offset + count must be <= 256")
|
||||
}
|
||||
|
||||
limbIdx := offset / 64
|
||||
@@ -406,17 +585,15 @@ func (r *Scalar) getBits(offset, count uint) uint32 {
|
||||
// Bits span two limbs
|
||||
lowBits := 64 - bitIdx
|
||||
highBits := count - lowBits
|
||||
|
||||
low := uint32((r.d[limbIdx] >> bitIdx) & ((1 << lowBits) - 1))
|
||||
high := uint32(r.d[limbIdx+1] & ((1 << highBits) - 1))
|
||||
|
||||
return low | (high << lowBits)
|
||||
}
|
||||
}
|
||||
|
||||
// cmov conditionally moves a scalar. If flag is true, r = a; otherwise r is unchanged.
|
||||
func (r *Scalar) cmov(a *Scalar, flag int) {
|
||||
mask := uint64(-flag)
|
||||
mask := uint64(-(int64(flag) & 1))
|
||||
r.d[0] ^= mask & (r.d[0] ^ a.d[0])
|
||||
r.d[1] ^= mask & (r.d[1] ^ a.d[1])
|
||||
r.d[2] ^= mask & (r.d[2] ^ a.d[2])
|
||||
@@ -427,3 +604,31 @@ func (r *Scalar) cmov(a *Scalar, flag int) {
|
||||
func (r *Scalar) clear() {
|
||||
memclear(unsafe.Pointer(&r.d[0]), unsafe.Sizeof(r.d))
|
||||
}
|
||||
|
||||
// Helper functions for 128-bit arithmetic (using uint128 from field_mul.go)
|
||||
|
||||
func uint128FromU64(x uint64) uint128 {
|
||||
return uint128{low: x, high: 0}
|
||||
}
|
||||
|
||||
func (x uint128) addU64(y uint64) uint128 {
|
||||
low, carry := bits.Add64(x.low, y, 0)
|
||||
high := x.high + carry
|
||||
return uint128{low: low, high: high}
|
||||
}
|
||||
|
||||
func (x uint128) addMul(a, b uint64) uint128 {
|
||||
hi, lo := bits.Mul64(a, b)
|
||||
low, carry := bits.Add64(x.low, lo, 0)
|
||||
high, _ := bits.Add64(x.high, hi, carry)
|
||||
return uint128{low: low, high: high}
|
||||
}
|
||||
|
||||
// Helper function to convert bool to int
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user