use std::cmp::Ordering;
use std::hash::{Hash as StdHash, Hasher as StdHasher};
use std::panic::panic_any;
use anyhow::ensure;
use bellperson::{
gadgets::{boolean::Boolean, num::AllocatedNum},
ConstraintSystem, SynthesisError,
};
use blstrs::Scalar as Fr;
use ff::{Field, PrimeField};
use generic_array::typenum::{marker_traits::Unsigned, U2};
use merkletree::{
hash::{Algorithm as LightAlgorithm, Hashable},
merkle::Element,
};
use neptune::{circuit::poseidon_hash, poseidon::Poseidon};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use crate::types::{
Domain, HashFunction, Hasher, PoseidonArity, PoseidonMDArity, POSEIDON_CONSTANTS_16,
POSEIDON_CONSTANTS_2, POSEIDON_CONSTANTS_4, POSEIDON_CONSTANTS_8, POSEIDON_MD_CONSTANTS,
};
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct PoseidonHasher {}
impl Hasher for PoseidonHasher {
type Domain = PoseidonDomain;
type Function = PoseidonFunction;
fn name() -> String {
"poseidon_hasher".into()
}
}
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct PoseidonFunction(Fr);
impl Default for PoseidonFunction {
fn default() -> PoseidonFunction {
PoseidonFunction(Fr::ZERO)
}
}
impl Hashable<PoseidonFunction> for Fr {
fn hash(&self, state: &mut PoseidonFunction) {
state.write(&self.to_repr());
}
}
impl Hashable<PoseidonFunction> for PoseidonDomain {
fn hash(&self, state: &mut PoseidonFunction) {
state.write(&self.0);
}
}
#[derive(Default, Copy, Clone, Debug, Serialize, Deserialize)]
pub struct PoseidonDomain(<Fr as PrimeField>::Repr);
impl AsRef<PoseidonDomain> for PoseidonDomain {
fn as_ref(&self) -> &PoseidonDomain {
self
}
}
impl StdHash for PoseidonDomain {
fn hash<H: StdHasher>(&self, state: &mut H) {
StdHash::hash(&self.0, state);
}
}
impl PartialEq for PoseidonDomain {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl Eq for PoseidonDomain {}
impl Ord for PoseidonDomain {
#[inline(always)]
fn cmp(&self, other: &PoseidonDomain) -> Ordering {
(self.0).cmp(&other.0)
}
}
impl PartialOrd for PoseidonDomain {
#[inline(always)]
fn partial_cmp(&self, other: &PoseidonDomain) -> Option<Ordering> {
Some((self.0).cmp(&other.0))
}
}
impl AsRef<[u8]> for PoseidonDomain {
#[inline]
fn as_ref(&self) -> &[u8] {
&self.0
}
}
impl Domain for PoseidonDomain {
fn into_bytes(&self) -> Vec<u8> {
self.0.to_vec()
}
fn try_from_bytes(raw: &[u8]) -> anyhow::Result<Self> {
ensure!(
raw.len() == PoseidonDomain::byte_len(),
"invalid amount of bytes"
);
let mut repr = <Fr as PrimeField>::Repr::default();
repr.copy_from_slice(raw);
Ok(PoseidonDomain(repr))
}
fn write_bytes(&self, dest: &mut [u8]) -> anyhow::Result<()> {
ensure!(
dest.len() == PoseidonDomain::byte_len(),
"invalid amount of bytes"
);
dest.copy_from_slice(&self.0);
Ok(())
}
fn random<R: RngCore>(rng: &mut R) -> Self {
Fr::random(rng).into()
}
}
impl Element for PoseidonDomain {
fn byte_len() -> usize {
32
}
fn from_slice(bytes: &[u8]) -> Self {
match PoseidonDomain::try_from_bytes(bytes) {
Ok(res) => res,
Err(err) => panic_any(err),
}
}
fn copy_to_slice(&self, bytes: &mut [u8]) {
bytes.copy_from_slice(&self.0);
}
}
impl StdHasher for PoseidonFunction {
#[inline]
fn write(&mut self, msg: &[u8]) {
self.0 = Fr::from_repr_vartime(shared_hash(msg).0).expect("from_repr failure");
}
#[inline]
fn finish(&self) -> u64 {
unimplemented!()
}
}
fn shared_hash(data: &[u8]) -> PoseidonDomain {
let preimage = data
.chunks(32)
.map(|chunk| {
Fr::from_repr_vartime(PoseidonDomain::from_slice(chunk).0).expect("from_repr failure")
})
.collect::<Vec<_>>();
shared_hash_frs(&preimage).into()
}
fn shared_hash_frs(preimage: &[Fr]) -> Fr {
match preimage.len() {
2 => {
let mut p = Poseidon::new_with_preimage(preimage, &POSEIDON_CONSTANTS_2);
p.hash()
}
4 => {
let mut p = Poseidon::new_with_preimage(preimage, &POSEIDON_CONSTANTS_4);
p.hash()
}
8 => {
let mut p = Poseidon::new_with_preimage(preimage, &POSEIDON_CONSTANTS_8);
p.hash()
}
16 => {
let mut p = Poseidon::new_with_preimage(preimage, &POSEIDON_CONSTANTS_16);
p.hash()
}
_ => panic_any(format!(
"Unsupported arity for Poseidon hasher: {}",
preimage.len()
)),
}
}
impl HashFunction<PoseidonDomain> for PoseidonFunction {
fn hash(data: &[u8]) -> PoseidonDomain {
shared_hash(data)
}
fn hash2(a: &PoseidonDomain, b: &PoseidonDomain) -> PoseidonDomain {
let mut p =
Poseidon::new_with_preimage(&[(*a).into(), (*b).into()][..], &*POSEIDON_CONSTANTS_2);
let fr: Fr = p.hash();
fr.into()
}
fn hash_md(input: &[PoseidonDomain]) -> PoseidonDomain {
assert!(input.len() > 1, "hash_md needs more than one element.");
let arity = PoseidonMDArity::to_usize();
let mut p = Poseidon::new(&*POSEIDON_MD_CONSTANTS);
let fr_input = input
.iter()
.map(|x| Fr::from_repr_vartime(x.0).expect("from_repr failure"))
.collect::<Vec<_>>();
fr_input[1..]
.chunks(arity - 1)
.fold(fr_input[0], |acc, elts| {
p.reset();
p.input(acc).expect("input failure"); elts.iter().for_each(|elt| {
let _ = p.input(*elt).expect("input failure");
});
p.hash()
})
.into()
}
fn hash_leaf_circuit<CS: ConstraintSystem<Fr>>(
cs: CS,
left: &AllocatedNum<Fr>,
right: &AllocatedNum<Fr>,
_height: usize,
) -> Result<AllocatedNum<Fr>, SynthesisError> {
let preimage = vec![left.clone(), right.clone()];
poseidon_hash::<CS, Fr, U2>(cs, preimage, U2::PARAMETERS())
}
fn hash_multi_leaf_circuit<Arity: 'static + PoseidonArity, CS: ConstraintSystem<Fr>>(
cs: CS,
leaves: &[AllocatedNum<Fr>],
_height: usize,
) -> Result<AllocatedNum<Fr>, SynthesisError> {
let params = Arity::PARAMETERS();
poseidon_hash::<CS, Fr, Arity>(cs, leaves.to_vec(), params)
}
fn hash_md_circuit<CS: ConstraintSystem<Fr>>(
cs: &mut CS,
elements: &[AllocatedNum<Fr>],
) -> Result<AllocatedNum<Fr>, SynthesisError> {
let params = PoseidonMDArity::PARAMETERS();
let arity = PoseidonMDArity::to_usize();
let mut hash = elements[0].clone();
let mut preimage = vec![hash.clone(); arity]; for (hash_num, elts) in elements[1..].chunks(arity - 1).enumerate() {
preimage[0] = hash;
for (i, elt) in elts.iter().enumerate() {
preimage[i + 1] = elt.clone();
}
#[allow(clippy::needless_range_loop)]
for i in (elts.len() + 1)..arity {
preimage[i] =
AllocatedNum::alloc(cs.namespace(|| format!("padding {}", i)), || Ok(Fr::ZERO))
.expect("alloc failure");
}
let cs = cs.namespace(|| format!("hash md {}", hash_num));
hash = poseidon_hash::<_, Fr, PoseidonMDArity>(cs, preimage.clone(), params)?.clone();
}
Ok(hash)
}
fn hash_circuit<CS: ConstraintSystem<Fr>>(
_cs: CS,
_bits: &[Boolean],
) -> Result<AllocatedNum<Fr>, SynthesisError> {
unimplemented!();
}
fn hash2_circuit<CS>(
cs: CS,
a: &AllocatedNum<Fr>,
b: &AllocatedNum<Fr>,
) -> Result<AllocatedNum<Fr>, SynthesisError>
where
CS: ConstraintSystem<Fr>,
{
let preimage = vec![a.clone(), b.clone()];
poseidon_hash::<CS, Fr, U2>(cs, preimage, U2::PARAMETERS())
}
}
impl LightAlgorithm<PoseidonDomain> for PoseidonFunction {
#[inline]
fn hash(&mut self) -> PoseidonDomain {
self.0.into()
}
#[inline]
fn reset(&mut self) {
self.0 = Fr::ZERO;
}
fn leaf(&mut self, leaf: PoseidonDomain) -> PoseidonDomain {
leaf
}
fn node(
&mut self,
left: PoseidonDomain,
right: PoseidonDomain,
_height: usize,
) -> PoseidonDomain {
shared_hash_frs(&[
Fr::from_repr_vartime(left.0).expect("from_repr failure"),
Fr::from_repr_vartime(right.0).expect("from_repr failure"),
])
.into()
}
fn multi_node(&mut self, parts: &[PoseidonDomain], _height: usize) -> PoseidonDomain {
match parts.len() {
1 | 2 | 4 | 8 | 16 => shared_hash_frs(
&parts
.iter()
.enumerate()
.map(|(i, x)| {
if let Some(fr) = Fr::from_repr_vartime(x.0) {
fr
} else {
panic_any(format!("from_repr failure at {}", i));
}
})
.collect::<Vec<_>>(),
)
.into(),
arity => panic_any(format!("unsupported arity {}", arity)),
}
}
}
impl From<Fr> for PoseidonDomain {
#[inline]
fn from(val: Fr) -> Self {
PoseidonDomain(val.to_repr())
}
}
impl From<[u8; 32]> for PoseidonDomain {
#[inline]
fn from(val: [u8; 32]) -> Self {
PoseidonDomain(val)
}
}
impl From<PoseidonDomain> for [u8; 32] {
#[inline]
fn from(val: PoseidonDomain) -> Self {
val.0
}
}
impl From<PoseidonDomain> for Fr {
#[inline]
fn from(val: PoseidonDomain) -> Self {
Fr::from_repr_vartime(val.0).expect("from_repr failure")
}
}
#[cfg(test)]
mod tests {
use super::*;
use bellperson::util_cs::test_cs::TestConstraintSystem;
use merkletree::{merkle::MerkleTree, store::VecStore};
fn u64s_to_u8s(u64s: [u64; 4]) -> [u8; 32] {
let mut bytes = [0u8; 32];
bytes[..8].copy_from_slice(&u64s[0].to_le_bytes());
bytes[8..16].copy_from_slice(&u64s[1].to_le_bytes());
bytes[16..24].copy_from_slice(&u64s[2].to_le_bytes());
bytes[24..].copy_from_slice(&u64s[3].to_le_bytes());
bytes
}
#[test]
fn test_path() {
let values = [
PoseidonDomain(Fr::ONE.to_repr()),
PoseidonDomain(Fr::ONE.to_repr()),
PoseidonDomain(Fr::ONE.to_repr()),
PoseidonDomain(Fr::ONE.to_repr()),
];
let t = MerkleTree::<PoseidonDomain, PoseidonFunction, VecStore<_>, U2>::new(
values.iter().copied(),
)
.expect("merkle tree new failure");
let p = t.gen_proof(0).expect("gen_proof failure"); assert_eq!(*p.path(), vec![0, 0]);
assert!(p
.validate::<PoseidonFunction>()
.expect("failed to validate"));
}
#[test]
fn test_poseidon_hasher() {
let leaves = [
PoseidonDomain(Fr::ONE.to_repr()),
PoseidonDomain(Fr::ZERO.to_repr()),
PoseidonDomain(Fr::ZERO.to_repr()),
PoseidonDomain(Fr::ONE.to_repr()),
];
let t = MerkleTree::<PoseidonDomain, PoseidonFunction, VecStore<_>, U2>::new(
leaves.iter().copied(),
)
.expect("merkle tree new failure");
assert_eq!(t.leafs(), 4);
let mut a = PoseidonFunction::default();
assert_eq!(t.read_at(0).expect("read_at failure"), leaves[0]);
assert_eq!(t.read_at(1).expect("read_at failure"), leaves[1]);
assert_eq!(t.read_at(2).expect("read_at failure"), leaves[2]);
assert_eq!(t.read_at(3).expect("read_at failure"), leaves[3]);
let i1 = a.node(leaves[0], leaves[1], 0);
a.reset();
let i2 = a.node(leaves[2], leaves[3], 0);
a.reset();
assert_eq!(t.read_at(4).expect("read_at failure"), i1);
assert_eq!(t.read_at(5).expect("read_at failure"), i2);
let root = a.node(i1, i2, 1);
a.reset();
assert_eq!(
t.read_at(4).expect("read_at failure").0,
u64s_to_u8s([
0xb339ff6079800b5e,
0xec5907b3dc3094af,
0x93c003cc74a24f26,
0x042f94ffbe786bc3,
]),
);
let expected = u64s_to_u8s([
0xefbb8be3e291e671,
0x77cc72b8cb2b5ad2,
0x30eb6385ae6b74ae,
0x1effebb7b26ad9eb,
]);
let actual = t.read_at(6).expect("read_at failure").0;
assert_eq!(actual, expected);
assert_eq!(t.read_at(6).expect("read_at failure"), root);
}
#[test]
fn test_as_ref() {
let cases: Vec<[u64; 4]> = vec![
[0, 0, 0, 0],
[
14963070332212552755,
2414807501862983188,
16116531553419129213,
6357427774790868134,
],
];
for case in cases.into_iter() {
let val = PoseidonDomain(u64s_to_u8s(case));
for _ in 0..100 {
assert_eq!(val.into_bytes(), val.into_bytes());
}
let raw: &[u8] = val.as_ref();
for (limb, bytes) in case.iter().zip(raw.chunks(8)) {
assert_eq!(&limb.to_le_bytes(), bytes);
}
}
}
#[test]
fn test_serialize() {
let val = PoseidonDomain(u64s_to_u8s([1, 2, 3, 4]));
let ser = serde_json::to_string(&val)
.expect("Failed to serialize `PoseidonDomain` element to JSON string");
let val_back = serde_json::from_str(&ser)
.expect("Failed to deserialize JSON string to `PoseidonnDomain`");
assert_eq!(val, val_back);
}
#[test]
fn test_hash_md() {
let n = 71;
let data = vec![PoseidonDomain(Fr::ONE.to_repr()); n];
let hashed = PoseidonFunction::hash_md(&data);
assert_eq!(
hashed,
PoseidonDomain(u64s_to_u8s([
0x351c54133b332c90,
0xc26f6d625f4e8195,
0x5fd9623643ed9622,
0x59f42220e09ff6f7,
]))
);
}
#[test]
fn test_hash_md_circuit() {
let n = 71;
let data = vec![PoseidonDomain(Fr::ONE.to_repr()); n];
let mut cs = TestConstraintSystem::<Fr>::new();
let circuit_data = (0..n)
.map(|n| {
AllocatedNum::alloc(cs.namespace(|| format!("input {}", n)), || Ok(Fr::ONE))
.expect("alloc failure")
})
.collect::<Vec<_>>();
let hashed = PoseidonFunction::hash_md(&data);
let hashed_fr = Fr::from_repr_vartime(hashed.0).expect("from_repr failure");
let circuit_hashed = PoseidonFunction::hash_md_circuit(&mut cs, circuit_data.as_slice())
.expect("hash_md_circuit failure");
assert!(cs.is_satisfied());
let expected_constraints = 2_770;
let actual_constraints = cs.num_constraints();
assert_eq!(expected_constraints, actual_constraints);
assert_eq!(
hashed_fr,
circuit_hashed.get_value().expect("get_value failure")
);
}
}