@@ -49,15 +49,16 @@ Usage:
49
49
when verifying evidence from the chained protocol.
50
50
*/
51
51
52
- use alloc:: { boxed:: Box , collections:: BTreeMap } ;
52
+ use alloc:: { boxed:: Box , collections:: BTreeMap , format } ;
53
53
use core:: fmt:: { self , Debug } ;
54
54
55
55
use rand_core:: CryptoRngCore ;
56
56
57
57
use crate :: protocol:: {
58
58
Artifact , BoxedFormat , BoxedRound , BoxedRoundInfo , CommunicationInfo , DirectMessage , EchoBroadcast , EntryPoint ,
59
59
FinalizeOutcome , LocalError , MessageValidationError , NormalBroadcast , PartyId , Payload , Protocol , ProtocolError ,
60
- ProtocolMessage , ProtocolValidationError , ReceiveError , RequiredMessages , Round , RoundId , TransitionInfo ,
60
+ ProtocolMessage , ProtocolValidationError , ReceiveError , RequiredMessages , Round , RoundId , RoundInfo ,
61
+ TransitionInfo ,
61
62
} ;
62
63
63
64
/// A marker trait that is used to disambiguate blanket trait implementations for [`Protocol`] and [`EntryPoint`].
@@ -197,62 +198,152 @@ where
197
198
}
198
199
}
199
200
200
- impl < Id , C > Protocol < Id > for C
201
+ #[ derive_where:: derive_where( Debug ) ]
202
+ struct RoundInfoWrapper1 < Id : ' static , P : ChainedProtocol < Id > + ChainedMarker > ( BoxedRoundInfo < Id , P :: Protocol1 > ) ;
203
+
204
+ impl < Id , P > RoundInfo < Id > for RoundInfoWrapper1 < Id , P >
201
205
where
202
- Id : ' static ,
203
- C : ChainedProtocol < Id > + ChainedMarker ,
206
+ P : ChainedProtocol < Id > + ChainedMarker ,
204
207
{
205
- type Result = <C :: Protocol2 as Protocol < Id > >:: Result ;
206
- type ProtocolError = ChainedProtocolError < Id , C > ;
208
+ type Protocol = P ;
207
209
208
- fn round_info ( _round_id : & RoundId ) -> Option < BoxedRoundInfo < Id , Self > > {
209
- /*let rounds1 = C::Protocol1::rounds()
210
- .into_iter()
211
- .map(|(round_id, round)| (round_id.group_under(1), round));
212
- let rounds2 = C::Protocol2::rounds()
213
- .into_iter()
214
- .map(|(round_id, round)| (round_id.group_under(2), round));
215
- rounds1.chain(rounds2).collect()*/
216
- unimplemented ! ( )
210
+ fn verify_direct_message_is_invalid (
211
+ & self ,
212
+ round_id : & RoundId ,
213
+ format : & BoxedFormat ,
214
+ message : & DirectMessage ,
215
+ associated_data : & <<Self :: Protocol as Protocol < Id > >:: ProtocolError as ProtocolError < Id > >:: AssociatedData ,
216
+ ) -> Result < ( ) , MessageValidationError > {
217
+ let ( group, round_id) = round_id. split_group ( ) ?;
218
+ if group != 1 {
219
+ return Err ( MessageValidationError :: Local ( LocalError :: new ( format ! (
220
+ "Expected round ID from group 1, got {round_id}"
221
+ ) ) ) ) ;
222
+ }
223
+ self . 0
224
+ . as_ref ( )
225
+ . verify_direct_message_is_invalid ( & round_id, format, message, & associated_data. protocol1 )
217
226
}
218
227
219
- fn verify_direct_message_is_invalid (
228
+ fn verify_echo_broadcast_is_invalid (
229
+ & self ,
230
+ round_id : & RoundId ,
220
231
format : & BoxedFormat ,
232
+ message : & EchoBroadcast ,
233
+ associated_data : & <<Self :: Protocol as Protocol < Id > >:: ProtocolError as ProtocolError < Id > >:: AssociatedData ,
234
+ ) -> Result < ( ) , MessageValidationError > {
235
+ let ( group, round_id) = round_id. split_group ( ) ?;
236
+ if group != 1 {
237
+ return Err ( MessageValidationError :: Local ( LocalError :: new ( format ! (
238
+ "Expected round ID from group 1, got {round_id}"
239
+ ) ) ) ) ;
240
+ }
241
+ self . 0
242
+ . as_ref ( )
243
+ . verify_echo_broadcast_is_invalid ( & round_id, format, message, & associated_data. protocol1 )
244
+ }
245
+
246
+ fn verify_normal_broadcast_is_invalid (
247
+ & self ,
248
+ round_id : & RoundId ,
249
+ format : & BoxedFormat ,
250
+ message : & NormalBroadcast ,
251
+ associated_data : & <<Self :: Protocol as Protocol < Id > >:: ProtocolError as ProtocolError < Id > >:: AssociatedData ,
252
+ ) -> Result < ( ) , MessageValidationError > {
253
+ let ( group, round_id) = round_id. split_group ( ) ?;
254
+ if group != 1 {
255
+ return Err ( MessageValidationError :: Local ( LocalError :: new ( format ! (
256
+ "Expected round ID from group 1, got {round_id}"
257
+ ) ) ) ) ;
258
+ }
259
+ self . 0
260
+ . as_ref ( )
261
+ . verify_normal_broadcast_is_invalid ( & round_id, format, message, & associated_data. protocol1 )
262
+ }
263
+ }
264
+
265
+ #[ derive_where:: derive_where( Debug ) ]
266
+ struct RoundInfoWrapper2 < Id : ' static , P : ChainedProtocol < Id > + ChainedMarker > ( BoxedRoundInfo < Id , P :: Protocol2 > ) ;
267
+
268
+ impl < Id , P > RoundInfo < Id > for RoundInfoWrapper2 < Id , P >
269
+ where
270
+ P : ChainedProtocol < Id > + ChainedMarker ,
271
+ {
272
+ type Protocol = P ;
273
+
274
+ fn verify_direct_message_is_invalid (
275
+ & self ,
221
276
round_id : & RoundId ,
277
+ format : & BoxedFormat ,
222
278
message : & DirectMessage ,
223
- associated_data : & <Self :: ProtocolError as ProtocolError < Id > >:: AssociatedData ,
279
+ associated_data : & << Self :: Protocol as Protocol < Id > > :: ProtocolError as ProtocolError < Id > >:: AssociatedData ,
224
280
) -> Result < ( ) , MessageValidationError > {
225
281
let ( group, round_id) = round_id. split_group ( ) ?;
226
- if group == 1 {
227
- C :: Protocol1 :: verify_direct_message_is_invalid ( format, & round_id , message , & associated_data . protocol1 )
228
- } else {
229
- C :: Protocol2 :: verify_direct_message_is_invalid ( format , & round_id , message , & associated_data . protocol2 )
282
+ if group != 2 {
283
+ return Err ( MessageValidationError :: Local ( LocalError :: new ( format ! (
284
+ "Expected round ID from group 2, got {round_id}"
285
+ ) ) ) ) ;
230
286
}
287
+ self . 0
288
+ . as_ref ( )
289
+ . verify_direct_message_is_invalid ( & round_id, format, message, & associated_data. protocol2 )
231
290
}
232
291
233
292
fn verify_echo_broadcast_is_invalid (
234
- format : & BoxedFormat ,
293
+ & self ,
235
294
round_id : & RoundId ,
295
+ format : & BoxedFormat ,
236
296
message : & EchoBroadcast ,
297
+ associated_data : & <<Self :: Protocol as Protocol < Id > >:: ProtocolError as ProtocolError < Id > >:: AssociatedData ,
237
298
) -> Result < ( ) , MessageValidationError > {
238
299
let ( group, round_id) = round_id. split_group ( ) ?;
239
- if group == 1 {
240
- C :: Protocol1 :: verify_echo_broadcast_is_invalid ( format, & round_id , message )
241
- } else {
242
- C :: Protocol2 :: verify_echo_broadcast_is_invalid ( format , & round_id , message )
300
+ if group != 2 {
301
+ return Err ( MessageValidationError :: Local ( LocalError :: new ( format ! (
302
+ "Expected round ID from group 2, got {round_id}"
303
+ ) ) ) ) ;
243
304
}
305
+ self . 0
306
+ . as_ref ( )
307
+ . verify_echo_broadcast_is_invalid ( & round_id, format, message, & associated_data. protocol2 )
244
308
}
245
309
246
310
fn verify_normal_broadcast_is_invalid (
247
- format : & BoxedFormat ,
311
+ & self ,
248
312
round_id : & RoundId ,
313
+ format : & BoxedFormat ,
249
314
message : & NormalBroadcast ,
315
+ associated_data : & <<Self :: Protocol as Protocol < Id > >:: ProtocolError as ProtocolError < Id > >:: AssociatedData ,
250
316
) -> Result < ( ) , MessageValidationError > {
251
317
let ( group, round_id) = round_id. split_group ( ) ?;
318
+ if group != 2 {
319
+ return Err ( MessageValidationError :: Local ( LocalError :: new ( format ! (
320
+ "Expected round ID from group 2, got {round_id}"
321
+ ) ) ) ) ;
322
+ }
323
+ self . 0
324
+ . as_ref ( )
325
+ . verify_normal_broadcast_is_invalid ( & round_id, format, message, & associated_data. protocol2 )
326
+ }
327
+ }
328
+
329
+ impl < Id , C > Protocol < Id > for C
330
+ where
331
+ Id : ' static ,
332
+ C : ChainedProtocol < Id > + ChainedMarker ,
333
+ {
334
+ type Result = <C :: Protocol2 as Protocol < Id > >:: Result ;
335
+ type ProtocolError = ChainedProtocolError < Id , C > ;
336
+
337
+ fn round_info ( round_id : & RoundId ) -> Option < BoxedRoundInfo < Id , Self > > {
338
+ let ( group, round_id) = round_id. split_group ( ) . ok ( ) ?;
252
339
if group == 1 {
253
- C :: Protocol1 :: verify_normal_broadcast_is_invalid ( format, & round_id, message)
340
+ let round_info = C :: Protocol1 :: round_info ( & round_id) ?;
341
+ Some ( BoxedRoundInfo :: new_obj ( Box :: new ( RoundInfoWrapper1 ( round_info) ) ) )
342
+ } else if group == 2 {
343
+ let round_info = C :: Protocol2 :: round_info ( & round_id) ?;
344
+ Some ( BoxedRoundInfo :: new_obj ( Box :: new ( RoundInfoWrapper2 ( round_info) ) ) )
254
345
} else {
255
- C :: Protocol2 :: verify_normal_broadcast_is_invalid ( format , & round_id , message )
346
+ None
256
347
}
257
348
}
258
349
}
0 commit comments