Skip to content

Commit 328cfc0

Browse files
committed
bindings/rust/src/lib.rs: enforce expected slice sizes when deserializing.
Fixes #14
1 parent bfd21a0 commit 328cfc0

File tree

1 file changed

+53
-36
lines changed

1 file changed

+53
-36
lines changed

bindings/rust/src/lib.rs

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -419,30 +419,38 @@ macro_rules! sig_variant_impl {
419419
}
420420

421421
pub fn uncompress(pk_comp: &[u8]) -> Result<Self, BLST_ERROR> {
422-
let mut pk = std::mem::MaybeUninit::<$pk_aff>::uninit();
423-
424-
unsafe {
425-
let err = $pk_uncomp(pk.as_mut_ptr(), pk_comp.as_ptr());
426-
if err != BLST_ERROR::BLST_SUCCESS {
427-
return Err(err);
422+
if pk_comp.len() == $pk_comp_size {
423+
unsafe {
424+
let mut pk = MaybeUninit::<$pk_aff>::uninit();
425+
let err = $pk_uncomp(pk.as_mut_ptr(), pk_comp.as_ptr());
426+
if err != BLST_ERROR::BLST_SUCCESS {
427+
return Err(err);
428+
}
429+
Ok(Self {
430+
point: pk.assume_init(),
431+
})
428432
}
429-
Ok(Self {
430-
point: pk.assume_init(),
431-
})
433+
} else {
434+
Err(BLST_ERROR::BLST_BAD_ENCODING)
432435
}
433436
}
434437

435438
pub fn deserialize(pk_in: &[u8]) -> Result<Self, BLST_ERROR> {
436-
let mut pk = std::mem::MaybeUninit::<$pk_aff>::uninit();
437-
438-
unsafe {
439-
let err = $pk_deser(pk.as_mut_ptr(), pk_in.as_ptr());
440-
if err != BLST_ERROR::BLST_SUCCESS {
441-
return Err(err);
439+
if pk_in.len() == $pk_ser_size
440+
|| pk_in.len() == $pk_comp_size && (pk_in[0] & 0x80) != 0
441+
{
442+
unsafe {
443+
let mut pk = MaybeUninit::<$pk_aff>::uninit();
444+
let err = $pk_deser(pk.as_mut_ptr(), pk_in.as_ptr());
445+
if err != BLST_ERROR::BLST_SUCCESS {
446+
return Err(err);
447+
}
448+
Ok(Self {
449+
point: pk.assume_init(),
450+
})
442451
}
443-
Ok(Self {
444-
point: pk.assume_init(),
445-
})
452+
} else {
453+
Err(BLST_ERROR::BLST_BAD_ENCODING)
446454
}
447455
}
448456

@@ -839,30 +847,39 @@ macro_rules! sig_variant_impl {
839847
}
840848

841849
pub fn uncompress(sig_comp: &[u8]) -> Result<Self, BLST_ERROR> {
842-
let mut sig = std::mem::MaybeUninit::<$sig_aff>::uninit();
843-
844-
unsafe {
845-
let err = $sig_uncomp(sig.as_mut_ptr(), sig_comp.as_ptr());
846-
if err != BLST_ERROR::BLST_SUCCESS {
847-
return Err(err);
850+
if sig_comp.len() == $sig_comp_size {
851+
unsafe {
852+
let mut sig = MaybeUninit::<$sig_aff>::uninit();
853+
let err =
854+
$sig_uncomp(sig.as_mut_ptr(), sig_comp.as_ptr());
855+
if err != BLST_ERROR::BLST_SUCCESS {
856+
return Err(err);
857+
}
858+
Ok(Self {
859+
point: sig.assume_init(),
860+
})
848861
}
849-
Ok(Self {
850-
point: sig.assume_init(),
851-
})
862+
} else {
863+
Err(BLST_ERROR::BLST_BAD_ENCODING)
852864
}
853865
}
854866

855867
pub fn deserialize(sig_in: &[u8]) -> Result<Self, BLST_ERROR> {
856-
let mut sig = std::mem::MaybeUninit::<$sig_aff>::uninit();
857-
858-
unsafe {
859-
let err = $sig_deser(sig.as_mut_ptr(), sig_in.as_ptr());
860-
if err != BLST_ERROR::BLST_SUCCESS {
861-
return Err(err);
868+
if sig_in.len() == $sig_ser_size
869+
|| sig_in.len() == $sig_comp_size && (sig_in[0] & 0x80) != 0
870+
{
871+
unsafe {
872+
let mut sig = MaybeUninit::<$sig_aff>::uninit();
873+
let err = $sig_deser(sig.as_mut_ptr(), sig_in.as_ptr());
874+
if err != BLST_ERROR::BLST_SUCCESS {
875+
return Err(err);
876+
}
877+
Ok(Self {
878+
point: sig.assume_init(),
879+
})
862880
}
863-
Ok(Self {
864-
point: sig.assume_init(),
865-
})
881+
} else {
882+
Err(BLST_ERROR::BLST_BAD_ENCODING)
866883
}
867884
}
868885

0 commit comments

Comments
 (0)