1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
// Copyright 2019-2024 ChainSafe Systems
// SPDX-License-Identifier: Apache-2.0, MIT

mod decoder;
use std::{io, marker::PhantomData, time::Duration};

use async_trait::async_trait;
use decoder::DagCborDecodingReader;
use futures::prelude::*;
use libp2p::request_response::{self, OutboundFailure};
use serde::{de::DeserializeOwned, Serialize};

/// Generic CBOR codec type for libp2p request-response protocols.
///
/// It stores no data. It only records the protocol name type `P`, the request
/// type `RQ` and the response type `RS`. This lets a single
/// [`request_response::Codec`] implementation serve both the Hello and
/// `ChainExchange` protocols without duplication.
pub struct CborRequestResponse<P, RQ, RS> {
    protocol: PhantomData<P>,
    request: PhantomData<RQ>,
    response: PhantomData<RS>,
}

// `Clone` and `Default` are written by hand instead of derived. The derive
// macros would add `P: Clone` / `RQ: Clone` / `RS: Clone` (resp. `Default`)
// bounds even though only zero-sized `PhantomData` markers are stored. Those
// bounds would needlessly stop the codec from being cloned when the message
// types themselves are not `Clone`.
impl<P, RQ, RS> Clone for CborRequestResponse<P, RQ, RS> {
    fn clone(&self) -> Self {
        Self::default()
    }
}

impl<P, RQ, RS> Default for CborRequestResponse<P, RQ, RS> {
    fn default() -> Self {
        Self {
            protocol: PhantomData,
            request: PhantomData,
            response: PhantomData,
        }
    }
}

/// Failure to obtain a response to an outbound libp2p request-response
/// request.
///
/// This describes a failure that prevented any response from arriving. It is
/// distinct from a remote node successfully replying with an error response.
///
/// Each variant maps one-to-one onto a variant of libp2p's public
/// [`OutboundFailure`] (see the `From<OutboundFailure>` conversion). Callers
/// can therefore handle these errors without depending on the libp2p type
/// directly.
#[derive(Debug, thiserror::Error)]
pub enum RequestResponseError {
    /// The request could not be sent because a dialing attempt failed.
    #[error("DialFailure")]
    DialFailure,
    /// The request timed out before a response was received.
    ///
    /// It is not known whether the request may have been
    /// received (and processed) by the remote peer.
    #[error("Timeout")]
    Timeout,
    /// The connection closed before a response was received.
    ///
    /// It is not known whether the request may have been
    /// received (and processed) by the remote peer.
    #[error("ConnectionClosed")]
    ConnectionClosed,
    /// The remote supports none of the requested protocols.
    #[error("UnsupportedProtocols")]
    UnsupportedProtocols,
    /// An I/O failure happened on an outbound stream.
    ///
    /// Its `Display` output is exactly the wrapped [`io::Error`]'s message.
    #[error("{0}")]
    Io(io::Error),
}

impl From<OutboundFailure> for RequestResponseError {
    fn from(err: OutboundFailure) -> Self {
        match err {
            OutboundFailure::DialFailure => Self::DialFailure,
            OutboundFailure::Timeout => Self::Timeout,
            OutboundFailure::ConnectionClosed => Self::ConnectionClosed,
            OutboundFailure::UnsupportedProtocols => Self::UnsupportedProtocols,
            OutboundFailure::Io(e) => Self::Io(e),
        }
    }
}

#[async_trait]
impl<P, RQ, RS> request_response::Codec for CborRequestResponse<P, RQ, RS>
where
    P: AsRef<str> + Send + Clone,
    RQ: Serialize + DeserializeOwned + Send + Sync,
    RS: Serialize + DeserializeOwned + Send + Sync,
{
    type Protocol = P;
    type Request = RQ;
    type Response = RS;

    async fn read_request<T>(&mut self, _: &Self::Protocol, io: &mut T) -> io::Result<Self::Request>
    where
        T: AsyncRead + Unpin + Send,
    {
        read_request_and_decode(io).await
    }

    async fn read_response<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
    ) -> io::Result<Self::Response>
    where
        T: AsyncRead + Unpin + Send,
    {
        let mut bytes = vec![];
        io.read_to_end(&mut bytes).await?;
        serde_ipld_dagcbor::de::from_reader(bytes.as_slice()).map_err(io::Error::other)
    }

    async fn write_request<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        req: Self::Request,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        encode_and_write(io, req).await
    }

    async fn write_response<T>(
        &mut self,
        _: &Self::Protocol,
        io: &mut T,
        res: Self::Response,
    ) -> io::Result<()>
    where
        T: AsyncWrite + Unpin + Send,
    {
        encode_and_write(io, res).await
    }
}

// Requests must NOT be read with `io.ReadToEnd`: given how Lotus implements the
// protocol, doing so deadlocks.
//
// When Lotus *sends a request*, it:
// 1. writes the encoded request bytes,
// 2. waits for the response,
// 3. only then closes the request stream, which sends the `FIN` flag over `yamux`.
// Reading to EOF before that `FIN` arrives would wait forever: the `FIN` only
// comes after we have responded.
//
// When Lotus *sends a response*, it:
// 1. receives the request,
// 2. writes the encoded response bytes,
// 3. closes the response stream, which sends the `FIN` flag over `yamux`.
// Here `FIN` follows the payload, so reading responses to EOF does not deadlock.
//
// Note: `FIN` performs a half-close of a stream and may be sent with a data
// message or a window update. See <https://github.com/libp2p/go-yamux/blob/master/spec.md#flag-field>
//
// `io` is essentially a [yamux::Stream](https://docs.rs/yamux/0.11.0/yamux/struct.Stream.html).
//
/// Incrementally decodes a single DAG-CBOR request from `io`.
///
/// Gives up with an [`io::Error`] (and logs a warning) if no complete request
/// has been decoded within 30 seconds. Decoder errors are returned as-is.
async fn read_request_and_decode<IO, T>(io: &mut IO) -> io::Result<T>
where
    IO: AsyncRead + Unpin,
    T: serde::de::DeserializeOwned,
{
    // Byte limit handed to the decoder; messages over 2 MiB are likely malicious.
    const MAX_BYTES_ALLOWED: usize = 2 * 1024 * 1024;
    // Upper bound on how long we wait for one complete request.
    const TIMEOUT: Duration = Duration::from_secs(30);

    // Messages are not length-prefixed. End of frame is detected as "decode
    // succeeded with no trailing data", much like `FramedRead` does. A peer
    // that never finishes its message could therefore leave us waiting on the
    // stream indefinitely, so the whole decode is bounded by `TIMEOUT`.
    let decoding = DagCborDecodingReader::new(io, MAX_BYTES_ALLOWED);
    tokio::time::timeout(TIMEOUT, decoding)
        .await
        .unwrap_or_else(|_elapsed| {
            let err = io::Error::other("read_and_decode timeout");
            tracing::warn!("{err}");
            Err(err)
        })
}

async fn encode_and_write<IO, T>(io: &mut IO, data: T) -> io::Result<()>
where
    IO: AsyncWrite + Unpin,
    T: serde::Serialize,
{
    let bytes = fvm_ipld_encoding::to_vec(&data).map_err(io::Error::other)?;
    io.write_all(&bytes).await?;
    io.close().await?;
    Ok(())
}