bulletproofs: speedup PROVE
This commit is contained in:
parent
2287fb9fb4
commit
4564a5d17b
|
@ -127,15 +127,6 @@ static void sub_acc_p3(ge_p3 *acc_p3, const rct::key &point)
|
||||||
ge_p1p1_to_p3(acc_p3, &p1);
|
ge_p1p1_to_p3(acc_p3, &p1);
|
||||||
}
|
}
|
||||||
|
|
||||||
static rct::key scalarmultKey(const ge_p3 &P, const rct::key &a)
|
|
||||||
{
|
|
||||||
ge_p2 R;
|
|
||||||
ge_scalarmult(&R, a.bytes, &P);
|
|
||||||
rct::key aP;
|
|
||||||
ge_tobytes(aP.bytes, &R);
|
|
||||||
return aP;
|
|
||||||
}
|
|
||||||
|
|
||||||
static rct::key get_exponent(const rct::key &base, size_t idx)
|
static rct::key get_exponent(const rct::key &base, size_t idx)
|
||||||
{
|
{
|
||||||
static const std::string salt("bulletproof");
|
static const std::string salt("bulletproof");
|
||||||
|
@ -193,23 +184,28 @@ static rct::key vector_exponent(const rct::keyV &a, const rct::keyV &b)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Compute a custom vector-scalar commitment */
|
/* Compute a custom vector-scalar commitment */
|
||||||
static rct::key vector_exponent_custom(const rct::keyV &A, const rct::keyV &B, const rct::keyV &a, const rct::keyV &b)
|
static rct::key cross_vector_exponent8(size_t size, const std::vector<ge_p3> &A, size_t Ao, const std::vector<ge_p3> &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const ge_p3 *extra_point, const rct::key *extra_scalar)
|
||||||
{
|
{
|
||||||
CHECK_AND_ASSERT_THROW_MES(A.size() == B.size(), "Incompatible sizes of A and B");
|
CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A");
|
||||||
CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b");
|
CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B");
|
||||||
CHECK_AND_ASSERT_THROW_MES(a.size() == A.size(), "Incompatible sizes of a and A");
|
CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a");
|
||||||
CHECK_AND_ASSERT_THROW_MES(a.size() <= maxN*maxM, "Incompatible sizes of a and maxN");
|
CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b");
|
||||||
|
CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large");
|
||||||
|
CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present");
|
||||||
|
|
||||||
std::vector<MultiexpData> multiexp_data;
|
std::vector<MultiexpData> multiexp_data;
|
||||||
multiexp_data.reserve(a.size()*2);
|
multiexp_data.resize(size*2 + (!!extra_point));
|
||||||
for (size_t i = 0; i < a.size(); ++i)
|
for (size_t i = 0; i < size; ++i)
|
||||||
{
|
{
|
||||||
multiexp_data.resize(multiexp_data.size() + 1);
|
sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);;
|
||||||
multiexp_data.back().scalar = a[i];
|
multiexp_data[i*2].point = A[Ao+i];
|
||||||
CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, A[i].bytes) == 0, "ge_frombytes_vartime failed");
|
sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes);
|
||||||
multiexp_data.resize(multiexp_data.size() + 1);
|
multiexp_data[i*2+1].point = B[Bo+i];
|
||||||
multiexp_data.back().scalar = b[i];
|
}
|
||||||
CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, B[i].bytes) == 0, "ge_frombytes_vartime failed");
|
if (extra_point)
|
||||||
|
{
|
||||||
|
sc_mul(multiexp_data.back().scalar.bytes, extra_scalar->bytes, INV_EIGHT.bytes);
|
||||||
|
multiexp_data.back().point = *extra_point;
|
||||||
}
|
}
|
||||||
return multiexp(multiexp_data, false);
|
return multiexp(multiexp_data, false);
|
||||||
}
|
}
|
||||||
|
@ -273,16 +269,19 @@ static rct::keyV hadamard(const rct::keyV &a, const rct::keyV &b)
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Given two curvepoint arrays, construct the Hadamard product */
|
/* folds a curvepoint array using a two way scaled Hadamard product */
|
||||||
static rct::keyV hadamard2(const rct::keyV &a, const rct::keyV &b)
|
static void hadamard_fold(std::vector<ge_p3> &v, const rct::key &a, const rct::key &b)
|
||||||
{
|
{
|
||||||
CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b");
|
CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even");
|
||||||
rct::keyV res(a.size());
|
const size_t sz = v.size() / 2;
|
||||||
for (size_t i = 0; i < a.size(); ++i)
|
for (size_t n = 0; n < sz; ++n)
|
||||||
{
|
{
|
||||||
rct::addKeys(res[i], a[i], b[i]);
|
ge_dsmp c[2];
|
||||||
|
ge_dsm_precomp(c[0], &v[n]);
|
||||||
|
ge_dsm_precomp(c[1], &v[sz + n]);
|
||||||
|
ge_double_scalarmult_precomp_vartime2_p3(&v[n], a.bytes, c[0], b.bytes, c[1]);
|
||||||
}
|
}
|
||||||
return res;
|
v.resize(sz);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Add two vectors */
|
/* Add two vectors */
|
||||||
|
@ -326,17 +325,6 @@ static rct::keyV vector_dup(const rct::key &x, size_t N)
|
||||||
return rct::keyV(N, x);
|
return rct::keyV(N, x);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* Exponentiate a curve vector by a scalar */
|
|
||||||
static rct::keyV vector_scalar2(const rct::keyV &a, const rct::key &x)
|
|
||||||
{
|
|
||||||
rct::keyV res(a.size());
|
|
||||||
for (size_t i = 0; i < a.size(); ++i)
|
|
||||||
{
|
|
||||||
rct::scalarmultKey(res[i], a[i], x);
|
|
||||||
}
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Get the sum of a vector's elements */
|
/* Get the sum of a vector's elements */
|
||||||
static rct::key vector_sum(const rct::keyV &a)
|
static rct::key vector_sum(const rct::keyV &a)
|
||||||
{
|
{
|
||||||
|
@ -620,16 +608,16 @@ try_again:
|
||||||
|
|
||||||
// These are used in the inner product rounds
|
// These are used in the inner product rounds
|
||||||
size_t nprime = N;
|
size_t nprime = N;
|
||||||
rct::keyV Gprime(N);
|
std::vector<ge_p3> Gprime(N);
|
||||||
rct::keyV Hprime(N);
|
std::vector<ge_p3> Hprime(N);
|
||||||
rct::keyV aprime(N);
|
rct::keyV aprime(N);
|
||||||
rct::keyV bprime(N);
|
rct::keyV bprime(N);
|
||||||
const rct::key yinv = invert(y);
|
const rct::key yinv = invert(y);
|
||||||
rct::key yinvpow = rct::identity();
|
rct::key yinvpow = rct::identity();
|
||||||
for (size_t i = 0; i < N; ++i)
|
for (size_t i = 0; i < N; ++i)
|
||||||
{
|
{
|
||||||
Gprime[i] = Gi[i];
|
Gprime[i] = Gi_p3[i];
|
||||||
Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow);
|
ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]);
|
||||||
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
|
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
|
||||||
aprime[i] = l[i];
|
aprime[i] = l[i];
|
||||||
bprime[i] = r[i];
|
bprime[i] = r[i];
|
||||||
|
@ -652,14 +640,10 @@ try_again:
|
||||||
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
|
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
|
||||||
|
|
||||||
// PAPER LINES 18-19
|
// PAPER LINES 18-19
|
||||||
L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
|
|
||||||
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
|
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
|
||||||
rct::addKeys(L[round], L[round], rct::scalarmultH(tmp));
|
L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp);
|
||||||
L[round] = rct::scalarmultKey(L[round], INV_EIGHT);
|
|
||||||
R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
|
|
||||||
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
|
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
|
||||||
rct::addKeys(R[round], R[round], rct::scalarmultH(tmp));
|
R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp);
|
||||||
R[round] = rct::scalarmultKey(R[round], INV_EIGHT);
|
|
||||||
|
|
||||||
// PAPER LINES 21-22
|
// PAPER LINES 21-22
|
||||||
w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
|
w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
|
||||||
|
@ -672,8 +656,11 @@ try_again:
|
||||||
|
|
||||||
// PAPER LINES 24-25
|
// PAPER LINES 24-25
|
||||||
const rct::key winv = invert(w[round]);
|
const rct::key winv = invert(w[round]);
|
||||||
Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round]));
|
if (nprime > 1)
|
||||||
Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv));
|
{
|
||||||
|
hadamard_fold(Gprime, winv, w[round]);
|
||||||
|
hadamard_fold(Hprime, w[round], winv);
|
||||||
|
}
|
||||||
|
|
||||||
// PAPER LINES 28-29
|
// PAPER LINES 28-29
|
||||||
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
|
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
|
||||||
|
@ -914,16 +901,16 @@ try_again:
|
||||||
|
|
||||||
// These are used in the inner product rounds
|
// These are used in the inner product rounds
|
||||||
size_t nprime = MN;
|
size_t nprime = MN;
|
||||||
rct::keyV Gprime(MN);
|
std::vector<ge_p3> Gprime(MN);
|
||||||
rct::keyV Hprime(MN);
|
std::vector<ge_p3> Hprime(MN);
|
||||||
rct::keyV aprime(MN);
|
rct::keyV aprime(MN);
|
||||||
rct::keyV bprime(MN);
|
rct::keyV bprime(MN);
|
||||||
const rct::key yinv = invert(y);
|
const rct::key yinv = invert(y);
|
||||||
rct::key yinvpow = rct::identity();
|
rct::key yinvpow = rct::identity();
|
||||||
for (size_t i = 0; i < MN; ++i)
|
for (size_t i = 0; i < MN; ++i)
|
||||||
{
|
{
|
||||||
Gprime[i] = Gi[i];
|
Gprime[i] = Gi_p3[i];
|
||||||
Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow);
|
ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]);
|
||||||
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
|
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
|
||||||
aprime[i] = l[i];
|
aprime[i] = l[i];
|
||||||
bprime[i] = r[i];
|
bprime[i] = r[i];
|
||||||
|
@ -942,18 +929,18 @@ try_again:
|
||||||
nprime /= 2;
|
nprime /= 2;
|
||||||
|
|
||||||
// PAPER LINES 16-17
|
// PAPER LINES 16-17
|
||||||
|
PERF_TIMER_START_BP(PROVE_inner_product);
|
||||||
rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
|
rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
|
||||||
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
|
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
|
||||||
|
PERF_TIMER_STOP(PROVE_inner_product);
|
||||||
|
|
||||||
// PAPER LINES 18-19
|
// PAPER LINES 18-19
|
||||||
L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
|
PERF_TIMER_START_BP(PROVE_LR);
|
||||||
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
|
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
|
||||||
rct::addKeys(L[round], L[round], rct::scalarmultH(tmp));
|
L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp);
|
||||||
L[round] = rct::scalarmultKey(L[round], INV_EIGHT);
|
|
||||||
R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
|
|
||||||
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
|
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
|
||||||
rct::addKeys(R[round], R[round], rct::scalarmultH(tmp));
|
R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp);
|
||||||
R[round] = rct::scalarmultKey(R[round], INV_EIGHT);
|
PERF_TIMER_STOP(PROVE_LR);
|
||||||
|
|
||||||
// PAPER LINES 21-22
|
// PAPER LINES 21-22
|
||||||
w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
|
w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
|
||||||
|
@ -966,12 +953,19 @@ try_again:
|
||||||
|
|
||||||
// PAPER LINES 24-25
|
// PAPER LINES 24-25
|
||||||
const rct::key winv = invert(w[round]);
|
const rct::key winv = invert(w[round]);
|
||||||
Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round]));
|
if (nprime > 1)
|
||||||
Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv));
|
{
|
||||||
|
PERF_TIMER_START_BP(PROVE_hadamard2);
|
||||||
|
hadamard_fold(Gprime, winv, w[round]);
|
||||||
|
hadamard_fold(Hprime, w[round], winv);
|
||||||
|
PERF_TIMER_STOP(PROVE_hadamard2);
|
||||||
|
}
|
||||||
|
|
||||||
// PAPER LINES 28-29
|
// PAPER LINES 28-29
|
||||||
|
PERF_TIMER_START_BP(PROVE_prime);
|
||||||
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
|
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
|
||||||
bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round]));
|
bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round]));
|
||||||
|
PERF_TIMER_STOP(PROVE_prime);
|
||||||
|
|
||||||
++round;
|
++round;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue