Skip to content

Commit 8de19fb

Browse files
committed
Add StaticRound to eliminate some boilerplate when writing protocols
Make Round2 static as well Make payloads and artifacts typed Don't use postcard to create NoMessage
1 parent 46fe191 commit 8de19fb

File tree

6 files changed

+355
-81
lines changed

6 files changed

+355
-81
lines changed

examples/src/simple.rs

Lines changed: 54 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use alloc::collections::{BTreeMap, BTreeSet};
22
use core::fmt::Debug;
33

44
use manul::protocol::{
5-
Artifact, BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome,
6-
LocalError, MessageValidationError, NormalBroadcast, PartyId, Payload, Protocol, ProtocolError, ProtocolMessage,
7-
ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, Round, RoundId,
8-
TransitionInfo,
5+
BoxedFormat, BoxedRound, CommunicationInfo, DirectMessage, EchoBroadcast, EntryPoint, FinalizeOutcome, LocalError,
6+
MessageValidationError, NoMessage, NormalBroadcast, PartyId, Protocol, ProtocolError, ProtocolMessage,
7+
ProtocolMessagePart, ProtocolValidationError, ReceiveError, RequiredMessageParts, RequiredMessages, RoundId,
8+
StaticProtocolMessage, StaticRound, TransitionInfo,
99
};
1010
use rand_core::CryptoRngCore;
1111
use serde::{Deserialize, Serialize};
@@ -121,7 +121,7 @@ pub(crate) struct Context<Id> {
121121
}
122122

