1
1
#![ deny( missing_docs) ]
2
2
//! Dummy doc
3
+ #[ cfg( any( feature = "py38" , feature = "py311" ) ) ]
4
+ mod view;
3
5
use memmap2:: { Mmap , MmapOptions } ;
4
6
use pyo3:: exceptions:: { PyException , PyFileNotFoundError } ;
5
7
use pyo3:: prelude:: * ;
@@ -10,94 +12,27 @@ use pyo3::Bound as PyBound;
10
12
use pyo3:: { intern, PyErr } ;
11
13
use safetensors:: slice:: TensorIndexer ;
12
14
use safetensors:: tensor:: { Dtype , Metadata , SafeTensors , TensorInfo , TensorView } ;
13
- use safetensors:: View ;
14
- use std:: borrow:: Cow ;
15
15
use std:: collections:: HashMap ;
16
16
use std:: fs:: File ;
17
17
use std:: iter:: FromIterator ;
18
18
use std:: ops:: Bound ;
19
19
use std:: path:: PathBuf ;
20
20
use std:: sync:: Arc ;
21
+ #[ cfg( any( feature = "py38" , feature = "py311" ) ) ]
22
+ use view:: prepare;
21
23
22
24
static TORCH_MODULE : GILOnceCell < Py < PyModule > > = GILOnceCell :: new ( ) ;
23
25
static NUMPY_MODULE : GILOnceCell < Py < PyModule > > = GILOnceCell :: new ( ) ;
24
26
static TENSORFLOW_MODULE : GILOnceCell < Py < PyModule > > = GILOnceCell :: new ( ) ;
25
27
static FLAX_MODULE : GILOnceCell < Py < PyModule > > = GILOnceCell :: new ( ) ;
26
28
static MLX_MODULE : GILOnceCell < Py < PyModule > > = GILOnceCell :: new ( ) ;
27
29
28
- struct PyView < ' a > {
29
- shape : Vec < usize > ,
30
- dtype : Dtype ,
31
- data : PyBound < ' a , PyBytes > ,
32
- data_len : usize ,
33
- }
34
-
35
- impl View for & PyView < ' _ > {
36
- fn data ( & self ) -> std:: borrow:: Cow < [ u8 ] > {
37
- Cow :: Borrowed ( self . data . as_bytes ( ) )
38
- }
39
- fn shape ( & self ) -> & [ usize ] {
40
- & self . shape
41
- }
42
- fn dtype ( & self ) -> Dtype {
43
- self . dtype
44
- }
45
- fn data_len ( & self ) -> usize {
46
- self . data_len
47
- }
48
- }
49
-
50
- fn prepare ( tensor_dict : HashMap < String , PyBound < PyDict > > ) -> PyResult < HashMap < String , PyView > > {
51
- let mut tensors = HashMap :: with_capacity ( tensor_dict. len ( ) ) ;
52
- for ( tensor_name, tensor_desc) in & tensor_dict {
53
- let shape: Vec < usize > = tensor_desc
54
- . get_item ( "shape" ) ?
55
- . ok_or_else ( || SafetensorError :: new_err ( format ! ( "Missing `shape` in {tensor_desc:?}" ) ) ) ?
56
- . extract ( ) ?;
57
- let pydata: PyBound < PyAny > = tensor_desc. get_item ( "data" ) ?. ok_or_else ( || {
58
- SafetensorError :: new_err ( format ! ( "Missing `data` in {tensor_desc:?}" ) )
59
- } ) ?;
60
- // Make sure it's extractable first.
61
- let data: & [ u8 ] = pydata. extract ( ) ?;
62
- let data_len = data. len ( ) ;
63
- let data: PyBound < PyBytes > = pydata. extract ( ) ?;
64
- let pydtype = tensor_desc. get_item ( "dtype" ) ?. ok_or_else ( || {
65
- SafetensorError :: new_err ( format ! ( "Missing `dtype` in {tensor_desc:?}" ) )
66
- } ) ?;
67
- let dtype: String = pydtype. extract ( ) ?;
68
- let dtype = match dtype. as_ref ( ) {
69
- "bool" => Dtype :: BOOL ,
70
- "int8" => Dtype :: I8 ,
71
- "uint8" => Dtype :: U8 ,
72
- "int16" => Dtype :: I16 ,
73
- "uint16" => Dtype :: U16 ,
74
- "int32" => Dtype :: I32 ,
75
- "uint32" => Dtype :: U32 ,
76
- "int64" => Dtype :: I64 ,
77
- "uint64" => Dtype :: U64 ,
78
- "float16" => Dtype :: F16 ,
79
- "float32" => Dtype :: F32 ,
80
- "float64" => Dtype :: F64 ,
81
- "bfloat16" => Dtype :: BF16 ,
82
- "float8_e4m3fn" => Dtype :: F8_E4M3 ,
83
- "float8_e5m2" => Dtype :: F8_E5M2 ,
84
- dtype_str => {
85
- return Err ( SafetensorError :: new_err ( format ! (
86
- "dtype {dtype_str} is not covered" ,
87
- ) ) ) ;
88
- }
89
- } ;
90
-
91
- let tensor = PyView {
92
- shape,
93
- dtype,
94
- data,
95
- data_len,
96
- } ;
97
- tensors. insert ( tensor_name. to_string ( ) , tensor) ;
98
- }
99
- Ok ( tensors)
100
- }
30
+ #[ cfg( not( any( feature = "py38" , feature = "py311" ) ) ) ]
31
+ compile_error ! (
32
+ "At least one python version must be enabled, use `maturin develop --features py311,pyo3/extension-module`"
33
+ ) ;
34
+ #[ cfg( all( feature = "py38" , feature = "py311" ) ) ]
35
+ compile_error ! ( "Only one python version must be enabled" ) ;
101
36
102
37
/// Serializes raw data.
103
38
///
0 commit comments