|
7 | 7 | import logging
|
8 | 8 | from abc import ABC, abstractmethod
|
9 | 9 | from collections import OrderedDict
|
| 10 | +from pathlib import Path |
10 | 11 | from typing import TYPE_CHECKING, Any
|
11 | 12 |
|
12 | 13 | import lightning.pytorch as pl
|
@@ -275,3 +276,65 @@ def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
|
275 | 276 | """
|
276 | 277 | self._transform = checkpoint["transform"]
|
277 | 278 | self.setup("load_checkpoint")
|
| 279 | + |
| 280 | + @classmethod |
| 281 | + def from_config( |
| 282 | + cls: type["AnomalyModule"], |
| 283 | + config_path: str | Path, |
| 284 | + **kwargs, |
| 285 | + ) -> "AnomalyModule": |
| 286 | + """Create a model instance from the configuration. |
| 287 | +
|
| 288 | + Args: |
| 289 | + config_path (str | Path): Path to the model configuration file. |
| 290 | + **kwargs (dict): Additional keyword arguments. |
| 291 | +
|
| 292 | + Returns: |
| 293 | + AnomalyModule: model instance. |
| 294 | +
|
| 295 | + Example: |
| 296 | + The following example shows how to get model from patchcore.yaml: |
| 297 | +
|
| 298 | + .. code-block:: python |
| 299 | + >>> model_config = "configs/model/patchcore.yaml" |
| 300 | + >>> model = AnomalyModule.from_config(config_path=model_config) |
| 301 | +
|
| 302 | + The following example shows overriding the configuration file with additional keyword arguments: |
| 303 | +
|
| 304 | + .. code-block:: python |
| 305 | + >>> override_kwargs = {"model.pre_trained": False} |
| 306 | + >>> model = AnomalyModule.from_config(config_path=model_config, **override_kwargs) |
| 307 | + """ |
| 308 | + from jsonargparse import ActionConfigFile, ArgumentParser |
| 309 | + from lightning.pytorch import Trainer |
| 310 | + |
| 311 | + from anomalib import TaskType |
| 312 | + |
| 313 | + if not Path(config_path).exists(): |
| 314 | + msg = f"Configuration file not found: {config_path}" |
| 315 | + raise FileNotFoundError(msg) |
| 316 | + |
| 317 | + model_parser = ArgumentParser() |
| 318 | + model_parser.add_argument( |
| 319 | + "-c", |
| 320 | + "--config", |
| 321 | + action=ActionConfigFile, |
| 322 | + help="Path to a configuration file in json or yaml format.", |
| 323 | + ) |
| 324 | + model_parser.add_subclass_arguments(AnomalyModule, "model", required=False, fail_untyped=False) |
| 325 | + model_parser.add_argument("--task", type=TaskType | str, default=TaskType.SEGMENTATION) |
| 326 | + model_parser.add_argument("--metrics.image", type=list[str] | str | None, default=["F1Score", "AUROC"]) |
| 327 | + model_parser.add_argument("--metrics.pixel", type=list[str] | str | None, default=None, required=False) |
| 328 | + model_parser.add_argument("--metrics.threshold", type=BaseThreshold | str, default="F1AdaptiveThreshold") |
| 329 | + model_parser.add_class_arguments(Trainer, "trainer", fail_untyped=False, instantiate=False, sub_configs=True) |
| 330 | + args = ["--config", str(config_path)] |
| 331 | + for key, value in kwargs.items(): |
| 332 | + args.extend([f"--{key}", str(value)]) |
| 333 | + config = model_parser.parse_args(args=args) |
| 334 | + instantiated_classes = model_parser.instantiate_classes(config) |
| 335 | + model = instantiated_classes.get("model") |
| 336 | + if isinstance(model, AnomalyModule): |
| 337 | + return model |
| 338 | + |
| 339 | + msg = f"Model is not an instance of AnomalyModule: {model}" |
| 340 | + raise ValueError(msg) |
0 commit comments