1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
// Copyright 2019-2024 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT
use std::{pin::Pin, task::Poll};

use digest::{Digest, Output};
use pin_project_lite::pin_project;
use tokio::io::{AsyncWrite, AsyncWriteExt, BufWriter};

pin_project! {
    /// An [`AsyncWrite`] wrapper that can compute a checksum of everything written through it.
    ///
    /// Writes are forwarded to an inner [`BufWriter`]. When checksumming is enabled, every byte
    /// the inner writer accepts is also fed to a hasher of type `D`. Both the [`Digest`] (e.g.
    /// `Sha256`) and the underlying writer `W` are generic, so any combination can be used.
    pub struct AsyncWriterWithChecksum<D, W> {
        // Structurally pinned so the inner writer can be polled through `Pin` without
        // requiring `W: Unpin`.
        #[pin]
        inner: BufWriter<W>,
        // `None` when checksumming was disabled at construction time.
        hasher: Option<D>,
    }
}

/// Implemented by types that collect a checksum, computed with digest `D`, of the data
/// passing through them.
pub trait Checksum<D: Digest> {
    /// Returns the checksum of the data collected so far and resets the internal hasher.
    ///
    /// Implementations may return `Ok(None)` when no checksum is being collected.
    ///
    /// # Errors
    ///
    /// Returns an [`std::io::Error`] if the checksum cannot be produced.
    fn finalize(&mut self) -> std::io::Result<Option<Output<D>>>;
}

impl<D: Digest, W: AsyncWriteExt> AsyncWrite for AsyncWriterWithChecksum<D, W> {
    /// Forwards `buf` to the inner writer. When checksumming is enabled, it hashes exactly the
    /// prefix of `buf` that the inner writer reported as written.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let projected = self.project();
        let outcome = projected.inner.poll_write(cx, buf);

        // Only bytes the inner writer accepted are hashed; pending or failed writes add nothing.
        if let (Poll::Ready(Ok(written)), Some(hasher)) = (&outcome, projected.hasher.as_mut()) {
            if *written > 0 {
                // `poll_write` must not report more bytes than `buf` holds, so this slice
                // stays in bounds for a conforming inner writer.
                #[allow(clippy::indexing_slicing)]
                hasher.update(&buf[..*written]);
            }
        }
        outcome
    }

    /// Flushes the buffered inner writer; the hasher is not touched.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let projected = self.project();
        projected.inner.poll_flush(cx)
    }

    /// Shuts down the buffered inner writer; the hasher is not touched.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let projected = self.project();
        projected.inner.poll_shutdown(cx)
    }
}

impl<D: Digest, W: AsyncWriteExt> Checksum<D> for AsyncWriterWithChecksum<D, W> {
    /// Returns the digest of the bytes hashed since construction or since the previous call,
    /// and replaces the hasher with a fresh one. Yields `Ok(None)` when checksumming is
    /// disabled; this implementation never returns an error.
    fn finalize(&mut self) -> std::io::Result<Option<Output<D>>> {
        // `Digest::finalize` consumes the hasher, so swap a fresh one in and finalize the old one.
        let digest = self
            .hasher
            .as_mut()
            .map(|current| std::mem::replace(current, D::new()).finalize());
        Ok(digest)
    }
}

impl<D: Digest, W> AsyncWriterWithChecksum<D, W> {
    /// Wraps `writer`.
    ///
    /// When `checksum_enabled` is `true`, a fresh hasher is created and written bytes are
    /// hashed. Otherwise no hasher is stored and no checksum is collected.
    pub fn new(writer: BufWriter<W>, checksum_enabled: bool) -> Self {
        let hasher = checksum_enabled.then(D::new);
        Self {
            inner: writer,
            hasher,
        }
    }
}

/// An [`AsyncWrite`] sink that discards everything written to it.
///
/// Every write reports the whole buffer as written. Flushing and shutting down always
/// succeed immediately.
#[derive(Debug, Clone, Default)]
pub struct VoidAsyncWriter;

impl AsyncWrite for VoidAsyncWriter {
    fn poll_write(
        self: Pin<&mut Self>,
        _: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let accepted = buf.len();
        Poll::Ready(Ok(accepted))
    }

    fn poll_flush(
        self: Pin<&mut Self>,
        _: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    fn poll_shutdown(
        self: Pin<&mut Self>,
        _: &mut std::task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}

#[cfg(test)]
mod test {
    use rand::{rngs::OsRng, RngCore};
    use sha2::{Sha256, Sha512};
    use tokio::io::{AsyncWriteExt, BufWriter};

    use super::*;

    /// Streams 1 MiB of random bytes into a real file. The on-the-fly checksum must match a
    /// SHA-256 computed over the file's final contents.
    #[tokio::test]
    async fn file_writer_fs_buf_writer() {
        let named_file = tempfile::Builder::new().tempfile().unwrap();
        let file = tokio::fs::File::create(named_file.path()).await.unwrap();
        let mut writer = AsyncWriterWithChecksum::<Sha256, _>::new(BufWriter::new(file), true);

        let mut chunk = [0u8; 1024];
        for _ in 0..1024 {
            OsRng.fill_bytes(&mut chunk);
            writer.write_all(&chunk).await.unwrap();
        }
        writer.flush().await.unwrap();
        writer.shutdown().await.unwrap();

        let streamed = writer.finalize().unwrap();
        let on_disk = std::fs::read(named_file.path()).unwrap();

        assert_eq!(streamed, Some(Sha256::digest(&on_disk)));
    }

    /// Hashes the same sequence of writes twice. Equal digests in both rounds show that
    /// `finalize` resets the hasher.
    #[tokio::test]
    async fn given_buffered_writer_and_sha256_digest_should_return_correct_checksum() {
        const EXPECTED: &str = "3386191dc5c285074c3827452f4e3b685e3253f5b9ca7c4c2bb3f44d1263aef1";
        let old_gods = ["cthulhu", "azathoth", "dagon"];
        let sink = BufWriter::new(Vec::<u8>::new());
        let mut writer = AsyncWriterWithChecksum::<Sha256, _>::new(sink, true);

        for _round in 0..2 {
            for old_god in &old_gods {
                writer.write_all(old_god.as_bytes()).await.unwrap();
            }
            assert_eq!(
                format!("{:x}", writer.finalize().unwrap().unwrap()),
                EXPECTED
            );
        }
    }

    /// With checksumming enabled and nothing written, the result is the SHA-512 of empty input.
    #[tokio::test]
    async fn digest_of_nothing() {
        let mut writer =
            AsyncWriterWithChecksum::<Sha512, _>::new(BufWriter::new(Vec::<u8>::new()), true);
        assert_eq!(
            format!("{:x}", writer.finalize().unwrap().unwrap()),
            "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e"
        );
    }

    /// With checksumming disabled, `finalize` reports no checksum at all.
    #[tokio::test]
    async fn no_checksum_of_nothing() {
        let mut writer =
            AsyncWriterWithChecksum::<Sha512, _>::new(BufWriter::new(Vec::<u8>::new()), false);
        let checksum = writer.finalize().unwrap();
        assert!(checksum.is_none());
    }
}