Skip to content

Commit 0ef0c98

Browse files
committed
Add StaticRound to eliminate some boilerplate when writing protocols
1 parent 46fe191 commit 0ef0c98

File tree

6 files changed

+355
-81
lines changed

6 files changed

+355
-81
lines changed

examples/src/simple.rs

Lines changed: 54 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use alloc::collections::{BTreeMap, BTreeSet};
22
use core::fmt::Debug;
33

44
use manul::protocol::{
5-
Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome,
6-
LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage,
7-
ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, Round, RoundId,
8-
TransitionInfo,
5+
BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, LocalError,
6+
MessageValidationError, NoMessage, NormalBroadcast, PartyId, Protocol, ProtocolError, ProtocolMessage,
7+
ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, RoundId,
8+
StaticProtocolMessage, StaticRound, TransitionInfo,
99
};
1010
use rand_core::CryptoRngCore;
1111
use serde::{Deserialize, Serialize};
@@ -121,7 +121,7 @@ pub(crate) struct Context<Id> {
121121
}
122122

123123
#[derive(Debug)]
124-
pub struct Round1<Id> {
124+
pub(crate) struct Round1<Id> {
125125
pub(crate) context: Context<Id>,
126126
}
127127

@@ -132,17 +132,17 @@ pub(crate) struct Round1Message {
132132
}
133133

134134
#[derive(Serialize, Deserialize)]
135-
struct Round1Echo {
135+
pub(crate) struct Round1Echo {
136136
my_position: u8,
137137
}
138138

139139
#[derive(Serialize, Deserialize)]
140-
struct Round1Broadcast {
140+
pub(crate) struct Round1Broadcast {
141141
x: u8,
142142
my_position: u8,
143143
}
144144

145-
struct Round1Payload {
145+
pub(crate) struct Round1Payload {
146146
x: u8,
147147
}
148148

@@ -182,7 +182,7 @@ impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
182182
let mut ids = self.all_ids;
183183
ids.remove(id);
184184

185-
Ok(BoxedRound::new_dynamic(Round1 {
185+
Ok(BoxedRound::new_static(Round1 {
186186
context: Context {
187187
id: id.clone(),
188188
other_ids: ids,
@@ -192,7 +192,7 @@ impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
192192
}
193193
}
194194

195-
impl<Id: PartyId> Round<Id> for Round1<Id> {
195+
impl<Id: PartyId> StaticRound<Id> for Round1<Id> {
196196
type Protocol = SimpleProtocol;
197197

198198
fn transition_info(&self) -> TransitionInfo {
@@ -203,92 +203,71 @@ impl<Id: PartyId> Round<Id> for Round1<Id> {
203203
CommunicationInfo::regular(&self.context.other_ids)
204204
}
205205

206-
fn make_normal_broadcast(
207-
&self,
208-
_rng: &mut dyn CryptoRngCore,
209-
format: &BoxedFormat,
210-
) -> Result<NormalBroadcast, LocalError> {
211-
debug!("{:?}: making normal broadcast", self.context.id);
206+
type NormalBroadcast = Round1Broadcast;
207+
type EchoBroadcast = Round1Echo;
208+
type DirectMessage = Round1Message;
212209

213-
let message = Round1Broadcast {
210+
type Payload = Round1Payload;
211+
type Artifact = ();
212+
213+
fn make_normal_broadcast(&self, _rng: &mut dyn CryptoRngCore) -> Result<Self::NormalBroadcast, LocalError> {
214+
debug!("{:?}: making normal broadcast", self.context.id);
215+
Ok(Round1Broadcast {
214216
x: 0,
215217
my_position: self.context.ids_to_positions[&self.context.id],
216-
};
217-
218-
NormalBroadcast::new(format, message)
218+
})
219219
}
220220

221-
fn make_echo_broadcast(
222-
&self,
223-
_rng: &mut dyn CryptoRngCore,
224-
format: &BoxedFormat,
225-
) -> Result<EchoBroadcast, LocalError> {
221+
fn make_echo_broadcast(&self, _rng: &mut dyn CryptoRngCore) -> Result<Self::EchoBroadcast, LocalError> {
226222
debug!("{:?}: making echo broadcast", self.context.id);
227-
228-
let message = Round1Echo {
223+
Ok(Round1Echo {
229224
my_position: self.context.ids_to_positions[&self.context.id],
230-
};
231-
232-
EchoBroadcast::new(format, message)
225+
})
233226
}
234227

235228
fn make_direct_message(
236229
&self,
237230
_rng: &mut dyn CryptoRngCore,
238-
format: &BoxedFormat,
239231
destination: &Id,
240-
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
232+
) -> Result<(Self::DirectMessage, Option<Self::Artifact>), LocalError> {
241233
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
242-
243234
let message = Round1Message {
244235
my_position: self.context.ids_to_positions[&self.context.id],
245236
your_position: self.context.ids_to_positions[destination],
246237
};
247-
let dm = DirectMessage::new(format, message)?;
248-
Ok((dm, None))
238+
Ok((message, Some(())))
249239
}
250240

251241
fn receive_message(
252242
&self,
253-
format: &BoxedFormat,
254243
from: &Id,
255-
message: ProtocolMessage,
256-
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
244+
message: StaticProtocolMessage<Id, Self>,
245+
) -> Result<Self::Payload, ReceiveError<Id, Self::Protocol>> {
257246
debug!("{:?}: receiving message from {:?}", self.context.id, from);
258-
259-
let _echo = message.echo_broadcast.deserialize::<Round1Echo>(format)?;
260-
let _normal = message.normal_broadcast.deserialize::<Round1Broadcast>(format)?;
261-
let message = message.direct_message.deserialize::<Round1Message>(format)?;
262-
263-
debug!("{:?}: received message: {:?}", self.context.id, message);
247+
let message = message.direct_message;
264248

265249
if self.context.ids_to_positions[&self.context.id] != message.your_position {
266250
return Err(ReceiveError::protocol(SimpleProtocolError::Round1InvalidPosition));
267251
}
268-
269-
Ok(Payload::new(Round1Payload { x: message.my_position }))
252+
Ok(Round1Payload { x: message.my_position })
270253
}
271254

272255
fn finalize(
273-
self: Box<Self>,
256+
self,
274257
_rng: &mut dyn CryptoRngCore,
275-
payloads: BTreeMap<Id, Payload>,
276-
_artifacts: BTreeMap<Id, Artifact>,
258+
payloads: BTreeMap<Id, Self::Payload>,
259+
_artifacts: BTreeMap<Id, Self::Artifact>,
277260
) -> Result<FinalizeOutcome<Id, Self::Protocol>, LocalError> {
278261
debug!(
279262
"{:?}: finalizing with messages from {:?}",
280263
self.context.id,
281264
payloads.keys().cloned().collect::<Vec<_>>()
282265
);
283266

284-
let typed_payloads = payloads
285-
.into_values()
286-
.map(|payload| payload.downcast::<Round1Payload>())
287-
.collect::<Result<Vec<_>, _>>()?;
288-
let sum = self.context.ids_to_positions[&self.context.id]
289-
+ typed_payloads.iter().map(|payload| payload.x).sum::<u8>();
267+
let sum =
268+
self.context.ids_to_positions[&self.context.id] + payloads.values().map(|payload| payload.x).sum::<u8>();
290269

291-
let round2 = BoxedRound::new_dynamic(Round2 {
270+
let round2 = BoxedRound::new_static(Round2 {
292271
round1_sum: sum,
293272
context: self.context,
294273
});
@@ -308,7 +287,7 @@ pub(crate) struct Round2Message {
308287
pub(crate) your_position: u8,
309288
}
310289

311-
impl<Id: PartyId> Round<Id> for Round2<Id> {
290+
impl<Id: PartyId> StaticRound<Id> for Round2<Id> {
312291
type Protocol = SimpleProtocol;
313292

314293
fn transition_info(&self) -> TransitionInfo {
@@ -319,62 +298,59 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {
319298
CommunicationInfo::regular(&self.context.other_ids)
320299
}
321300

301+
type DirectMessage = Round2Message;
302+
type EchoBroadcast = NoMessage;
303+
type NormalBroadcast = NoMessage;
304+
305+
type Payload = Round1Payload;
306+
type Artifact = ();
307+
322308
fn make_direct_message(
323309
&self,
324310
_rng: &mut dyn CryptoRngCore,
325-
format: &BoxedFormat,
326311
destination: &Id,
327-
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
312+
) -> Result<(Self::DirectMessage, Option<Self::Artifact>), LocalError> {
328313
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
329314

330315
let message = Round2Message {
331316
my_position: self.context.ids_to_positions[&self.context.id],
332317
your_position: self.context.ids_to_positions[destination],
333318
};
334-
let dm = DirectMessage::new(format, message)?;
335-
Ok((dm, None))
319+
Ok((message, Some(())))
336320
}
337321

338322
fn receive_message(
339323
&self,
340-
format: &BoxedFormat,
341324
from: &Id,
342-
message: ProtocolMessage,
343-
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
325+
message: StaticProtocolMessage<Id, Self>,
326+
) -> Result<Self::Payload, ReceiveError<Id, Self::Protocol>> {
344327
debug!("{:?}: receiving message from {:?}", self.context.id, from);
345328

346-
message.echo_broadcast.assert_is_none()?;
347-
message.normal_broadcast.assert_is_none()?;
348-
349-
let message = message.direct_message.deserialize::<Round1Message>(format)?;
329+
let message = message.direct_message;
350330

351331
debug!("{:?}: received message: {:?}", self.context.id, message);
352332

353333
if self.context.ids_to_positions[&self.context.id] != message.your_position {
354334
return Err(ReceiveError::protocol(SimpleProtocolError::Round2InvalidPosition));
355335
}
356336

357-
Ok(Payload::new(Round1Payload { x: message.my_position }))
337+
Ok(Round1Payload { x: message.my_position })
358338
}
359339

360340
fn finalize(
361-
self: Box<Self>,
341+
self: Self,
362342
_rng: &mut dyn CryptoRngCore,
363-
payloads: BTreeMap<Id, Payload>,
364-
_artifacts: BTreeMap<Id, Artifact>,
343+
payloads: BTreeMap<Id, Self::Payload>,
344+
_artifacts: BTreeMap<Id, Self::Artifact>,
365345
) -> Result<FinalizeOutcome<Id, Self::Protocol>, LocalError> {
366346
debug!(
367347
"{:?}: finalizing with messages from {:?}",
368348
self.context.id,
369349
payloads.keys().cloned().collect::<Vec<_>>()
370350
);
371351

372-
let typed_payloads = payloads
373-
.into_values()
374-
.map(|payload| payload.downcast::<Round1Payload>())
375-
.collect::<Result<Vec<_>, _>>()?;
376-
let sum = self.context.ids_to_positions[&self.context.id]
377-
+ typed_payloads.iter().map(|payload| payload.x).sum::<u8>();
352+
let sum =
353+
self.context.ids_to_positions[&self.context.id] + payloads.values().map(|payload| payload.x).sum::<u8>();
378354

379355
Ok(FinalizeOutcome::Result(sum + self.round1_sum))
380356
}

examples/src/simple_malicious.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
3939
match behavior {
4040
Behavior::SerializedGarbage => DirectMessage::new(format, [99u8])?,
4141
Behavior::AttributableFailure => {
42-
let round1 = round.downcast_ref::<Round1<Id>>()?;
42+
let round1 = round.downcast_static_ref::<Round1<Id>>()?;
4343
let message = Round1Message {
4444
my_position: round1.context.ids_to_positions[&round1.context.id],
4545
your_position: round1.context.ids_to_positions[&round1.context.id],
@@ -51,7 +51,7 @@ impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
5151
} else if round.id() == 2 {
5252
match behavior {
5353
Behavior::AttributableFailureRound2 => {
54-
let round2 = round.downcast_ref::<Round2<Id>>()?;
54+
let round2 = round.downcast_static_ref::<Round2<Id>>()?;
5555
let message = Round2Message {
5656
my_position: round2.context.ids_to_positions[&round2.context.id],
5757
your_position: round2.context.ids_to_positions[&round2.context.id],

manul/src/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mod errors;
1717
mod message;
1818
mod round;
1919
mod round_id;
20+
mod static_round;
2021

2122
pub use boxed_format::BoxedFormat;
2223
pub use boxed_round::BoxedRound;
@@ -30,6 +31,7 @@ pub use round::{
3031
Payload, Protocol, ProtocolError, RequiredMessageParts, RequiredMessages, Round,
3132
};
3233
pub use round_id::{RoundId, TransitionInfo};
34+
pub use static_round::{NoMessage, StaticProtocolMessage, StaticRound};
3335

3436
pub(crate) use errors::ReceiveErrorType;
3537
pub(crate) use message::ProtocolMessagePartHashable;

manul/src/protocol/boxed_round.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::{
44
errors::LocalError,
55
round::{PartyId, Protocol, Round},
66
round_id::RoundId,
7+
static_round::{StaticRound, StaticRoundAdapter},
78
};
89

910
/// A wrapped new round that may be returned by [`Round::finalize`]
@@ -17,6 +18,11 @@ impl<Id: PartyId, P: Protocol<Id>> BoxedRound<Id, P> {
1718
Self(Box::new(round))
1819
}
1920

21+
/// Wraps an object implementing the dynamic round trait ([`StaticRound`](`crate::protocol::StaticRound`)).
22+
pub fn new_static<R: StaticRound<Id, Protocol = P>>(round: R) -> Self {
23+
Self(Box::new(StaticRoundAdapter::new(round)))
24+
}
25+
2026
pub(crate) fn as_ref(&self) -> &dyn Round<Id, Protocol = P> {
2127
self.0.as_ref()
2228
}
@@ -64,6 +70,22 @@ impl<Id: PartyId, P: Protocol<Id>> BoxedRound<Id, P> {
6470
}
6571
}
6672

73+
/// Attempts to provide a reference to an object of a concrete type.
74+
///
75+
/// Fails if the wrapped type is not `T`.
76+
pub fn downcast_static_ref<T: StaticRound<Id>>(&self) -> Result<&T, LocalError> {
77+
if self.boxed_type_is::<StaticRoundAdapter<T>>() {
78+
let ptr: *const dyn Round<Id, Protocol = P> = self.0.as_ref();
79+
// Safety: This is safe since we just checked that we are casting to the correct type.
80+
Ok(unsafe { &*(ptr as *const StaticRoundAdapter<T>) }.as_inner())
81+
} else {
82+
Err(LocalError::new(format!(
83+
"Failed to downcast into type {}",
84+
core::any::type_name::<T>()
85+
)))
86+
}
87+
}
88+
6789
/// Returns the round's ID.
6890
pub fn id(&self) -> RoundId {
6991
// This constructs a new `TransitionInfo` object, so calling this method inside `Session`

manul/src/protocol/round.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ mod sealed {
384384
impl<T: 'static> DynTypeId for T {}
385385
}
386386

387-
use sealed::DynTypeId;
387+
pub(crate) use sealed::DynTypeId;
388388

389389
/**
390390
A type representing a single round of a protocol.

0 commit comments

Comments
 (0)