bulletproofs: speedup PROVE

This commit is contained in:
moneromooo-monero 2018-08-07 08:02:42 +00:00
parent 2287fb9fb4
commit 4564a5d17b
No known key found for this signature in database
GPG Key ID: 686F07454D6CEFC3
1 changed files with 58 additions and 64 deletions

View File

@ -127,15 +127,6 @@ static void sub_acc_p3(ge_p3 *acc_p3, const rct::key &point)
ge_p1p1_to_p3(acc_p3, &p1);
}
static rct::key scalarmultKey(const ge_p3 &P, const rct::key &a)
{
ge_p2 R;
ge_scalarmult(&R, a.bytes, &P);
rct::key aP;
ge_tobytes(aP.bytes, &R);
return aP;
}
static rct::key get_exponent(const rct::key &base, size_t idx)
{
static const std::string salt("bulletproof");
@ -193,23 +184,28 @@ static rct::key vector_exponent(const rct::keyV &a, const rct::keyV &b)
}
/* Compute a custom vector-scalar commitment */
static rct::key vector_exponent_custom(const rct::keyV &A, const rct::keyV &B, const rct::keyV &a, const rct::keyV &b)
static rct::key cross_vector_exponent8(size_t size, const std::vector<ge_p3> &A, size_t Ao, const std::vector<ge_p3> &B, size_t Bo, const rct::keyV &a, size_t ao, const rct::keyV &b, size_t bo, const ge_p3 *extra_point, const rct::key *extra_scalar)
{
CHECK_AND_ASSERT_THROW_MES(A.size() == B.size(), "Incompatible sizes of A and B");
CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b");
CHECK_AND_ASSERT_THROW_MES(a.size() == A.size(), "Incompatible sizes of a and A");
CHECK_AND_ASSERT_THROW_MES(a.size() <= maxN*maxM, "Incompatible sizes of a and maxN");
CHECK_AND_ASSERT_THROW_MES(size + Ao <= A.size(), "Incompatible size for A");
CHECK_AND_ASSERT_THROW_MES(size + Bo <= B.size(), "Incompatible size for B");
CHECK_AND_ASSERT_THROW_MES(size + ao <= a.size(), "Incompatible size for a");
CHECK_AND_ASSERT_THROW_MES(size + bo <= b.size(), "Incompatible size for b");
CHECK_AND_ASSERT_THROW_MES(size <= maxN*maxM, "size is too large");
CHECK_AND_ASSERT_THROW_MES(!!extra_point == !!extra_scalar, "only one of extra point/scalar present");
std::vector<MultiexpData> multiexp_data;
multiexp_data.reserve(a.size()*2);
for (size_t i = 0; i < a.size(); ++i)
multiexp_data.resize(size*2 + (!!extra_point));
for (size_t i = 0; i < size; ++i)
{
multiexp_data.resize(multiexp_data.size() + 1);
multiexp_data.back().scalar = a[i];
CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, A[i].bytes) == 0, "ge_frombytes_vartime failed");
multiexp_data.resize(multiexp_data.size() + 1);
multiexp_data.back().scalar = b[i];
CHECK_AND_ASSERT_THROW_MES(ge_frombytes_vartime(&multiexp_data.back().point, B[i].bytes) == 0, "ge_frombytes_vartime failed");
sc_mul(multiexp_data[i*2].scalar.bytes, a[ao+i].bytes, INV_EIGHT.bytes);;
multiexp_data[i*2].point = A[Ao+i];
sc_mul(multiexp_data[i*2+1].scalar.bytes, b[bo+i].bytes, INV_EIGHT.bytes);
multiexp_data[i*2+1].point = B[Bo+i];
}
if (extra_point)
{
sc_mul(multiexp_data.back().scalar.bytes, extra_scalar->bytes, INV_EIGHT.bytes);
multiexp_data.back().point = *extra_point;
}
return multiexp(multiexp_data, false);
}
@ -273,16 +269,19 @@ static rct::keyV hadamard(const rct::keyV &a, const rct::keyV &b)
return res;
}
/* Given two curvepoint arrays, construct the Hadamard product */
static rct::keyV hadamard2(const rct::keyV &a, const rct::keyV &b)
/* folds a curvepoint array using a two way scaled Hadamard product */
static void hadamard_fold(std::vector<ge_p3> &v, const rct::key &a, const rct::key &b)
{
CHECK_AND_ASSERT_THROW_MES(a.size() == b.size(), "Incompatible sizes of a and b");
rct::keyV res(a.size());
for (size_t i = 0; i < a.size(); ++i)
CHECK_AND_ASSERT_THROW_MES((v.size() & 1) == 0, "Vector size should be even");
const size_t sz = v.size() / 2;
for (size_t n = 0; n < sz; ++n)
{
rct::addKeys(res[i], a[i], b[i]);
ge_dsmp c[2];
ge_dsm_precomp(c[0], &v[n]);
ge_dsm_precomp(c[1], &v[sz + n]);
ge_double_scalarmult_precomp_vartime2_p3(&v[n], a.bytes, c[0], b.bytes, c[1]);
}
return res;
v.resize(sz);
}
/* Add two vectors */
@ -326,17 +325,6 @@ static rct::keyV vector_dup(const rct::key &x, size_t N)
return rct::keyV(N, x);
}
/* Exponentiate a curve vector by a scalar */
static rct::keyV vector_scalar2(const rct::keyV &a, const rct::key &x)
{
rct::keyV res(a.size());
for (size_t i = 0; i < a.size(); ++i)
{
rct::scalarmultKey(res[i], a[i], x);
}
return res;
}
/* Get the sum of a vector's elements */
static rct::key vector_sum(const rct::keyV &a)
{
@ -620,16 +608,16 @@ try_again:
// These are used in the inner product rounds
size_t nprime = N;
rct::keyV Gprime(N);
rct::keyV Hprime(N);
std::vector<ge_p3> Gprime(N);
std::vector<ge_p3> Hprime(N);
rct::keyV aprime(N);
rct::keyV bprime(N);
const rct::key yinv = invert(y);
rct::key yinvpow = rct::identity();
for (size_t i = 0; i < N; ++i)
{
Gprime[i] = Gi[i];
Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow);
Gprime[i] = Gi_p3[i];
ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]);
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
aprime[i] = l[i];
bprime[i] = r[i];
@ -652,14 +640,10 @@ try_again:
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
// PAPER LINES 18-19
L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
rct::addKeys(L[round], L[round], rct::scalarmultH(tmp));
L[round] = rct::scalarmultKey(L[round], INV_EIGHT);
R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp);
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
rct::addKeys(R[round], R[round], rct::scalarmultH(tmp));
R[round] = rct::scalarmultKey(R[round], INV_EIGHT);
R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp);
// PAPER LINES 21-22
w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
@ -672,8 +656,11 @@ try_again:
// PAPER LINES 24-25
const rct::key winv = invert(w[round]);
Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round]));
Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv));
if (nprime > 1)
{
hadamard_fold(Gprime, winv, w[round]);
hadamard_fold(Hprime, w[round], winv);
}
// PAPER LINES 28-29
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
@ -914,16 +901,16 @@ try_again:
// These are used in the inner product rounds
size_t nprime = MN;
rct::keyV Gprime(MN);
rct::keyV Hprime(MN);
std::vector<ge_p3> Gprime(MN);
std::vector<ge_p3> Hprime(MN);
rct::keyV aprime(MN);
rct::keyV bprime(MN);
const rct::key yinv = invert(y);
rct::key yinvpow = rct::identity();
for (size_t i = 0; i < MN; ++i)
{
Gprime[i] = Gi[i];
Hprime[i] = scalarmultKey(Hi_p3[i], yinvpow);
Gprime[i] = Gi_p3[i];
ge_scalarmult_p3(&Hprime[i], yinvpow.bytes, &Hi_p3[i]);
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
aprime[i] = l[i];
bprime[i] = r[i];
@ -942,18 +929,18 @@ try_again:
nprime /= 2;
// PAPER LINES 16-17
PERF_TIMER_START_BP(PROVE_inner_product);
rct::key cL = inner_product(slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
rct::key cR = inner_product(slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
PERF_TIMER_STOP(PROVE_inner_product);
// PAPER LINES 18-19
L[round] = vector_exponent_custom(slice(Gprime, nprime, Gprime.size()), slice(Hprime, 0, nprime), slice(aprime, 0, nprime), slice(bprime, nprime, bprime.size()));
PERF_TIMER_START_BP(PROVE_LR);
sc_mul(tmp.bytes, cL.bytes, x_ip.bytes);
rct::addKeys(L[round], L[round], rct::scalarmultH(tmp));
L[round] = rct::scalarmultKey(L[round], INV_EIGHT);
R[round] = vector_exponent_custom(slice(Gprime, 0, nprime), slice(Hprime, nprime, Hprime.size()), slice(aprime, nprime, aprime.size()), slice(bprime, 0, nprime));
L[round] = cross_vector_exponent8(nprime, Gprime, nprime, Hprime, 0, aprime, 0, bprime, nprime, &ge_p3_H, &tmp);
sc_mul(tmp.bytes, cR.bytes, x_ip.bytes);
rct::addKeys(R[round], R[round], rct::scalarmultH(tmp));
R[round] = rct::scalarmultKey(R[round], INV_EIGHT);
R[round] = cross_vector_exponent8(nprime, Gprime, 0, Hprime, nprime, aprime, nprime, bprime, 0, &ge_p3_H, &tmp);
PERF_TIMER_STOP(PROVE_LR);
// PAPER LINES 21-22
w[round] = hash_cache_mash(hash_cache, L[round], R[round]);
@ -966,12 +953,19 @@ try_again:
// PAPER LINES 24-25
const rct::key winv = invert(w[round]);
Gprime = hadamard2(vector_scalar2(slice(Gprime, 0, nprime), winv), vector_scalar2(slice(Gprime, nprime, Gprime.size()), w[round]));
Hprime = hadamard2(vector_scalar2(slice(Hprime, 0, nprime), w[round]), vector_scalar2(slice(Hprime, nprime, Hprime.size()), winv));
if (nprime > 1)
{
PERF_TIMER_START_BP(PROVE_hadamard2);
hadamard_fold(Gprime, winv, w[round]);
hadamard_fold(Hprime, w[round], winv);
PERF_TIMER_STOP(PROVE_hadamard2);
}
// PAPER LINES 28-29
PERF_TIMER_START_BP(PROVE_prime);
aprime = vector_add(vector_scalar(slice(aprime, 0, nprime), w[round]), vector_scalar(slice(aprime, nprime, aprime.size()), winv));
bprime = vector_add(vector_scalar(slice(bprime, 0, nprime), winv), vector_scalar(slice(bprime, nprime, bprime.size()), w[round]));
PERF_TIMER_STOP(PROVE_prime);
++round;
}