use std::time::Duration;
use std::{borrow::Cow, num::NonZeroUsize};
use super::{
beacon_entries::BeaconEntry,
signatures::{
verify_messages_chained, PublicKeyOnG1, PublicKeyOnG2, SignatureOnG1, SignatureOnG2,
},
};
use crate::shim::clock::ChainEpoch;
use crate::shim::version::NetworkVersion;
use crate::utils::net::global_http_client;
use anyhow::Context as _;
use async_trait::async_trait;
use bls_signatures::Serialize as _;
use itertools::Itertools as _;
use lru::LruCache;
use parking_lot::RwLock;
use serde::{Deserialize as SerdeDeserialize, Serialize as SerdeSerialize};
use url::Url;
/// Key name `IGNORE_DRAND`.
///
/// NOTE(review): presumably the name of an environment variable that disables drand
/// beacon checks when set — confirm at its usage site.
pub const IGNORE_DRAND_VAR: &str = "IGNORE_DRAND";
/// Identifies which drand network a beacon draws randomness from.
#[derive(PartialEq, Eq, Copy, Clone, Debug)]
pub enum DrandNetwork {
    /// Chained network.
    Mainnet,
    /// Unchained network.
    Quicknet,
    /// Chained network.
    Incentinet,
}

impl DrandNetwork {
    /// `true` for networks whose entries do not chain to the previous signature.
    pub fn is_unchained(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces a decision here.
        match self {
            Self::Quicknet => true,
            Self::Mainnet | Self::Incentinet => false,
        }
    }

    /// `true` for networks whose entries chain to the previous round's signature;
    /// always the negation of [`DrandNetwork::is_unchained`].
    pub fn is_chained(&self) -> bool {
        !matches!(self, Self::Quicknet)
    }
}
/// Settings consumed by [`DrandBeacon::new`] to build a drand-backed beacon.
#[derive(Clone)]
pub struct DrandConfig<'a> {
    /// Base URLs queried for beacon entries; tried in order until one succeeds.
    pub servers: Vec<Url>,
    /// Parameters of the drand chain (public key, period, genesis time, hashes).
    pub chain_info: ChainInfo<'a>,
    /// Whether the network produces chained or unchained entries.
    pub network_type: DrandNetwork,
}
/// List of beacon points; [`BeaconSchedule::beacon_for_epoch`] picks the last point whose
/// `height` is `<=` the requested epoch, so points are expected in ascending `height` order.
pub struct BeaconSchedule(pub Vec<BeaconPoint>);

impl BeaconSchedule {
    /// Creates an empty schedule with room reserved for `capacity` points.
    pub fn with_capacity(capacity: usize) -> Self {
        Self(Vec::with_capacity(capacity))
    }

