Skip to content

Add OpenVINO backend support for argmin and argmax #21060

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Apr 9, 2025
6 changes: 1 addition & 5 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ NumpyDtypeTest::test_absolute_bool
NumpyDtypeTest::test_add_
NumpyDtypeTest::test_all
NumpyDtypeTest::test_any
NumpyDtypeTest::test_argmax
NumpyDtypeTest::test_argmin
NumpyDtypeTest::test_argpartition
NumpyDtypeTest::test_array
NumpyDtypeTest::test_bitwise
Expand Down Expand Up @@ -77,8 +75,6 @@ NumpyDtypeTest::test_square_bool
HistogramTest
NumpyOneInputOpsCorrectnessTest::test_all
NumpyOneInputOpsCorrectnessTest::test_any
NumpyOneInputOpsCorrectnessTest::test_argmax
NumpyOneInputOpsCorrectnessTest::test_argmin
NumpyOneInputOpsCorrectnessTest::test_argpartition
NumpyOneInputOpsCorrectnessTest::test_array
NumpyOneInputOpsCorrectnessTest::test_bitwise_invert
Expand Down Expand Up @@ -161,4 +157,4 @@ NumpyTwoInputOpsCorrectnessTest::test_quantile
NumpyTwoInputOpsCorrectnessTest::test_take_along_axis
NumpyTwoInputOpsCorrectnessTest::test_tensordot
NumpyTwoInputOpsCorrectnessTest::test_vdot
NumpyTwoInputOpsCorrectnessTest::test_where
NumpyTwoInputOpsCorrectnessTest::test_where
60 changes: 58 additions & 2 deletions keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,11 +328,67 @@ def arctanh(x):


def argmax(x, axis=None, keepdims=False):
raise NotImplementedError("`argmax` is not supported with openvino backend")
x = get_ov_output(x)
x_shape = x.get_partial_shape()
rank = x_shape.rank.get_length()
if rank == 0:
return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))
if axis is None:
flatten_shape = ov_opset.constant(
[-1] + [1] * (rank - 1), Type.i32
).output(0)
x = ov_opset.reshape(x, flatten_shape, False).output(0)
axis = 0
k = ov_opset.constant(1, Type.i32).output(0)
else:
if axis < 0:
axis = rank + axis
k = ov_opset.constant(1, Type.i32).output(0)
topk_outputs = ov_opset.topk(
x,
k=k,
axis=axis,
mode="max",
sort="value",
stable=True,
index_element_type=Type.i32,
)
topk_indices = topk_outputs.output(1)
if not keepdims:
topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0)
return OpenVINOKerasTensor(topk_indices)


def argmin(x, axis=None, keepdims=False):
raise NotImplementedError("`argmin` is not supported with openvino backend")
x = get_ov_output(x)
x_shape = x.get_partial_shape()
rank = x_shape.rank.get_length()
if rank == 0:
return OpenVINOKerasTensor(ov_opset.constant([0], Type.i32).output(0))
if axis is None:
flatten_shape = ov_opset.constant(
[-1] + [1] * (rank - 1), Type.i32
).output(0)
x = ov_opset.reshape(x, flatten_shape, False).output(0)
axis = 0
k = ov_opset.constant(1, Type.i32).output(0)
else:
if axis < 0:
axis = rank + axis
k = ov_opset.constant(1, Type.i32).output(0)
topk_outputs = ov_opset.topk(
x,
k=k,
axis=axis,
mode="min",
sort="value",
stable=True,
index_element_type=Type.i32,
)
topk_indices = topk_outputs.output(1)
if not keepdims:
topk_indices = ov_opset.squeeze(topk_indices, [axis]).output(0)
return OpenVINOKerasTensor(topk_indices)


def argsort(x, axis=-1):
Expand Down