Skip to content

Commit 4b3864c

Browse files
authored
Demoing zero-copy save. (#567)
* Demoing zero-copy save.
* Fixing clippy issue.
* Some cleanup.
* Sdist doesn't require feature ?
* Incorrect clean.
* Clippy ?
* Sanity check
* Fixing the doc builder?
* Using pre-commit for quality.
* This should work.
* Clippy variant.
* pyfeature typo.
* Bypassing the necessity for an env ?
* Remove maturin.
* BigEndian fix.
* Only black.
* We need to check both features.
* ??
* Asking for readonly is not possible.
* Before the build.
* Fixing byte-endian?
1 parent fa83351 commit 4b3864c

File tree

10 files changed

+191
-93
lines changed

10 files changed

+191
-93
lines changed

.github/workflows/build_documentation.yml

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ on:
1111
jobs:
1212
build:
1313
uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
14+
env:
15+
MATURIN_PEP517_ARGS: "--features py311,pyo3/extension-module"
1416
with:
1517
commit_sha: ${{ github.sha }}
1618
package: safetensors

.github/workflows/build_pr_documentation.yml

+2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ concurrency:
1414
jobs:
1515
build:
1616
uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
17+
env:
18+
MATURIN_PEP517_ARGS: "--features py311,pyo3/extension-module"
1719
with:
1820
commit_sha: ${{ github.event.pull_request.head.sha }}
1921
pr_number: ${{ github.event.number }}

.github/workflows/python-release.yml

+12-8
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ jobs:
2323
runs-on: ${{ matrix.platform.runner }}
2424
strategy:
2525
matrix:
26+
pyfeature: ["py38", "py311"]
2627
platform:
2728
- runner: ubuntu-latest
2829
target: x86_64
@@ -45,19 +46,20 @@ jobs:
4546
uses: PyO3/maturin-action@v1
4647
with:
4748
target: ${{ matrix.platform.target }}
48-
args: --release --out dist --manifest-path bindings/python/Cargo.toml
49+
args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }}
4950
sccache: 'true'
5051
manylinux: auto
5152
- name: Upload wheels
5253
uses: actions/upload-artifact@v4
5354
with:
54-
name: wheels-linux-${{ matrix.platform.target }}
55+
name: wheels-linux-${{ matrix.platform.target }}-${{ matrix.pyfeature }}
5556
path: dist
5657

5758
musllinux:
5859
runs-on: ${{ matrix.platform.runner }}
5960
strategy:
6061
matrix:
62+
pyfeature: ["py38", "py311"]
6163
platform:
6264
- runner: ubuntu-latest
6365
target: x86_64
@@ -76,19 +78,20 @@ jobs:
7678
uses: PyO3/maturin-action@v1
7779
with:
7880
target: ${{ matrix.platform.target }}
79-
args: --release --out dist --manifest-path bindings/python/Cargo.toml
81+
args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }}
8082
sccache: 'true'
8183
manylinux: musllinux_1_2
8284
- name: Upload wheels
8385
uses: actions/upload-artifact@v4
8486
with:
85-
name: wheels-musllinux-${{ matrix.platform.target }}
87+
name: wheels-musllinux-${{ matrix.platform.target }}-${{ matrix.pyfeature }}
8688
path: dist
8789

8890
windows:
8991
runs-on: ${{ matrix.platform.runner }}
9092
strategy:
9193
matrix:
94+
pyfeature: ["py38", "py311"]
9295
platform:
9396
- runner: windows-latest
9497
target: x64
@@ -104,18 +107,19 @@ jobs:
104107
uses: PyO3/maturin-action@v1
105108
with:
106109
target: ${{ matrix.platform.target }}
107-
args: --release --out dist --manifest-path bindings/python/Cargo.toml
110+
args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }}
108111
sccache: 'true'
109112
- name: Upload wheels
110113
uses: actions/upload-artifact@v4
111114
with:
112-
name: wheels-windows-${{ matrix.platform.target }}
115+
name: wheels-windows-${{ matrix.platform.target }}-${{ matrix.pyfeature }}
113116
path: dist
114117

115118
macos:
116119
runs-on: ${{ matrix.platform.runner }}
117120
strategy:
118121
matrix:
122+
pyfeature: ["py38", "py311"]
119123
platform:
120124
- runner: macos-13
121125
target: x86_64
@@ -130,12 +134,12 @@ jobs:
130134
uses: PyO3/maturin-action@v1
131135
with:
132136
target: ${{ matrix.platform.target }}
133-
args: --release --out dist --manifest-path bindings/python/Cargo.toml
137+
args: --release --out dist --manifest-path bindings/python/Cargo.toml --features pyo3/extension-module,${{ matrix.pyfeature }}
134138
sccache: 'true'
135139
- name: Upload wheels
136140
uses: actions/upload-artifact@v4
137141
with:
138-
name: wheels-macos-${{ matrix.platform.target }}
142+
name: wheels-macos-${{ matrix.platform.target }}-${{ matrix.pyfeature }}
139143
path: dist
140144

141145
sdist:

.github/workflows/python.yml

