use std::cmp;
use anyhow::{ensure, Context};
use bellperson::{
groth16::{
self,
aggregate::{
aggregate_proofs, verify_aggregate_proof, AggregateProof, ProverSRS, VerifierSRS,
},
create_random_proof_batch, create_random_proof_batch_in_priority, verify_proofs_batch,
PreparedVerifyingKey,
},
Circuit,
};
use blstrs::{Bls12, Scalar as Fr};
use log::info;
use rand::{rngs::OsRng, RngCore};
use rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
use crate::{
error::Result,
multi_proof::MultiProof,
parameter_cache::{Bls12GrothParams, CacheableParameters, ParameterSetMetadata},
partitions::partition_count,
proof::ProofScheme,
};
/// Maximum number of circuits handed to a single Groth16 batch-proving call in
/// `CompoundProof::circuit_proofs`; larger sets of circuits are proved in
/// successive chunks of at most this many.
const MAX_GROTH16_BATCH_SIZE: usize = 10;
#[derive(Clone)]
pub struct SetupParams<'a, S: ProofScheme<'a>> {
pub vanilla_params: <S as ProofScheme<'a>>::SetupParams,
pub partitions: Option<usize>,
pub priority: bool,
}
#[derive(Clone)]
pub struct PublicParams<'a, S: ProofScheme<'a>> {
pub vanilla_params: S::PublicParams,
pub partitions: Option<usize>,
pub priority: bool,
}
/// Associates a circuit type with the component-specific private inputs it is built from.
pub trait CircuitComponent {
    /// Extra private inputs for the circuit. The default methods of `CompoundProof`
    /// build circuits with `Default::default()` of this type.
    type ComponentPrivateInputs: Clone + Default;
}
/// A proof that wraps a vanilla [`ProofScheme`] `S` in Groth16 circuits of type `C`,
/// one circuit (and one Groth16 proof) per partition.
///
/// Implementors must provide [`CompoundProof::generate_public_inputs`],
/// [`CompoundProof::circuit`] and [`CompoundProof::blank_circuit`]. Proving, verification,
/// batch verification, aggregation and parameter retrieval have default implementations
/// built on top of those three.
pub trait CompoundProof<'a, S: ProofScheme<'a>, C: Circuit<Fr> + CircuitComponent + Send>
where
    S::Proof: Sync + Send,
    S::PublicParams: ParameterSetMetadata + Sync + Send,
    S::PublicInputs: Clone + Sync,
    Self: CacheableParameters<C, S::PublicParams>,
{
    /// Runs the vanilla scheme's setup and carries `partitions` and `priority` over.
    ///
    /// # Errors
    /// Propagates any error returned by `S::setup`.
    fn setup(sp: &SetupParams<'a, S>) -> Result<PublicParams<'a, S>> {
        Ok(PublicParams {
            vanilla_params: S::setup(&sp.vanilla_params)?,
            partitions: sp.partitions,
            priority: sp.priority,
        })
    }

    /// Number of partitions, and therefore Groth16 proofs, used by this proof.
    ///
    /// `None` means a single partition.
    ///
    /// # Panics
    /// Panics if `partitions` is `Some(0)`.
    fn partition_count(public_params: &PublicParams<'a, S>) -> usize {
        match public_params.partitions {
            None => 1,
            Some(0) => panic!("cannot specify zero partitions"),
            Some(k) => k,
        }
    }

    /// Produces one Groth16 proof per partition.
    ///
    /// First generates the vanilla proofs for every partition, then verifies them
    /// as a sanity check, and finally proves the corresponding circuits with
    /// [`CompoundProof::circuit_proofs`].
    ///
    /// # Errors
    /// Fails if vanilla proving errors, if the vanilla proofs do not verify
    /// ("sanity check failed"), or if circuit construction or Groth16 proving fails.
    fn prove(
        pub_params: &PublicParams<'a, S>,
        pub_in: &S::PublicInputs,
        priv_in: &S::PrivateInputs,
        groth_params: &Bls12GrothParams,
    ) -> Result<Vec<groth16::Proof<Bls12>>> {
        let partition_count = Self::partition_count(pub_params);
        // `partition_count` already panics on zero; kept as a defensive check.
        ensure!(partition_count > 0, "There must be partitions");

        info!("vanilla_proofs:start");
        let vanilla_proofs =
            S::prove_all_partitions(&pub_params.vanilla_params, pub_in, priv_in, partition_count)?;
        info!("vanilla_proofs:finish");

        // Catch bad vanilla proofs before spending time on SNARK proving.
        let sanity_check =
            S::verify_all_partitions(&pub_params.vanilla_params, pub_in, &vanilla_proofs)?;
        ensure!(sanity_check, "sanity check failed");

        info!("snark_proof:start");
        let groth_proofs = Self::circuit_proofs(
            pub_in,
            vanilla_proofs,
            &pub_params.vanilla_params,
            groth_params,
            pub_params.priority,
        )?;
        info!("snark_proof:finish");

        Ok(groth_proofs)
    }

    /// Like [`CompoundProof::prove`], but takes already-computed vanilla proofs.
    ///
    /// The vanilla proofs are NOT re-verified here.
    ///
    /// # Errors
    /// Fails if circuit construction or Groth16 proving fails, or if
    /// `vanilla_proofs` is empty.
    fn prove_with_vanilla(
        pub_params: &PublicParams<'a, S>,
        pub_in: &S::PublicInputs,
        vanilla_proofs: Vec<S::Proof>,
        groth_params: &Bls12GrothParams,
    ) -> Result<Vec<groth16::Proof<Bls12>>> {
        let partition_count = Self::partition_count(pub_params);
        ensure!(partition_count > 0, "There must be partitions");

        info!("snark_proof:start");
        let groth_proofs = Self::circuit_proofs(
            pub_in,
            vanilla_proofs,
            &pub_params.vanilla_params,
            groth_params,
            pub_params.priority,
        )?;
        info!("snark_proof:finish");

        Ok(groth_proofs)
    }

    /// Verifies all partition proofs inside `multi_proof` with a single batch check.
    ///
    /// Returns `Ok(false)` if the vanilla scheme reports that `requirements` are not
    /// satisfied for this number of proofs.
    ///
    /// # Errors
    /// Fails if the number of circuit proofs differs from the partition count, if
    /// public-input generation fails, or if batch verification itself errors.
    fn verify<'b>(
        public_params: &PublicParams<'a, S>,
        public_inputs: &S::PublicInputs,
        multi_proof: &MultiProof<'b>,
        requirements: &S::Requirements,
    ) -> Result<bool> {
        ensure!(
            multi_proof.circuit_proofs.len() == Self::partition_count(public_params),
            "Inconsistent inputs"
        );

        let vanilla_public_params = &public_params.vanilla_params;
        let pvk = &multi_proof.verifying_key;

        if !<S as ProofScheme>::satisfies_requirements(
            vanilla_public_params,
            requirements,
            multi_proof.circuit_proofs.len(),
        ) {
            return Ok(false);
        }

        // Public inputs for partition k, in the same order as `circuit_proofs`.
        let inputs: Vec<_> = (0..multi_proof.circuit_proofs.len())
            .into_par_iter()
            .map(|k| Self::generate_public_inputs(public_inputs, vanilla_public_params, Some(k)))
            .collect::<Result<_>>()?;

        let proofs: Vec<_> = multi_proof.circuit_proofs.iter().collect();
        let res = verify_proofs_batch(pvk, &mut OsRng, &proofs, &inputs)?;
        Ok(res)
    }

    /// Verifies several multi-proofs (paired element-wise with `public_inputs`) in one
    /// batch check.
    ///
    /// All circuit proofs are checked against the verifying key of the FIRST
    /// multi-proof. Returns `Ok(false)` if any multi-proof fails the vanilla scheme's
    /// `requirements` check.
    ///
    /// # Errors
    /// Fails if `public_inputs` and `multi_proofs` differ in length, if any
    /// multi-proof's proof count differs from the partition count, if the inputs are
    /// empty, if public-input generation fails, or if batch verification errors.
    fn batch_verify<'b>(
        public_params: &PublicParams<'a, S>,
        public_inputs: &[S::PublicInputs],
        multi_proofs: &[MultiProof<'b>],
        requirements: &S::Requirements,
    ) -> Result<bool> {
        ensure!(
            public_inputs.len() == multi_proofs.len(),
            "Inconsistent inputs"
        );
        for proof in multi_proofs {
            ensure!(
                proof.circuit_proofs.len() == Self::partition_count(public_params),
                "Inconsistent inputs"
            );
        }
        ensure!(!public_inputs.is_empty(), "Cannot verify empty proofs");

        let vanilla_public_params = &public_params.vanilla_params;
        // Non-empty is guaranteed by the checks above.
        let pvk = &multi_proofs[0].verifying_key;

        let requirements_met = multi_proofs.iter().all(|multi_proof| {
            <S as ProofScheme>::satisfies_requirements(
                vanilla_public_params,
                requirements,
                multi_proof.circuit_proofs.len(),
            )
        });
        if !requirements_met {
            return Ok(false);
        }

        // One public-input vector per circuit proof, flattened in the same order as
        // `circuit_proofs` below. Generation errors are propagated as `Err` rather than
        // panicking inside the rayon pool.
        let inputs: Vec<Vec<Fr>> = multi_proofs
            .par_iter()
            .zip(public_inputs.par_iter())
            .map(|(multi_proof, pub_inputs)| {
                (0..multi_proof.circuit_proofs.len())
                    .into_par_iter()
                    .map(|k| {
                        Self::generate_public_inputs(pub_inputs, vanilla_public_params, Some(k))
                    })
                    .collect::<Result<Vec<_>>>()
            })
            .collect::<Result<Vec<_>>>()
            .context("Invalid public inputs")?
            .into_iter()
            .flatten()
            .collect();

        let circuit_proofs: Vec<_> = multi_proofs
            .iter()
            .flat_map(|m| m.circuit_proofs.iter())
            .collect();

        let res = verify_proofs_batch(pvk, &mut OsRng, &circuit_proofs[..], &inputs)?;
        Ok(res)
    }

    /// Builds one circuit per vanilla proof and proves them with Groth16.
    ///
    /// Circuits are built in parallel and proved in chunks of at most
    /// `MAX_GROTH16_BATCH_SIZE`. When `priority` is `true` the priority batch prover is
    /// used. Proof order matches `vanilla_proofs` order.
    ///
    /// # Errors
    /// Fails if `vanilla_proofs` is empty, or if circuit construction, proving, or the
    /// proof serialization round trip fails.
    fn circuit_proofs(
        pub_in: &S::PublicInputs,
        vanilla_proofs: Vec<S::Proof>,
        pub_params: &S::PublicParams,
        groth_params: &Bls12GrothParams,
        priority: bool,
    ) -> Result<Vec<groth16::Proof<Bls12>>> {
        let mut rng = OsRng;
        ensure!(
            !vanilla_proofs.is_empty(),
            "cannot create a circuit proof over missing vanilla proofs"
        );

        let mut circuits = vanilla_proofs
            .into_par_iter()
            .enumerate()
            .map(|(k, vanilla_proof)| {
                Self::circuit(
                    pub_in,
                    C::ComponentPrivateInputs::default(),
                    &vanilla_proof,
                    pub_params,
                    Some(k),
                )
            })
            .collect::<Result<Vec<_>>>()?;

        let create_random_proof_batch_fun = if priority {
            create_random_proof_batch_in_priority
        } else {
            create_random_proof_batch
        };

        // Prove in bounded chunks; proofs are moved (not cloned) into the result.
        let mut groth_proofs = Vec::with_capacity(circuits.len());
        while !circuits.is_empty() {
            let size = cmp::min(MAX_GROTH16_BATCH_SIZE, circuits.len());
            let batch = circuits.drain(..size).collect();
            let proofs = create_random_proof_batch_fun(batch, groth_params, &mut rng)?;
            groth_proofs.extend(proofs);
        }

        // Each proof is serialized and parsed back before being returned.
        // NOTE(review): presumably this normalizes/validates the encoding — confirm.
        groth_proofs
            .into_iter()
            .map(|groth_proof| {
                let mut proof_vec = Vec::new();
                groth_proof.write(&mut proof_vec)?;
                let gp = groth16::Proof::<Bls12>::read(&proof_vec[..])?;
                Ok(gp)
            })
            .collect()
    }

    /// Aggregates Groth16 proofs into a single aggregate proof using `prover_srs`.
    ///
    /// `hashed_seeds_and_comm_rs` is passed through to bellperson's `aggregate_proofs`.
    ///
    /// # Errors
    /// Propagates any aggregation error.
    fn aggregate_proofs(
        prover_srs: &ProverSRS<Bls12>,
        hashed_seeds_and_comm_rs: &[u8],
        proofs: &[groth16::Proof<Bls12>],
        version: groth16::aggregate::AggregateVersion,
    ) -> Result<AggregateProof<Bls12>> {
        Ok(aggregate_proofs::<Bls12>(
            prover_srs,
            hashed_seeds_and_comm_rs,
            proofs,
            version,
        )?)
    }

    /// Verifies an aggregate proof against `public_inputs` (one vector per proof),
    /// using `OsRng` for the verifier's randomness.
    ///
    /// # Errors
    /// Propagates any error from bellperson's `verify_aggregate_proof`.
    fn verify_aggregate_proofs(
        ip_verifier_srs: &VerifierSRS<Bls12>,
        pvk: &PreparedVerifyingKey<Bls12>,
        hashed_seeds_and_comm_rs: &[u8],
        public_inputs: &[Vec<Fr>],
        aggregate_proof: &groth16::aggregate::AggregateProof<Bls12>,
        version: groth16::aggregate::AggregateVersion,
    ) -> Result<bool> {
        let rng = OsRng;
        Ok(verify_aggregate_proof(
            ip_verifier_srs,
            pvk,
            rng,
            public_inputs,
            aggregate_proof,
            hashed_seeds_and_comm_rs,
            version,
        )?)
    }

    /// Returns the circuit's public inputs for partition `partition_k`.
    ///
    /// Must produce inputs in the order the circuit allocates them.
    fn generate_public_inputs(
        pub_in: &S::PublicInputs,
        pub_params: &S::PublicParams,
        partition_k: Option<usize>,
    ) -> Result<Vec<Fr>>;

    /// Builds the circuit instance proving `vanilla_proof` for partition `partition_k`.
    fn circuit(
        public_inputs: &S::PublicInputs,
        component_private_inputs: C::ComponentPrivateInputs,
        vanilla_proof: &S::Proof,
        public_param: &S::PublicParams,
        partition_k: Option<usize>,
    ) -> Result<C>;

    /// Circuit without witness values, used to obtain Groth16 parameters, verifying
    /// keys and SRS keys.
    fn blank_circuit(public_params: &S::PublicParams) -> C;

    /// Obtains Groth16 parameters for `blank_circuit` via
    /// `CacheableParameters::get_groth_params`; `rng` is passed through unchanged.
    fn groth_params<R: RngCore>(
        rng: Option<&mut R>,
        public_params: &S::PublicParams,
    ) -> Result<Bls12GrothParams> {
        Self::get_groth_params(rng, Self::blank_circuit(public_params), public_params)
    }

    /// Obtains the Groth16 verifying key for `blank_circuit` via
    /// `CacheableParameters::get_verifying_key`; `rng` is passed through unchanged.
    fn verifying_key<R: RngCore>(
        rng: Option<&mut R>,
        public_params: &S::PublicParams,
    ) -> Result<groth16::VerifyingKey<Bls12>> {
        Self::get_verifying_key(rng, Self::blank_circuit(public_params), public_params)
    }

    /// Obtains the generic inner-product SRS and specializes it for
    /// `num_proofs_to_aggregate`, returning the prover half.
    fn srs_key<R: RngCore>(
        rng: Option<&mut R>,
        public_params: &S::PublicParams,
        num_proofs_to_aggregate: usize,
    ) -> Result<ProverSRS<Bls12>> {
        let generic_srs = Self::get_inner_product(
            rng,
            Self::blank_circuit(public_params),
            public_params,
            num_proofs_to_aggregate,
        )?;
        let (prover_srs, _verifier_srs) = generic_srs.specialize(num_proofs_to_aggregate);
        Ok(prover_srs)
    }

    /// Obtains the generic inner-product SRS and specializes it for
    /// `num_proofs_to_aggregate`, returning the verifier half.
    fn srs_verifier_key<R: RngCore>(
        rng: Option<&mut R>,
        public_params: &S::PublicParams,
        num_proofs_to_aggregate: usize,
    ) -> Result<VerifierSRS<Bls12>> {
        let generic_srs = Self::get_inner_product(
            rng,
            Self::blank_circuit(public_params),
            public_params,
            num_proofs_to_aggregate,
        )?;
        let (_prover_srs, verifier_srs) = generic_srs.specialize(num_proofs_to_aggregate);
        Ok(verifier_srs)
    }

    /// Test helper: proves and verifies all vanilla partitions, then returns the circuit
    /// and public inputs for partition 0 only.
    ///
    /// Uses `crate::partitions::partition_count` (not `Self::partition_count`) to size
    /// the partitions.
    ///
    /// # Errors
    /// Fails if vanilla proving or verification fails, if the proof count does not
    /// match the partition count, or if no partition proofs were produced.
    fn circuit_for_test(
        public_parameters: &PublicParams<'a, S>,
        public_inputs: &S::PublicInputs,
        private_inputs: &S::PrivateInputs,
    ) -> Result<(C, Vec<Fr>)> {
        let vanilla_params = &public_parameters.vanilla_params;
        let partition_count = partition_count(public_parameters.partitions);
        let vanilla_proofs = S::prove_all_partitions(
            vanilla_params,
            public_inputs,
            private_inputs,
            partition_count,
        )
        .context("failed to generate partition proofs")?;

        ensure!(
            vanilla_proofs.len() == partition_count,
            "Vanilla proofs didn't match number of partitions."
        );

        let partitions_are_verified =
            S::verify_all_partitions(vanilla_params, public_inputs, &vanilla_proofs)
                .context("failed to verify partition proofs")?;
        ensure!(partitions_are_verified, "Vanilla proof didn't verify.");

        // Return an error instead of panicking when zero partitions were produced.
        let first_proof = vanilla_proofs
            .first()
            .context("no partition proofs were generated")?;

        let partition_pub_in = S::with_partition(public_inputs.clone(), Some(0));
        let inputs = Self::generate_public_inputs(&partition_pub_in, vanilla_params, Some(0))?;
        let circuit = Self::circuit(
            &partition_pub_in,
            C::ComponentPrivateInputs::default(),
            first_proof,
            vanilla_params,
            Some(0),
        )?;

        Ok((circuit, inputs))
    }

    /// Test helper: like [`CompoundProof::circuit_for_test`], but returns the circuit
    /// and public inputs for every partition, in partition order.
    ///
    /// # Errors
    /// Fails if vanilla proving or verification fails, if the proof count does not
    /// match the partition count, or if public-input or circuit construction fails.
    fn circuit_for_test_all(
        public_parameters: &PublicParams<'a, S>,
        public_inputs: &S::PublicInputs,
        private_inputs: &S::PrivateInputs,
    ) -> Result<Vec<(C, Vec<Fr>)>> {
        let vanilla_params = &public_parameters.vanilla_params;
        let partition_count = partition_count(public_parameters.partitions);
        let vanilla_proofs = S::prove_all_partitions(
            vanilla_params,
            public_inputs,
            private_inputs,
            partition_count,
        )
        .context("failed to generate partition proofs")?;

        ensure!(
            vanilla_proofs.len() == partition_count,
            "Vanilla proofs didn't match number of partitions."
        );

        let partitions_are_verified =
            S::verify_all_partitions(vanilla_params, public_inputs, &vanilla_proofs)
                .context("failed to verify partition proofs")?;
        ensure!(partitions_are_verified, "Vanilla proof didn't verify.");

        let mut res = Vec::with_capacity(partition_count);
        for (partition, vanilla_proof) in vanilla_proofs.iter().enumerate() {
            let partition_pub_in = S::with_partition(public_inputs.clone(), Some(partition));
            let inputs =
                Self::generate_public_inputs(&partition_pub_in, vanilla_params, Some(partition))?;

            let circuit = Self::circuit(
                &partition_pub_in,
                C::ComponentPrivateInputs::default(),
                vanilla_proof,
                vanilla_params,
                Some(partition),
            )?;
            res.push((circuit, inputs));
        }

        Ok(res)
    }
}