use std::mem::size_of;
use blake2b_simd::blake2b;
/// Maximum number of round keys consumed by `encode` and `decode`.
pub const FEISTEL_ROUNDS: usize = 3;
/// Integer type used for the indices being permuted and for the round keys.
pub type Index = u64;
/// Values produced by `precompute`, in order: `(left_mask, right_mask, half_bits)`.
pub type FeistelPrecomputed = (Index, Index, Index);
/// Computes the bit masks used to split an index into two equal halves.
///
/// The halves are sized so that `2 * half_bits` bits cover the smallest power
/// of four that is `>= num_elements`, with a minimum of one bit per half
/// (so `num_elements <= 4`, including 0, yields one-bit halves).
///
/// Returns `(left_mask, right_mask, half_bits)`:
/// * `right_mask` selects the low `half_bits` bits,
/// * `left_mask` selects the `half_bits` bits directly above them,
/// * `half_bits` is the width of each half.
///
/// When no power of four `>= num_elements` fits in `Index`, the halves are
/// capped at half the width of `Index`, which covers every representable
/// index, instead of overflowing while searching for that power.
pub fn precompute(num_elements: Index) -> FeistelPrecomputed {
    let one: Index = 1;
    let mut half_bits: Index = 1;
    let mut capacity: Index = 4;
    while capacity < num_elements {
        half_bits += 1;
        match capacity.checked_mul(4) {
            Some(next) => capacity = next,
            // The next power of four does not fit in `Index`: the split at
            // `half_bits` already spans the whole type, so stop here.
            None => break,
        }
    }
    let right_mask = (one << half_bits) - one;
    let left_mask = right_mask << half_bits;
    (left_mask, right_mask, half_bits)
}
/// Maps `index` to its permuted position below `num_elements`.
///
/// `encode` works on the whole power-of-four domain described by
/// `precomputed`, so its output may land at or above `num_elements`; in that
/// case the output is fed back into `encode` until a value below
/// `num_elements` comes out.
///
/// NOTE(review): never returns when `num_elements` is 0, because no `Index`
/// value is below it.
pub fn permute(
    num_elements: Index,
    index: Index,
    keys: &[Index],
    precomputed: FeistelPrecomputed,
) -> Index {
    let mut candidate = index;
    loop {
        candidate = encode(candidate, keys, precomputed);
        if candidate < num_elements {
            return candidate;
        }
    }
}
/// Maps a permuted position back to the index that `permute` produced it from.
///
/// Mirrors `permute`: `decode` is applied repeatedly until the result is
/// below `num_elements`.
///
/// NOTE(review): never returns when `num_elements` is 0, because no `Index`
/// value is below it.
pub fn invert_permute(
    num_elements: Index,
    index: Index,
    keys: &[Index],
    precomputed: FeistelPrecomputed,
) -> Index {
    let mut candidate = index;
    loop {
        candidate = decode(candidate, keys, precomputed);
        if candidate < num_elements {
            return candidate;
        }
    }
}
/// Splits `index` into its two halves using the masks in `precomputed`.
///
/// Returns `(left, right, right_mask, half_bits)`, where `left` has already
/// been shifted down so that both halves occupy the low `half_bits` bits.
/// Bits of `index` outside both masks are discarded.
fn common_setup(index: Index, precomputed: FeistelPrecomputed) -> (Index, Index, Index, Index) {
    let (upper_mask, lower_mask, shift) = precomputed;
    let lower = index & lower_mask;
    let upper = (index & upper_mask) >> shift;
    (upper, lower, lower_mask, shift)
}
/// Runs the forward Feistel network over `index`.
///
/// At most the first `FEISTEL_ROUNDS` entries of `keys` are used, one per
/// round. Each round turns `(left, right)` into
/// `(right, left ^ feistel(right, key))`; the two halves are then packed back
/// into a single index.
fn encode(index: Index, keys: &[Index], precomputed: FeistelPrecomputed) -> Index {
    let (left, right, right_mask, half_bits) = common_setup(index, precomputed);
    let (left, right) = keys
        .iter()
        .take(FEISTEL_ROUNDS)
        .fold((left, right), |(l, r), &key| {
            (r, l ^ feistel(r, key, right_mask))
        });
    (left << half_bits) | right
}
/// Runs the Feistel network backwards, undoing `encode` for the same `keys`.
///
/// The keys are visited in the reverse of the order `encode` uses them:
/// the first `FEISTEL_ROUNDS` entries of `keys` (or all of them when fewer are
/// supplied), last to first. Selecting the keys the same way as `encode`
/// keeps the two functions exact inverses for any key slice, rather than
/// panicking on an out-of-bounds index when fewer than `FEISTEL_ROUNDS` keys
/// are given.
fn decode(index: Index, keys: &[Index], precomputed: FeistelPrecomputed) -> Index {
    let (mut left, mut right, right_mask, half_bits) = common_setup(index, precomputed);
    for &key in keys.iter().take(FEISTEL_ROUNDS).rev() {
        // A forward round produced (left, right) = (r_prev, l_prev ^ F(r_prev, key)),
        // so r_prev is the current left half and l_prev is recovered by XOR.
        let previous_right = left;
        let previous_left = right ^ feistel(previous_right, key, right_mask);
        left = previous_left;
        right = previous_right;
    }
    (left << half_bits) | right
}
/// Size in bytes of a single `Index` value.
const HALF_FEISTEL_BYTES: usize = size_of::<Index>();
/// Length of the buffer hashed by `feistel`: room for two `Index`-sized
/// values (the half-block and the round key).
const FEISTEL_BYTES: usize = 2 * HALF_FEISTEL_BYTES;
fn feistel(right: Index, key: Index, right_mask: Index) -> Index {
let mut data: [u8; FEISTEL_BYTES] = [0; FEISTEL_BYTES];
let r = if FEISTEL_BYTES <= 8 {
data[0] = (right >> 24) as u8;
data[1] = (right >> 16) as u8;
data[2] = (right >> 8) as u8;
data[3] = right as u8;
data[4] = (key >> 24) as u8;
data[5] = (key >> 16) as u8;
data[6] = (key >> 8) as u8;
data[7] = key as u8;
let raw = blake2b(&data);
let hash = raw.as_bytes();
Index::from(hash[0]) << 24
| Index::from(hash[1]) << 16
| Index::from(hash[2]) << 8
| Index::from(hash[3])
} else {
data[0] = (right >> 56) as u8;
data[1] = (right >> 48) as u8;
data[2] = (right >> 40) as u8;
data[3] = (right >> 32) as u8;
data[4] = (right >> 24) as u8;
data[5] = (right >> 16) as u8;
data[6] = (right >> 8) as u8;
data[7] = right as u8;
data[8] = (key >> 56) as u8;
data[9] = (key >> 48) as u8;
data[10] = (key >> 40) as u8;
data[11] = (key >> 32) as u8;
data[12] = (key >> 24) as u8;
data[13] = (key >> 16) as u8;
data[14] = (key >> 8) as u8;
data[15] = key as u8;
let raw = blake2b(&data);
let hash = raw.as_bytes();
Index::from(hash[0]) << 56
| Index::from(hash[1]) << 48
| Index::from(hash[2]) << 40
| Index::from(hash[3]) << 32
| Index::from(hash[4]) << 24
| Index::from(hash[5]) << 16
| Index::from(hash[6]) << 8
| Index::from(hash[7])
};
r & right_mask
}
#[cfg(test)]
mod tests {
    use super::*;
    use rayon::prelude::{IntoParallelIterator, ParallelIterator};