+6-4
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@ jobs:
77
build_and_test:
88
name: Check everything builds & tests
99
runs-on: ${{ matrix.os }}
10+
env:
11+
MATURIN_PEP517_ARGS: "--features ${{ matrix.version.pyfeature }},pyo3/extension-module"
1012
strategy:
1113
matrix:
1214
os: [ubuntu-latest, macos-13, windows-latest]
1315
# Lowest and highest, no version specified so that
1416
# new releases get automatically tested against
15-
version: [{torch: torch==1.10, python: "3.8"}, {torch: torch, python: "3.12"}]
17+
version: [{torch: torch==1.10, python: "3.8", pyfeature: "py38"}, {torch: torch, python: "3.12", pyfeature: "py311"}]
1618
# TODO this would include macos ARM target.
1719
# however jax has an illegal instruction issue
1820
# that exists only in CI (probably difference in instruction support).
@@ -52,14 +54,14 @@ jobs:
5254
run: cargo fmt -- --check
5355

5456
- name: Lint with Clippy
55-
run: cargo clippy --all-targets --all-features -- -D warnings
57+
run: |
58+
cargo clippy --features ${{ matrix.version.pyfeature }} -- -D warnings
5659
5760
- name: Run Audit
5861
run: cargo audit -D warnings
5962

6063
- name: Install
6164
run: |
62-
pip install -U pip
6365
pip install .[numpy,tensorflow]
6466
pip install ${{ matrix.version.torch }}
6567
@@ -82,7 +84,7 @@ jobs:
8284
8385
- name: Run tests
8486
run: |
85-
cargo test
87+
cargo test --features ${{ matrix.version.pyfeature }}
8688
pip install .[testing]
8789
pytest -sv tests/
8890

.pre-commit-config.yaml

