use std::convert::TryFrom;
use fvm_ipld_encoding::strict_bytes;
use serde::{Deserialize, Deserializer, Serialize};
use super::BitField;
use crate::{Error, MAX_ENCODED_SIZE};
/// Conversion of a borrowed value into a shared reference to a [`BitField`].
///
/// The conversion consumes `self` and may fail with an [`Error`]. The returned
/// reference lives for the trait's lifetime parameter `'a`.
pub trait Validate<'a> {
/// Returns a shared reference to a [`BitField`] on success, or an [`Error`]
/// if the value could not be turned into one.
fn validate(self) -> Result<&'a BitField, Error>;
}
/// Validation through a mutable borrow of an [`UnvalidatedBitField`].
impl<'a> Validate<'a> for &'a mut UnvalidatedBitField {
    /// Delegates to [`UnvalidatedBitField::validate_mut`] and downgrades the
    /// resulting mutable reference to a shared one.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_mut` reports.
    fn validate(self) -> Result<&'a BitField, Error> {
        let validated = self.validate_mut()?;
        // Reborrow immutably: callers of `validate` only get read access.
        Ok(&*validated)
    }
}
/// Validating an existing `&BitField` is the identity operation and never fails.
impl<'a> Validate<'a> for &'a BitField {
/// Returns `self` unchanged, wrapped in `Ok`.
fn validate(self) -> Result<&'a BitField, Error> {
Ok(self)
}
}
/// A bitfield that is either already decoded or still held as raw encoded bytes.
///
/// The `Deserialize` impl in this file always produces the `Unvalidated`
/// variant; `validate_mut` decodes it in place into `Validated`.
/// Serialization carries no variant tag (`#[serde(untagged)]`).
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum UnvalidatedBitField {
/// A decoded [`BitField`].
Validated(BitField),
/// Raw bytes that have not yet been passed through `BitField::from_bytes`.
/// Serialized via `strict_bytes`.
Unvalidated(#[serde(with = "strict_bytes")] Vec<u8>),
}
impl UnvalidatedBitField {
    /// Makes sure this value holds a decoded [`BitField`] and returns a
    /// mutable reference to it.
    ///
    /// If the value is still `Unvalidated`, its bytes are decoded with
    /// `BitField::from_bytes` and the result replaces `self` as `Validated`.
    /// An already `Validated` value is returned as-is.
    ///
    /// # Errors
    ///
    /// Propagates the error from `BitField::from_bytes`. In that case `self`
    /// is left untouched, because decoding happens before the assignment.
    pub fn validate_mut(&mut self) -> Result<&mut BitField, Error> {
        if let Self::Unvalidated(raw) = self {
            let decoded = BitField::from_bytes(raw)?;
            *self = Self::Validated(decoded);
        }

        // At this point the value is always `Validated`: it either started
        // that way or was replaced just above.
        match self {
            Self::Unvalidated(_) => unreachable!(),
            Self::Validated(bitfield) => Ok(bitfield),
        }
    }
}
#[cfg(feature = "enable-arbitrary")]
use arbitrary::{Arbitrary, Unstructured};
#[cfg(feature = "enable-arbitrary")]
impl<'a> Arbitrary<'a> for UnvalidatedBitField {
    /// Generates an arbitrary [`BitField`] and then picks, from the same
    /// unstructured input, whether to expose it as `Validated` or as
    /// `Unvalidated` holding the output of `BitField::to_bytes`.
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        // The bitfield is drawn first, the variant choice second.
        let bf: BitField = u.arbitrary()?;
        let as_validated = *u.choose(&[true, false])?;

        if as_validated {
            Ok(Self::Validated(bf))
        } else {
            Ok(Self::Unvalidated(bf.to_bytes()))
        }
    }

    /// Combines the size hint of [`BitField`] with a fixed `(1, Some(1))`
    /// term, presumably accounting for the variant choice.
    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        let bitfield_hint = BitField::size_hint(depth);
        let choice_hint = (1, Some(1));
        arbitrary::size_hint::and(bitfield_hint, choice_hint)
    }
}
impl From<BitField> for UnvalidatedBitField {
fn from(bf: BitField) -> Self {
Self::Validated(bf)
}
}
/// Fallible extraction of a [`BitField`] from an [`UnvalidatedBitField`].
impl TryFrom<UnvalidatedBitField> for BitField {
    type Error = Error;

    /// Consumes `bf` and yields the contained bitfield.
    ///
    /// A `Validated` value is unwrapped directly; an `Unvalidated` value has
    /// its bytes decoded with `BitField::from_bytes`.
    ///
    /// # Errors
    ///
    /// Returns the error from `BitField::from_bytes` when decoding fails.
    fn try_from(bf: UnvalidatedBitField) -> Result<Self, Self::Error> {
        match bf {
            UnvalidatedBitField::Unvalidated(raw) => BitField::from_bytes(&raw),
            UnvalidatedBitField::Validated(decoded) => Ok(decoded),
        }
    }
}
impl<'de> Deserialize<'de> for UnvalidatedBitField {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let bytes: Vec<u8> = strict_bytes::deserialize(deserializer)?;
if bytes.len() > MAX_ENCODED_SIZE {
return Err(serde::de::Error::custom(format!(
"encoded bitfield was too large {}",
bytes.len()
)));
}
Ok(Self::Unvalidated(bytes))
}
}