123123
#[derive(Debug)]
124-
pub struct Round1<Id> {
124+
pub(crate) struct Round1<Id> {
125125
pub(crate) context: Context<Id>,
126126
}
127127

@@ -132,17 +132,17 @@ pub(crate) struct Round1Message {
132132
}
133133

134134
#[derive(Serialize, Deserialize)]
135-
struct Round1Echo {
135+
pub(crate) struct Round1Echo {
136136
my_position: u8,
137137
}
138138

139139
#[derive(Serialize, Deserialize)]
140-
struct Round1Broadcast {
140+
pub(crate) struct Round1Broadcast {
141141
x: u8,
142142
my_position: u8,
143143
}
144144

145-
struct Round1Payload {
145+
pub(crate) struct Round1Payload {
146146
x: u8,
147147
}
148148

@@ -182,7 +182,7 @@ impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
182182
let mut ids = self.all_ids;
183183
ids.remove(id);
184184

185-
Ok(BoxedRound::new_dynamic(Round1 {
185+
Ok(BoxedRound::new_static(Round1 {
186186
context: Context {
187187
id: id.clone(),
188188
other_ids: ids,
@@ -192,7 +192,7 @@ impl<Id: PartyId> EntryPoint<Id> for SimpleProtocolEntryPoint<Id> {
192192
}
193193
}
194194

195-
impl<Id: PartyId> Round<Id> for Round1<Id> {
195+
impl<Id: PartyId> StaticRound<Id> for Round1<Id> {
196196
type Protocol = SimpleProtocol;
197197

198198
fn transition_info(&self) -> TransitionInfo {
@@ -203,92 +203,71 @@ impl<Id: PartyId> Round<Id> for Round1<Id> {
203203
CommunicationInfo::regular(&self.context.other_ids)
204204
}
205205

206-
fn make_normal_broadcast(
207-
&self,
208-
_rng: &mut dyn CryptoRngCore,
209-
format: &BoxedFormat,
210-
) -> Result<NormalBroadcast, LocalError> {
211-
debug!("{:?}: making normal broadcast", self.context.id);
206+
type NormalBroadcast = Round1Broadcast;
207+
type EchoBroadcast = Round1Echo;
208+
type DirectMessage = Round1Message;
212209

213-
let message = Round1Broadcast {
210+
type Payload = Round1Payload;
211+
type Artifact = ();
212+
213+
fn make_normal_broadcast(&self, _rng: &mut dyn CryptoRngCore) -> Result<Self::NormalBroadcast, LocalError> {
214+
debug!("{:?}: making normal broadcast", self.context.id);
215+
Ok(Round1Broadcast {
214216
x: 0,
215217
my_position: self.context.ids_to_positions[&self.context.id],
216-
};
217-
218-
NormalBroadcast::new(format, message)
218+
})
219219
}
220220

221-
fn make_echo_broadcast(
222-
&self,
223-
_rng: &mut dyn CryptoRngCore,
224-
format: &BoxedFormat,
225-
) -> Result<EchoBroadcast, LocalError> {
221+
fn make_echo_broadcast(&self, _rng: &mut dyn CryptoRngCore) -> Result<Self::EchoBroadcast, LocalError> {
226222
debug!("{:?}: making echo broadcast", self.context.id);
227-
228-
let message = Round1Echo {
223+
Ok(Round1Echo {
229224
my_position: self.context.ids_to_positions[&self.context.id],
230-
};
231-
232-
EchoBroadcast::new(format, message)
225+
})
233226
}
234227

235228
fn make_direct_message(
236229
&self,
237230
_rng: &mut dyn CryptoRngCore,
238-
format: &BoxedFormat,
239231
destination: &Id,
240-
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
232+
) -> Result<(Self::DirectMessage, Option<Self::Artifact>), LocalError> {
241233
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
242-
243234
let message = Round1Message {
244235
my_position: self.context.ids_to_positions[&self.context.id],
245236
your_position: self.context.ids_to_positions[destination],
246237
};
247-
let dm = DirectMessage::new(format, message)?;
248-
Ok((dm, None))
238+
Ok((message, Some(())))
249239
}
250240

251241
fn receive_message(
252242
&self,
253-
format: &BoxedFormat,
254243
from: &Id,
255-
message: ProtocolMessage,
256-
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
244+
message: StaticProtocolMessage<Id, Self>,
245+
) -> Result<Self::Payload, ReceiveError<Id, Self::Protocol>> {
257246
debug!("{:?}: receiving message from {:?}", self.context.id, from);
258-
259-
let _echo = message.echo_broadcast.deserialize::<Round1Echo>(format)?;
260-
let _normal = message.normal_broadcast.deserialize::<Round1Broadcast>(format)?;
261-
let message = message.direct_message.deserialize::<Round1Message>(format)?;
262-
263-
debug!("{:?}: received message: {:?}", self.context.id, message);
247+
let message = message.direct_message;
264248

265249
if self.context.ids_to_positions[&self.context.id] != message.your_position {
266250
return Err(ReceiveError::protocol(SimpleProtocolError::Round1InvalidPosition));
267251
}
268-
269-
Ok(Payload::new(Round1Payload { x: message.my_position }))
252+
Ok(Round1Payload { x: message.my_position })
270253
}
271254

272255
fn finalize(
273-
self: Box<Self>,
256+
self,
274257
_rng: &mut dyn CryptoRngCore,
275-
payloads: BTreeMap<Id, Payload>,
276-
_artifacts: BTreeMap<Id, Artifact>,
258+
payloads: BTreeMap<Id, Self::Payload>,
259+
_artifacts: BTreeMap<Id, Self::Artifact>,
277260
) -> Result<FinalizeOutcome<Id, Self::Protocol>, LocalError> {
278261
debug!(
279262
"{:?}: finalizing with messages from {:?}",
280263
self.context.id,
281264
payloads.keys().cloned().collect::<Vec<_>>()
282265
);
283266

284-
let typed_payloads = payloads
285-
.into_values()
286-
.map(|payload| payload.downcast::<Round1Payload>())
287-
.collect::<Result<Vec<_>, _>>()?;
288-
let sum = self.context.ids_to_positions[&self.context.id]
289-
+ typed_payloads.iter().map(|payload| payload.x).sum::<u8>();
267+
let sum =
268+
self.context.ids_to_positions[&self.context.id] + payloads.values().map(|payload| payload.x).sum::<u8>();
290269

291-
let round2 = BoxedRound::new_dynamic(Round2 {
270+
let round2 = BoxedRound::new_static(Round2 {
292271
round1_sum: sum,
293272
context: self.context,
294273
});
@@ -308,7 +287,7 @@ pub(crate) struct Round2Message {
308287
pub(crate) your_position: u8,
309288
}
310289

311-
impl<Id: PartyId> Round<Id> for Round2<Id> {
290+
impl<Id: PartyId> StaticRound<Id> for Round2<Id> {
312291
type Protocol = SimpleProtocol;
313292

314293
fn transition_info(&self) -> TransitionInfo {
@@ -319,62 +298,59 @@ impl<Id: PartyId> Round<Id> for Round2<Id> {
319298
CommunicationInfo::regular(&self.context.other_ids)
320299
}
321300

301+
type DirectMessage = Round2Message;
302+
type EchoBroadcast = NoMessage;
303+
type NormalBroadcast = NoMessage;
304+
305+
type Payload = Round1Payload;
306+
type Artifact = ();
307+
322308
fn make_direct_message(
323309
&self,
324310
_rng: &mut dyn CryptoRngCore,
325-
format: &BoxedFormat,
326311
destination: &Id,
327-
) -> Result<(DirectMessage, Option<Artifact>), LocalError> {
312+
) -> Result<(Self::DirectMessage, Option<Self::Artifact>), LocalError> {
328313
debug!("{:?}: making direct message for {:?}", self.context.id, destination);
329314

330315
let message = Round2Message {
331316
my_position: self.context.ids_to_positions[&self.context.id],
332317
your_position: self.context.ids_to_positions[destination],
333318
};
334-
let dm = DirectMessage::new(format, message)?;
335-
Ok((dm, None))
319+
Ok((message, Some(())))
336320
}
337321

338322
fn receive_message(
339323
&self,
340-
format: &BoxedFormat,
341324
from: &Id,
342-
message: ProtocolMessage,
343-
) -> Result<Payload, ReceiveError<Id, Self::Protocol>> {
325+
message: StaticProtocolMessage<Id, Self>,
326+
) -> Result<Self::Payload, ReceiveError<Id, Self::Protocol>> {
344327
debug!("{:?}: receiving message from {:?}", self.context.id, from);
345328

346-
message.echo_broadcast.assert_is_none()?;
347-
message.normal_broadcast.assert_is_none()?;
348-
349-
let message = message.direct_message.deserialize::<Round1Message>(format)?;
329+
let message = message.direct_message;
350330

351331
debug!("{:?}: received message: {:?}", self.context.id, message);
352332

353333
if self.context.ids_to_positions[&self.context.id] != message.your_position {
354334
return Err(ReceiveError::protocol(SimpleProtocolError::Round2InvalidPosition));
355335
}
356336

357-
Ok(Payload::new(Round1Payload { x: message.my_position }))
337+
Ok(Round1Payload { x: message.my_position })
358338
}
359339

360340
fn finalize(
361-
self: Box<Self>,
341+
self: Self,
362342
_rng: &mut dyn CryptoRngCore,
363-
payloads: BTreeMap<Id, Payload>,
364-
_artifacts: BTreeMap<Id, Artifact>,
343+
payloads: BTreeMap<Id, Self::Payload>,
344+
_artifacts: BTreeMap<Id, Self::Artifact>,
365345
) -> Result<FinalizeOutcome<Id, Self::Protocol>, LocalError> {
366346
debug!(
367347
"{:?}: finalizing with messages from {:?}",
368348
self.context.id,
369349
payloads.keys().cloned().collect::<Vec<_>>()
370350
);
371351

372-
let typed_payloads = payloads
373-
.into_values()
374-
.map(|payload| payload.downcast::<Round1Payload>())
375-
.collect::<Result<Vec<_>, _>>()?;
376-
let sum = self.context.ids_to_positions[&self.context.id]
377-
+ typed_payloads.iter().map(|payload| payload.x).sum::<u8>();
352+
let sum =
353+
self.context.ids_to_positions[&self.context.id] + payloads.values().map(|payload| payload.x).sum::<u8>();
378354

379355
Ok(FinalizeOutcome::Result(sum + self.round1_sum))
380356
}

examples/src/simple_malicious.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
3939
match behavior {
4040
Behavior::SerializedGarbage => DirectMessage::new(format, [99u8])?,
4141
Behavior::AttributableFailure => {
42-
let round1 = round.downcast_ref::<Round1<Id>>()?;
42+
let round1 = round.downcast_static_ref::<Round1<Id>>()?;
4343
let message = Round1Message {
4444
my_position: round1.context.ids_to_positions[&round1.context.id],
4545
your_position: round1.context.ids_to_positions[&round1.context.id],
@@ -51,7 +51,7 @@ impl<Id: PartyId> Misbehaving<Id, Behavior> for MaliciousLogic {
5151
} else if round.id() == 2 {
5252
match behavior {
5353
Behavior::AttributableFailureRound2 => {
54-
let round2 = round.downcast_ref::<Round2<Id>>()?;
54+
let round2 = round.downcast_static_ref::<Round2<Id>>()?;
5555
let message = Round2Message {
5656
my_position: round2.context.ids_to_positions[&round2.context.id],
5757
your_position: round2.context.ids_to_positions[&round2.context.id],

manul/src/protocol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ mod errors;
1717
mod message;
1818
mod round;
1919
mod round_id;
20+
mod static_round;
2021

2122
pub use boxed_format::BoxedFormat;
2223
pub use boxed_round::BoxedRound;
@@ -30,6 +31,7 @@ pub use round::{
3031
Payload, Protocol, ProtocolError, RequiredMessageParts, RequiredMessages, Round,
3132
};
3233
pub use round_id::{RoundId, TransitionInfo};
34+
pub use static_round::{NoMessage, StaticProtocolMessage, StaticRound};
3335

3436
pub(crate) use errors::ReceiveErrorType;
3537
pub(crate) use message::ProtocolMessagePartHashable;

manul/src/protocol/boxed_round.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use super::{
44
errors::LocalError,
55
round::{PartyId, Protocol, Round},
66
round_id::RoundId,
7+
static_round::{StaticRound, StaticRoundAdapter},
78
};
89

910
/// A wrapped new round that may be returned by [`Round::finalize`]
@@ -17,6 +18,11 @@ impl<Id: PartyId, P: Protocol<Id>> BoxedRound<Id, P> {
1718
Self(Box::new(round))
1819
}
1920

21+
/// Wraps an object implementing the dynamic round trait ([`StaticRound`](`crate::protocol::StaticRound`)).
22+
pub fn new_static<R: StaticRound<Id, Protocol = P>>(round: R) -> Self {
23+
Self(Box::new(StaticRoundAdapter::new(round)))
24+
}
25+
2026
pub(crate) fn as_ref(&self) -> &dyn Round<Id, Protocol = P> {
2127
self.0.as_ref()
2228
}
@@ -64,6 +70,22 @@ impl<Id: PartyId, P: Protocol<Id>> BoxedRound<Id, P> {
6470
}
6571
}
6672

73+
/// Attempts to provide a reference to an object of a concrete type.
74+
///
75+
/// Fails if the wrapped type is not `T`.
76+
pub fn downcast_static_ref<T: StaticRound<Id>>(&self) -> Result<&T, LocalError> {
77+
if self.boxed_type_is::<StaticRoundAdapter<T>>() {
78+
let ptr: *const dyn Round<Id, Protocol = P> = self.0.as_ref();
79+
// Safety: This is safe since we just checked that we are casting to the correct type.
80+
Ok(unsafe { &*(ptr as *const StaticRoundAdapter<T>) }.as_inner())
81+
} else {
82+
Err(LocalError::new(format!(
83+
"Failed to downcast into type {}",
84+
core::any::type_name::<T>()
85+
)))
86+
}
87+
}
88+
6789
/// Returns the round's ID.
6890
pub fn id(&self) -> RoundId {
6991
// This constructs a new `TransitionInfo` object, so calling this method inside `Session`

manul/src/protocol/round.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ mod sealed {
384384
impl<T: 'static> DynTypeId for T {}
385385
}
386386

387-
use sealed::DynTypeId;
387+
pub(crate) use sealed::DynTypeId;
388388

389389
/**
390390
A type representing a single round of a protocol.

0 commit comments

Comments
 (0)