+2
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ repos:
2727
[
2828
"--manifest-path",
2929
"bindings/python/Cargo.toml",
30+
"--features",
31+
"py311",
3032
"--all-targets",
3133
"--",
3234
"-Dwarnings",

Dockerfile.s390x.test

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ RUN /root/miniconda3/bin/pip install -U pip pytest
1111
COPY . .
1212
SHELL ["/bin/bash", "-c"]
1313
WORKDIR /safetensors/bindings/python/
14+
ENV MATURIN_PEP517_ARGS="--features py311,pyo3/extension-module"
1415
RUN source /root/.cargo/env && /root/miniconda3/bin/pip install -e .
1516
RUN /root/miniconda3/bin/pytest -sv tests/test_pt_* tests/test_simple.py
1617
# RUN /root/miniconda3/bin/python -c 'from huggingface_hub import hf_hub_download; filename = hf_hub_download("roberta-base", "model.safetensors"); from safetensors.torch import load_file; weights = load_file(filename); assert weights["roberta.embeddings.position_embeddings.weight"][0][0].abs().item() > 1e-10'

bindings/python/Cargo.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,12 @@ rust-version = "1.74"
99
name = "safetensors_rust"
1010
crate-type = ["cdylib"]
1111

12+
[features]
13+
py38 = ["pyo3/abi3-py38"]
14+
py311 = ["pyo3/abi3-py311"]
15+
1216
[dependencies]
13-
pyo3 = { version = "0.23", features = ["abi3", "abi3-py38"] }
17+
pyo3 = { version = "0.23", features = ["abi3"] }
1418
memmap2 = "0.9"
1519
serde_json = "1.0"
1620

bindings/python/py_src/safetensors/torch.py

+14-5
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,10 @@ def _remove_duplicate_names(
128128

129129

130130
def save_model(
131-
model: torch.nn.Module, filename: str, metadata: Optional[Dict[str, str]] = None, force_contiguous: bool = True
131+
model: torch.nn.Module,
132+
filename: str,
133+
metadata: Optional[Dict[str, str]] = None,
134+
force_contiguous: bool = True,
132135
):
133136
"""
134137
Saves a given torch model to specified filename.
@@ -174,7 +177,10 @@ def save_model(
174177

175178

176179
def load_model(
177-
model: torch.nn.Module, filename: Union[str, os.PathLike], strict: bool = True, device: Union[str, int] = "cpu"
180+
model: torch.nn.Module,
181+
filename: Union[str, os.PathLike],
182+
strict: bool = True,
183+
device: Union[str, int] = "cpu",
178184
) -> Tuple[List[str], List[str]]:
179185
"""
180186
Loads a given filename onto a torch model.
@@ -402,7 +408,7 @@ def _view2torch(safeview) -> Dict[str, torch.Tensor]:
402408
return result
403409

404410

405-
def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
411+
def _tobytes(tensor: torch.Tensor, name: str) -> Union[memoryview, bytes]:
406412
if tensor.layout != torch.strided:
407413
raise ValueError(
408414
f"You are trying to save a sparse tensor: `{name}` which this library does not support."
@@ -456,8 +462,11 @@ def _tobytes(tensor: torch.Tensor, name: str) -> bytes:
456462
}
457463
npdtype = NPDTYPES[tensor.dtype]
458464
# Not in place as that would potentially modify a live running model
459-
data = data.view(npdtype).byteswap(inplace=False)
460-
return data.tobytes()
465+
data = data.view(npdtype).byteswap(inplace=False).view(np.uint8)
466+
if sys.version_info >= (3, 11):
467+
return data.data
468+
else:
469+
return data.tobytes()
461470

462471

463472
def _flatten(tensors: Dict[str, torch.Tensor]) -> Dict[str, Dict[str, Any]]:

bindings/python/src/lib.rs

+10-75
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
#![deny(missing_docs)]
22
//! Dummy doc
3+
#[cfg(any(feature = "py38", feature = "py311"))]
4+
mod view;
35
use memmap2::{Mmap, MmapOptions};
46
use pyo3::exceptions::{PyException, PyFileNotFoundError};
57
use pyo3::prelude::*;
@@ -10,94 +12,27 @@ use pyo3::Bound as PyBound;
1012
use pyo3::{intern, PyErr};
1113
use safetensors::slice::TensorIndexer;
1214
use safetensors::tensor::{Dtype, Metadata, SafeTensors, TensorInfo, TensorView};
13-
use safetensors::View;
14-
use std::borrow::Cow;
1515
use std::collections::HashMap;
1616
use std::fs::File;
1717
use std::iter::FromIterator;
1818
use std::ops::Bound;
1919
use std::path::PathBuf;
2020
use std::sync::Arc;
21+
#[cfg(any(feature = "py38", feature = "py311"))]
22+
use view::prepare;
2123

2224
static TORCH_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
2325
static NUMPY_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
2426
static TENSORFLOW_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
2527
static FLAX_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
2628
static MLX_MODULE: GILOnceCell<Py<PyModule>> = GILOnceCell::new();
2729

28-
struct PyView<'a> {
29-
shape: Vec<usize>,
30-
dtype: Dtype,
31-
data: PyBound<'a, PyBytes>,
32-
data_len: usize,
33-
}
34-
35-
impl View for &PyView<'_> {
36-
fn data(&self) -> std::borrow::Cow<[u8]> {
37-
Cow::Borrowed(self.data.as_bytes())
38-
}
39-
fn shape(&self) -> &[usize] {
40-
&self.shape
41-
}
42-
fn dtype(&self) -> Dtype {
43-
self.dtype
44-
}
45-
fn data_len(&self) -> usize {
46-
self.data_len
47-
}
48-
}
49-
50-
fn prepare(tensor_dict: HashMap<String, PyBound<PyDict>>) -> PyResult<HashMap<String, PyView>> {
51-
let mut tensors = HashMap::with_capacity(tensor_dict.len());
52-
for (tensor_name, tensor_desc) in &tensor_dict {
53-
let shape: Vec<usize> = tensor_desc
54-
.get_item("shape")?
55-
.ok_or_else(|| SafetensorError::new_err(format!("Missing `shape` in {tensor_desc:?}")))?
56-
.extract()?;
57-
let pydata: PyBound<PyAny> = tensor_desc.get_item("data")?.ok_or_else(|| {
58-
SafetensorError::new_err(format!("Missing `data` in {tensor_desc:?}"))
59-
})?;
60-
// Make sure it's extractable first.
61-
let data: &[u8] = pydata.extract()?;
62-
let data_len = data.len();
63-
let data: PyBound<PyBytes> = pydata.extract()?;
64-
let pydtype = tensor_desc.get_item("dtype")?.ok_or_else(|| {
65-
SafetensorError::new_err(format!("Missing `dtype` in {tensor_desc:?}"))
66-
})?;
67-
let dtype: String = pydtype.extract()?;
68-
let dtype = match dtype.as_ref() {
69-
"bool" => Dtype::BOOL,
70-
"int8" => Dtype::I8,
71-
"uint8" => Dtype::U8,
72-
"int16" => Dtype::I16,
73-
"uint16" => Dtype::U16,
74-
"int32" => Dtype::I32,
75-
"uint32" => Dtype::U32,
76-
"int64" => Dtype::I64,
77-
"uint64" => Dtype::U64,
78-
"float16" => Dtype::F16,
79-
"float32" => Dtype::F32,
80-
"float64" => Dtype::F64,
81-
"bfloat16" => Dtype::BF16,
82-
"float8_e4m3fn" => Dtype::F8_E4M3,
83-
"float8_e5m2" => Dtype::F8_E5M2,
84-
dtype_str => {
85-
return Err(SafetensorError::new_err(format!(
86-
"dtype {dtype_str} is not covered",
87-
)));
88-
}
89-
};
90-
91-
let tensor = PyView {
92-
shape,
93-
dtype,
94-
data,
95-
data_len,
96-
};
97-
tensors.insert(tensor_name.to_string(), tensor);
98-
}
99-
Ok(tensors)
100-
}
30+
#[cfg(not(any(feature = "py38", feature = "py311")))]
31+
compile_error!(
32+
"At least one python version must be enabled, use `maturin develop --features py311,pyo3/extension-module`"
33+
);
34+
#[cfg(all(feature = "py38", feature = "py311"))]
35+
compile_error!("Only one python version must be enabled");
10136

10237
/// Serializes raw data.
10338
///

0 commit comments

Comments
 (0)