Skip to content

Commit e40ac7f

Browse files
committed
Use optimized convert from the Reference lib
1 parent e385a74 commit e40ac7f

File tree

3 files changed

+23
-38
lines changed

3 files changed

+23
-38
lines changed

src/core/src/node.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "openvino/core/descriptor_tensor.hpp"
1717
#include "openvino/core/log_util.hpp"
1818
#include "openvino/core/rt_info.hpp"
19+
#include "openvino/core/rt_info/weightless_caching_attributes.hpp"
1920
#include "openvino/core/shape_util.hpp"
2021
#include "openvino/op/util/op_types.hpp"
2122
#include "openvino/pass/constant_folding.hpp"
@@ -744,6 +745,7 @@ bool ov::Node::constant_fold(OutputVector& output_values, const OutputVector& in
744745
for (size_t i = 0; i < output_tensors.size(); ++i) {
745746
output_values[i] = make_shared<ov::op::v0::Constant>(output_tensors[i]);
746747
ov::copy_runtime_info(nodes, output_values[i].get_node_shared_ptr());
748+
ov::copy_weightless_cache_attr(nodes[0], output_values[i].get_node_shared_ptr());
747749
}
748750
return true;
749751
}

src/frontends/ir/src/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,5 @@
55
ov_add_frontend(NAME ir
66
FILEDESCRIPTION "FrontEnd to load OpenVINO IR file format"
77
LINK_LIBRARIES openvino::pugixml
8-
openvino::core::dev)
8+
openvino::core::dev
9+
openvino_reference)

src/frontends/ir/src/ir_deserializer.cpp

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "openvino/op/util/op_types.hpp"
2121
#include "openvino/op/util/read_value_base.hpp"
2222
#include "openvino/op/util/variable.hpp"
23+
#include "../../../core/reference/include/openvino/reference/convert.hpp"
2324
#include "openvino/runtime/shared_buffer.hpp"
2425
#include "openvino/runtime/string_aligned_buffer.hpp"
2526
#include "openvino/util/xml_parse_utils.hpp"
@@ -262,46 +263,27 @@ ov::op::v5::Loop::SpecialBodyPorts ov::XmlDeserializer::parse_purpose_attribute(
262263
}
263264

