Skip to content

Commit 5c3f521

Browse files
committed
Implement MultiPartSigner/Verifier for ML-DSA and SLH-DSA
1 parent 2798f79 commit 5c3f521

File tree

8 files changed

+142
-33
lines changed

8 files changed

+142
-33
lines changed

Cargo.lock

Lines changed: 1 addition & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,6 @@ lms-signature = { path = "./lms" }
2525
ml-dsa = { path = "./ml-dsa" }
2626
rfc6979 = { path = "./rfc6979" }
2727
slh-dsa = { path = "./slh-dsa" }
28+
29+
# https://github.com/RustCrypto/traits/pull/1880
30+
signature = { git = "https://github.com/RustCrypto/traits", rev = "ce4a05bc4266dd1d4c7c713fb7f9c211c76375ec" }

ml-dsa/src/lib.rs

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ use core::fmt;
8989

9090
pub use crate::param::{EncodedSignature, EncodedSigningKey, EncodedVerifyingKey, MlDsaParams};
9191
pub use crate::util::B32;
92-
pub use signature::{self, Error};
92+
pub use signature::{self, Error, MultiPartSigner, MultiPartVerifier};
9393

9494
/// An ML-DSA signature
9595
#[derive(Clone, PartialEq, Debug)]
@@ -168,10 +168,10 @@ where
168168
// This method takes a slice of slices so that we can accommodate the varying calculations (direct
169169
// for test vectors, 0... for sign/sign_deterministic, 1... for the pre-hashed version) without
170170
// having to allocate memory for components.
171-
fn message_representative(tr: &[u8], Mp: &[&[u8]]) -> B64 {
171+
fn message_representative(tr: &[u8], Mp: &[&[&[u8]]]) -> B64 {
172172
let mut h = H::default().absorb(tr);
173173

174-
for m in Mp {
174+
for m in Mp.iter().copied().flatten() {
175175
h = h.absorb(m);
176176
}
177177

@@ -245,7 +245,13 @@ where
245245
/// only supports signing with an empty context string.
246246
impl<P: MlDsaParams> signature::Signer<Signature<P>> for KeyPair<P> {
247247
fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
248-
self.signing_key.sign_deterministic(msg, &[])
248+
self.try_multi_part_sign(&[msg])
249+
}
250+
}
251+
252+
impl<P: MlDsaParams> MultiPartSigner<Signature<P>> for KeyPair<P> {
253+
fn try_multi_part_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
254+
self.signing_key.raw_sign_deterministic(msg, &[])
249255
}
250256
}
251257

@@ -350,6 +356,13 @@ impl<P: MlDsaParams> SigningKey<P> {
350356
// Algorithm 7 ML-DSA.Sign_internal
351357
// TODO(RLB) Only expose based on a feature. Tests need access, but normal code shouldn't.
352358
pub fn sign_internal(&self, Mp: &[&[u8]], rnd: &B32) -> Signature<P>
359+
where
360+
P: MlDsaParams,
361+
{
362+
self.raw_sign_internal(&[Mp], rnd)
363+
}
364+
365+
fn raw_sign_internal(&self, Mp: &[&[&[u8]]], rnd: &B32) -> Signature<P>
353366
where
354367
P: MlDsaParams,
355368
{
@@ -440,13 +453,17 @@ impl<P: MlDsaParams> SigningKey<P> {
440453
/// This method will return an opaque error if the context string is more than 255 bytes long.
441454
// Algorithm 2 ML-DSA.Sign (optional deterministic variant)
442455
pub fn sign_deterministic(&self, M: &[u8], ctx: &[u8]) -> Result<Signature<P>, Error> {
456+
self.raw_sign_deterministic(&[M], ctx)
457+
}
458+
459+
fn raw_sign_deterministic(&self, M: &[&[u8]], ctx: &[u8]) -> Result<Signature<P>, Error> {
443460
if ctx.len() > 255 {
444461
return Err(Error::new());
445462
}
446463

447464
let rnd = B32::default();
448-
let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
449-
Ok(self.sign_internal(Mp, &rnd))
465+
let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
466+
Ok(self.raw_sign_internal(Mp, &rnd))
450467
}
451468

452469
/// Encode the key in a fixed-size byte array.
@@ -492,7 +509,13 @@ impl<P: MlDsaParams> SigningKey<P> {
492509
/// string, use the [`SigningKey::sign_deterministic`] method.
493510
impl<P: MlDsaParams> signature::Signer<Signature<P>> for SigningKey<P> {
494511
fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
495-
self.sign_deterministic(msg, &[])
512+
self.try_multi_part_sign(&[msg])
513+
}
514+
}
515+
516+
impl<P: MlDsaParams> MultiPartSigner<Signature<P>> for SigningKey<P> {
517+
fn try_multi_part_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
518+
self.raw_sign_deterministic(msg, &[])
496519
}
497520
}
498521

@@ -576,6 +599,13 @@ impl<P: MlDsaParams> VerifyingKey<P> {
576599
/// and it does not separate the context string from the rest of the message.
577600
// Algorithm 8 ML-DSA.Verify_internal
578601
pub fn verify_internal(&self, Mp: &[&[u8]], sigma: &Signature<P>) -> bool
602+
where
603+
P: MlDsaParams,
604+
{
605+
self.raw_verify_internal(&[Mp], sigma)
606+
}
607+
608+
fn raw_verify_internal(&self, Mp: &[&[&[u8]]], sigma: &Signature<P>) -> bool
579609
where
580610
P: MlDsaParams,
581611
{
@@ -605,12 +635,16 @@ impl<P: MlDsaParams> VerifyingKey<P> {
605635
/// This algorithm reflect the ML-DSA.Verify algorithm from FIPS 204.
606636
// Algorithm 3 ML-DSA.Verify
607637
pub fn verify_with_context(&self, M: &[u8], ctx: &[u8], sigma: &Signature<P>) -> bool {
638+
self.raw_verify_with_context(&[M], ctx, sigma)
639+
}
640+
641+
fn raw_verify_with_context(&self, M: &[&[u8]], ctx: &[u8], sigma: &Signature<P>) -> bool {
608642
if ctx.len() > 255 {
609643
return false;
610644
}
611645

612-
let Mp = &[&[0], &[Truncate::truncate(ctx.len())], ctx, M];
613-
self.verify_internal(Mp, sigma)
646+
let Mp = &[&[&[0], &[Truncate::truncate(ctx.len())], ctx], M];
647+
self.raw_verify_internal(Mp, sigma)
614648
}
615649

616650
fn encode_internal(rho: &B32, t1: &Vector<P::K>) -> EncodedVerifyingKey<P> {
@@ -635,7 +669,13 @@ impl<P: MlDsaParams> VerifyingKey<P> {
635669

636670
impl<P: MlDsaParams> signature::Verifier<Signature<P>> for VerifyingKey<P> {
637671
fn verify(&self, msg: &[u8], signature: &Signature<P>) -> Result<(), Error> {
638-
self.verify_with_context(msg, &[], signature)
672+
self.multi_part_verify(&[msg], signature)
673+
}
674+
}
675+
676+
impl<P: MlDsaParams> MultiPartVerifier<Signature<P>> for VerifyingKey<P> {
677+
fn multi_part_verify(&self, msg: &[&[u8]], signature: &Signature<P>) -> Result<(), Error> {
678+
self.raw_verify_with_context(msg, &[], signature)
639679
.then_some(())
640680
.ok_or(Error::new())
641681
}

slh-dsa/src/hashes.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ pub(crate) trait HashSuite: Sized + Clone + Debug + PartialEq + Eq {
2323
fn prf_msg(
2424
sk_prf: &SkPrf<Self::N>,
2525
opt_rand: &Array<u8, Self::N>,
26-
msg: &[impl AsRef<[u8]>],
26+
msg: &[&[impl AsRef<[u8]>]],
2727
) -> Array<u8, Self::N>;
2828

2929
/// Hashes a message using a given randomizer
3030
fn h_msg(
3131
rand: &Array<u8, Self::N>,
3232
pk_seed: &PkSeed<Self::N>,
3333
pk_root: &Array<u8, Self::N>,
34-
msg: &[impl AsRef<[u8]>],
34+
msg: &[&[impl AsRef<[u8]>]],
3535
) -> Array<u8, Self::M>;
3636

3737
/// PRF that is used to generate the secret values in WOTS+ and FORS private keys.
@@ -76,7 +76,7 @@ mod tests {
7676
let opt_rand = Array::<u8, H::N>::from_fn(|_| 1);
7777
let msg = [2u8; 32];
7878

79-
let result = H::prf_msg(&sk_prf, &opt_rand, &[msg]);
79+
let result = H::prf_msg(&sk_prf, &opt_rand, &[&[msg]]);
8080

8181
assert_eq!(result.as_slice(), expected);
8282
}
@@ -87,7 +87,7 @@ mod tests {
8787
let pk_root = Array::<u8, H::N>::from_fn(|_| 2);
8888
let msg = [3u8; 32];
8989

90-
let result = H::h_msg(&rand, &pk_seed, &pk_root, &[msg]);
90+
let result = H::h_msg(&rand, &pk_seed, &pk_root, &[&[msg]]);
9191

9292
assert_eq!(result.as_slice(), expected);
9393
}

slh-dsa/src/hashes/sha2.rs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,13 @@ where
6161
fn prf_msg(
6262
sk_prf: &SkPrf<Self::N>,
6363
opt_rand: &Array<u8, Self::N>,
64-
msg: &[impl AsRef<[u8]>],
64+
msg: &[&[impl AsRef<[u8]>]],
6565
) -> Array<u8, Self::N> {
6666
let mut mac = Hmac::<Sha256>::new_from_slice(sk_prf.as_ref()).unwrap();
6767
mac.update(opt_rand.as_slice());
6868
msg.iter()
69+
.copied()
70+
.flatten()
6971
.for_each(|msg_part| mac.update(msg_part.as_ref()));
7072
let result = mac.finalize().into_bytes();
7173
Array::clone_from_slice(&result[..Self::N::USIZE])
@@ -75,13 +77,16 @@ where
7577
rand: &Array<u8, Self::N>,
7678
pk_seed: &PkSeed<Self::N>,
7779
pk_root: &Array<u8, Self::N>,
78-
msg: &[impl AsRef<[u8]>],
80+
msg: &[&[impl AsRef<[u8]>]],
7981
) -> Array<u8, Self::M> {
8082
let mut h = Sha256::new();
8183
h.update(rand);
8284
h.update(pk_seed);
8385
h.update(pk_root);
84-
msg.iter().for_each(|msg_part| h.update(msg_part.as_ref()));
86+
msg.iter()
87+
.copied()
88+
.flatten()
89+
.for_each(|msg_part| h.update(msg_part.as_ref()));
8590
let result = Array(h.finalize().into());
8691
let seed = rand.clone().concat(pk_seed.0.clone()).concat(result);
8792
mgf1::<Sha256, Self::M>(&seed)
@@ -224,11 +229,13 @@ where
224229
fn prf_msg(
225230
sk_prf: &SkPrf<Self::N>,
226231
opt_rand: &Array<u8, Self::N>,
227-
msg: &[impl AsRef<[u8]>],
232+
msg: &[&[impl AsRef<[u8]>]],
228233
) -> Array<u8, Self::N> {
229234
let mut mac = Hmac::<Sha512>::new_from_slice(sk_prf.as_ref()).unwrap();
230235
mac.update(opt_rand.as_slice());
231236
msg.iter()
237+
.copied()
238+
.flatten()
232239
.for_each(|msg_part| mac.update(msg_part.as_ref()));
233240
let result = mac.finalize().into_bytes();
234241
Array::clone_from_slice(&result[..Self::N::USIZE])
@@ -238,13 +245,16 @@ where
238245
rand: &Array<u8, Self::N>,
239246
pk_seed: &PkSeed<Self::N>,
240247
pk_root: &Array<u8, Self::N>,
241-
msg: &[impl AsRef<[u8]>],
248+
msg: &[&[impl AsRef<[u8]>]],
242249
) -> Array<u8, Self::M> {
243250
let mut h = Sha512::new();
244251
h.update(rand);
245252
h.update(pk_seed);
246253
h.update(pk_root);
247-
msg.iter().for_each(|msg_part| h.update(msg_part.as_ref()));
254+
msg.iter()
255+
.copied()
256+
.flatten()
257+
.for_each(|msg_part| h.update(msg_part.as_ref()));
248258
let result = Array(h.finalize().into());
249259
let seed = rand.clone().concat(pk_seed.0.clone()).concat(result);
250260
mgf1::<Sha512, Self::M>(&seed)

slh-dsa/src/hashes/shake.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ where
3535
fn prf_msg(
3636
sk_prf: &SkPrf<Self::N>,
3737
opt_rand: &Array<u8, Self::N>,
38-
msg: &[impl AsRef<[u8]>],
38+
msg: &[&[impl AsRef<[u8]>]],
3939
) -> Array<u8, Self::N> {
4040
let mut hasher = Shake256::default();
4141
hasher.update(sk_prf.as_ref());
4242
hasher.update(opt_rand.as_slice());
4343
msg.iter()
44+
.copied()
45+
.flatten()
4446
.for_each(|msg_part| hasher.update(msg_part.as_ref()));
4547
let mut output = Array::<u8, Self::N>::default();
4648
hasher.finalize_xof_into(&mut output);
@@ -51,13 +53,15 @@ where
5153
rand: &Array<u8, Self::N>,
5254
pk_seed: &PkSeed<Self::N>,
5355
pk_root: &Array<u8, Self::N>,
54-
msg: &[impl AsRef<[u8]>],
56+
msg: &[&[impl AsRef<[u8]>]],
5557
) -> Array<u8, Self::M> {
5658
let mut hasher = Shake256::default();
5759
hasher.update(rand.as_slice());
5860
hasher.update(pk_seed.as_ref());
5961
hasher.update(pk_root.as_ref());
6062
msg.iter()
63+
.copied()
64+
.flatten()
6165
.for_each(|msg_part| hasher.update(msg_part.as_ref()));
6266
let mut output = Array::<u8, Self::M>::default();
6367
hasher.finalize_xof_into(&mut output);
@@ -276,7 +280,7 @@ mod tests {
276280

277281
let expected = hex!("bc5c062307df0a41aeeae19ad655f7b2");
278282

279-
let result = H::prf_msg(&sk_prf, &opt_rand, &[msg]);
283+
let result = H::prf_msg(&sk_prf, &opt_rand, &[&[msg]]);
280284

281285
assert_eq!(result.as_slice(), expected);
282286
}

slh-dsa/src/signing_key.rs

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use crate::util::split_digest;
44
use crate::verifying_key::VerifyingKey;
55
use crate::{ParameterSet, PkSeed, Sha2L1, Sha2L35, Shake, VerifyingKeyLen};
66
use ::signature::{
7-
Error, KeypairRef, RandomizedSigner, Signer,
7+
Error, KeypairRef, MultiPartSigner, RandomizedMultiPartSigner, RandomizedSigner, Signer,
88
rand_core::{CryptoRng, TryCryptoRng},
99
};
1010
use hybrid_array::{Array, ArraySize};
@@ -133,6 +133,10 @@ impl<P: ParameterSet> SigningKey<P> {
133133
/// Published for KAT validation purposes but not intended for general use.
134134
/// opt_rand must be a P::N length slice, panics otherwise.
135135
pub fn slh_sign_internal(&self, msg: &[&[u8]], opt_rand: Option<&[u8]>) -> Signature<P> {
136+
self.raw_slh_sign_internal(&[msg], opt_rand)
137+
}
138+
139+
fn raw_slh_sign_internal(&self, msg: &[&[&[u8]]], opt_rand: Option<&[u8]>) -> Signature<P> {
136140
let rand = opt_rand
137141
.unwrap_or(&self.verifying_key.pk_seed.0)
138142
.try_into()
@@ -167,12 +171,21 @@ impl<P: ParameterSet> SigningKey<P> {
167171
msg: &[u8],
168172
ctx: &[u8],
169173
opt_rand: Option<&[u8]>,
174+
) -> Result<Signature<P>, Error> {
175+
self.raw_try_sign_with_context(&[msg], ctx, opt_rand)
176+
}
177+
178+
fn raw_try_sign_with_context(
179+
&self,
180+
msg: &[&[u8]],
181+
ctx: &[u8],
182+
opt_rand: Option<&[u8]>,
170183
) -> Result<Signature<P>, Error> {
171184
let ctx_len = u8::try_from(ctx.len()).map_err(|_| Error::new())?;
172185
let ctx_len_bytes = ctx_len.to_be_bytes();
173186

174-
let ctx_msg = [&[0], &ctx_len_bytes, ctx, msg];
175-
Ok(self.slh_sign_internal(&ctx_msg, opt_rand))
187+
let ctx_msg = [&[&[0], &ctx_len_bytes, ctx], msg];
188+
Ok(self.raw_slh_sign_internal(&ctx_msg, opt_rand))
176189
}
177190

178191
/// Serialize the signing key to a new stack-allocated array
@@ -218,7 +231,13 @@ impl<P: ParameterSet> TryFrom<&[u8]> for SigningKey<P> {
218231

219232
impl<P: ParameterSet> Signer<Signature<P>> for SigningKey<P> {
220233
fn try_sign(&self, msg: &[u8]) -> Result<Signature<P>, Error> {
221-
self.try_sign_with_context(msg, &[], None)
234+
self.try_multi_part_sign(&[msg])
235+
}
236+
}
237+
238+
impl<P: ParameterSet> MultiPartSigner<Signature<P>> for SigningKey<P> {
239+
fn try_multi_part_sign(&self, msg: &[&[u8]]) -> Result<Signature<P>, Error> {
240+
self.raw_try_sign_with_context(msg, &[], None)
222241
}
223242
}
224243

@@ -228,10 +247,20 @@ impl<P: ParameterSet> RandomizedSigner<Signature<P>> for SigningKey<P> {
228247
rng: &mut R,
229248
msg: &[u8],
230249
) -> Result<Signature<P>, signature::Error> {
250+
self.try_multi_part_sign_with_rng(rng, &[msg])
251+
}
252+
}
253+
254+
impl<P: ParameterSet> RandomizedMultiPartSigner<Signature<P>> for SigningKey<P> {
255+
fn try_multi_part_sign_with_rng<R: TryCryptoRng + ?Sized>(
256+
&self,
257+
rng: &mut R,
258+
msg: &[&[u8]],
259+
) -> Result<Signature<P>, Error> {
231260
let mut randomizer = Array::<u8, P::N>::default();
232261
rng.try_fill_bytes(randomizer.as_mut_slice())
233262
.map_err(|_| signature::Error::new())?;
234-
self.try_sign_with_context(msg, &[], Some(&randomizer))
263+
self.raw_try_sign_with_context(msg, &[], Some(&randomizer))
235264
}
236265
}
237266

0 commit comments

Comments
 (0)