Skip to content

Commit 8c52b1e

Browse files
committed
Fix for Serializer
1 parent 2f16435 commit 8c52b1e

File tree

17 files changed

+164
-127
lines changed

17 files changed

+164
-127
lines changed

src/common/transformations/include/transformations/common_optimizations/eliminate_weightless_attributes.hpp

Lines changed: 0 additions & 25 deletions
This file was deleted.

src/common/transformations/src/transformations/common_optimizations/eliminate_weightless_attributes.cpp

Lines changed: 0 additions & 33 deletions
This file was deleted.

src/core/include/openvino/pass/serialize.hpp

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,21 @@ class OPENVINO_API Serialize : public ov::pass::ModelPass {
3232
};
3333
bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
3434

35-
Serialize(std::ostream& xmlFile, std::ostream& binFile, Version version = Version::UNSPECIFIED);
35+
Serialize(std::ostream& xmlFile,
36+
std::ostream& binFile,
37+
Version version = Version::UNSPECIFIED,
38+
bool skip_weightless_constants = false);
3639

37-
Serialize(const std::string& xmlPath, const std::string& binPath, Version version = Version::UNSPECIFIED);
40+
Serialize(const std::string& xmlPath,
41+
const std::string& binPath,
42+
Version version = Version::UNSPECIFIED,
43+
bool skip_weightless_constants = false);
3844

3945
Serialize(const std::filesystem::path& xmlPath,
4046
const std::filesystem::path& binPath,
41-
Version version = Version::UNSPECIFIED)
42-
: Serialize(xmlPath.string(), binPath.string(), version) {}
47+
Version version = Version::UNSPECIFIED,
48+
bool skip_weightless_constants = false)
49+
: Serialize(xmlPath.string(), binPath.string(), version, skip_weightless_constants) {}
4350

4451
private:
4552
std::ostream* m_xmlFile;
@@ -48,6 +55,8 @@ class OPENVINO_API Serialize : public ov::pass::ModelPass {
4855
const std::string m_binPath;
4956
const Version m_version;
5057
const std::map<std::string, ov::OpSet> m_custom_opsets;
58+
// If True, don't serialize weights of Constants nodes with WeightlessCache attribute.
59+
bool m_skip_weightless_constants;
5160
};
5261