    /// Domain sizes that are not powers of four.
    const BAD_NS: &[Index] = &[5, 6, 8, 12, 17];

    /// Round keys shared by every test; only the first `FEISTEL_ROUNDS` matter.
    const KEYS: [Index; 4] = [1, 2, 3, 4];

    /// Round-trips every index below `n` through the raw `encode`/`decode`
    /// pair (no re-encoding of out-of-range outputs).
    ///
    /// With `expect_success`, every index must decode back to itself and
    /// encode to a value below `n`. Otherwise at least one index must violate
    /// one of those two conditions.
    fn encode_decode(n: Index, expect_success: bool) {
        let precomputed = precompute(n);
        let round_trip = |i: Index| {
            let p = encode(i, &KEYS, precomputed);
            (p, decode(p, &KEYS, precomputed))
        };

        if expect_success {
            for i in 0..n {
                let (p, v) = round_trip(i);
                assert!(i == v, "failed to permute (n = {})", n);
                assert!(p < n, "output number is too big (n = {})", n);
            }
        } else {
            let failed = (0..n).any(|i| {
                let (p, v) = round_trip(i);
                i != v || p >= n
            });
            assert!(failed, "expected failure (n = {})", n);
        }
    }

    #[test]
    fn test_feistel_power_of_4() {
        // 4, 16, 64, 256: the raw network is already a permutation here.
        for exponent in 1..=4u32 {
            encode_decode(Index::pow(4, exponent), true);
        }
        // On other sizes the raw network must be caught producing a bad value.
        for &n in BAD_NS {
            encode_decode(n, false);
        }
    }

    #[test]
    fn test_feistel_on_arbitrary_set() {
        for &n in BAD_NS {
            let precomputed = precompute(n);
            for i in 0..n {
                let p = permute(n, i, &KEYS, precomputed);
                let v = invert_permute(n, p, &KEYS, precomputed);
                assert_eq!(i, v, "failed to permute");
                assert!(p < n, "output number is too big");
            }
        }
    }

    #[test]
    #[ignore]
    fn test_feistel_valid_permutation() {
        let n: Index = 1 << 30;
        let precomputed = precompute(n);

        let perm: Vec<Index> = (0..n)
            .into_par_iter()
            .map(|i| permute(n, i, &KEYS, precomputed))
            .collect();

        // Mark every position that was produced; all of them must be hit.
        let mut flags = vec![false; n as usize];
        for i in perm {
            assert!(i < n, "output number is too big");
            flags[i as usize] = true;
        }
        assert!(flags.iter().all(|f| *f), "output isn't a permutation");
    }
}