|
18 | 18 |
|
19 | 19 | import logging
|
20 | 20 | from types import MappingProxyType
|
21 |
| -from typing import Optional |
| 21 | +from typing import List, Optional, Union |
22 | 22 |
|
23 | 23 | from sparsezoo import Model
|
24 | 24 |
|
@@ -75,21 +75,41 @@ def _accuracy(model: Model, metric_name=None) -> float:
|
75 | 75 |
|
76 | 76 | if metric_name is not None:
|
77 | 77 | for result in validation_results:
|
78 |
| - if metric_name in result.recorded_units.lower(): |
| 78 | + if _metric_name_matches(metric_name, result.recorded_units.lower()): |
79 | 79 | return result.recorded_value
|
80 | 80 | _LOGGER.info(f"metric name {metric_name} not found for model {model}")
|
81 | 81 |
|
82 | 82 | # fallback to if any accuracy metric found
|
83 | 83 | accuracy_metrics = ["accuracy", "f1", "recall", "map", "top1 accuracy"]
|
84 | 84 | for result in validation_results:
|
85 |
| - if result.recorded_units.lower() in accuracy_metrics: |
| 85 | + if _metric_name_matches(result.recorded_units.lower(), accuracy_metrics): |
86 | 86 | return result.recorded_value
|
87 | 87 |
|
88 | 88 | raise ValueError(
|
89 | 89 | f"Could not find any accuracy metric {accuracy_metrics} for model {model}"
|
90 | 90 | )
|
91 | 91 |
|
92 | 92 |
|
| 93 | +def _metric_name_matches( |
| 94 | + metric_name: str, target_metrics: Union[str, List[str]] |
| 95 | +) -> bool: |
| 96 | + # returns true if metric name is included in the target metrics |
| 97 | + if isinstance(target_metrics, str): |
| 98 | + target_metrics = [target_metrics] |
| 99 | + return any( |
| 100 | + _standardized_str_eq(metric_name, target_metric) |
| 101 | + for target_metric in target_metrics |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +def _standardized_str_eq(str_1: str, str_2: str) -> bool: |
| 106 | + # strings are equal if lowercase, striped of spaces, -, and _ are equal |
| 107 | + def _standardize(string): |
| 108 | + return string.lower().replace(" ", "").replace("-", "").replace("_", "") |
| 109 | + |
| 110 | + return _standardize(str_1) == _standardize(str_2) |
| 111 | + |
| 112 | + |
93 | 113 | EXTRACTORS = MappingProxyType(
|
94 | 114 | {
|
95 | 115 | "compression": _size,
|
|
0 commit comments