264265
namespace {
265-
template <typename src_type, typename dst_type>
266-
inline dst_type convert_value(src_type val) {
267-
if (val > std::numeric_limits<dst_type>::max()) {
268-
return std::numeric_limits<dst_type>::max();
269-
} else if (val < std::numeric_limits<dst_type>::lowest()) {
270-
return std::numeric_limits<dst_type>::lowest();
271-
}
272-
return static_cast<dst_type>(val);
273-
}
274-
275-
template <ov::element::Type_t DT_FROM, ov::element::Type_t DT_TO>
276-
void convert_dt(char* dst, const char* src, size_t el_num) {
277-
using src_type = typename ov::element_type_traits<DT_FROM>::value_type;
278-
using dst_type = typename ov::element_type_traits<DT_TO>::value_type;
279-
280-
auto src_data = reinterpret_cast<const src_type*>(src);
281-
auto dst_data = reinterpret_cast<dst_type*>(dst);
282-
283-
for (size_t i = 0lu; i < el_num; i++) {
284-
dst_data[i] = convert_value<src_type, dst_type>(src_data[i]);
285-
}
286-
}
287-
288266
void convert_dt(ov::element::Type to_dt, ov::element::Type from_dt, char* dst, const char* src, size_t el_num) {
289-
if (from_dt == ov::element::i64 && to_dt == ov::element::i32) {
290-
convert_dt<ov::element::Type_t::i64, ov::element::Type_t::i32>(dst, src, el_num);
291-
} else if (from_dt == ov::element::f32 && to_dt == ov::element::f16) {
292-
convert_dt<ov::element::Type_t::f32, ov::element::Type_t::f16>(dst, src, el_num);
293-
} else if (from_dt == ov::element::f16 && to_dt == ov::element::f32) {
294-
convert_dt<ov::element::Type_t::f16, ov::element::Type_t::f32>(dst, src, el_num);
267+
if (from_dt == ov::element::f16 && to_dt == ov::element::f32) {
268+
ov::reference::convert(reinterpret_cast<const ov::float16*>(src), reinterpret_cast<float*>(dst), el_num);
269+
} else if (from_dt == ov::element::i64 && to_dt == ov::element::i32) {
270+
ov::reference::convert(reinterpret_cast<const int64_t*>(src), reinterpret_cast<int32_t*>(dst), el_num);
271+
} else if (from_dt == ov::element::u64 && to_dt == ov::element::i32) {
272+
ov::reference::convert(reinterpret_cast<const uint64_t*>(src), reinterpret_cast<int32_t*>(dst), el_num);
295273
} else if (from_dt == ov::element::bf16 && to_dt == ov::element::f32) {
296-
convert_dt<ov::element::Type_t::bf16, ov::element::Type_t::f32>(dst, src, el_num);
297-
} else if (from_dt == ov::element::f32 && to_dt == ov::element::bf16) {
298-
convert_dt<ov::element::Type_t::f32, ov::element::Type_t::bf16>(dst, src, el_num);
274+
ov::reference::convert(reinterpret_cast<const ov::bfloat16*>(src), reinterpret_cast<float*>(dst), el_num);
299275
} else if (from_dt == ov::element::u8 && to_dt == ov::element::i32) {
300-
convert_dt<ov::element::Type_t::u8, ov::element::Type_t::i32>(dst, src, el_num);
301-
} else if (from_dt == ov::element::i8 && to_dt == ov::element::f32) {
302-
convert_dt<ov::element::Type_t::i8, ov::element::Type_t::f32>(dst, src, el_num);
276+
ov::reference::convert(reinterpret_cast<const uint8_t*>(src), reinterpret_cast<int32_t*>(dst), el_num);
277+
} else if (from_dt == ov::element::i8 && to_dt == ov::element::i32) {
278+
ov::reference::convert(reinterpret_cast<const int8_t*>(src), reinterpret_cast<int32_t*>(dst), el_num);
303279
} else if (from_dt == ov::element::u8 && to_dt == ov::element::f32) {
304-
convert_dt<ov::element::Type_t::u8, ov::element::Type_t::f32>(dst, src, el_num);
280+
ov::reference::convert(reinterpret_cast<const uint8_t*>(src), reinterpret_cast<float*>(dst), el_num);
281+
} else if (from_dt == ov::element::i8 && to_dt == ov::element::f32) {
282+
ov::reference::convert(reinterpret_cast<const int8_t*>(src), reinterpret_cast<float*>(dst), el_num);
283+
} else if (from_dt == ov::element::f32 && to_dt == ov::element::f16) {
284+
ov::reference::convert(reinterpret_cast<const float*>(src), reinterpret_cast<ov::float16*>(dst), el_num);
285+
} else if (from_dt == ov::element::f32 && to_dt == ov::element::bf16) {
286+
ov::reference::convert(reinterpret_cast<const float*>(src), reinterpret_cast<ov::bfloat16*>(dst), el_num);
305287
} else {
306288
OPENVINO_THROW("Unsupported element types conversion from ", from_dt, " to ", to_dt);
307289
}
@@ -444,7 +426,7 @@ void ov::XmlDeserializer::on_adapter(const std::string& name, ov::ValueAccessor<
444426
for (auto attr : child.attributes()) {
445427
if (strcmp(attr.name(), "name") == 0 &&
446428
strcmp(attr.value(), ov::WeightlessCacheAttribute::get_type_info_static().name) == 0) {
447-
ov::element::Type original_dt(child.attribute("original_dtype").value());
429+
const ov::element::Type original_dt(child.attribute("original_dtype").value());
448430
offset = static_cast<size_t>(pugixml::get_uint64_attr(child, "bin_offset"));
449431
origin_size = static_cast<size_t>(pugixml::get_uint64_attr(child, "original_size"));
450432
actual_size = ((el_num * el_type.bitwidth() + 7) >> 3);

0 commit comments

Comments
 (0)