1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
// Copyright 2021-2023 Protocol Labs
// Copyright 2019-2022 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT

use std::convert::TryFrom;

use fvm_ipld_encoding::strict_bytes;
use serde::{Deserialize, Deserializer, Serialize};

use super::BitField;
use crate::{Error, MAX_ENCODED_SIZE};

/// A trait for types that can produce a `&BitField` (or fail to do so).
/// Generalizes over `&BitField` and `&mut UnvalidatedBitField`.
pub trait Validate<'a> {
    fn validate(self) -> Result<&'a BitField, Error>;
}

impl<'a> Validate<'a> for &'a mut UnvalidatedBitField {
    /// Runs `validate_mut` (decoding the RLE+ bytes if that has not happened
    /// yet) and downgrades the resulting unique borrow to a shared one.
    fn validate(self) -> Result<&'a BitField, Error> {
        let decoded = self.validate_mut()?;
        Ok(&*decoded)
    }
}

impl<'a> Validate<'a> for &'a BitField {
    /// A `&BitField` needs no further checking: this always succeeds and
    /// returns the very same reference.
    fn validate(self) -> Result<&'a BitField, Error> {
        Ok(self)
    }
}

/// A bit field whose RLE+ encoding may not have been validated yet.
///
/// Validation is deferred until the bit field is first used (via
/// `validate_mut` or `Validate::validate`) rather than performed at
/// deserialization time; the custom `Deserialize` impl in this file always
/// produces the `Unvalidated` variant.
///
/// Serialization is `untagged`: no variant tag is written, each variant
/// serializes as its inner value.
#[derive(Debug, Serialize)]
#[serde(untagged)]
pub enum UnvalidatedBitField {
    /// An already-decoded bit field, serialized with `BitField`'s own
    /// `Serialize` impl (presumably the same RLE+ byte-string form as the raw
    /// variant — TODO confirm).
    Validated(BitField),
    /// Raw encoded bytes that have not been decoded yet; serialized as a byte
    /// string through `strict_bytes`, without any validation.
    Unvalidated(#[serde(with = "strict_bytes")] Vec<u8>),
}

impl UnvalidatedBitField {
    /// Ensures the RLE+ encoding has been validated and returns a unique
    /// reference to the decoded bit field.
    ///
    /// For an `Unvalidated` value the raw bytes are decoded with
    /// `BitField::from_bytes` and `self` is switched to `Validated`, so later
    /// calls skip decoding entirely.
    ///
    /// # Errors
    ///
    /// Propagates the error from `BitField::from_bytes` when the stored bytes
    /// fail to decode. In that case `self` is left unchanged (still
    /// `Unvalidated`).
    pub fn validate_mut(&mut self) -> Result<&mut BitField, Error> {
        // Decode first so the borrow of the raw buffer has ended before
        // `self` is overwritten.
        let freshly_decoded = match self {
            Self::Validated(_) => None,
            Self::Unvalidated(raw) => Some(BitField::from_bytes(raw)?),
        };
        if let Some(bf) = freshly_decoded {
            *self = Self::Validated(bf);
        }

        match self {
            Self::Validated(bf) => Ok(bf),
            // Any `Unvalidated` value was replaced above, or we already
            // returned the decoding error.
            Self::Unvalidated(_) => unreachable!(),
        }
    }
}
#[cfg(feature = "enable-arbitrary")]
use arbitrary::{Arbitrary, Unstructured};

#[cfg(feature = "enable-arbitrary")]
impl<'a> Arbitrary<'a> for UnvalidatedBitField {
    /// Generates an arbitrary `BitField` first, then uses `Unstructured::choose`
    /// to decide whether to keep it decoded (`Validated`) or re-encode it into
    /// raw bytes (`Unvalidated`).
    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
        let bf = u.arbitrary::<BitField>()?;
        let keep_decoded: bool = *u.choose(&[true, false])?;
        let generated = if keep_decoded {
            Self::Validated(bf)
        } else {
            Self::Unvalidated(bf.to_bytes())
        };
        Ok(generated)
    }

    /// The inner `BitField`'s hint combined with an exact `(1, Some(1))` for
    /// the variant choice.
    fn size_hint(depth: usize) -> (usize, Option<usize>) {
        let variant_choice = (1, Some(1));
        arbitrary::size_hint::and(BitField::size_hint(depth), variant_choice)
    }
}

impl From<BitField> for UnvalidatedBitField {
    fn from(bf: BitField) -> Self {
        Self::Validated(bf)
    }
}

impl TryFrom<UnvalidatedBitField> for BitField {
    type Error = Error;

    fn try_from(bf: UnvalidatedBitField) -> Result<Self, Self::Error> {
        match bf {
            UnvalidatedBitField::Validated(bf) => Ok(bf),
            UnvalidatedBitField::Unvalidated(bf) => BitField::from_bytes(&bf),
        }
    }
}

impl<'de> Deserialize<'de> for UnvalidatedBitField {
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let bytes: Vec<u8> = strict_bytes::deserialize(deserializer)?;
        if bytes.len() > MAX_ENCODED_SIZE {
            return Err(serde::de::Error::custom(format!(
                "encoded bitfield was too large {}",
                bytes.len()
            )));
        }
        Ok(Self::Unvalidated(bytes))
    }
}