Skip to content

Commit 751a449

Browse files
ashwinvaidya17adrianboguszewski
authored andcommitted
🐞 Fix GMM test (open-edge-platform#1696)
* Separate data Signed-off-by: Ashwin Vaidya <[email protected]> * limit torch version Signed-off-by: Ashwin Vaidya <[email protected]> --------- Signed-off-by: Ashwin Vaidya <[email protected]>
1 parent 8f56a37 commit 751a449

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

requirements/base.txt

+1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ pandas>=1.1.0
1212
lightning
1313
setuptools>=41.0.0
1414
timm>=0.5.4,<=0.6.13
15+
torch>=2,<2.2.0 # rkde export fails even with ONNX 17 (latest) with torch 2.2.0. TODO(ashwinvaidya17): revisit
1516
torchmetrics==0.10.3
1617
rich-argparse
1718
open-clip-torch>=2.23.0

tests/unit/models/components/clustering/test_gmm.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,8 @@ def test_fit_and_predict() -> None:
1616
# Create some synthetic data
1717
data = torch.cat(
1818
[
19-
torch.randn(100, 2) + torch.tensor([2.0, 2.0]),
20-
torch.randn(100, 2) + torch.tensor([-2.0, -2.0]),
19+
torch.randn(100, 2) + torch.tensor([10.0, 10.0]),
20+
torch.randn(100, 2) + torch.tensor([-10.0, -10.0]),
2121
],
2222
)
2323

0 commit comments

Comments
 (0)