1
+ // Copyright (C) 2020 Intel Corporation
2
+ // SPDX-License-Identifier: Apache-2.0
3
+ //
4
+
5
+ #include < memory>
6
+
7
+ // ! [ngraph:include]
8
+ #include < ngraph/opsets/opset3.hpp>
9
+ // ! [ngraph:include]
10
+
11
+ #include < ngraph/function.hpp>
12
+ #include < ngraph/pattern/op/label.hpp>
13
+ #include < ngraph/rt_info.hpp>
14
+ #include < ngraph/pass/graph_rewrite.hpp>
15
+ #include < ngraph/pass/visualize_tree.hpp>
16
+
17
+ using namespace ngraph ;
18
+
19
+ // ! [ngraph_utils:simple_function]
20
+ std::shared_ptr<ngraph::Function> create_simple_function () {
21
+ // This example shows how to create ngraph::Function
22
+ //
23
+ // Parameter--->Multiply--->Add--->Result
24
+ // Constant---' /
25
+ // Constant---'
26
+
27
+ // Create opset3::Parameter operation with static shape
28
+ auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32 , ngraph::Shape{3 , 1 , 2 });
29
+
30
+ auto mul_constant = ngraph::opset3::Constant::create (ngraph::element::f32 , ngraph::Shape{1 }, {1.5 });
31
+ auto mul = std::make_shared<ngraph::opset3::Multiply>(data, mul_constant);
32
+
33
+ auto add_constant = ngraph::opset3::Constant::create (ngraph::element::f32 , ngraph::Shape{1 }, {0.5 });
34
+ auto add = std::make_shared<ngraph::opset3::Add>(mul, add_constant);
35
+
36
+ // Create opset3::Result operation
37
+ auto res = std::make_shared<ngraph::opset3::Result>(mul);
38
+
39
+ // Create nGraph function
40
+ return std::make_shared<ngraph::Function>(ngraph::ResultVector{res}, ngraph::ParameterVector{data});
41
+ }
42
+ // ! [ngraph_utils:simple_function]
43
+
44
+ // ! [ngraph_utils:advanced_function]
45
+ std::shared_ptr<ngraph::Function> create_advanced_function () {
46
+ // Advanced example with multi output operation
47
+ //
48
+ // Parameter->Split---0-->Result
49
+ // | `--1-->Relu-->Result
50
+ // `----2-->Result
51
+
52
+ auto data = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32 , ngraph::Shape{1 , 3 , 64 , 64 });
53
+
54
+ // Create Constant for axis value
55
+ auto axis_const = ngraph::opset3::Constant::create (ngraph::element::i64 , ngraph::Shape{}/* scalar shape*/ , {1 });
56
+
57
+ // Create opset3::Split operation that splits input to three slices across 1st dimension
58
+ auto split = std::make_shared<ngraph::opset3::Split>(data, axis_const, 3 );
59
+
60
+ // Create opset3::Relu operation that takes 1st Split output as input
61
+ auto relu = std::make_shared<ngraph::opset3::Relu>(split->output (1 )/* specify explicit output*/ );
62
+
63
+ // Results operations will be created automatically based on provided OutputVector
64
+ return std::make_shared<ngraph::Function>(ngraph::OutputVector{split->output (0 ), relu, split->output (2 )}, ngraph::ParameterVector{data});
65
+ }
66
+ // ! [ngraph_utils:advanced_function]
67
+
68
+ void pattern_matcher_examples () {
69
+ {
70
+ // ! [pattern:simple_example]
71
+ // Pattern example
72
+ auto input = std::make_shared<ngraph::opset3::Parameter>(element::i64 , Shape{1 });
73
+ auto shapeof = std::make_shared<ngraph::opset3::ShapeOf>(input);
74
+
75
+ // Create Matcher with Parameter->ShapeOf pattern
76
+ auto m = std::make_shared<ngraph::pattern::Matcher>(shapeof, " MyPatternBasedTransformation" );
77
+ // ! [pattern:simple_example]
78
+
79
+ // ! [pattern:callback_example]
80
+ ngraph::graph_rewrite_callback callback = [](pattern::Matcher& m) {
81
+ // Get root node
82
+ std::shared_ptr<Node> root_node = m.get_match_root ();
83
+
84
+ // Get all nodes matched by pattern
85
+ NodeVector nodes = m.get_matched_nodes ();
86
+
87
+ // Transformation code
88
+ return false ;
89
+ };
90
+ // ! [pattern:callback_example]
91
+ }
92
+
93
+ {
94
+ // ! [pattern:label_example]
95
+ // Detect Multiply with arbitrary first input and second as Constant
96
+ // ngraph::pattern::op::Label - represent arbitrary input
97
+ auto input = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32 , ngraph::Shape{1 });
98
+ auto value = ngraph::opset3::Constant::create (ngraph::element::f32 , ngraph::Shape{1 }, {0.5 });
99
+ auto mul = std::make_shared<ngraph::opset3::Multiply>(input, value);
100
+ auto m = std::make_shared<ngraph::pattern::Matcher>(mul, " MultiplyMatcher" );
101
+ // ! [pattern:label_example]
102
+ }
103
+
104
+ {
105
+ // ! [pattern:concat_example]
106
+ // Detect Concat operation with arbitrary number of inputs
107
+ auto concat = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32 , ngraph::Shape{}, ngraph::pattern::has_class<ngraph::opset3::Concat>());
108
+ auto m = std::make_shared<ngraph::pattern::Matcher>(concat, " ConcatMatcher" );
109
+ // ! [pattern:concat_example]
110
+ }
111
+
112
+ {
113
+ // ! [pattern:predicate_example]
114
+ // Detect Multiply or Add operation
115
+ auto lin_op = std::make_shared<ngraph::pattern::op::Label>(ngraph::element::f32 , ngraph::Shape{},
116
+ [](const std::shared_ptr<ngraph::Node> & node) -> bool {
117
+ return std::dynamic_pointer_cast<ngraph::opset3::Multiply>(node) ||
118
+ std::dynamic_pointer_cast<ngraph::opset3::Add>(node);
119
+ });
120
+ auto m = std::make_shared<ngraph::pattern::Matcher>(lin_op, " MultiplyOrAddMatcher" );
121
+ // ! [pattern:predicate_example]
122
+ }
123
+
124
+ }
125
+
126
+ bool ngraph_api_examples (std::shared_ptr<Node> node) {
127
+ {
128
+ // ! [ngraph:ports_example]
129
+ // Let's suppose that node is opset3::Convolution operation
130
+ // as we know opset3::Convolution has two input ports (data, weights) and one output port
131
+ Input <Node> data = node->input (0 );
132
+ Input <Node> weights = node->input (1 );
133
+ Output <Node> output = node->output (0 );
134
+
135
+ // Getting shape and type
136
+ auto pshape = data.get_partial_shape ();
137
+ auto el_type = data.get_element_type ();
138
+
139
+ // Ggetting parent for input port
140
+ Output <Node> parent_output;
141
+ parent_output = data.get_source_output ();
142
+
143
+ // Another short way to get partent for output port
144
+ parent_output = node->input_value (0 );
145
+
146
+ // Getting all consumers for output port
147
+ auto consumers = output.get_target_inputs ();
148
+ // ! [ngraph:ports_example]
149
+ }
150
+
151
+ {
152
+ // ! [ngraph:shape]
153
+ auto partial_shape = node->input (0 ).get_partial_shape (); // get zero input partial shape
154
+ if (partial_shape.is_dynamic () /* or !partial_shape.is_staic() */ ) {
155
+ return false ;
156
+ }
157
+ auto static_shape = partial_shape.get_shape ();
158
+ // ! [ngraph:shape]
159
+ }
160
+
161
+ {
162
+ // ! [ngraph:shape_check]
163
+ auto partial_shape = node->input (0 ).get_partial_shape (); // get zero input partial shape
164
+
165
+ // Check that input shape rank is static
166
+ if (!partial_shape.rank ().is_static ()) {
167
+ return false ;
168
+ }
169
+ auto rank_size = partial_shape.rank ().get_length ();
170
+
171
+ // Check that second dimension is not dynamic
172
+ if (rank_size < 2 || partial_shape[1 ].is_dynamic ()) {
173
+ return false ;
174
+ }
175
+ auto dim = partial_shape[1 ].get_length ();
176
+ // ! [ngraph:shape_check]
177
+ }
178
+
179
+ return true ;
180
+ }
181
+
182
+ // ! [ngraph:replace_node]
183
+ bool ngraph_replace_node (std::shared_ptr<Node> node) {
184
+ // Step 1. Verify that node has opset3::Negative type
185
+ auto neg = std::dynamic_pointer_cast<ngraph::opset3::Negative>(node);
186
+ if (!neg) {
187
+ return false ;
188
+ }
189
+
190
+ // Step 2. Create opset3::Multiply operation where the first input is negative operation input and second as Constant with -1 value
191
+ auto mul = std::make_shared<ngraph::opset3::Multiply>(neg->input_value (0 ),
192
+ opset3::Constant::create (neg->get_element_type (), Shape{1 }, {-1 }));
193
+
194
+ mul->set_friendly_name (neg->get_friendly_name ());
195
+ ngraph::copy_runtime_info (neg, mul);
196
+
197
+ // Step 3. Replace Negative operation with Multiply operation
198
+ ngraph::replace_node (neg, mul);
199
+ return true ;
200
+
201
+ // Step 4. Negative operation will be removed automatically because all consumers was moved to Multiply operation
202
+ }
203
+ // ! [ngraph:replace_node]
204
+
205
+ // ! [ngraph:insert_node]
206
+ // Step 1. Lets suppose that we have a node with single output port and we want to insert additional operation new_node after it
207
+ void insert_example (std::shared_ptr<ngraph::Node> node) {
208
+ // Get all consumers for node
209
+ auto consumers = node->output (0 ).get_target_inputs ();
210
+
211
+ // Step 2. Create new node. Let it be opset1::Relu.
212
+ auto new_node = std::make_shared<ngraph::opset3::Relu>(node);
213
+
214
+ // Step 3. Reconnect all consumers to new_node
215
+ for (auto input : consumers) {
216
+ input.replace_source_output (new_node);
217
+ }
218
+ }
219
+ // ! [ngraph:insert_node]
220
+
221
+ // ! [ngraph:insert_node_with_copy]
222
+ void insert_example_with_copy (std::shared_ptr<ngraph::Node> node) {
223
+ // Make a node copy
224
+ auto node_copy = node->clone_with_new_inputs (node->input_values ());
225
+ // Create new node
226
+ auto new_node = std::make_shared<ngraph::opset3::Relu>(node_copy);
227
+ ngraph::replace_node (node, new_node);
228
+ }
229
+ // ! [ngraph:insert_node_with_copy]
230
+
231
+ void eliminate_example (std::shared_ptr<ngraph::Node> node) {
232
+ // ! [ngraph:eliminate_node]
233
+ // Suppose we have a node that we want to remove
234
+ bool success = replace_output_update_name (node->output (0 ), node->input_value (0 ));
235
+ // ! [ngraph:eliminate_node]
236
+ }
237
+
238
+ // ! [ngraph:visualize]
239
+ void visualization_example (std::shared_ptr<ngraph::Function> f) {
240
+ std::vector<std::shared_ptr<ngraph::Function> > g{f};
241
+
242
+ // Serialize ngraph::Function to before.svg file before transformation
243
+ ngraph::pass::VisualizeTree (" /path/to/file/before.svg" ).run_on_module (g);
244
+
245
+ // Run your transformation
246
+ // ngraph::pass::MyTransformation().run_on_function();
247
+
248
+ // Serialize ngraph::Function to after.svg file after transformation
249
+ ngraph::pass::VisualizeTree (" /path/to/file/after.svg" ).run_on_module (g);
250
+ }
251
+ // ! [ngraph:visualize]
0 commit comments