Skip to content

Commit fa2028b

Browse files
committed
Fix chain combinator
1 parent 11620e5 commit fa2028b

File tree

6 files changed

+229
-130
lines changed

6 files changed

+229
-130
lines changed

Cargo.lock

Lines changed: 39 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

manul/src/combinators/chain.rs

Lines changed: 121 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,16 @@ Usage:
4949
when verifying evidence from the chained protocol.
5050
*/
5151

52-
use alloc::{boxed::Box, collections::BTreeMap};
52+
use alloc::{boxed::Box, collections::BTreeMap, format};
5353
use core::fmt::{self, Debug};
5454

5555
use rand_core::CryptoRngCore;
5656

5757
use crate::protocol::{
5858
Artifact, BoxedFormat, BoxedRound, BoxedRoundInfo, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint,
5959
FinalizeOutcome, LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError,
60-
ProtocolMessage, ProtocolValidationError, ReceiveError, RequiredMessages, Round, RoundId, TransitionInfo,
60+
ProtocolMessage, ProtocolValidationError, ReceiveError, RequiredMessages, Round, RoundId, RoundInfo,
61+
TransitionInfo,
6162
};
6263

6364
/// A marker trait that is used to disambiguate blanket trait implementations for [`Protocol`] and [`EntryPoint`].
@@ -197,62 +198,152 @@ where
197198
}
198199
}
199200

