bulletproofs: rework flow to use sarang's fast batch inversion code
This commit is contained in:
parent
fc9f7d9c81
commit
8629a42cf6
|
@ -29,8 +29,6 @@
|
|||
// Adapted from Java code by Sarang Noether
|
||||
|
||||
#include <stdlib.h>
|
||||
#include <openssl/ssl.h>
|
||||
#include <openssl/bn.h>
|
||||
#include <boost/thread/mutex.hpp>
|
||||
#include "misc_log_ex.h"
|
||||
#include "common/perf_timer.h"
|
||||
|
@ -289,37 +287,59 @@ static rct::keyV vector_dup(const rct::key &x, size_t N)
|
|||
return rct::keyV(N, x);
|
||||
}
|
||||
|
||||
static rct::key switch_endianness(rct::key k)
|
||||
static rct::key sm(rct::key y, int n, const rct::key &x)
|
||||
{
|
||||
std::reverse(k.bytes, k.bytes + sizeof(k));
|
||||
return k;
|
||||
while (n--)
|
||||
sc_mul(y.bytes, y.bytes, y.bytes);
|
||||
sc_mul(y.bytes, y.bytes, x.bytes);
|
||||
return y;
|
||||
}
|
||||
|
||||
/* Compute the inverse of a scalar, the stupid way */
|
||||
/* Compute the inverse of a scalar, the clever way */
|
||||
static rct::key invert(const rct::key &x)
|
||||
{
|
||||
rct::key _1, _10, _100, _11, _101, _111, _1001, _1011, _1111;
|
||||
|
||||
_1 = x;
|
||||
sc_mul(_10.bytes, _1.bytes, _1.bytes);
|
||||
sc_mul(_100.bytes, _10.bytes, _10.bytes);
|
||||
sc_mul(_11.bytes, _10.bytes, _1.bytes);
|
||||
sc_mul(_101.bytes, _10.bytes, _11.bytes);
|
||||
sc_mul(_111.bytes, _10.bytes, _101.bytes);
|
||||
sc_mul(_1001.bytes, _10.bytes, _111.bytes);
|
||||
sc_mul(_1011.bytes, _10.bytes, _1001.bytes);
|
||||
sc_mul(_1111.bytes, _100.bytes, _1011.bytes);
|
||||
|
||||
rct::key inv;
|
||||
sc_mul(inv.bytes, _1111.bytes, _1.bytes);
|
||||
|
||||
BN_CTX *ctx = BN_CTX_new();
|
||||
BIGNUM *X = BN_new();
|
||||
BIGNUM *L = BN_new();
|
||||
BIGNUM *I = BN_new();
|
||||
|
||||
BN_bin2bn(switch_endianness(x).bytes, sizeof(rct::key), X);
|
||||
BN_bin2bn(switch_endianness(rct::curveOrder()).bytes, sizeof(rct::key), L);
|
||||
|
||||
CHECK_AND_ASSERT_THROW_MES(BN_mod_inverse(I, X, L, ctx), "Failed to invert");
|
||||
|
||||
const int len = BN_num_bytes(I);
|
||||
CHECK_AND_ASSERT_THROW_MES((size_t)len <= sizeof(rct::key), "Invalid number length");
|
||||
inv = rct::zero();
|
||||
BN_bn2bin(I, inv.bytes);
|
||||
std::reverse(inv.bytes, inv.bytes + len);
|
||||
|
||||
BN_free(I);
|
||||
BN_free(L);
|
||||
BN_free(X);
|
||||
BN_CTX_free(ctx);
|
||||
inv = sm(inv, 123 + 3, _101);
|
||||
inv = sm(inv, 2 + 2, _11);
|
||||
inv = sm(inv, 1 + 4, _1111);
|
||||
inv = sm(inv, 1 + 4, _1111);
|
||||
inv = sm(inv, 4, _1001);
|
||||
inv = sm(inv, 2, _11);
|
||||
inv = sm(inv, 1 + 4, _1111);
|
||||
inv = sm(inv, 1 + 3, _101);
|
||||
inv = sm(inv, 3 + 3, _101);
|
||||
inv = sm(inv, 3, _111);
|
||||
inv = sm(inv, 1 + 4, _1111);
|
||||
inv = sm(inv, 2 + 3, _111);
|
||||
inv = sm(inv, 2 + 2, _11);
|
||||
inv = sm(inv, 1 + 4, _1011);
|
||||
inv = sm(inv, 2 + 4, _1011);
|
||||
inv = sm(inv, 6 + 4, _1001);
|
||||
inv = sm(inv, 2 + 2, _11);
|
||||
inv = sm(inv, 3 + 2, _11);
|
||||
inv = sm(inv, 3 + 2, _11);
|
||||
inv = sm(inv, 1 + 4, _1001);
|
||||
inv = sm(inv, 1 + 3, _111);
|
||||
inv = sm(inv, 2 + 4, _1111);
|
||||
inv = sm(inv, 1 + 4, _1011);
|
||||
inv = sm(inv, 3, _101);
|
||||
inv = sm(inv, 2 + 4, _1111);
|
||||
inv = sm(inv, 3, _101);
|
||||
inv = sm(inv, 1 + 2, _11);
|
||||
|
||||
#ifdef DEBUG_BP
|
||||
rct::key tmp;
|
||||
|
@ -329,6 +349,34 @@ static rct::key invert(const rct::key &x)
|
|||
return inv;
|
||||
}
|
||||
|
||||
static rct::keyV invert(rct::keyV x)
|
||||
{
|
||||
rct::keyV scratch;
|
||||
scratch.reserve(x.size());
|
||||
|
||||
rct::key acc = rct::identity();
|
||||
for (size_t n = 0; n < x.size(); ++n)
|
||||
{
|
||||
scratch.push_back(acc);
|
||||
if (n == 0)
|
||||
acc = x[0];
|
||||
else
|
||||
sc_mul(acc.bytes, acc.bytes, x[n].bytes);
|
||||
}
|
||||
|
||||
acc = invert(acc);
|
||||
|
||||
rct::key tmp;
|
||||
for (int i = x.size(); i-- > 0; )
|
||||
{
|
||||
sc_mul(tmp.bytes, acc.bytes, x[i].bytes);
|
||||
sc_mul(x[i].bytes, acc.bytes, scratch[i].bytes);
|
||||
acc = tmp;
|
||||
}
|
||||
|
||||
return x;
|
||||
}
|
||||
|
||||
/* Compute the slice of a vector */
|
||||
static rct::keyV slice(const rct::keyV &a, size_t start, size_t stop)
|
||||
{
|
||||
|
@ -702,6 +750,13 @@ Bulletproof bulletproof_PROVE(const std::vector<uint64_t> &v, const rct::keyV &g
|
|||
return bulletproof_PROVE(sv, gamma);
|
||||
}
|
||||
|
||||
struct proof_data_t
|
||||
{
|
||||
rct::key x, y, z, x_ip;
|
||||
std::vector<rct::key> w;
|
||||
size_t logM, inv_offset;
|
||||
};
|
||||
|
||||
/* Given a range proof, determine if it is valid */
|
||||
bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
||||
{
|
||||
|
@ -709,9 +764,17 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
|
||||
PERF_TIMER_START_BP(VERIFY);
|
||||
|
||||
const size_t logN = 6;
|
||||
const size_t N = 1 << logN;
|
||||
|
||||
// sanity and figure out which proof is longest
|
||||
size_t max_length = 0;
|
||||
size_t nV = 0;
|
||||
std::vector<proof_data_t> proof_data;
|
||||
proof_data.reserve(proofs.size());
|
||||
size_t inv_offset = 0;
|
||||
std::vector<rct::key> to_invert;
|
||||
to_invert.reserve(11 * sizeof(proofs));
|
||||
for (const Bulletproof *p: proofs)
|
||||
{
|
||||
const Bulletproof &proof = *p;
|
||||
|
@ -729,46 +792,75 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
|
||||
max_length = std::max(max_length, proof.L.size());
|
||||
nV += proof.V.size();
|
||||
|
||||
// Reconstruct the challenges
|
||||
PERF_TIMER_START_BP(VERIFY_start);
|
||||
proof_data.resize(proof_data.size() + 1);
|
||||
proof_data_t &pd = proof_data.back();
|
||||
rct::key hash_cache = rct::hash_to_scalar(proof.V);
|
||||
pd.y = hash_cache_mash(hash_cache, proof.A, proof.S);
|
||||
CHECK_AND_ASSERT_MES(!(pd.y == rct::zero()), false, "y == 0");
|
||||
pd.z = hash_cache = rct::hash_to_scalar(pd.y);
|
||||
CHECK_AND_ASSERT_MES(!(pd.z == rct::zero()), false, "z == 0");
|
||||
pd.x = hash_cache_mash(hash_cache, pd.z, proof.T1, proof.T2);
|
||||
CHECK_AND_ASSERT_MES(!(pd.x == rct::zero()), false, "x == 0");
|
||||
pd.x_ip = hash_cache_mash(hash_cache, pd.x, proof.taux, proof.mu, proof.t);
|
||||
CHECK_AND_ASSERT_MES(!(pd.x_ip == rct::zero()), false, "x_ip == 0");
|
||||
PERF_TIMER_STOP(VERIFY_start);
|
||||
|
||||
size_t M;
|
||||
for (pd.logM = 0; (M = 1<<pd.logM) <= maxM && M < proof.V.size(); ++pd.logM);
|
||||
CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size");
|
||||
|
||||
const size_t rounds = pd.logM+logN;
|
||||
CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds");
|
||||
|
||||
PERF_TIMER_START_BP(VERIFY_line_21_22);
|
||||
// PAPER LINES 21-22
|
||||
// The inner product challenges are computed per round
|
||||
pd.w.resize(rounds);
|
||||
for (size_t i = 0; i < rounds; ++i)
|
||||
{
|
||||
pd.w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]);
|
||||
CHECK_AND_ASSERT_MES(!(pd.w[i] == rct::zero()), false, "w[i] == 0");
|
||||
}
|
||||
PERF_TIMER_STOP(VERIFY_line_21_22);
|
||||
|
||||
pd.inv_offset = inv_offset;
|
||||
for (size_t i = 0; i < rounds; ++i)
|
||||
to_invert.push_back(pd.w[i]);
|
||||
to_invert.push_back(pd.y);
|
||||
inv_offset += rounds + 1;
|
||||
}
|
||||
CHECK_AND_ASSERT_MES(max_length < 32, false, "At least one proof is too large");
|
||||
size_t maxMN = 1u << max_length;
|
||||
|
||||
const size_t logN = 6;
|
||||
const size_t N = 1 << logN;
|
||||
rct::key tmp;
|
||||
|
||||
std::vector<MultiexpData> multiexp_data;
|
||||
multiexp_data.reserve(nV + (2 * (10/*logM*/ + logN) + 4) * proofs.size() + 2 * maxMN);
|
||||
|
||||
PERF_TIMER_START_BP(VERIFY_line_24_25_invert);
|
||||
const std::vector<rct::key> inverses = invert(to_invert);
|
||||
PERF_TIMER_STOP(VERIFY_line_24_25_invert);
|
||||
|
||||
// setup weighted aggregates
|
||||
rct::key z1 = rct::zero();
|
||||
rct::key z3 = rct::zero();
|
||||
rct::keyV z4(maxMN, rct::zero()), z5(maxMN, rct::zero());
|
||||
rct::key y0 = rct::zero(), y1 = rct::zero();
|
||||
int proof_data_index = 0;
|
||||
for (const Bulletproof *p: proofs)
|
||||
{
|
||||
const Bulletproof &proof = *p;
|
||||
const proof_data_t &pd = proof_data[proof_data_index++];
|
||||
|
||||
size_t M, logM;
|
||||
for (logM = 0; (M = 1<<logM) <= maxM && M < proof.V.size(); ++logM);
|
||||
CHECK_AND_ASSERT_MES(proof.L.size() == 6+logM, false, "Proof is not the expected size");
|
||||
CHECK_AND_ASSERT_MES(proof.L.size() == 6+pd.logM, false, "Proof is not the expected size");
|
||||
const size_t M = 1 << pd.logM;
|
||||
const size_t MN = M*N;
|
||||
const rct::key weight_y = rct::skGen();
|
||||
const rct::key weight_z = rct::skGen();
|
||||
|
||||
// Reconstruct the challenges
|
||||
PERF_TIMER_START_BP(VERIFY_start);
|
||||
rct::key hash_cache = rct::hash_to_scalar(proof.V);
|
||||
rct::key y = hash_cache_mash(hash_cache, proof.A, proof.S);
|
||||
CHECK_AND_ASSERT_MES(!(y == rct::zero()), false, "y == 0");
|
||||
rct::key z = hash_cache = rct::hash_to_scalar(y);
|
||||
CHECK_AND_ASSERT_MES(!(z == rct::zero()), false, "z == 0");
|
||||
rct::key x = hash_cache_mash(hash_cache, z, proof.T1, proof.T2);
|
||||
CHECK_AND_ASSERT_MES(!(x == rct::zero()), false, "x == 0");
|
||||
rct::key x_ip = hash_cache_mash(hash_cache, x, proof.taux, proof.mu, proof.t);
|
||||
CHECK_AND_ASSERT_MES(!(x_ip == rct::zero()), false, "x_ip == 0");
|
||||
PERF_TIMER_STOP(VERIFY_start);
|
||||
|
||||
// pre-multiply some points by 8
|
||||
rct::keyV proof8_V = proof.V; for (rct::key &k: proof8_V) k = rct::scalarmult8(k);
|
||||
rct::keyV proof8_L = proof.L; for (rct::key &k: proof8_L) k = rct::scalarmult8(k);
|
||||
|
@ -782,10 +874,10 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
// PAPER LINE 61
|
||||
sc_muladd(y0.bytes, proof.taux.bytes, weight_y.bytes, y0.bytes);
|
||||
|
||||
const rct::keyV zpow = vector_powers(z, M+3);
|
||||
const rct::keyV zpow = vector_powers(pd.z, M+3);
|
||||
|
||||
rct::key k;
|
||||
const rct::key ip1y = vector_power_sum(y, MN);
|
||||
const rct::key ip1y = vector_power_sum(pd.y, MN);
|
||||
sc_mulsub(k.bytes, zpow[2].bytes, ip1y.bytes, rct::zero().bytes);
|
||||
for (size_t j = 1; j <= M; ++j)
|
||||
{
|
||||
|
@ -795,7 +887,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
PERF_TIMER_STOP(VERIFY_line_61);
|
||||
|
||||
PERF_TIMER_START_BP(VERIFY_line_61rl_new);
|
||||
sc_muladd(tmp.bytes, z.bytes, ip1y.bytes, k.bytes);
|
||||
sc_muladd(tmp.bytes, pd.z.bytes, ip1y.bytes, k.bytes);
|
||||
sc_sub(tmp.bytes, proof.t.bytes, tmp.bytes);
|
||||
sc_muladd(y1.bytes, tmp.bytes, weight_y.bytes, y1.bytes);
|
||||
for (size_t j = 0; j < proof8_V.size(); j++)
|
||||
|
@ -803,10 +895,10 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
sc_mul(tmp.bytes, zpow[j+2].bytes, weight_y.bytes);
|
||||
multiexp_data.emplace_back(tmp, proof8_V[j]);
|
||||
}
|
||||
sc_mul(tmp.bytes, x.bytes, weight_y.bytes);
|
||||
sc_mul(tmp.bytes, pd.x.bytes, weight_y.bytes);
|
||||
multiexp_data.emplace_back(tmp, proof8_T1);
|
||||
rct::key xsq;
|
||||
sc_mul(xsq.bytes, x.bytes, x.bytes);
|
||||
sc_mul(xsq.bytes, pd.x.bytes, pd.x.bytes);
|
||||
sc_mul(tmp.bytes, xsq.bytes, weight_y.bytes);
|
||||
multiexp_data.emplace_back(tmp, proof8_T2);
|
||||
PERF_TIMER_STOP(VERIFY_line_61rl_new);
|
||||
|
@ -814,49 +906,34 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
PERF_TIMER_START_BP(VERIFY_line_62);
|
||||
// PAPER LINE 62
|
||||
multiexp_data.emplace_back(weight_z, proof8_A);
|
||||
sc_mul(tmp.bytes, x.bytes, weight_z.bytes);
|
||||
sc_mul(tmp.bytes, pd.x.bytes, weight_z.bytes);
|
||||
multiexp_data.emplace_back(tmp, proof8_S);
|
||||
PERF_TIMER_STOP(VERIFY_line_62);
|
||||
|
||||
// Compute the number of rounds for the inner product
|
||||
const size_t rounds = logM+logN;
|
||||
const size_t rounds = pd.logM+logN;
|
||||
CHECK_AND_ASSERT_MES(rounds > 0, false, "Zero rounds");
|
||||
|
||||
PERF_TIMER_START_BP(VERIFY_line_21_22);
|
||||
// PAPER LINES 21-22
|
||||
// The inner product challenges are computed per round
|
||||
rct::keyV w(rounds);
|
||||
for (size_t i = 0; i < rounds; ++i)
|
||||
{
|
||||
w[i] = hash_cache_mash(hash_cache, proof.L[i], proof.R[i]);
|
||||
CHECK_AND_ASSERT_MES(!(w[i] == rct::zero()), false, "w[i] == 0");
|
||||
}
|
||||
PERF_TIMER_STOP(VERIFY_line_21_22);
|
||||
|
||||
PERF_TIMER_START_BP(VERIFY_line_24_25);
|
||||
// Basically PAPER LINES 24-25
|
||||
// Compute the curvepoints from G[i] and H[i]
|
||||
rct::key yinvpow = rct::identity();
|
||||
rct::key ypow = rct::identity();
|
||||
|
||||
PERF_TIMER_START_BP(VERIFY_line_24_25_invert);
|
||||
const rct::key yinv = invert(y);
|
||||
rct::keyV winv(rounds);
|
||||
for (size_t i = 0; i < rounds; ++i)
|
||||
winv[i] = invert(w[i]);
|
||||
PERF_TIMER_STOP(VERIFY_line_24_25_invert);
|
||||
const rct::key *winv = &inverses[pd.inv_offset];
|
||||
const rct::key yinv = inverses[pd.inv_offset + rounds];
|
||||
|
||||
// precalc
|
||||
PERF_TIMER_START_BP(VERIFY_line_24_25_precalc);
|
||||
rct::keyV w_cache(1<<rounds);
|
||||
w_cache[0] = winv[0];
|
||||
w_cache[1] = w[0];
|
||||
w_cache[1] = pd.w[0];
|
||||
for (size_t j = 1; j < rounds; ++j)
|
||||
{
|
||||
const size_t slots = 1<<(j+1);
|
||||
for (size_t s = slots; s-- > 0; --s)
|
||||
{
|
||||
sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, w[j].bytes);
|
||||
sc_mul(w_cache[s].bytes, w_cache[s/2].bytes, pd.w[j].bytes);
|
||||
sc_mul(w_cache[s-1].bytes, w_cache[s/2].bytes, winv[j].bytes);
|
||||
}
|
||||
}
|
||||
|
@ -876,18 +953,18 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
sc_mul(h_scalar.bytes, h_scalar.bytes, w_cache[(~i) & (MN-1)].bytes);
|
||||
|
||||
// Adjust the scalars using the exponents from PAPER LINE 62
|
||||
sc_add(g_scalar.bytes, g_scalar.bytes, z.bytes);
|
||||
sc_add(g_scalar.bytes, g_scalar.bytes, pd.z.bytes);
|
||||
CHECK_AND_ASSERT_MES(2+i/N < zpow.size(), false, "invalid zpow index");
|
||||
CHECK_AND_ASSERT_MES(i%N < twoN.size(), false, "invalid twoN index");
|
||||
sc_mul(tmp.bytes, zpow[2+i/N].bytes, twoN[i%N].bytes);
|
||||
if (i == 0)
|
||||
{
|
||||
sc_add(tmp.bytes, tmp.bytes, z.bytes);
|
||||
sc_add(tmp.bytes, tmp.bytes, pd.z.bytes);
|
||||
sc_sub(h_scalar.bytes, h_scalar.bytes, tmp.bytes);
|
||||
}
|
||||
else
|
||||
{
|
||||
sc_muladd(tmp.bytes, z.bytes, ypow.bytes, tmp.bytes);
|
||||
sc_muladd(tmp.bytes, pd.z.bytes, ypow.bytes, tmp.bytes);
|
||||
sc_mulsub(h_scalar.bytes, tmp.bytes, yinvpow.bytes, h_scalar.bytes);
|
||||
}
|
||||
|
||||
|
@ -897,12 +974,12 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
if (i == 0)
|
||||
{
|
||||
yinvpow = yinv;
|
||||
ypow = y;
|
||||
ypow = pd.y;
|
||||
}
|
||||
else if (i != MN-1)
|
||||
{
|
||||
sc_mul(yinvpow.bytes, yinvpow.bytes, yinv.bytes);
|
||||
sc_mul(ypow.bytes, ypow.bytes, y.bytes);
|
||||
sc_mul(ypow.bytes, ypow.bytes, pd.y.bytes);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -913,7 +990,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
sc_muladd(z1.bytes, proof.mu.bytes, weight_z.bytes, z1.bytes);
|
||||
for (size_t i = 0; i < rounds; ++i)
|
||||
{
|
||||
sc_mul(tmp.bytes, w[i].bytes, w[i].bytes);
|
||||
sc_mul(tmp.bytes, pd.w[i].bytes, pd.w[i].bytes);
|
||||
sc_mul(tmp.bytes, tmp.bytes, weight_z.bytes);
|
||||
multiexp_data.emplace_back(tmp, proof8_L[i]);
|
||||
sc_mul(tmp.bytes, winv[i].bytes, winv[i].bytes);
|
||||
|
@ -921,7 +998,7 @@ bool bulletproof_VERIFY(const std::vector<const Bulletproof*> &proofs)
|
|||
multiexp_data.emplace_back(tmp, proof8_R[i]);
|
||||
}
|
||||
sc_mulsub(tmp.bytes, proof.a.bytes, proof.b.bytes, proof.t.bytes);
|
||||
sc_mul(tmp.bytes, tmp.bytes, x_ip.bytes);
|
||||
sc_mul(tmp.bytes, tmp.bytes, pd.x_ip.bytes);
|
||||
sc_muladd(z3.bytes, tmp.bytes, weight_z.bytes, z3.bytes);
|
||||
PERF_TIMER_STOP(VERIFY_line_26_new);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue