Skip to content

Commit 5ca1612

Browse files
authored
🚀from_config API: Create a path between API & configuration file (CLI) (#2065)
* Add from_config feature in model and data Signed-off-by: Kang, Harim <[email protected]> * Add Engine.from_config features Signed-off-by: Kang, Harim <[email protected]> * Add Unit-test for from_config Signed-off-by: Kang, Harim <[email protected]> * Add comment in CHANGELOG.md Signed-off-by: Kang, Harim <[email protected]> --------- Signed-off-by: Kang, Harim <[email protected]>
1 parent 64c123b commit 5ca1612

19 files changed

+442
-4
lines changed

‎CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
1111
- 🚀 Update OpenVINO and ONNX export to support fixed input shape by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2006
1212
- Add data_path argument to predict entrypoint and add properties for retrieving model path by @djdameln in https://github.com/openvinotoolkit/anomalib/pull/2018
1313
- 🚀 Add compression and quantization for OpenVINO export by @adrianboguszewski in https://github.com/openvinotoolkit/anomalib/pull/2052
14+
- 🚀from_config API: Create a path between API & configuration file (CLI) by @harimkang in https://github.com/openvinotoolkit/anomalib/pull/2065
1415

1516
### Changed
1617

‎configs/model/reverse_distillation.yaml

-2
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@ model:
66
- layer1
77
- layer2
88
- layer3
9-
beta1: 0.5
10-
beta2: 0.999
119
anomaly_map_mode: ADD
1210
pre_trained: true
1311

‎src/anomalib/cli/cli.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class AnomalibCLI:
5050
``SaveConfigCallback`` overwrites the config if it already exists.
5151
"""
5252

53-
def __init__(self, args: Sequence[str] | None = None) -> None:
53+
def __init__(self, args: Sequence[str] | None = None, run: bool = True) -> None:
5454
self.parser = self.init_parser()
5555
self.subcommand_parsers: dict[str, ArgumentParser] = {}
5656
self.subcommand_method_arguments: dict[str, list[str]] = {}
@@ -60,7 +60,8 @@ def __init__(self, args: Sequence[str] | None = None) -> None:
6060
if _LIGHTNING_AVAILABLE:
6161
self.before_instantiate_classes()
6262
self.instantiate_classes()
63-
self._run_subcommand()
63+
if run:
64+
self._run_subcommand()
6465

6566
def init_parser(self, **kwargs) -> ArgumentParser:
6667
"""Method that instantiates the argument parser."""

‎src/anomalib/data/base/datamodule.py

+49
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import logging
88
from abc import ABC, abstractmethod
9+
from pathlib import Path
910
from typing import TYPE_CHECKING, Any
1011

1112
from lightning.pytorch import LightningDataModule
@@ -289,3 +290,51 @@ def eval_transform(self) -> Transform:
289290
if self.image_size:
290291
return Resize(self.image_size, antialias=True)
291292
return None
293+
294+
@classmethod
295+
def from_config(
296+
cls: type["AnomalibDataModule"],
297+
config_path: str | Path,
298+
**kwargs,
299+
) -> "AnomalibDataModule":
300+
"""Create a datamodule instance from the configuration.
301+
302+
Args:
303+
config_path (str | Path): Path to the data configuration file.
304+
**kwargs (dict): Additional keyword arguments.
305+
306+
Returns:
307+
AnomalibDataModule: Datamodule instance.
308+
309+
Example:
310+
The following example shows how to get datamodule from mvtec.yaml:
311+
312+
.. code-block:: python
313+
>>> data_config = "configs/data/mvtec.yaml"
314+
>>> datamodule = AnomalibDataModule.from_config(config_path=data_config)
315+
316+
The following example shows overriding the configuration file with additional keyword arguments:
317+
318+
.. code-block:: python
319+
>>> override_kwargs = {"data.train_batch_size": 8}
320+
>>> datamodule = AnomalibDataModule.from_config(config_path=data_config, **override_kwargs)
321+
"""
322+
from jsonargparse import ArgumentParser
323+
324+
if not Path(config_path).exists():
325+
msg = f"Configuration file not found: {config_path}"
326+
raise FileNotFoundError(msg)
327+
328+
data_parser = ArgumentParser()
329+
data_parser.add_subclass_arguments(AnomalibDataModule, "data", required=False, fail_untyped=False)
330+
args = ["--data", str(config_path)]
331+
for key, value in kwargs.items():
332+
args.extend([f"--{key}", str(value)])
333+
config = data_parser.parse_args(args=args)
334+
instantiated_classes = data_parser.instantiate_classes(config)
335+
datamodule = instantiated_classes.get("data")
336+
if isinstance(datamodule, AnomalibDataModule):
337+
return datamodule
338+
339+
msg = f"Datamodule is not an instance of AnomalibDataModule: {datamodule}"
340+
raise ValueError(msg)

‎src/anomalib/engine/engine.py

+49
Original file line numberDiff line numberDiff line change
@@ -962,3 +962,52 @@ def export(
962962
if exported_model_path:
963963
logging.info(f"Exported model to {exported_model_path}")
964964
return exported_model_path
965+
966+
@classmethod
967+
def from_config(
968+
cls: type["Engine"],
969+
config_path: str | Path,
970+
**kwargs,
971+
) -> tuple["Engine", AnomalyModule, AnomalibDataModule]:
972+
"""Create an Engine instance from a configuration file.
973+
974+
Args:
975+
config_path (str | Path): Path to the full configuration file.
976+
**kwargs (dict): Additional keyword arguments.
977+
978+
Returns:
979+
tuple[Engine, AnomalyModule, AnomalibDataModule]: Engine instance.
980+
981+
Example:
982+
The following example shows training with full configuration file:
983+
984+
.. code-block:: python
985+
>>> config_path = "anomalib_full_config.yaml"
986+
>>> engine, model, datamodule = Engine.from_config(config_path=config_path)
987+
>>> engine.fit(datamodule=datamodule, model=model)
988+
989+
The following example shows overriding the configuration file with additional keyword arguments:
990+
991+
.. code-block:: python
992+
>>> override_kwargs = {"data.train_batch_size": 8}
993+
>>> engine, model, datamodule = Engine.from_config(config_path=config_path, **override_kwargs)
994+
>>> engine.fit(datamodule=datamodule, model=model)
995+
"""
996+
from anomalib.cli.cli import AnomalibCLI
997+
998+
if not Path(config_path).exists():
999+
msg = f"Configuration file not found: {config_path}"
1000+
raise FileNotFoundError(msg)
1001+
1002+
args = [
1003+
"fit",
1004+
"--config",
1005+
str(config_path),
1006+
]
1007+
for key, value in kwargs.items():
1008+
args.extend([f"--{key}", str(value)])
1009+
anomalib_cli = AnomalibCLI(
1010+
args=args,
1011+
run=False,
1012+
)
1013+
return anomalib_cli.engine, anomalib_cli.model, anomalib_cli.datamodule

‎src/anomalib/models/components/base/anomaly_module.py

+63
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
from abc import ABC, abstractmethod
99
from collections import OrderedDict
10+
from pathlib import Path
1011
from typing import TYPE_CHECKING, Any
1112

1213
import lightning.pytorch as pl
@@ -275,3 +276,65 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
275276
"""
276277
self._transform = checkpoint["transform"]
277278
self.setup("load_checkpoint")
279+
280+
@classmethod
281+
def from_config(
282+
cls: type["AnomalyModule"],
283+
config_path: str | Path,
284+
**kwargs,
285+
) -> "AnomalyModule":
286+
"""Create a model instance from the configuration.
287+
288+
Args:
289+
config_path (str | Path): Path to the model configuration file.
290+
**kwargs (dict): Additional keyword arguments.
291+
292+
Returns:
293+
AnomalyModule: model instance.
294+
295+
Example:
296+
The following example shows how to get model from patchcore.yaml:
297+
298+
.. code-block:: python
299+
>>> model_config = "configs/model/patchcore.yaml"
300+
>>> model = AnomalyModule.from_config(config_path=model_config)
301+
302+
The following example shows overriding the configuration file with additional keyword arguments:
303+
304+
.. code-block:: python
305+
>>> override_kwargs = {"model.pre_trained": False}
306+
>>> model = AnomalyModule.from_config(config_path=model_config, **override_kwargs)
307+
"""
308+
from jsonargparse import ActionConfigFile, ArgumentParser
309+
from lightning.pytorch import Trainer
310+
311+
from anomalib import TaskType
312+
313+
if not Path(config_path).exists():
314+
msg = f"Configuration file not found: {config_path}"
315+
raise FileNotFoundError(msg)
316+
317+
model_parser = ArgumentParser()
318+
model_parser.add_argument(
319+
"-c",
320+
"--config",
321+
action=ActionConfigFile,
322+
help="Path to a configuration file in json or yaml format.",
323+
)
324+
model_parser.add_subclass_arguments(AnomalyModule, "model", required=False, fail_untyped=False)
325+
model_parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION)
326+
model_parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"])
327+
model_parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False)
328+
model_parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold")
329+
model_parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True)
330+
args = ["--config", str(config_path)]
331+
for key, value in kwargs.items():
332+
args.extend([f"--{key}", str(value)])
333+
config = model_parser.parse_args(args=args)
334+
instantiated_classes = model_parser.instantiate_classes(config)
335+
model = instantiated_classes.get("model")
336+
if isinstance(model, AnomalyModule):
337+
return model
338+
339+
msg = f"Model is not an instance of AnomalyModule: {model}"
340+
raise ValueError(msg)

‎tests/unit/data/base/base.py

+16
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,19 @@ def test_datamodule_has_dataloader_attributes(self, datamodule: AnomalibDataModu
2323
dataloader = f"{subset}_dataloader"
2424
assert hasattr(datamodule, dataloader)
2525
assert isinstance(getattr(datamodule, dataloader)(), DataLoader)
26+
27+
def test_datamodule_from_config(self, fxt_data_config_path: str) -> None:
28+
# 1. Wrong file path:
29+
with pytest.raises(FileNotFoundError):
30+
AnomalibDataModule.from_config(config_path="wrong_configs.yaml")
31+
32+
# 2. Correct file path:
33+
datamodule = AnomalibDataModule.from_config(config_path=fxt_data_config_path)
34+
assert datamodule is not None
35+
assert isinstance(datamodule, AnomalibDataModule)
36+
37+
# 3. Override batch_size & num_workers
38+
override_kwargs = {"data.train_batch_size": 1, "data.num_workers": 1}
39+
datamodule = AnomalibDataModule.from_config(config_path=fxt_data_config_path, **override_kwargs)
40+
assert datamodule.train_batch_size == 1
41+
assert datamodule.num_workers == 1

‎tests/unit/data/image/test_btech.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> BTech:
3131
_datamodule.setup()
3232

3333
return _datamodule
34+
35+
@pytest.fixture()
36+
def fxt_data_config_path(self) -> str:
37+
"""Return the path to the test data config."""
38+
return "configs/data/btech.yaml"

‎tests/unit/data/image/test_folder.py

+5
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> Folder:
4141
_datamodule.setup()
4242

4343
return _datamodule
44+
45+
@pytest.fixture()
46+
def fxt_data_config_path(self) -> str:
47+
"""Return the path to the test data config."""
48+
return "configs/data/folder.yaml"

‎tests/unit/data/image/test_folder_3d.py

+17
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,20 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> Folder3D:
3838
_datamodule.setup()
3939

4040
return _datamodule
41+
42+
@pytest.fixture()
43+
def fxt_data_config_path(self) -> str:
44+
"""Return the path to the test data config."""
45+
return "configs/data/folder_3d.yaml"
46+
47+
def test_datamodule_from_config(self, fxt_data_config_path: str) -> None:
48+
"""Test method to create a datamodule from a configuration file.
49+
50+
Args:
51+
fxt_data_config_path (str): The path to the configuration file.
52+
53+
Returns:
54+
None
55+
"""
56+
pytest.skip("The configuration file does not exist.")
57+
_ = fxt_data_config_path

‎tests/unit/data/image/test_kolektor.py

+5
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> Kolektor:
3030
_datamodule.setup()
3131

3232
return _datamodule
33+
34+
@pytest.fixture()
35+
def fxt_data_config_path(self) -> str:
36+
"""Return the path to the test data config."""
37+
return "configs/data/kolektor.yaml"

‎tests/unit/data/image/test_mvtec.py

+5
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> MVTec:
2929
_datamodule.setup()
3030

3131
return _datamodule
32+
33+
@pytest.fixture()
34+
def fxt_data_config_path(self) -> str:
35+
"""Return the path to the test data config."""
36+
return "configs/data/mvtec.yaml"

‎tests/unit/data/image/test_mvtec_3d.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> MVTec3D:
3131
_datamodule.setup()
3232

3333
return _datamodule
34+
35+
@pytest.fixture()
36+
def fxt_data_config_path(self) -> str:
37+
"""Return the path to the test data config."""
38+
return "configs/data/mvtec_3d.yaml"

‎tests/unit/data/image/test_visa.py

+5
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType) -> Visa:
3131
_datamodule.setup()
3232