5362
/**
@@ -74,13 +83,16 @@ class OPENVINO_API StreamSerialize : public ov::pass::ModelPass {
7483
StreamSerialize(std::ostream& stream,
7584
const std::function<void(std::ostream&)>& custom_data_serializer = {},
7685
const std::function<std::string(const std::string&)>& cache_encrypt = {},
77-
Serialize::Version version = Serialize::Version::UNSPECIFIED);
86+
Serialize::Version version = Serialize::Version::UNSPECIFIED,
87+
bool skip_weightless_constants = false);
7888

7989
private:
8090
std::ostream& m_stream;
8191
std::function<void(std::ostream&)> m_custom_data_serializer;
8292
std::function<std::string(const std::string&)> m_cache_encrypt;
8393
const Serialize::Version m_version;
94+
// If True, don't serialize weights of Constants nodes with WeightlessCache attribute.
95+
bool m_skip_weightless_constants;
8496
};
8597
} // namespace pass
8698
} // namespace ov

src/core/src/pass/serialize.cpp

Lines changed: 73 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,8 @@ void ngfunction_2_ir(pugi::xml_node& node,
216216
const ov::Model& model,
217217
ConstantWriter& constant_write_handler,
218218
int64_t version,
219-
bool deterministic);
219+
bool deterministic,
220+
const bool skip_weightless_constants = false);
220221

221222
namespace rt_info {
222223
static const std::vector<std::string> list_of_names{
@@ -330,7 +331,7 @@ class XmlSerializer : public ov::AttributeVisitor {
330331
bool m_compress_to_fp16;
331332
ov::element::Type m_output_element_type;
332333
bool m_data_is_temporary;
333-
bool m_weightless_const;
334+
bool m_skip_weightless_constants;
334335

335336
template <typename T>
336337
std::string create_atribute_list(ov::ValueAccessor<std::vector<T>>& adapter) {
@@ -451,7 +452,7 @@ class XmlSerializer : public ov::AttributeVisitor {
451452
bool compress_to_fp16 = false,
452453
ov::element::Type output_element_type = ov::element::dynamic,
453454
bool data_is_temporary = false,
454-
bool weightless_const = false)
455+
bool skip_weightless_constants = false)
455456
: m_xml_node(data),
456457
m_node_type_name(node_type_name),
457458
m_constant_write_handler(constant_write_handler),
@@ -460,7 +461,7 @@ class XmlSerializer : public ov::AttributeVisitor {
460461
m_compress_to_fp16(compress_to_fp16),
461462
m_output_element_type(output_element_type),
462463
m_data_is_temporary(data_is_temporary),
463-
m_weightless_const(weightless_const) {}
464+
m_skip_weightless_constants(skip_weightless_constants) {}
464465

465466
void on_adapter(const std::string& name, ov::ValueAccessor<void>& adapter) override {
466467
using BodyTargetNames = std::tuple<std::string, std::string, std::vector<std::string>>;
@@ -595,7 +596,7 @@ class XmlSerializer : public ov::AttributeVisitor {
595596
}
596597
} else if (const auto& a = ov::as_type<ov::AttributeAdapter<std::shared_ptr<ov::AlignedBuffer>>>(&adapter)) {
597598
if (name == "value" && translate_type_name(m_node_type_name) == "Const") {
598-
const size_t size = m_weightless_const ? 0lu : a->get()->size();
599+
const size_t size = m_skip_weightless_constants ? 0lu : a->get()->size();
599600
size_t new_size = 0lu;
600601
int64_t offset = m_constant_write_handler.write(static_cast<const char*>(a->get()->get_ptr()),
601602
size,
@@ -690,11 +691,21 @@ class XmlSerializer : public ov::AttributeVisitor {
690691
// to layer above (m_xml_node.parent()) as in ngfunction_2_ir() layer (m_xml_node) with empty attributes
691692
// is removed.
692693
pugi::xml_node xml_body = m_xml_node.parent().append_child(name.c_str());
693-
ngfunction_2_ir(xml_body, *adapter.get(), m_constant_write_handler, m_version, m_deterministic);
694+
ngfunction_2_ir(xml_body,
695+
*adapter.get(),
696+
m_constant_write_handler,
697+
m_version,
698+
m_deterministic,
699+
m_skip_weightless_constants);
694700
xml_body.remove_attribute("name");
695701
xml_body.remove_attribute("version");
696702
} else if (name == "net") {
697-
ngfunction_2_ir(m_xml_node, *adapter.get(), m_constant_write_handler, m_version, m_deterministic);
703+
ngfunction_2_ir(m_xml_node,
704+
*adapter.get(),
705+
m_constant_write_handler,
706+
m_version,
707+
m_deterministic,
708+
m_skip_weightless_constants);
698709
} else {
699710
OPENVINO_THROW("Unsupported Model name.");
700711
}
@@ -1009,7 +1020,8 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
10091020
const ov::Model& model,
10101021
ConstantWriter& constant_node_write_handler,
10111022
int64_t version,
1012-
bool deterministic) {
1023+
bool deterministic,
1024+
const bool skip_weightless_constants) {
10131025
// If determinism is not required, include auto-generated names into xml
10141026
// model name is not critical for hash computing
10151027
if (!deterministic) {
@@ -1075,9 +1087,10 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
10751087

10761088
// <layers/data> general attributes
10771089
pugi::xml_node data = layer.append_child("data");
1078-
bool weightless_const = false;
10791090

1080-
auto append_runtime_info = [&weightless_const](pugi::xml_node& node, ov::RTMap& attributes) {
1091+
auto append_runtime_info = [&skip_weightless_constants](pugi::xml_node& node,
1092+
ov::RTMap& attributes,
1093+
bool& weightless_const) {
10811094
pugi::xml_node rt_node = node.append_child("rt_info");
10821095
bool has_attrs = false;
10831096
for (auto& item : attributes) {
@@ -1092,7 +1105,8 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
10921105
rt_node.remove_child(attribute_node);
10931106
} else {
10941107
has_attrs = true;
1095-
if (strcmp(type_info.name, ov::WeightlessCacheAttribute::get_type_info_static().name) == 0) {
1108+
if (skip_weightless_constants &&
1109+
strcmp(type_info.name, ov::WeightlessCacheAttribute::get_type_info_static().name) == 0) {
10961110
weightless_const = true;
10971111
}
10981112
}
@@ -1103,8 +1117,9 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
11031117
}
11041118
};
11051119

1120+
bool weightless_const = false;
11061121
if (version >= 11) {
1107-
append_runtime_info(layer, node->get_rt_info());
1122+
append_runtime_info(layer, node->get_rt_info(), weightless_const);
11081123
}
11091124

11101125
int port_id = 0;
@@ -1135,8 +1150,10 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
11351150
.set_value(std::to_string(d.get_length()).c_str());
11361151
}
11371152
}
1138-
if (version >= 11)
1139-
append_runtime_info(port, i.get_rt_info());
1153+
if (version >= 11) {
1154+
bool weightless_const_tmp = false;
1155+
append_runtime_info(port, i.get_rt_info(), weightless_const_tmp);
1156+
}
11401157
}
11411158

11421159
if (node_type_name == "TensorIterator" || node_type_name == "Loop") {
@@ -1192,8 +1209,10 @@ void ngfunction_2_ir(pugi::xml_node& netXml,
11921209
.set_value(std::to_string(d.get_length()).c_str());
11931210
}
11941211
}
1195-
if (version >= 11)
1196-
append_runtime_info(port, o.get_rt_info());
1212+
if (version >= 11) {
1213+
bool weightless_const_tmp = false;
1214+
append_runtime_info(port, o.get_rt_info(), weightless_const_tmp);
1215+
}
11971216
}
11981217
if (node_type_name == "TensorIterator" || node_type_name == "Loop") {
11991218
layer.insert_move_after(output, layer.first_child());
@@ -1288,7 +1307,8 @@ void serializeFunc(std::ostream& xml_file,
12881307
std::ostream& bin_file,
12891308
std::shared_ptr<ov::Model> model,
12901309
ov::pass::Serialize::Version ver,
1291-
bool deterministic = false) {
1310+
bool deterministic = false,
1311+
bool skip_weightless_constants = false) {
12921312
auto version = static_cast<int64_t>(ver);
12931313

12941314
auto& rt_info = model->get_rt_info();
@@ -1310,7 +1330,15 @@ void serializeFunc(std::ostream& xml_file,
13101330
pugi::xml_document xml_doc;
13111331
pugi::xml_node net_node = xml_doc.append_child(name.c_str());
13121332
ConstantWriter constant_write_handler(bin_file);
1313-
XmlSerializer visitor(net_node, name, constant_write_handler, version, deterministic);
1333+
XmlSerializer visitor(net_node,
1334+
name,
1335+
constant_write_handler,
1336+
version,
1337+
deterministic,
1338+
false,
1339+
ov::element::dynamic,
1340+
false,
1341+
skip_weightless_constants);
13141342
visitor.on_attribute(name, model);
13151343

13161344
xml_doc.save(xml_file);
@@ -1334,7 +1362,7 @@ bool pass::Serialize::run_on_model(const std::shared_ptr<ov::Model>& model) {
13341362
disable_fp16_compression(node);
13351363

13361364
if (m_xmlFile && m_binFile) {
1337-
serializeFunc(*m_xmlFile, *m_binFile, model, m_version);
1365+
serializeFunc(*m_xmlFile, *m_binFile, model, m_version, false, m_skip_weightless_constants);
13381366
} else {
13391367
#if defined(OPENVINO_ENABLE_UNICODE_PATH_SUPPORT) && defined(_WIN32)
13401368
const auto& xmlPath_ref = ov::util::string_to_wstring(m_xmlPath);
@@ -1359,7 +1387,7 @@ bool pass::Serialize::run_on_model(const std::shared_ptr<ov::Model>& model) {
13591387
OPENVINO_ASSERT(xml_file, message_xml);
13601388

13611389
try {
1362-
serializeFunc(xml_file, bin_file, model, m_version);
1390+
serializeFunc(xml_file, bin_file, model, m_version, false, m_skip_weightless_constants);
13631391
} catch (const ov::AssertFailure&) {
13641392
// optimization decision was made to create .bin file upfront and
13651393
// write to it directly instead of buffering its content in memory,
@@ -1376,28 +1404,38 @@ bool pass::Serialize::run_on_model(const std::shared_ptr<ov::Model>& model) {
13761404
return false;
13771405
}
13781406

1379-
pass::Serialize::Serialize(std::ostream& xmlFile, std::ostream& binFile, pass::Serialize::Version version)
1407+
pass::Serialize::Serialize(std::ostream& xmlFile,
1408+
std::ostream& binFile,
1409+
pass::Serialize::Version version,
1410+
bool weightless_cache)
13801411
: m_xmlFile{&xmlFile},
13811412
m_binFile{&binFile},
13821413
m_xmlPath{},
13831414
m_binPath{},
1384-
m_version{version} {}
1415+
m_version{version},
1416+
m_skip_weightless_constants{weightless_cache} {}
13851417

1386-
pass::Serialize::Serialize(const std::string& xmlPath, const std::string& binPath, pass::Serialize::Version version)
1418+
pass::Serialize::Serialize(const std::string& xmlPath,
1419+
const std::string& binPath,
1420+
pass::Serialize::Version version,
1421+
bool weightless_cache)
13871422
: m_xmlFile{nullptr},
13881423
m_binFile{nullptr},
13891424
m_xmlPath{valid_xml_path(xmlPath)},
13901425
m_binPath{provide_bin_path(xmlPath, binPath)},
1391-
m_version{version} {}
1426+
m_version{version},
1427+
m_skip_weightless_constants{weightless_cache} {}
13921428

13931429
pass::StreamSerialize::StreamSerialize(std::ostream& stream,
13941430
const std::function<void(std::ostream&)>& custom_data_serializer,
13951431
const std::function<std::string(const std::string&)>& cache_encrypt,
1396-
Serialize::Version version)
1432+
Serialize::Version version,
1433+
bool skip_weightless_constants)
13971434
: m_stream(stream),
13981435
m_custom_data_serializer(custom_data_serializer),
13991436
m_cache_encrypt(cache_encrypt),
1400-
m_version(version) {
1437+
m_version(version),
1438+
m_skip_weightless_constants(skip_weightless_constants) {
14011439
if (version != Serialize::Version::UNSPECIFIED && version != Serialize::Version::IR_V10 &&
14021440
version != Serialize::Version::IR_V11) {
14031441
OPENVINO_THROW("Unsupported version");
@@ -1447,7 +1485,15 @@ bool pass::StreamSerialize::run_on_model(const std::shared_ptr<ov::Model>& model
14471485
pugi::xml_document xml_doc;
14481486
pugi::xml_node net_node = xml_doc.append_child(name.c_str());
14491487
ConstantWriter constant_write_handler(m_stream);
1450-
XmlSerializer visitor(net_node, name, constant_write_handler, version);
1488+
XmlSerializer visitor(net_node,
1489+
name,
1490+
constant_write_handler,
1491+
version,
1492+
false,
1493+
false,
1494+
ov::element::dynamic,
1495+
false,
1496+
m_skip_weightless_constants);
14511497
std::shared_ptr<ov::Model> fun = model;
14521498
visitor.on_attribute(name, fun);
14531499

0 commit comments

Comments
 (0)