Skip to content

Commit e87f63b

Browse files
committed
Implement ScatterND operation in MO and transform for SparseToDense
SparseToDense used in Wide and Deep model is expressed through ScatterND operation. ScatterND is more functional than SparseToDense. Hence, it was decided to replace SparseToDense with ScatterND. ScatterND is more useful for other models. Remove SparseToDense from the previous opset Signed-off-by: Roman Kazantsev <[email protected]>
1 parent 1f8a8ab commit e87f63b

File tree

10 files changed

+384
-201
lines changed

10 files changed

+384
-201
lines changed

docs/ops/opset3.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ declared in `namespace opset3`.
114114
* [ROIAlign](detection/ROIAlign_3.md)
115115
* [ROIPooling](detection/ROIPooling_1.md)
116116
* [ScatterElementsUpdate](movement/ScatterElementsUpdate_3.md)
117-
* [ScatterNDUpdate](movement/ScatterNDUpdate_3.md)
118117
* [ScatterUpdate](movement/ScatterUpdate_3.md)
119118
* [Select](condition/Select_1.md)
120119
* [Selu](arithmetic/Selu_1.md)

model-optimizer/automation/package_BOM.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,7 @@ extensions/front/tf/sparse_fill_empty_rows_ext.py
428428
extensions/front/tf/sparse_segment_mean_ext.py
429429
extensions/front/tf/sparse_segment_sqrtn_ext.py
430430
extensions/front/tf/sparse_segment_sum_ext.py
431-
extensions/front/tf/sparse_to_dense_ext.py
431+
extensions/front/tf/sparse_to_dense.py
432432
extensions/front/tf/split_ext.py
433433
extensions/front/tf/ssd_support.json
434434
extensions/front/tf/ssd_support_api_v1.14.json
@@ -654,6 +654,7 @@ extensions/ops/RNNCell.py
654654
extensions/ops/roialign.py
655655
extensions/ops/roifeatureextractor_onnx.py
656656
extensions/ops/scatter.py
657+
extensions/ops/scatternd.py
657658
extensions/ops/select.py
658659
extensions/ops/shufflechannel.py
659660
extensions/ops/simplernms.py
@@ -665,7 +666,6 @@ extensions/ops/sparse_reshape.py
665666
extensions/ops/sparse_segment_mean.py
666667
extensions/ops/sparse_segment_sqrtn.py
667668
extensions/ops/sparse_segment_sum.py
668-
extensions/ops/sparse_to_dense.py
669669
extensions/ops/spatial_transformer.py
670670
extensions/ops/splice.py
671671
extensions/ops/split.py
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
"""
2+
Copyright (C) 2020 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import numpy as np
18+
19+
from extensions.ops.Cast import Cast
20+
from extensions.ops.scatternd import ScatterNDUpdate
21+
from mo.front.common.replacement import FrontReplacementOp
22+
from mo.graph.graph import Node, Graph, rename_nodes
23+
from mo.ops.broadcast import Broadcast
24+
from mo.ops.const import Const
25+
26+
27+
class SparseToDense(FrontReplacementOp):
28+
"""
29+
This replacer substitutes TensorFlow SparseToDense operation with Broadcast -> ScatterND chain.
30+
The Broadcast operation creates a tensor filled with default value and of required shape.
31+
The ScatterND operation updates the created tensor with required values at required locations.
32+
"""
33+
op = "SparseToDense"
34+
enabled = True
35+
36+
def run_after(self):
37+
from extensions.front.tf.CTCGreedyDecoder import CTCGreedyDecoderReplacement
38+
return [CTCGreedyDecoderReplacement]
39+
40+
def replace_op(self, graph: Graph, node: Node):
41+
node_name = node.soft_get('name', node.id)
42+
43+
# broadcast default value to required shape
44+
broadcast_node = Broadcast(graph, {'name': node_name + '/Broadcast_'}).create_node()
45+
node.in_port(1).get_connection().set_destination(broadcast_node.in_port(1))
46+
if not node.in_port(3).disconnected():
47+
# cast default value
48+
cast_default_value = Cast(graph, {'name': node_name + '/CastDefaultValue', 'dst_type': np.int32}).create_node()
49+
node.in_port(3).get_connection().set_destination(cast_default_value.in_port(0))
50+
broadcast_node.in_port(0).connect(cast_default_value.out_port(0))
51+
#node.in_port(3).get_connection().set_destination(broadcast_node.in_port(0))
52+
else:
53+
default_value = np.float32(0)
54+
broadcast_node.in_port(0).connect(Const(graph, {'name': broadcast_node.name + '/FillValue_',
55+
'value': default_value}
56+
).create_node().out_port(0))
57+
58+
# update broadcasted tensor with required values at required locations
59+
scatternd_node = ScatterNDUpdate(graph, {'name': node_name + '/ScatterNDUpdate_'}).create_node()
60+
scatternd_node.in_port(0).connect(broadcast_node.out_port(0))
61+
node.in_port(0).get_connection().set_destination(scatternd_node.in_port(1))
62+
node.in_port(2).get_connection().set_destination(scatternd_node.in_port(2))
63+
64+
rename_nodes([(node, node_name + "/AbandonedName"), (scatternd_node, node_name)])
65+
66+
# The "explicit" version of the return value is: [(out_node.id, 0)])
67+
return [scatternd_node.id]

model-optimizer/extensions/front/tf/sparse_to_dense_ext.py

Lines changed: 0 additions & 28 deletions
This file was deleted.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
"""
2+
Copyright (C) 2020 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import unittest
18+
19+
from extensions.front.tf.sparse_to_dense import SparseToDense
20+
from mo.front.common.partial_infer.utils import int64_array
21+
from mo.utils.ir_engine.compare_graphs import compare_graphs
22+
from mo.utils.unittest.graph import build_graph
23+
from mo.utils.unittest.graph import build_graph, const
24+
25+
26+
class SparseToDenseFrontReplacersTest(unittest.TestCase):
27+
def test1(self):
28+
nodes_attributes = {
29+
'input_indices': {'shape': int64_array([5, 2]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
30+
'input_values' : {'shape': int64_array([5]), 'type': 'Parameter', 'kind': 'op', 'op': 'Parameter'},
31+
32+
'sparse_to_dense' : {'kind': 'op', 'op': 'SparseToDense'},
33+
'broadcast' : {'kind': 'op', 'op': 'Broadcast'},
34+
'scatternd' : {'kind': 'op', 'op': 'ScatterNDUpdate'},
35+
'cast_default_value': {'kind': 'op', 'op': 'Cast'},
36+
37+
'last': {'type': None, 'value': None, 'kind': 'op', 'op': 'Result'},
38+
39+
**const('input_dense_shape', int64_array([50, 40])),
40+
**const('input_default_value', int64_array(0))}
41+
42+
graph = build_graph(nodes_attributes,
43+
[('input_indices', 'sparse_to_dense', {'out': 0, 'in': 0}),
44+
('input_dense_shape', 'sparse_to_dense', {'out': 0, 'in': 1}),
45+
('input_values', 'sparse_to_dense', {'out': 0, 'in': 2}),
46+
('input_default_value', 'sparse_to_dense', {'out': 0, 'in': 3}),
47+
('sparse_to_dense', 'last', {'out': 0, 'in': 0})],
48+
nodes_with_edges_only=True)
49+
graph.stage = 'front'
50+
SparseToDense().find_and_replace_pattern(graph)
51+
52+
graph_ref = build_graph(nodes_attributes,
53+
[('input_default_value', 'cast_default_value', {'in': 0}),
54+
('cast_default_value', 'broadcast', {'in': 0}),
55+
('input_dense_shape', 'broadcast', {'in': 1}),
56+
('broadcast', 'scatternd', {'in': 0}),
57+
('input_indices', 'scatternd', {'in': 1}),
58+
('input_values', 'scatternd', {'in': 2}),
59+
('scatternd', 'last', {'in': 0})],
60+
nodes_with_edges_only=True)
61+
62+
(flag, resp) = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
63+
self.assertTrue(flag, resp)
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""
2+
Copyright (C) 2020 Intel Corporation
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import numpy as np
18+
19+
from mo.front.common.partial_infer.utils import int64_array
20+
from mo.graph.graph import Node, Graph
21+
from mo.ops.op import Op
22+
23+
24+
class ScatterNDBase(Op):
25+
enabled = False
26+
27+
op = op_type = None
28+
version = None
29+
30+
def __init__(self, graph: Graph, attrs: dict):
31+
assert self.op is not None and self.op_type is not None and self.version is not None, \
32+
'Please use specialized ScatterNDBase operation class, ScatterNDBase is base class'
33+
34+
mandatory_props = {
35+
'op': self.op,
36+
'type': self.op_type,
37+
'version': self.version,
38+
39+
'infer': self.infer,
40+
41+
'in_ports_count': 3,
42+
'out_ports_count': 1,
43+
}
44+
super().__init__(graph, mandatory_props, attrs)
45+
46+
@staticmethod
47+
def infer(node: Node):
48+
node_name = node.soft_get('name', node.id)
49+
50+
input_shape = node.in_port(0).data.get_shape()
51+
indices_shape = node.in_port(1).data.get_shape()
52+
updates_shape = node.in_port(2).data.get_shape()
53+
assert input_shape is not None and updates_shape is not None and indices_shape is not None, \
54+
'The node "{}" input shape is None'.format(node_name)
55+
56+
# check that shapes are correct
57+
# 1. ranks of both input and indices must be at least 1
58+
assert len(input_shape) >= 1 and len(indices_shape) >= 1, \
59+
'The node "{}" input and indices ranks must be at least 1'.format(node_name)
60+
61+
# 2. the last dimension of indices shape must be at most a rank of input
62+
assert indices_shape[-1] <= len(input_shape), \
63+
'The last dimension of indices shape must be at most a rank of input for the node "{}"'.format(node_name)
64+
65+
# 3. updates is a tensor of shape indices_shape[:-1] + input_shape[indices_shape[-1]:]
66+
expected_updates_shape = np.concatenate((indices_shape[:-1], input_shape[indices_shape[-1]:]), axis=0)
67+
if np.array_equal(expected_updates_shape, int64_array([])):
68+
expected_updates_shape = 0 # updates is a scalar
69+
assert np.array_equal(updates_shape, expected_updates_shape), \
70+
'The updates shape must be equal to indices_shape[:-1] + input_shape[indices_shape[-1]:] for the node "{}"'.format(node_name)
71+
72+
node.out_port(0).data.set_shape(input_shape)
73+
74+
@staticmethod
75+
def type_infer(node: Node):
76+
assert node.in_port(0).get_source().get_data_type() == node.in_port(2).get_source().get_data_type(), \
77+
'The data type of the first and the third inputs must be equal for the node {}'.format(node.name)
78+
node.out_port(0).set_data_type(node.in_port(0).get_source().get_data_type())
79+
80+
81+
class ScatterNDUpdate(ScatterNDBase):
82+
op = op_type = 'ScatterNDUpdate'
83+
version = 'opset4'
84+
85+
@staticmethod
86+
def infer(node: Node):
87+
ScatterNDBase.infer(node)
88+
89+
input_value = node.in_port(0).data.get_value()
90+
indices_shape = node.in_port(1).data.get_shape()
91+
indices_value = node.in_port(1).data.get_value()
92+
updates_value = node.in_port(2).data.get_value()
93+
94+
# compute output value if all input is constant
95+
if input_value is not None and indices_value is not None and updates_value is not None:
96+
output_value = input_value.copy()
97+
indx_range = int64_array(indices_shape[:-1])
98+
for indx in np.ndindex(tuple(indx_range)):
99+
output_value[indices_value[indx]] = updates_value[indx]
100+
101+
node.out_port(0).data.set_value(output_value)

0 commit comments

Comments
 (0)