3333
return _datamodule
34+
35+
@pytest.fixture()
36+
def fxt_data_config_path(self) -> str:
37+
"""Return the path to the test data config."""
38+
return "configs/data/visa.yaml"

‎tests/unit/data/video/test_avenue.py

+5
Original file line numberDiff line numberDiff line change
@@ -38,3 +38,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType, clip_length_in_fra
3838
_datamodule.setup()
3939

4040
return _datamodule
41+
42+
@pytest.fixture()
43+
def fxt_data_config_path(self) -> str:
44+
"""Return the path to the test data config."""
45+
return "configs/data/avenue.yaml"

‎tests/unit/data/video/test_shanghaitech.py

+5
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType, clip_length_in_fra
3939
_datamodule.setup()
4040

4141
return _datamodule
42+
43+
@pytest.fixture()
44+
def fxt_data_config_path(self) -> str:
45+
"""Return the path to the test data config."""
46+
return "configs/data/shanghaitec.yaml"

‎tests/unit/data/video/test_ucsdped.py

+5
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,8 @@ def datamodule(self, dataset_path: Path, task_type: TaskType, clip_length_in_fra
3737
_datamodule.setup()
3838

3939
return _datamodule
40+
41+
@pytest.fixture()
42+
def fxt_data_config_path(self) -> str:
43+
"""Return the path to the test data config."""
44+
return "configs/data/ucsd_ped.yaml"

0 commit comments

Comments
 (0)