Support XAI Method Abstract Interface #7

Merged · 13 commits · Jun 10, 2024
Changes from 8 commits
2 changes: 1 addition & 1 deletion GETTING_STARTED.ipynb
@@ -429,7 +429,7 @@
"\n",
"`White Box explainer` can be configured with the following parameters:\n",
"- `target_layer` - specifies the layer after which the XAI nodes should be inserted (the last convolutional layer is a good default option). Example: `/backbone/conv/conv.2/Div`. This parameter can be useful if `WhiteBoxExplainer` fails to find a place where to insert XAI branch.\n",
"- `embed_normalization` - **default True** (for speed purposes), but you can disable embedding of normalization into the model.\n",
"- `embed_scale` - **default True** (for speed purposes), but you can disable embedding of normalization into the model.\n",
"- `explain_method` - **default reciprocam**:\n",
"\n",
" For Classification models `White Box` algorithm supports 2 `Method`:\n",
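Note: a minimal sketch of how the three parameters described in the notebook cell above fit together after this PR. It assumes the `import openvino_xai as xai` alias used in the examples; the MobileNetV3 layer name is only an illustration.

```python
import openvino_xai as xai  # assumed alias, as used in docs/Usage.md below
from openvino_xai.inserter.parameters import ClassificationInsertionParameters

# Configure white-box XAI insertion for a classification model.
insertion_parameters = ClassificationInsertionParameters(
    target_layer="/backbone/conv/conv.2/Div",  # optional; set manually if automatic search fails (OTX MobileNetV3 example)
    embed_scale=True,                          # renamed from embed_normalization; embeds saliency-map scaling into the model
    explain_method=xai.Method.RECIPROCAM,      # default white-box method for CNNs
)
```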
4 changes: 2 additions & 2 deletions docs/Usage.md
@@ -109,7 +109,7 @@ model = ov.Core().read_model("path/to/model.xml") # type: ov.Model
# Optional - create insertion parameters
insertion_parameters = ClassificationInsertionParameters(
# target_layer="last_conv_node_name", # target_layer - node after which XAI branch will be inserted
embed_normalization=True, # True by default. If set to True, saliency map normalization is embedded in the model
embed_scale=True, # True by default. If set to True, saliency map normalization is embedded in the model
explain_method=xai.Method.RECIPROCAM, # ReciproCAM is the default XAI method for CNNs
)

@@ -221,7 +221,7 @@ model = ov.Core().read_model("path/to/model.xml") # type: ov.Model
# Optional - create insertion parameters
insertion_parameters = ClassificationInsertionParameters(
# target_layer="last_conv_node_name", # target_layer - node after which XAI branch will be inserted
embed_normalization=True, # True by default. If set to True, saliency map normalization is embedded in the model
embed_scale=True, # True by default. If set to True, saliency map normalization is embedded in the model
explain_method=xai.Method.RECIPROCAM, # ReciproCAM is the default XAI method for CNNs
)

7 changes: 5 additions & 2 deletions examples/run_classification.py
@@ -16,6 +16,7 @@
ExplainMode,
ExplanationParameters,
TargetExplainGroup,
VisualizationParameters,
)
from openvino_xai.inserter.parameters import ClassificationInsertionParameters

@@ -94,7 +95,7 @@ def explain_white_box(args):
# target_layer="last_conv_node_name", # target_layer - node after which XAI branch will be inserted
target_layer="/backbone/conv/conv.2/Div", # OTX mnet_v3
# target_layer="/backbone/features/final_block/activate/Mul", # OTX effnet
embed_normalization=True, # True by default. If set to True, saliency map normalization is embedded in the model
embed_scale=True, # True by default. If set to True, saliency map normalization is embedded in the model
explain_method=xai.Method.RECIPROCAM, # ReciproCAM is the default XAI method for CNNs
)

@@ -115,6 +116,7 @@ def explain_white_box(args):
target_explain_group=TargetExplainGroup.CUSTOM, # CUSTOM list of classes to explain, also ALL possible
target_explain_labels=[11, 14], # target classes to explain, also ['dog', 'person'] is a valid input
label_names=voc_labels, # optional names
visualization_parameters=VisualizationParameters(overlay=True)
)

# Generate explanation
@@ -158,6 +160,7 @@ def explain_black_box(args):
target_explain_group=TargetExplainGroup.CUSTOM, # CUSTOM list of classes to explain, also ALL possible
target_explain_labels=['dog', 'person'], # target classes to explain, also [11, 14] possible
label_names=voc_labels, # optional names
visualization_parameters=VisualizationParameters(overlay=True)
)

# Generate explanation
@@ -306,7 +309,7 @@ def insert_xai_w_params(args):
insertion_parameters = ClassificationInsertionParameters(
target_layer="/backbone/conv/conv.2/Div", # OTX mnet_v3
# target_layer="/backbone/features/final_block/activate/Mul", # OTX effnet
embed_normalization=True,
embed_scale=True,
explain_method=xai.Method.RECIPROCAM,
)

4 changes: 2 additions & 2 deletions notebooks/xai_classification/xai_classification.ipynb
@@ -311,7 +311,7 @@
"\n",
"If automatic search for correct node fails, you can set up a correct node manually with `target_layer` argument. For classification it's the last backbone node with shape [-1, num_channels, feature_map_height, feature_map_width]. For example, for MobileNetV3 it will be `/backbone/conv/conv.2/Div` layer with [-1, 960, 7, 7] input shape.\n",
"\n",
"`embed_normalization` **default True** (for speed purposes), this parameter adds normalization to the XAI branch, which results in being able to visualize saliency maps right away without further postprocessing.\n",
"`embed_scale` **default True** (for speed purposes), this parameter adds normalization to the XAI branch, which results in being able to visualize saliency maps right away without further postprocessing.\n",
"\n",
"`explain_method` can be:\n",
"\n",
@@ -340,7 +340,7 @@
" target_layer=\"/blocks/blocks.6/blocks.6.0/bn1/act/HardSwish\",\n",
" # target_layer=\"/backbone/conv/conv.2/Div\", # OTX mnet_v3\n",
" # target_layer=\"/backbone/features/final_block/activate/Mul\", # OTX effnet\n",
" embed_normalization=True,\n",
" embed_scale=True,\n",
" explain_method=Method.RECIPROCAM,\n",
")\n",
"\n",
142 changes: 58 additions & 84 deletions openvino_xai/explainer/explainer.py
@@ -6,14 +6,8 @@
import numpy as np
import openvino.runtime as ov

import openvino_xai
from openvino_xai import Task
from openvino_xai.common.utils import (
SALIENCY_MAP_OUTPUT_NAME,
IdentityPreprocessFN,
has_xai,
logger,
)
from openvino_xai.common.utils import IdentityPreprocessFN, logger
from openvino_xai.explainer.explanation import Explanation
from openvino_xai.explainer.parameters import (
ExplainMode,
@@ -23,12 +17,16 @@
from openvino_xai.explainer.utils import get_explain_target_indices
from openvino_xai.explainer.visualizer import Visualizer
from openvino_xai.inserter.parameters import InsertionParameters
from openvino_xai.methods.black_box.black_box_methods import RISE
from openvino_xai.methods.base import BlackBoxXAIMethodBase, MethodBase
from openvino_xai.methods.create_method import (
BlackBoxMethodFactory,
WhiteBoxMethodFactory,
)


class Explainer:
"""
Explainer sets explain mode, prepares the model, and generates explanations.
    Explainer creates methods and uses them to generate explanations.

Usage:
explanation = explainer_object(data, explanation_parameters)
@@ -39,7 +37,7 @@ class Explainer:
:type task: Task
:param preprocess_fn: Preprocessing function, identity function by default
(assume input images are already preprocessed by user).
:type preprocess_fn: Callable[[np.ndarray], np.ndarray] | IdentityPreprocessFN
:type preprocess_fn: Callable[[np.ndarray], np.ndarray]
:param postprocess_fn: Postprocessing functions, required for black-box.
:type postprocess_fn: Callable[[ov.utils.data_helpers.wrappers.OVDict], np.ndarray]
:param explain_mode: Explain mode.
@@ -74,60 +72,61 @@ def __init__(

self.explain_mode = explain_mode

self._set_explain_mode()

self._load_model()

def _set_explain_mode(self) -> None:
if self.explain_mode == ExplainMode.WHITEBOX:
if has_xai(self.model):
logger.info("Model already has XAI - using white-box mode.")
else:
self._insert_xai()
logger.info("Explaining the model in the white-box mode.")
self.method: MethodBase = None
self.create_method(self.explain_mode, self.task)

def create_method(self, explain_mode: ExplainMode, task: Task) -> None:
if explain_mode == ExplainMode.WHITEBOX:
try:
self.method = WhiteBoxMethodFactory.create_method(
task, self.model, self.preprocess_fn, self.insertion_parameters
)
logger.info("Explaining the model in white-box mode.")
except Exception as e:
print(e)
raise RuntimeError("Failed to insert XAI into the model. Try to use black-box.")
elif self.explain_mode == ExplainMode.BLACKBOX:
if self.postprocess_fn is None:
raise ValueError("Postprocess function has to be provided for the black-box mode.")
logger.info("Explaining the model in the black-box mode.")
self._check_postprocess_fn()
self.method = BlackBoxMethodFactory.create_method(task, self.model, self.preprocess_fn, self.postprocess_fn)
elif self.explain_mode == ExplainMode.AUTO:
if has_xai(self.model):
logger.info("Model already has XAI - using white-box mode.")
self.explain_mode = ExplainMode.WHITEBOX
else:
try:
self._insert_xai()
self.explain_mode = ExplainMode.WHITEBOX
logger.info("Explaining the model in the white-box mode.")
except Exception as e:
print(e)
logger.info("Failed to insert XAI into the model - use black-box mode.")
if self.postprocess_fn is None:
raise ValueError("Postprocess function has to be provided for the black-box mode.")
self.explain_mode = ExplainMode.BLACKBOX
logger.info("Explaining the model in the black-box mode.")
try:
self.method = WhiteBoxMethodFactory.create_method(
task, self.model, self.preprocess_fn, self.insertion_parameters
)
logger.info("Explaining the model in the white-box mode.")
except Exception as e:
print(e)
logger.info("Failed to insert XAI into the model - using black-box mode.")
self._check_postprocess_fn()
self.method = BlackBoxMethodFactory.create_method(
task, self.model, self.preprocess_fn, self.postprocess_fn
)
logger.info("Explaining the model in the black-box mode.")
else:
raise ValueError(f"Not supported explain mode {self.explain_mode}.")

def _insert_xai(self) -> None:
logger.info("Model does not have XAI - trying to insert XAI to use white-box mode.")
# Do we need to keep the original model?
self.model = openvino_xai.insert_xai(self.model, self.task, self.insertion_parameters)

def _load_model(self) -> None:
self.compiled_model = ov.Core().compile_model(self.model, "CPU")

def __call__(
self,
data: np.ndarray,
explanation_parameters: ExplanationParameters,
**kwargs,
) -> Explanation:
"""Explainer call that generates processed explanation result."""
# TODO (negvet): support output_shape as argument among other post process parameters
if self.explain_mode == ExplainMode.WHITEBOX:
saliency_map = self._generate_saliency_map_white_box(data)
else:
saliency_map = self._generate_saliency_map_black_box(data, explanation_parameters, **kwargs)
explain_target_indices = None
if (
isinstance(self.method, BlackBoxXAIMethodBase)
and explanation_parameters.target_explain_group == TargetExplainGroup.CUSTOM
):
explain_target_indices = get_explain_target_indices(
explanation_parameters.target_explain_labels,
explanation_parameters.label_names,
)

saliency_map = self.method.generate_saliency_map(
data,
explain_target_indices=explain_target_indices, # type: ignore
**kwargs,
)

explanation = Explanation(
saliency_map=saliency_map,
@@ -137,38 +136,9 @@ def __call__(
)
return self._visualize(explanation, data, explanation_parameters)

def model_forward(self, x: np.ndarray) -> ov.utils.data_helpers.wrappers.OVDict:
"""Forward pass of the compiled model. Applies preprocess_fn."""
x = self.preprocess_fn(x)
return self.compiled_model(x)

def _generate_saliency_map_white_box(self, data: np.ndarray) -> np.ndarray:
model_output = self.model_forward(data)
return model_output[SALIENCY_MAP_OUTPUT_NAME]

def _generate_saliency_map_black_box(
self,
data: np.ndarray,
explanation_parameters: ExplanationParameters,
**kwargs,
) -> np.ndarray:
explain_target_indices = None
if explanation_parameters.target_explain_group == TargetExplainGroup.CUSTOM:
explain_target_indices = get_explain_target_indices(
explanation_parameters.target_explain_labels,
explanation_parameters.label_names,
)
if self.task == Task.CLASSIFICATION:
saliency_map = RISE.run(
self.compiled_model,
self.preprocess_fn,
self.postprocess_fn,
data,
explain_target_indices,
**kwargs,
)
return saliency_map
raise ValueError(f"Task type {self.task} is not supported in the black-box mode.")
def model_forward(self, x: np.ndarray, preprocess: bool = True) -> ov.utils.data_helpers.wrappers.OVDict:
"""Forward pass of the compiled model."""
return self.method.model_forward(x, preprocess)

def _visualize(
self, explanation: Explanation, data: np.ndarray, explanation_parameters: ExplanationParameters
@@ -188,3 +158,7 @@ def _visualize(
visualization_parameters=explanation_parameters.visualization_parameters,
).run()
return explanation

def _check_postprocess_fn(self) -> None:
if self.postprocess_fn is None:
raise ValueError("Postprocess function has to be provided for the black-box mode.")
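Note: with this refactor the `Explainer` builds a `MethodBase` implementation in `create_method()` and delegates both `model_forward` and saliency-map generation to it. A minimal usage sketch follows; the constructor keywords are inferred from the docstring above, the parameter module path mirrors the imports in this diff, and the input image and class indices are placeholders.

```python
import numpy as np
import openvino.runtime as ov

import openvino_xai as xai  # assumes Explainer and Task are exposed at the package root
from openvino_xai.explainer.parameters import (  # assumed module path, matching the imports above
    ExplainMode,
    ExplanationParameters,
    TargetExplainGroup,
)

model = ov.Core().read_model("path/to/model.xml")

# AUTO tries white-box insertion first and falls back to black-box
# (the black-box path additionally requires postprocess_fn).
explainer = xai.Explainer(
    model=model,
    task=xai.Task.CLASSIFICATION,
    explain_mode=ExplainMode.AUTO,
)

explanation_parameters = ExplanationParameters(
    target_explain_group=TargetExplainGroup.CUSTOM,
    target_explain_labels=[11, 14],  # placeholder class indices
)

# __call__ runs self.method.generate_saliency_map(...) and then visualizes the result.
explanation = explainer(np.zeros((224, 224, 3), dtype=np.uint8), explanation_parameters)
```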
38 changes: 21 additions & 17 deletions openvino_xai/inserter/inserter.py
@@ -5,12 +5,13 @@
from openvino.preprocess import PrePostProcessor

from openvino_xai import Task
from openvino_xai.common.utils import SALIENCY_MAP_OUTPUT_NAME, has_xai, logger
from openvino_xai.inserter.parameters import InsertionParameters
from openvino_xai.methods.white_box.create_method import (
create_white_box_classification_explain_method,
create_white_box_detection_explain_method,
from openvino_xai.common.utils import (
SALIENCY_MAP_OUTPUT_NAME,
IdentityPreprocessFN,
has_xai,
logger,
)
from openvino_xai.inserter.parameters import InsertionParameters


def insert_xai(
@@ -33,12 +34,21 @@
:type insertion_parameters: InsertionParameters
:return: IR with XAI branch.
"""
from openvino_xai.methods.create_method import WhiteBoxMethodFactory

if has_xai(model):
logger.info("Provided IR model already contains XAI branch, return it as-is.")
return model

model_xai = _insert_xai_branch_into_model(model, task, insertion_parameters)
method = WhiteBoxMethodFactory.create_method(
task=task,
model=model,
preprocess_fn=IdentityPreprocessFN(),
insertion_parameters=insertion_parameters,
prepare_model=False,
)

model_xai = method.prepare_model(load_model=False)

if not has_xai(model_xai):
raise RuntimeError("Insertion of the XAI branch into the model was not successful.")
@@ -47,23 +57,17 @@
return model_xai


def _insert_xai_branch_into_model(
model: ov.Model, task: Task, insertion_parameters: InsertionParameters | None
def insert_xai_branch_into_model(
model: ov.Model,
xai_output_node,
set_uint8,
) -> ov.Model:
if task == Task.CLASSIFICATION:
explain_method = create_white_box_classification_explain_method(model, insertion_parameters) # type: ignore
elif task == Task.DETECTION:
explain_method = create_white_box_detection_explain_method(model, insertion_parameters) # type: ignore
else:
raise ValueError(f"Model type {task} is not supported")

xai_output_node = explain_method.generate_xai_branch()
"""Creates new model with XAI branch."""
model_ori_outputs = model.outputs
model_ori_params = model.get_parameters()
model_xai = ov.Model([*model_ori_outputs, xai_output_node.output(0)], model_ori_params)

xai_output_index = len(model_ori_outputs)
set_uint8 = explain_method.embed_normalization # TODO: make a property
model_xai = _set_xai_output_name_and_precision(model_xai, xai_output_index, set_uint8)
return model_xai

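Note: `insert_xai` now delegates to `WhiteBoxMethodFactory.create_method(..., prepare_model=False)` followed by `method.prepare_model(load_model=False)`, replacing the removed `_insert_xai_branch_into_model` helper. A rough end-to-end sketch under those assumptions (paths are placeholders; parameter choices mirror the examples above):

```python
import openvino.runtime as ov

import openvino_xai
from openvino_xai import Task
from openvino_xai.inserter.parameters import ClassificationInsertionParameters

model = ov.Core().read_model("path/to/model.xml")

insertion_parameters = ClassificationInsertionParameters(
    embed_scale=True,                               # embed saliency-map scaling into the XAI branch
    explain_method=openvino_xai.Method.RECIPROCAM,  # default white-box method for CNNs
)

# Returns the model as-is if it already contains an XAI branch; otherwise inserts one
# and names the additional saliency-map output.
model_xai = openvino_xai.insert_xai(model, Task.CLASSIFICATION, insertion_parameters)

# Persist the IR with the extra output (path is illustrative).
ov.serialize(model_xai, "path/to/model_xai.xml")
```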