@@ -29,6 +29,7 @@ use crate::{
29
29
ApplicationId , BlobId , BlobType , BytecodeId , Destination , GenericApplicationId , MessageId ,
30
30
UserApplicationId ,
31
31
} ,
32
+ limited_writer:: { LimitedWriter , LimitedWriterError } ,
32
33
time:: { Duration , SystemTime } ,
33
34
} ;
34
35
@@ -842,14 +843,8 @@ impl fmt::Debug for Bytecode {
842
843
#[ derive( Error , Debug ) ]
843
844
pub enum DecompressionError {
844
845
/// Compressed bytecode is invalid, and could not be decompressed.
845
- #[ cfg( not( target_arch = "wasm32" ) ) ]
846
- #[ error( "Bytecode could not be decompressed" ) ]
847
- InvalidCompressedBytecode ( #[ source] io:: Error ) ,
848
-
849
- /// Compressed bytecode is invalid, and could not be decompressed.
850
- #[ cfg( target_arch = "wasm32" ) ]
851
- #[ error( "Bytecode could not be decompressed" ) ]
852
- InvalidCompressedBytecode ( #[ from] ruzstd:: frame_decoder:: FrameDecoderError ) ,
846
+ #[ error( "Bytecode could not be decompressed: {0}" ) ]
847
+ InvalidCompressedBytecode ( #[ from] io:: Error ) ,
853
848
}
854
849
855
850
/// A compressed WebAssembly module's bytecode.
@@ -878,30 +873,57 @@ impl From<Bytecode> for CompressedBytecode {
878
873
}
879
874
880
875
#[ cfg( not( target_arch = "wasm32" ) ) ]
881
- impl TryFrom < & CompressedBytecode > for Bytecode {
882
- type Error = DecompressionError ;
883
-
884
- fn try_from ( compressed_bytecode : & CompressedBytecode ) -> Result < Self , Self :: Error > {
885
- let bytes = zstd:: stream:: decode_all ( & * compressed_bytecode. compressed_bytes )
886
- . map_err ( DecompressionError :: InvalidCompressedBytecode ) ?;
876
+ impl CompressedBytecode {
877
+ /// Returns `true` if the decompressed size does not exceed the limit.
878
+ pub fn decompressed_size_at_most ( & self , limit : u64 ) -> Result < bool , DecompressionError > {
879
+ let mut decoder = zstd:: stream:: Decoder :: new ( & * self . compressed_bytes ) ?;
880
+ let limit = usize:: try_from ( limit) . unwrap_or ( usize:: MAX ) ;
881
+ let mut writer = LimitedWriter :: new ( io:: sink ( ) , limit) ;
882
+ match io:: copy ( & mut decoder, & mut writer) {
883
+ Ok ( _) => Ok ( true ) ,
884
+ Err ( error) => {
885
+ error. downcast :: < LimitedWriterError > ( ) ?;
886
+ Ok ( false )
887
+ }
888
+ }
889
+ }
887
890
891
+ /// Decompresses a [`CompressedBytecode`] into a [`Bytecode`].
892
+ pub fn decompress ( & self ) -> Result < Bytecode , DecompressionError > {
893
+ let bytes = zstd:: stream:: decode_all ( & * self . compressed_bytes ) ?;
888
894
Ok ( Bytecode { bytes } )
889
895
}
890
896
}
891
897
892
898
#[ cfg( target_arch = "wasm32" ) ]
893
- impl TryFrom < & CompressedBytecode > for Bytecode {
894
- type Error = DecompressionError ;
899
+ impl CompressedBytecode {
900
+ /// Returns `true` if the decompressed size does not exceed the limit.
901
+ pub fn decompressed_size_at_most ( & self , limit : u64 ) -> Result < bool , DecompressionError > {
902
+ let compressed_bytes = & * self . compressed_bytes ;
903
+ let limit = usize:: try_from ( limit) . unwrap_or ( usize:: MAX ) ;
904
+ let mut writer = LimitedWriter :: new ( io:: sink ( ) , limit) ;
905
+ let mut decoder = ruzstd:: streaming_decoder:: StreamingDecoder :: new ( compressed_bytes)
906
+ . map_err ( io:: Error :: other) ?;
907
+
908
+ // TODO(#2710): Decode multiple frames, if present
909
+ match io:: copy ( & mut decoder, & mut writer) {
910
+ Ok ( _) => Ok ( true ) ,
911
+ Err ( error) => {
912
+ error. downcast :: < LimitedWriterError > ( ) ?;
913
+ Ok ( false )
914
+ }
915
+ }
916
+ }
895
917
896
- fn try_from ( compressed_bytecode : & CompressedBytecode ) -> Result < Self , Self :: Error > {
918
+ /// Decompresses a [`CompressedBytecode`] into a [`Bytecode`].
919
+ pub fn decompress ( & self ) -> Result < Bytecode , DecompressionError > {
897
920
use ruzstd:: { io:: Read , streaming_decoder:: StreamingDecoder } ;
898
921
899
- let compressed_bytes = & * compressed_bytecode . compressed_bytes ;
922
+ let compressed_bytes = & * self . compressed_bytes ;
900
923
let mut bytes = Vec :: new ( ) ;
901
- let mut decoder = StreamingDecoder :: new ( compressed_bytes) ?;
924
+ let mut decoder = StreamingDecoder :: new ( compressed_bytes) . map_err ( io :: Error :: other ) ?;
902
925
903
- // Decode multiple frames, if present
904
- // (https://github.com/KillingSpark/zstd-rs/issues/57)
926
+ // TODO(#2710): Decode multiple frames, if present
905
927
while !decoder. get_ref ( ) . is_empty ( ) {
906
928
decoder
907
929
. read_to_end ( & mut bytes)
@@ -912,14 +934,6 @@ impl TryFrom<&CompressedBytecode> for Bytecode {
912
934
}
913
935
}
914
936
915
- impl TryFrom < CompressedBytecode > for Bytecode {
916
- type Error = DecompressionError ;
917
-
918
- fn try_from ( compressed_bytecode : CompressedBytecode ) -> Result < Self , Self :: Error > {
919
- Bytecode :: try_from ( & compressed_bytecode)
920
- }
921
- }
922
-
923
937
impl fmt:: Debug for CompressedBytecode {
924
938
fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
925
939
f. debug_struct ( "CompressedBytecode" ) . finish_non_exhaustive ( )
0 commit comments