    /// Collects the beacon entries for a block at `epoch` whose parent is at
    /// `parent_epoch`, given the most recent beacon entry `prev`.
    ///
    /// - Chained network whose beacon differs between `parent_epoch` and `epoch`: returns
    ///   the entries for rounds `max_round - 1` and `max_round` of the current beacon.
    /// - `max_round == prev.round()`: logs a warning and returns an empty list.
    /// - Unchained network: returns only the entry for `max_round`.
    /// - Otherwise: walks back from `max_round` until passing `prev`'s round (or fetches only
    ///   `max_round` when `prev.round()` is 0) and returns entries in ascending round order.
    ///
    /// # Errors
    /// Fails if no beacon covers `epoch` / `parent_epoch`, or if fetching an entry fails.
    pub async fn beacon_entries_for_block(
        &self,
        network_version: NetworkVersion,
        epoch: ChainEpoch,
        parent_epoch: ChainEpoch,
        prev: &BeaconEntry,
    ) -> Result<Vec<BeaconEntry>, anyhow::Error> {
        let (cb_epoch, beacon) = self.beacon_for_epoch(epoch)?;
        let network = beacon.network();

        if network.is_chained() {
            let (pb_epoch, _) = self.beacon_for_epoch(parent_epoch)?;
            // Beacon changed between parent and child: hand back its two latest rounds.
            if pb_epoch != cb_epoch {
                let round = beacon.max_beacon_round_for_epoch(network_version, epoch);
                let preceding = beacon.entry(round - 1).await?;
                let latest = beacon.entry(round).await?;
                return Ok(vec![preceding, latest]);
            }
        }

        let max_round = beacon.max_beacon_round_for_epoch(network_version, epoch);
        if prev.round() == max_round {
            tracing::warn!("Unexpected `max_round == prev.round()` condition, network_version: {network_version:?}, max_round: {max_round}, prev_round: {}", prev.round());
            return Ok(vec![]);
        }

        if network.is_unchained() {
            return Ok(vec![beacon.entry(max_round).await?]);
        }

        // A previous round of 0 means only `max_round` itself is collected.
        let stop_round = match prev.round() {
            0 => max_round - 1,
            round => round,
        };

        let mut collected = Vec::new();
        let mut cursor = max_round;
        while cursor > stop_round {
            let entry = beacon.entry(cursor).await?;
            cursor = entry.round() - 1;
            collected.push(entry);
        }
        // Entries were gathered newest-first; callers get them oldest-first.
        collected.reverse();
        Ok(collected)
    }

    /// Returns the beacon in effect at `epoch` together with the height it activated at.
    ///
    /// # Errors
    /// Fails when no point in the schedule has `height <= epoch`.
    pub fn beacon_for_epoch(&self, epoch: ChainEpoch) -> anyhow::Result<(ChainEpoch, &dyn Beacon)> {
        let point = self
            .0
            .iter()
            .rfind(|point| point.height <= epoch)
            .context("Invalid beacon schedule, no valid beacon")?;
        Ok((point.height, point.beacon.as_ref()))
    }
}
/// One entry of a [`BeaconSchedule`]: a beacon and the epoch from which it applies.
pub struct BeaconPoint {
    /// First epoch at which `beacon` is selected by `BeaconSchedule::beacon_for_epoch`.
    pub height: ChainEpoch,
    /// Beacon used from `height` until a later point's height is reached.
    pub beacon: Box<dyn Beacon>,
}
/// A randomness beacon backed by a drand network that can fetch and verify round entries.
#[async_trait]
pub trait Beacon: Send + Sync + 'static {
    /// The drand network this beacon draws from.
    fn network(&self) -> DrandNetwork;

    /// Checks `entries`, which are expected to follow `prev`; `Ok(true)` means they verified.
    fn verify_entries(&self, entries: &[BeaconEntry], prev: &BeaconEntry) -> anyhow::Result<bool>;

    /// Returns the beacon entry for drand round `round`.
    async fn entry(&self, round: u64) -> anyhow::Result<BeaconEntry>;

    /// Returns the highest drand round associated with Filecoin epoch `fil_epoch` under
    /// `network_version`.
    fn max_beacon_round_for_epoch(
        &self,
        network_version: NetworkVersion,
        fil_epoch: ChainEpoch,
    ) -> u64;
}
/// Lets a boxed beacon be used wherever a [`Beacon`] is expected by forwarding every
/// call to the boxed trait object.
#[async_trait]
impl Beacon for Box<dyn Beacon> {
    fn network(&self) -> DrandNetwork {
        (**self).network()
    }

    fn verify_entries(&self, entries: &[BeaconEntry], prev: &BeaconEntry) -> anyhow::Result<bool> {
        (**self).verify_entries(entries, prev)
    }

    async fn entry(&self, round: u64) -> anyhow::Result<BeaconEntry> {
        (**self).entry(round).await
    }