200-
impl<Id, C> Protocol<Id> for C
201+
#[derive_where::derive_where(Debug)]
202+
struct RoundInfoWrapper1<Id: 'static, P: ChainedProtocol<Id> + ChainedMarker>(BoxedRoundInfo<Id, P::Protocol1>);
203+
204+
impl<Id, P> RoundInfo<Id> for RoundInfoWrapper1<Id, P>
201205
where
202-
Id: 'static,
203-
C: ChainedProtocol<Id> + ChainedMarker,
206+
P: ChainedProtocol<Id> + ChainedMarker,
204207
{
205-
type Result = <C::Protocol2 as Protocol<Id>>::Result;
206-
type ProtocolError = ChainedProtocolError<Id, C>;
208+
type Protocol = P;
207209

208-
fn round_info(_round_id: &RoundId) -> Option<BoxedRoundInfo<Id, Self>> {
209-
/*let rounds1 = C::Protocol1::rounds()
210-
.into_iter()
211-
.map(|(round_id, round)| (round_id.group_under(1), round));
212-
let rounds2 = C::Protocol2::rounds()
213-
.into_iter()
214-
.map(|(round_id, round)| (round_id.group_under(2), round));
215-
rounds1.chain(rounds2).collect()*/
216-
unimplemented!()
210+
fn verify_direct_message_is_invalid(
211+
&self,
212+
round_id: &RoundId,
213+
format: &BoxedFormat,
214+
message: &DirectMessage,
215+
associated_data: &<<Self::Protocol as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
216+
) -> Result<(), MessageValidationError> {
217+
let (group, round_id) = round_id.split_group()?;
218+
if group != 1 {
219+
return Err(MessageValidationError::Local(LocalError::new(format!(
220+
"Expected round ID from group 1, got {round_id}"
221+
))));
222+
}
223+
self.0
224+
.as_ref()
225+
.verify_direct_message_is_invalid(&round_id, format, message, &associated_data.protocol1)
217226
}
218227

219-
fn verify_direct_message_is_invalid(
228+
fn verify_echo_broadcast_is_invalid(
229+
&self,
230+
round_id: &RoundId,
220231
format: &BoxedFormat,
232+
message: &EchoBroadcast,
233+
associated_data: &<<Self::Protocol as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
234+
) -> Result<(), MessageValidationError> {
235+
let (group, round_id) = round_id.split_group()?;
236+
if group != 1 {
237+
return Err(MessageValidationError::Local(LocalError::new(format!(
238+
"Expected round ID from group 1, got {round_id}"
239+
))));
240+
}
241+
self.0
242+
.as_ref()
243+
.verify_echo_broadcast_is_invalid(&round_id, format, message, &associated_data.protocol1)
244+
}
245+
246+
fn verify_normal_broadcast_is_invalid(
247+
&self,
248+
round_id: &RoundId,
249+
format: &BoxedFormat,
250+
message: &NormalBroadcast,
251+
associated_data: &<<Self::Protocol as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
252+
) -> Result<(), MessageValidationError> {
253+
let (group, round_id) = round_id.split_group()?;
254+
if group != 1 {
255+
return Err(MessageValidationError::Local(LocalError::new(format!(
256+
"Expected round ID from group 1, got {round_id}"
257+
))));
258+
}
259+
self.0
260+
.as_ref()
261+
.verify_normal_broadcast_is_invalid(&round_id, format, message, &associated_data.protocol1)
262+
}
263+
}
264+
265+
#[derive_where::derive_where(Debug)]
266+
struct RoundInfoWrapper2<Id: 'static, P: ChainedProtocol<Id> + ChainedMarker>(BoxedRoundInfo<Id, P::Protocol2>);
267+
268+
impl<Id, P> RoundInfo<Id> for RoundInfoWrapper2<Id, P>
269+
where
270+
P: ChainedProtocol<Id> + ChainedMarker,
271+
{
272+
type Protocol = P;
273+
274+
fn verify_direct_message_is_invalid(
275+
&self,
221276
round_id: &RoundId,
277+
format: &BoxedFormat,
222278
message: &DirectMessage,
223-
associated_data: &<Self::ProtocolError as ProtocolError<Id>>::AssociatedData,
279+
associated_data: &<<Self::Protocol as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
224280
) -> Result<(), MessageValidationError> {
225281
let (group, round_id) = round_id.split_group()?;
226-
if group == 1 {
227-
C::Protocol1::verify_direct_message_is_invalid(format, &round_id, message, &associated_data.protocol1)
228-
} else {
229-
C::Protocol2::verify_direct_message_is_invalid(format, &round_id, message, &associated_data.protocol2)
282+
if group != 2 {
283+
return Err(MessageValidationError::Local(LocalError::new(format!(
284+
"Expected round ID from group 2, got {round_id}"
285+
))));
230286
}
287+
self.0
288+
.as_ref()
289+
.verify_direct_message_is_invalid(&round_id, format, message, &associated_data.protocol2)
231290
}
232291

233292
fn verify_echo_broadcast_is_invalid(
234-
format: &BoxedFormat,
293+
&self,
235294
round_id: &RoundId,
295+
format: &BoxedFormat,
236296
message: &EchoBroadcast,
297+
associated_data: &<<Self::Protocol as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
237298
) -> Result<(), MessageValidationError> {
238299
let (group, round_id) = round_id.split_group()?;
239-
if group == 1 {
240-
C::Protocol1::verify_echo_broadcast_is_invalid(format, &round_id, message)
241-
} else {
242-
C::Protocol2::verify_echo_broadcast_is_invalid(format, &round_id, message)
300+
if group != 2 {
301+
return Err(MessageValidationError::Local(LocalError::new(format!(
302+
"Expected round ID from group 2, got {round_id}"
303+
))));
243304
}
305+
self.0
306+
.as_ref()
307+
.verify_echo_broadcast_is_invalid(&round_id, format, message, &associated_data.protocol2)
244308
}
245309

246310
fn verify_normal_broadcast_is_invalid(
247-
format: &BoxedFormat,
311+
&self,
248312
round_id: &RoundId,
313+
format: &BoxedFormat,
249314
message: &NormalBroadcast,
315+
associated_data: &<<Self::Protocol as Protocol<Id>>::ProtocolError as ProtocolError<Id>>::AssociatedData,
250316
) -> Result<(), MessageValidationError> {
251317
let (group, round_id) = round_id.split_group()?;
318+
if group != 2 {
319+
return Err(MessageValidationError::Local(LocalError::new(format!(
320+
"Expected round ID from group 2, got {round_id}"
321+
))));
322+
}
323+
self.0
324+
.as_ref()
325+
.verify_normal_broadcast_is_invalid(&round_id, format, message, &associated_data.protocol2)
326+
}
327+
}
328+
329+
impl<Id, C> Protocol<Id> for C
330+
where
331+
Id: 'static,
332+
C: ChainedProtocol<Id> + ChainedMarker,
333+
{
334+
type Result = <C::Protocol2 as Protocol<Id>>::Result;
335+
type ProtocolError = ChainedProtocolError<Id, C>;
336+
337+
fn round_info(round_id: &RoundId) -> Option<BoxedRoundInfo<Id, Self>> {
338+
let (group, round_id) = round_id.split_group().ok()?;
252339
if group == 1 {
253-
C::Protocol1::verify_normal_broadcast_is_invalid(format, &round_id, message)
340+
let round_info = C::Protocol1::round_info(&round_id)?;
341+
Some(BoxedRoundInfo::new_obj(Box::new(RoundInfoWrapper1(round_info))))
342+
} else if group == 2 {
343+
let round_info = C::Protocol2::round_info(&round_id)?;
344+
Some(BoxedRoundInfo::new_obj(Box::new(RoundInfoWrapper2(round_info))))
254345
} else {
255-
C::Protocol2::verify_normal_broadcast_is_invalid(format, &round_id, message)
346+
None
256347
}
257348
}
258349
}

manul/src/protocol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,4 @@ pub use static_round::{NoMessage, StaticProtocolMessage, StaticRound};
3737

3838
pub(crate) use errors::ReceiveErrorType;
3939
pub(crate) use message::ProtocolMessagePartHashable;
40+
pub(crate) use round_info::RoundInfo;

manul/src/protocol/round.rs

Lines changed: 1 addition & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ use serde::{Deserialize, Serialize};
1414
use super::{
1515
boxed_format::BoxedFormat,
1616
boxed_round::BoxedRound,
17-
errors::{LocalError, MessageValidationError, ProtocolValidationError, ReceiveError},
17+
errors::{LocalError, ProtocolValidationError, ReceiveError},
1818
message::{DirectMessage, EchoBroadcast, NormalBroadcast, ProtocolMessage, ProtocolMessagePart},
1919
round_id::{RoundId, TransitionInfo},
2020
round_info::BoxedRoundInfo,
@@ -76,56 +76,6 @@ pub trait Protocol<Id>: 'static + Sized {
7676

7777
/// Returns the wrapped round types for each round mapped to round IDs.
7878
fn round_info(round_id: &RoundId) -> Option<BoxedRoundInfo<Id, Self>>;
79-
80-
// TODO: move out of `Protocol`. To `evidence.rs`, perhaps?
81-
/// Returns `Ok(())` if the given direct message cannot be deserialized
82-
/// assuming it is a direct message from the round `round_id`.
83-
///
84-
/// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`]
85-
/// when implementing this.
86-
fn verify_direct_message_is_invalid(
87-
format: &BoxedFormat,
88-
round_id: &RoundId,
89-
message: &DirectMessage,
90-
associated_data: &<Self::ProtocolError as ProtocolError<Id>>::AssociatedData,
91-
) -> Result<(), MessageValidationError> {
92-
let round_info = Self::round_info(round_id).ok_or_else(|| {
93-
MessageValidationError::Local(LocalError::new(format!("{round_id} is not in the protocol")))
94-
})?;
95-
round_info.verify_direct_message_is_invalid(round_id, format, message, associated_data)
96-
}
97-
98-
/// Returns `Ok(())` if the given echo broadcast cannot be deserialized
99-
/// assuming it is an echo broadcast from the round `round_id`.
100-
///
101-
/// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`]
102-
/// when implementing this.
103-
fn verify_echo_broadcast_is_invalid(
104-
format: &BoxedFormat,
105-
round_id: &RoundId,
106-
message: &EchoBroadcast,
107-
) -> Result<(), MessageValidationError> {
108-
let round_info = Self::round_info(round_id).ok_or_else(|| {
109-
MessageValidationError::Local(LocalError::new(format!("{round_id} is not in the protocol")))
110-
})?;
111-
round_info.verify_echo_broadcast_is_invalid(format, message)
112-
}
113-
114-
/// Returns `Ok(())` if the given echo broadcast cannot be deserialized
115-
/// assuming it is an echo broadcast from the round `round_id`.
116-
///
117-
/// Normally one would use [`ProtocolMessagePart::verify_is_not`] and [`ProtocolMessagePart::verify_is_some`]
118-
/// when implementing this.
119-
fn verify_normal_broadcast_is_invalid(
120-
format: &BoxedFormat,
121-
round_id: &RoundId,
122-
message: &NormalBroadcast,
123-
) -> Result<(), MessageValidationError> {
124-
let round_info = Self::round_info(round_id).ok_or_else(|| {
125-
MessageValidationError::Local(LocalError::new(format!("{round_id} is not in the protocol")))
126-
})?;
127-
round_info.verify_normal_broadcast_is_invalid(format, message)
128-
}
12979
}
13080

13181
/// Declares which parts of the message from a round have to be stored to serve as the evidence of malicious behavior.

0 commit comments

Comments
 (0)