Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 21e11cd

Browse files
authored
[fix] match metric names ignoring separators (#310) (#311)
1 parent c577c8c commit 21e11cd

File tree

1 file changed

+23
-3
lines changed

1 file changed

+23
-3
lines changed

src/sparsezoo/deployment_package/utils/extractors.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import logging
2020
from types import MappingProxyType
21-
from typing import Optional
21+
from typing import List, Optional, Union
2222

2323
from sparsezoo import Model
2424

@@ -75,21 +75,41 @@ def _accuracy(model: Model, metric_name=None) -> float:
7575

7676
if metric_name is not None:
7777
for result in validation_results:
78-
if metric_name in result.recorded_units.lower():
78+
if _metric_name_matches(metric_name, result.recorded_units.lower()):
7979
return result.recorded_value
8080
_LOGGER.info(f"metric name {metric_name} not found for model {model}")
8181

8282
# fallback to if any accuracy metric found
8383
accuracy_metrics = ["accuracy", "f1", "recall", "map", "top1 accuracy"]
8484
for result in validation_results:
85-
if result.recorded_units.lower() in accuracy_metrics:
85+
if _metric_name_matches(result.recorded_units.lower(), accuracy_metrics):
8686
return result.recorded_value
8787

8888
raise ValueError(
8989
f"Could not find any accuracy metric {accuracy_metrics} for model {model}"
9090
)
9191

9292

93+
def _metric_name_matches(
94+
metric_name: str, target_metrics: Union[str, List[str]]
95+
) -> bool:
96+
# returns true if metric name is included in the target metrics
97+
if isinstance(target_metrics, str):
98+
target_metrics = [target_metrics]
99+
return any(
100+
_standardized_str_eq(metric_name, target_metric)
101+
for target_metric in target_metrics
102+
)
103+
104+
105+
def _standardized_str_eq(str_1: str, str_2: str) -> bool:
106+
# strings are equal if lowercase, striped of spaces, -, and _ are equal
107+
def _standardize(string):
108+
return string.lower().replace(" ", "").replace("-", "").replace("_", "")
109+
110+
return _standardize(str_1) == _standardize(str_2)
111+
112+
93113
EXTRACTORS = MappingProxyType(
94114
{
95115
"compression": _size,

0 commit comments

Comments
 (0)