[Export][Transformers] Implementation of correctness validation #1935
@@ -16,6 +16,7 @@
 import os
 import shutil
 import tarfile
+from collections import OrderedDict
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -166,6 +167,17 @@ def create_data_samples(
         else labels_
     )

+    # turn all the returned lists into a list of dicts
+    # to facilitate the sample export
+    if inputs and not isinstance(inputs[0], dict):
+        inputs = [dict(input=input) for input in inputs]
+
+    if labels and not isinstance(labels[0], dict):
+        labels = [dict(label=label) for label in labels]
+
+    if outputs and not isinstance(outputs[0], dict):
+        outputs = [dict(output=output) for output in outputs]
+
     return inputs, outputs, labels
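As a quick illustration of what the new wrapping produces, a minimal sketch (the tensors and their shapes are made-up values, not taken from the PR):

```python
import torch

# hypothetical raw samples as create_data_samples might collect them before wrapping
inputs = [torch.randn(3, 224, 224), torch.randn(3, 224, 224)]
outputs = [torch.randn(1000), torch.randn(1000)]

# the added logic wraps each bare sample in a dict so the exporter
# can write named arrays (e.g. into .npz files) later on
if inputs and not isinstance(inputs[0], dict):
    inputs = [dict(input=sample) for sample in inputs]
if outputs and not isinstance(outputs[0], dict):
    outputs = [dict(output=sample) for sample in outputs]

print(inputs[0].keys())   # dict_keys(['input'])
print(outputs[0].keys())  # dict_keys(['output'])
```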
@@ -176,23 +188,25 @@ def run_inference_with_dict_data(
     Run inference on a model by inferring the appropriate
     inputs from the dictionary input data.

     :param data: The data to run inference on
     :param model: The model to run inference on (optional)
     :return: The inputs, labels and outputs
     """
     labels = None
     if model is None:
         output = None

     else:
-        inputs = {key: value.to(model.device) for key, value in data.items()}
+        # move the inputs to the model device and
+        # grab only the first sample from the batch
+        inputs = {
+            key: value[0].to(model.device).reshape(1, -1) for key, value in data.items()
+        }
         output_vals = model(**inputs)
         output = {
             name: torch.squeeze(val).detach().to("cpu")
             for name, val in output_vals.items()
         }
-        inputs = {key: value.to("cpu") for key, value in data.items()}
+        inputs = {key: value.to("cpu")[0] for key, value in data.items()}
     return inputs, labels, output
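A standalone sketch of the batching convention this hunk introduces for dict data (the tensor shapes are illustrative assumptions; the model forward pass and device handling are omitted):

```python
import torch

# hypothetical tokenized batch, as a transformers dataloader might yield it
data = {
    "input_ids": torch.randint(0, 30000, (8, 128)),
    "attention_mask": torch.ones(8, 128, dtype=torch.long),
}

# forward pass: keep only the first sample and re-add a batch dimension of 1
model_inputs = {key: value[0].reshape(1, -1) for key, value in data.items()}
print(model_inputs["input_ids"].shape)  # torch.Size([1, 128])

# export: store the same sample without the batch dimension
exported_inputs = {key: value.to("cpu")[0] for key, value in data.items()}
print(exported_inputs["input_ids"].shape)  # torch.Size([128])
```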
@@ -203,14 +217,17 @@ def run_inference_with_tuple_or_list_data(
     Run inference on a model by inferring the appropriate
     inputs from the tuple input data.

-    :param inputs: The data to run inference on
+    :param data: The data to run inference on
     :param model: The model to run inference on (optional)
     :return: The inputs, labels and outputs
     """
     # assume that
     inputs, labels = data

     outputs = model(inputs) if model else None
     if isinstance(outputs, tuple):
-        # outputs_ contains (logits, softmax)
-        outputs = outputs[0]
+        # outputs_ contains (logits, scores)
+        outputs = OrderedDict(logits=outputs[0], scores=outputs[1])
+    if len(inputs.size()) == 4:

Review comment: let's add a comment that this is IC specific.

+        # if the input is a batch, remove the batch dimension
+        inputs = torch.squeeze(inputs, 0)
     return inputs, labels, outputs
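To make the reviewer's point concrete, a self-contained sketch of the tuple handling and the image-classification-specific squeeze (the shapes and the softmax call are illustrative assumptions):

```python
from collections import OrderedDict

import torch

# hypothetical image-classification sample: a batch of one NCHW image
inputs = torch.randn(1, 3, 224, 224)
logits = torch.randn(1, 1000)
outputs = (logits, torch.softmax(logits, dim=-1))

# name the tuple members so they can be exported per key
if isinstance(outputs, tuple):
    outputs = OrderedDict(logits=outputs[0], scores=outputs[1])

# a 4-dimensional input is a batched image, so drop the batch dimension for export
if len(inputs.size()) == 4:
    inputs = torch.squeeze(inputs, 0)
print(inputs.shape)  # torch.Size([3, 224, 224])
```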
@@ -12,15 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import glob
 import logging
 import os.path
+from collections import OrderedDict
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union

 import numpy
+import onnxruntime as ort

 from sparseml.export.export_data import InputsNames, LabelNames, OutputsNames
 from sparseml.export.helpers import ONNX_MODEL_NAME
 from sparsezoo.inference import InferenceRunner
 from sparsezoo.objects import File, NumpyDirectory
+from sparsezoo.utils.numpy import load_numpy


 __all__ = ["validate_correctness", "validate_structure"]
@@ -98,47 +102,82 @@ def check_file_presence(file_paths: List[str]) -> List[str]:
     return missing_files


+# TODO: Need to add few changes to sparsezoo to support this function
+def top_k_match(
+    ground_truth: numpy.ndarray, prediction: numpy.ndarray, k: int = 2
+) -> bool:
+    """
+    Checks if the top k predictions match the ground truth.
+
+    :param ground_truth: The ground truth array.
+    :param prediction: The prediction array.
+    :param k: The number of top predictions to consider.
+    """
+    top_k_prediction = numpy.argsort(prediction.flatten())[-k:]
+    top_k_ground_truth = numpy.argsort(ground_truth.flatten())[-k:]
+    return numpy.all(top_k_prediction == top_k_ground_truth)
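A quick standalone sketch of how the new top_k_match helper behaves (the arrays are made-up values; the helper is inlined so the snippet runs on its own):

```python
import numpy

def top_k_match(ground_truth, prediction, k: int = 2) -> bool:
    # inlined copy of the helper added in this PR, so the snippet is self-contained
    top_k_prediction = numpy.argsort(prediction.flatten())[-k:]
    top_k_ground_truth = numpy.argsort(ground_truth.flatten())[-k:]
    return bool(numpy.all(top_k_prediction == top_k_ground_truth))

ground_truth = numpy.array([0.10, 0.20, 0.60, 0.10])
prediction = numpy.array([0.05, 0.25, 0.65, 0.05])
print(top_k_match(ground_truth, prediction))  # True: top-2 indices [1, 2] agree

shuffled = numpy.array([0.65, 0.25, 0.05, 0.05])
print(top_k_match(ground_truth, shuffled))    # False: top-2 indices differ
```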

Review comment (@bfineran): this could be in the future moved to

 def validate_correctness(
-    target_path: Union[str, Path], directory: Union[str, Path], onnx_model_name: str
-):
+    target_path: Union[str, Path],
+    directory: Union[str, Path],
+    onnx_model_name: str,
+    validation_function: Callable[..., bool] = top_k_match,
+) -> bool:
     """
     Validates the correctness of the exported ONNX model by
     running it on a set of sample inputs and comparing the
-    resulting outputs with precomputed ground truth values.
+    resulting outputs using a validation function.

     :param target_path: The directory where the sample inputs and outputs are stored.
     :param directory: The directory where the ONNX model is stored.
     :param onnx_model_name: The name of the ONNX model.
+    :param validation_function: The function that will be used to validate the outputs.
+    :return: True if the validation passes, False otherwise.
     """
     # TODO: During testing add a support for tar.gz scenario (potentially)

     sample_inputs_path = os.path.join(target_path, InputsNames.basename.value)
     sample_outputs_path = os.path.join(target_path, OutputsNames.basename.value)

-    sample_inputs = NumpyDirectory(
-        name=InputsNames.basename.value,
-        files=[
-            File(name=file_name, path=os.path.join(sample_inputs_path, file_name))
-            for file_name in os.listdir(sample_inputs_path)
-        ],
-        path=sample_inputs_path,
-    )
-    sample_outputs = NumpyDirectory(
-        name=OutputsNames.basename.value,
-        files=[
-            File(name=file_name, path=os.path.join(sample_outputs_path, file_name))
-            for file_name in os.listdir(sample_outputs_path)
-        ],
-        path=sample_outputs_path,
-    )
-    onnx_model = File(
-        name=onnx_model_name, path=os.path.join(directory, onnx_model_name)
-    )
+    sample_inputs_files = sorted(glob.glob(os.path.join(sample_inputs_path, "*")))
+    sample_outputs_files = sorted(glob.glob(os.path.join(sample_outputs_path, "*")))
+
+    session = ort.InferenceSession(os.path.join(directory, onnx_model_name))
+
+    validations = (
+        []
+    )  # stores boolean per sample pair (True if validation passes, False otherwise)
+
+    for sample_input_file, sample_output_file in zip(
+        sample_inputs_files, sample_outputs_files
+    ):
+        sample_input = load_numpy(sample_input_file)
+        sample_output = load_numpy(sample_output_file)
+
+        sample_input_with_batch_dim = OrderedDict(
+            (key, numpy.expand_dims(value, 0)) for key, value in sample_input.items()
+        )
+        outputs = session.run(None, sample_input_with_batch_dim)
+        if isinstance(outputs, list):
+            validations_sample = []
+            for o1, o2 in zip(outputs, sample_output.values()):
+                validations_sample.append(validation_function(o1, o2))
+            validations.append(all(validations_sample))
+        else:
+            validations.append(validation_function(outputs, sample_output))
+
+    if not all(validations):
+        _LOGGER.error(
+            f"Correctness validation failed for exported model: {onnx_model_name}. "
+            "The model outputs match the expected outputs "
+            f"only for {sum(validations)}/{len(validations)} samples "
+            f"(according to the validation function: {validation_function.__name__}). "
+            f"Some failures are expected in the case of quantized models, but not in "
+            f"the case of non-quantized models. If in doubt, validate the performance "
+            f"of the exported ONNX model using the NeuralMagic evaluation module."
+        )
+        return False
+
+    _LOGGER.info(
+        f"Successfully validated the exported model on all {len(validations)} samples."
+    )

-    runner = InferenceRunner(
-        sample_inputs=sample_inputs,
-        sample_outputs=sample_outputs,
-        onnx_file=onnx_model,
-    )
-
-    runner.validate_with_onnx_runtime()
+    return True
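For context, a hypothetical call-site sketch; the import path, folder layout, and file names below are assumptions for illustration, not taken from this PR:

```python
from functools import partial

# assumed import path for the helpers added in this PR
# from sparseml.export.validate import validate_correctness, top_k_match

export_dir = "./deployment"       # directory containing model.onnx
samples_dir = "./export_samples"  # directory containing the sample inputs/outputs folders

# default check: top-2 predicted indices must agree for every output array
passed = validate_correctness(
    target_path=samples_dir,
    directory=export_dir,
    onnx_model_name="model.onnx",
)

# any Callable[..., bool] can be swapped in, e.g. a stricter top-1 check
passed_strict = validate_correctness(
    target_path=samples_dir,
    directory=export_dir,
    onnx_model_name="model.onnx",
    validation_function=partial(top_k_match, k=1),
)
print(passed, passed_strict)
```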

Review comment: Removing the batch_size argument from the export. It does not matter for the model export, and it also does not matter for the sample export: by convention, all our sample inputs/outputs/labels are stored as "batchless" arrays, e.g. inp-0000.npz has shape (3, 244, 244).
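To illustrate the convention the comment describes, a small sketch (the array key name "input" and the use of numpy.savez are illustrative assumptions; only the file name and shape come from the comment above):

```python
import numpy

# a "batchless" sample: no leading batch axis is stored on disk
sample = {"input": numpy.random.rand(3, 244, 244).astype(numpy.float32)}
numpy.savez("inp-0000.npz", **sample)

# at validation time a batch dimension of 1 is re-added before ONNX Runtime inference,
# mirroring the numpy.expand_dims(value, 0) call in validate_correctness above
with numpy.load("inp-0000.npz") as data:
    loaded = {key: data[key] for key in data.files}
batched = {key: numpy.expand_dims(value, 0) for key, value in loaded.items()}
print(loaded["input"].shape)   # (3, 244, 244)
print(batched["input"].shape)  # (1, 3, 244, 244)
```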