    fn max_beacon_round_for_epoch(
        &self,
        network_version: NetworkVersion,
        fil_epoch: ChainEpoch,
    ) -> u64 {
        (**self).max_beacon_round_for_epoch(network_version, fil_epoch)
    }
}
/// Parameters describing a drand chain, (de)serializable with serde.
#[derive(SerdeDeserialize, SerdeSerialize, Debug, Clone, PartialEq, Eq, Default)]
pub struct ChainInfo<'a> {
    /// Hex-encoded group public key (hex-decoded by `DrandBeacon::new`).
    pub public_key: Cow<'a, str>,
    /// Round period; divides elapsed time to count rounds (presumably seconds — confirm).
    pub period: i32,
    /// Drand genesis timestamp (presumably Unix seconds — confirm).
    pub genesis_time: i32,
    /// Chain hash; used as the `{hash}/public/{round}` request path prefix.
    pub hash: Cow<'a, str>,
    /// Group hash; (de)serialized under the `groupHash` key.
    #[serde(rename = "groupHash")]
    pub group_hash: Cow<'a, str>,
}
/// JSON body returned by a drand HTTP server for a single round.
#[derive(SerdeDeserialize, SerdeSerialize, Debug, Clone)]
pub struct BeaconEntryJson {
    /// Round number.
    round: u64,
    /// Round randomness (presumably hex-encoded); not read by `DrandBeacon::entry`.
    randomness: String,
    /// Hex-encoded round signature.
    signature: String,
    /// Previous round's signature, when the server provides one
    /// (NOTE(review): presumably absent for unchained networks — confirm).
    previous_signature: Option<String>,
}
/// [`Beacon`] implementation that fetches entries from drand HTTP servers and keeps an
/// LRU cache of verified entries.
pub struct DrandBeacon {
    /// Servers queried for entries, tried in order.
    servers: Vec<Url>,
    /// Drand chain hash, used as the request path prefix.
    hash: String,
    /// Chained vs. unchained network; selects the verification scheme.
    network: DrandNetwork,
    /// Hex-decoded drand public key bytes.
    public_key: Vec<u8>,
    /// Drand round period (from `ChainInfo::period`).
    interval: u64,
    /// Drand genesis timestamp (from `ChainInfo::genesis_time`).
    drand_gen_time: u64,
    /// Filecoin genesis timestamp (the `genesis_ts` argument of `new`).
    fil_gen_time: u64,
    /// Filecoin epoch duration (the `interval` argument of `new`).
    fil_round_time: u64,
    /// Entries that passed `verify_entries`, keyed by round; also served by `entry`.
    verified_beacons: RwLock<LruCache<u64, BeaconEntry>>,
}
impl DrandBeacon {
    /// Builds a beacon from `config` for a Filecoin chain with genesis timestamp
    /// `genesis_ts` and epoch duration `interval`.
    ///
    /// # Panics
    /// Panics if `genesis_ts` is 0 or the configured public key is not valid hex.
    pub fn new(genesis_ts: u64, interval: u64, config: &DrandConfig<'_>) -> Self {
        const CACHE_SIZE: usize = 1000;
        assert_ne!(genesis_ts, 0, "Genesis timestamp cannot be 0");

        let info = &config.chain_info;
        let public_key = hex::decode(info.public_key.as_ref())
            .expect("invalid static encoding of drand hex public key");
        let cache_capacity = NonZeroUsize::new(CACHE_SIZE).expect("Infallible");

        Self {
            servers: config.servers.clone(),
            hash: info.hash.to_string(),
            network: config.network_type,
            public_key,
            interval: info.period as u64,
            drand_gen_time: info.genesis_time as u64,
            fil_gen_time: genesis_ts,
            fil_round_time: interval,
            verified_beacons: RwLock::new(LruCache::new(cache_capacity)),
        }
    }
}
#[async_trait]
impl Beacon for DrandBeacon {
fn network(&self) -> DrandNetwork {
self.network
}
fn verify_entries<'a>(
&self,
entries: &'a [BeaconEntry],
prev: &'a BeaconEntry,
) -> Result<bool, anyhow::Error> {
let mut validated = vec![];
let is_valid = if self.network.is_unchained() {
let mut messages = vec![];
let mut signatures = vec![];
let pk = PublicKeyOnG2::from_bytes(&self.public_key)?;
{
let cache = self.verified_beacons.read();
for entry in entries.iter() {
if cache.contains(&entry.round()) {
continue;
}
messages.push(BeaconEntry::message_unchained(entry.round()));
signatures.push(SignatureOnG1::from_bytes(entry.signature())?);
validated.push(entry);
}
}
pk.verify_batch(
messages.iter().map(AsRef::as_ref).collect_vec().as_slice(),
signatures.iter().collect_vec().as_slice(),
)
} else {
let mut messages = vec![];
let mut signatures = vec![];
let pk = PublicKeyOnG1::from_bytes(&self.public_key)?;
{
let prev_curr_pairs = std::iter::once(prev)
.chain(entries.iter())
.unique_by(|e| e.round())
.tuple_windows::<(_, _)>();
let cache = self.verified_beacons.read();
for (prev, curr) in prev_curr_pairs {
if prev.round() > 0 && !cache.contains(&curr.round()) {
messages.push(BeaconEntry::message_chained(curr.round(), prev.signature()));
signatures.push(SignatureOnG2::from_bytes(curr.signature())?);
validated.push(curr);
}
}
}
verify_messages_chained(
&pk,
messages.iter().map(AsRef::as_ref).collect_vec().as_slice(),
&signatures,
)
};
if is_valid && !validated.is_empty() {
let mut cache = self.verified_beacons.write();
assert!(cache.cap().get() >= validated.len());
for entry in validated {
cache.put(entry.round(), entry.clone());
}
}
Ok(is_valid)
}
async fn entry(&self, round: u64) -> anyhow::Result<BeaconEntry> {
let cached: Option<BeaconEntry> = self.verified_beacons.read().peek(&round).cloned();
match cached {
Some(cached_entry) => Ok(cached_entry),
None => {
async fn fetch_entry_from_url(
url: impl reqwest::IntoUrl,
) -> anyhow::Result<BeaconEntry> {
let resp: BeaconEntryJson = global_http_client()
.get(url)
.timeout(Duration::from_secs(15))
.send()
.await?
.error_for_status()?
.json()
.await?;
anyhow::Ok(BeaconEntry::new(resp.round, hex::decode(resp.signature)?))
}
async fn fetch_entry(
urls: impl Iterator<Item = impl reqwest::IntoUrl>,
) -> anyhow::Result<BeaconEntry> {
let mut errors = vec![];
for url in urls {
match fetch_entry_from_url(url).await {
Ok(e) => return Ok(e),
Err(e) => errors.push(e),
}
}
anyhow::bail!(
"Aggregated errors:\n{}",
errors.into_iter().map(|e| e.to_string()).join("\n\n")
);
}
let urls: Vec<_> = self
.servers
.iter()
.map(|server| {
anyhow::Ok(server.join(&format!("{}/public/{round}", self.hash))?)
})
.try_collect()?;
Ok(
backoff::future::retry(backoff::ExponentialBackoff::default(), || async {
Ok(fetch_entry(urls.iter().cloned()).await?)
})
.await?,
)
}
}
}
fn max_beacon_round_for_epoch(
&self,
network_version: NetworkVersion,
fil_epoch: ChainEpoch,
) -> u64 {
let latest_ts =
((fil_epoch as u64 * self.fil_round_time) + self.fil_gen_time) - self.fil_round_time;
if network_version <= NetworkVersion::V15 {
(latest_ts - self.drand_gen_time) / self.interval
} else {
if latest_ts < self.drand_gen_time {
return 1;
}
let from_genesis = latest_ts - self.drand_gen_time;
from_genesis / self.interval + 1
}
}
}