Add sanity tests for example scripts #39

Merged
6 changes: 3 additions & 3 deletions examples/run_classification.py
@@ -91,7 +91,7 @@ def explain_white_box(args):
explain_mode=ExplainMode.WHITEBOX, # defaults to AUTO
explain_method=xai.Method.RECIPROCAM, # ReciproCAM is the default XAI method for CNNs
# target_layer="last_conv_node_name", # target_layer - node after which XAI branch will be inserted
target_layer="/backbone/conv/conv.2/Div", # OTX mnet_v3
# target_layer="/backbone/conv/conv.2/Div", # OTX mnet_v3
# target_layer="/backbone/features/final_block/activate/Mul", # OTX effnet
embed_scaling=True, # True by default. If set to True, saliency map scale (0 ~ 255) operation is embedded in the model
)
@@ -103,7 +103,7 @@ def explain_white_box(args):

# Generate explanation
explanation = explainer(
- image,
+ image,
targets=[11, 14], # target classes to explain, also ['dog', 'person'] is a valid input
label_names=voc_labels, # optional names
overlay=True,
@@ -230,7 +230,7 @@ def explain_white_box_vit(args):

# Generate explanation
explanation = explainer(
- image,
+ image,
targets=[0, 1, 2, 3, 4], # target classes to explain
)

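With the hard-coded `target_layer` for the OTX MobileNetV3 IR commented out, the white-box explainer presumably falls back to automatic target-layer detection, which is what the new sanity test relies on. If auto-detection ever fails for a custom IR, one way to pick a node name by hand is to list the graph operations — a minimal sketch, not part of this change, with a hypothetical model path:

import openvino.runtime as ov

model = ov.Core().read_model("path/to/model.xml")  # hypothetical path
for op in model.get_ops():
    # Inspect friendly names to find a candidate such as "/backbone/conv/conv.2/Div"
    print(op.get_friendly_name(), "-", op.get_type_name())
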
28 changes: 14 additions & 14 deletions examples/run_detection.py
@@ -24,8 +24,8 @@ def get_argument_parser():

def preprocess_fn(x: np.ndarray) -> np.ndarray:
# TODO: make sure it is correct
- x = cv2.resize(src=x, dsize=(416, 416)) # OTX YOLOX
- # x = cv2.resize(src=x, dsize=(992, 736)) # OTX ATSS
+ # x = cv2.resize(src=x, dsize=(416, 416)) # OTX YOLOX
+ x = cv2.resize(src=x, dsize=(992, 736)) # OTX ATSS
x = x.transpose((2, 0, 1))
x = np.expand_dims(x, 0)
return x
@@ -45,18 +45,18 @@ def main(argv):
model = ov.Core().read_model(args.model_path)

# OTX YOLOX
- cls_head_output_node_names = [
- "/bbox_head/multi_level_conv_cls.0/Conv/WithoutBiases",
- "/bbox_head/multi_level_conv_cls.1/Conv/WithoutBiases",
- "/bbox_head/multi_level_conv_cls.2/Conv/WithoutBiases",
- ]
- # # OTX ATSS
# cls_head_output_node_names = [
- # "/bbox_head/atss_cls_1/Conv/WithoutBiases",
- # "/bbox_head/atss_cls_2/Conv/WithoutBiases",
- # "/bbox_head/atss_cls_3/Conv/WithoutBiases",
- # "/bbox_head/atss_cls_4/Conv/WithoutBiases",
+ # "/bbox_head/multi_level_conv_cls.0/Conv/WithoutBiases",
+ # "/bbox_head/multi_level_conv_cls.1/Conv/WithoutBiases",
+ # "/bbox_head/multi_level_conv_cls.2/Conv/WithoutBiases",
# ]
+ # OTX ATSS
+ cls_head_output_node_names = [
+ "/bbox_head/atss_cls_1/Conv/WithoutBiases",
+ "/bbox_head/atss_cls_2/Conv/WithoutBiases",
+ "/bbox_head/atss_cls_3/Conv/WithoutBiases",
+ "/bbox_head/atss_cls_4/Conv/WithoutBiases",
+ ]

# Create explainer object
explainer = xai.Explainer(
@@ -73,8 +73,8 @@

# Generate explanation
explanation = explainer(
- image,
- targets=[0, 1, 2, 3, 4], # target classes to explain
+ image,
+ targets=[0, 1, 2], # target classes to explain
)

logger.info(
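
The detection example still carries a "# TODO: make sure it is correct" note above the hard-coded resize. One way to avoid guessing the resolution would be to read it from the IR input shape — a sketch, not part of this change, assuming a static NCHW input and a hypothetical model path:

import cv2
import numpy as np
import openvino.runtime as ov

model = ov.Core().read_model("path/to/detection_model.xml")  # hypothetical path
_, _, height, width = model.inputs[0].shape  # assumes a static NCHW input

def preprocess_fn(x: np.ndarray) -> np.ndarray:
    x = cv2.resize(src=x, dsize=(width, height))  # cv2.resize expects (width, height)
    x = x.transpose((2, 0, 1))
    return np.expand_dims(x, 0)
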
20 changes: 20 additions & 0 deletions tests/intg/test_classification.py
@@ -1,6 +1,7 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

+ import subprocess
from pathlib import Path

import cv2
@@ -412,3 +413,22 @@ def test_classification_black_box_xai_model_as_input(self):
actual_sal_vals = explanation.saliency_map[0][0, :10].astype(np.int16)
ref_sal_vals = self._ref_sal_maps[DEFAULT_CLS_MODEL].astype(np.uint8)
assert np.all(np.abs(actual_sal_vals - ref_sal_vals) <= 1)


+ class TestExample:
+     """Test sanity of examples/run_classification.py."""
+ 
+     @pytest.fixture(autouse=True)
+     def setup(self, fxt_data_root):
+         self.data_dir = fxt_data_root
+ 
+     def test_default_model(self):
+         retrieve_otx_model(self.data_dir, DEFAULT_CLS_MODEL)
+         model_path = self.data_dir / "otx_models" / (DEFAULT_CLS_MODEL + ".xml")
+         cmd = [
+             "python",
+             "examples/run_classification.py",
+             model_path,
+             "tests/assets/cheetah_person.jpg",
+         ]
+         subprocess.run(cmd, check=True) # noqa: S603, PLW1510
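
One small hardening the sanity test could adopt (not done in this change) is launching the example with the same interpreter that runs pytest, so the test does not depend on a `python` binary being on PATH — a sketch reusing the `model_path` built in `test_default_model` above:

import subprocess
import sys

cmd = [
    sys.executable,  # same interpreter as the running test session
    "examples/run_classification.py",
    str(model_path),  # model_path as built from fxt_data_root in test_default_model
    "tests/assets/cheetah_person.jpg",
]
subprocess.run(cmd, check=True)
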
64 changes: 64 additions & 0 deletions tests/intg/test_classification_timm.py
@@ -8,6 +8,7 @@

import cv2
import numpy as np
+ import openvino
import openvino.runtime as ov
import pytest

@@ -322,6 +323,69 @@ def test_classification_black_box(self, model_id, dump_maps=False):
self.update_report("report_bb.csv", model_id, "True", "True", "True", shape_str, str(map_saved))
self.clear_cache()

+     @pytest.mark.parametrize(
+         "model_id",
+         [
+             "resnet18.a1_in1k",
+             "vit_tiny_patch16_224.augreg_in21k", # Downloads last month 15,345
+         ],
+     )
+     # @pytest.mark.parametrize("model_id", TEST_MODELS)
+     def test_ovc_ir_insertion(self, model_id):
+         if model_id in NON_SUPPORTED_BY_WB_MODELS:
+             pytest.skip(reason="Not supported yet")
+ 
+         if "convit_tiny.fb_in1k" in model_id:
+             pytest.skip(
+                 reason="RuntimeError: Couldn't get TorchScript module by tracing."
+             ) # Torch -> OV conversion error
+ 
+         timm_model, model_cfg = self.get_timm_model(model_id)
+         input_size = list(timm_model.default_cfg["input_size"])
+         dummy_tensor = torch.rand([1] + input_size)
+         model = openvino.convert_model(
+             timm_model, example_input=dummy_tensor, input=(ov.PartialShape([-1] + input_size),)
+         )
+ 
+         if model_id in LIMITED_DIVERSE_SET_OF_CNN_MODELS:
+             explain_method = Method.RECIPROCAM
+         elif model_id in LIMITED_DIVERSE_SET_OF_VISION_TRANSFORMER_MODELS:
+             explain_method = Method.VITRECIPROCAM
+         else:
+             raise ValueError
+ 
+         mean_values = [(item * 255) for item in model_cfg["mean"]]
+         scale_values = [(item * 255) for item in model_cfg["std"]]
+         preprocess_fn = get_preprocess_fn(
+             change_channel_order=True,
+             input_size=model_cfg["input_size"][1:],
+             mean=mean_values,
+             std=scale_values,
+             hwc_to_chw=True,
+         )
+ 
+         explainer = Explainer(
+             model=model,
+             task=Task.CLASSIFICATION,
+             preprocess_fn=preprocess_fn,
+             explain_mode=ExplainMode.WHITEBOX, # defaults to AUTO
+             explain_method=explain_method,
+             embed_scaling=False,
+         )
+ 
+         target_class = self.supported_num_classes[model_cfg["num_classes"]]
+         image = cv2.imread("tests/assets/cheetah_person.jpg")
+         explanation = explainer(
+             image,
+             targets=[target_class],
+             resize=False,
+             colormap=False,
+         )
+ 
+         assert explanation is not None
+         assert explanation.shape[-1] > 1 and explanation.shape[-2] > 1
+         print(f"{model_id}: Generated classification saliency maps with shape {explanation.shape}.")

def check_for_saved_map(self, model_id, directory):
for target in self.supported_num_classes.values():
map_name = model_id + "_target_" + str(target) + ".jpg"
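
In test_ovc_ir_insertion, the timm mean/std values (defined for inputs in the 0-1 range) are multiplied by 255 because cv2.imread returns 0-255 pixels. A rough manual equivalent of the preprocessing assembled by get_preprocess_fn might look like the sketch below; the exact behaviour of get_preprocess_fn is an assumption here, and the sketch only illustrates the scaling logic:

import cv2
import numpy as np

def manual_preprocess(image_bgr: np.ndarray, model_cfg: dict) -> np.ndarray:
    mean = np.array(model_cfg["mean"]) * 255  # timm mean/std are given for 0-1 inputs
    std = np.array(model_cfg["std"]) * 255    # cv2.imread produces 0-255 pixels
    x = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)  # change_channel_order=True
    height, width = model_cfg["input_size"][1:]     # input_size is (C, H, W)
    x = cv2.resize(x, (width, height))
    x = (x - mean) / std
    x = x.transpose((2, 0, 1))                      # hwc_to_chw=True
    return np.expand_dims(x, 0).astype(np.float32)
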
20 changes: 20 additions & 0 deletions tests/intg/test_detection.py
@@ -1,6 +1,7 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

+ import subprocess
from pathlib import Path

import addict
@@ -231,3 +232,22 @@ def get_default_model(self):
model_path = self.data_dir / "otx_models" / (DEFAULT_DET_MODEL + ".xml")
model = ov.Core().read_model(model_path)
return model


+ class TestExample:
+     """Test sanity of examples/run_detection.py."""
+ 
+     @pytest.fixture(autouse=True)
+     def setup(self, fxt_data_root):
+         self.data_dir = fxt_data_root
+ 
+     def test_default_model(self):
+         retrieve_otx_model(self.data_dir, DEFAULT_DET_MODEL)
+         model_path = self.data_dir / "otx_models" / (DEFAULT_DET_MODEL + ".xml")
+         cmd = [
+             "python",
+             "examples/run_detection.py",
+             model_path,
+             "tests/assets/blood.jpg",
+         ]
+         subprocess.run(cmd, check=True) # noqa: S603, PLW1510
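
The classification and detection TestExample classes are nearly identical; a possible follow-up (not part of this change) is one parametrized sanity test shared by both example scripts, assuming the model-name constants and the retrieve_otx_model helper can be imported from a common test utility module:

import subprocess

import pytest

# DEFAULT_CLS_MODEL, DEFAULT_DET_MODEL and retrieve_otx_model are assumed to be
# importable from a shared helper module; fxt_data_root is the existing fixture.
@pytest.mark.parametrize(
    "script, model_name, image",
    [
        ("examples/run_classification.py", DEFAULT_CLS_MODEL, "tests/assets/cheetah_person.jpg"),
        ("examples/run_detection.py", DEFAULT_DET_MODEL, "tests/assets/blood.jpg"),
    ],
)
def test_example_script(script, model_name, image, fxt_data_root):
    retrieve_otx_model(fxt_data_root, model_name)
    model_path = fxt_data_root / "otx_models" / (model_name + ".xml")
    subprocess.run(["python", script, str(model_path), image], check=True)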