|
20 | 20 | #include "openvino/op/util/op_types.hpp"
|
21 | 21 | #include "openvino/op/util/read_value_base.hpp"
|
22 | 22 | #include "openvino/op/util/variable.hpp"
|
| 23 | +#include "../../../core/reference/include/openvino/reference/convert.hpp" |
23 | 24 | #include "openvino/runtime/shared_buffer.hpp"
|
24 | 25 | #include "openvino/runtime/string_aligned_buffer.hpp"
|
25 | 26 | #include "openvino/util/xml_parse_utils.hpp"
|
@@ -262,46 +263,27 @@ ov::op::v5::Loop::SpecialBodyPorts ov::XmlDeserializer::parse_purpose_attribute(
|
262 | 263 | }
|
263 | 264 |
|
264 | 265 | namespace {
|
265 |
| -template <typename src_type, typename dst_type> |
266 |
| -inline dst_type convert_value(src_type val) { |
267 |
| - if (val > std::numeric_limits<dst_type>::max()) { |
268 |
| - return std::numeric_limits<dst_type>::max(); |
269 |
| - } else if (val < std::numeric_limits<dst_type>::lowest()) { |
270 |
| - return std::numeric_limits<dst_type>::lowest(); |
271 |
| - } |
272 |
| - return static_cast<dst_type>(val); |
273 |
| -} |
274 |
| - |
275 |
| -template <ov::element::Type_t DT_FROM, ov::element::Type_t DT_TO> |
276 |
| -void convert_dt(char* dst, const char* src, size_t el_num) { |
277 |
| - using src_type = typename ov::element_type_traits<DT_FROM>::value_type; |
278 |
| - using dst_type = typename ov::element_type_traits<DT_TO>::value_type; |
279 |
| - |
280 |
| - auto src_data = reinterpret_cast<const src_type*>(src); |
281 |
| - auto dst_data = reinterpret_cast<dst_type*>(dst); |
282 |
| - |
283 |
| - for (size_t i = 0lu; i < el_num; i++) { |
284 |
| - dst_data[i] = convert_value<src_type, dst_type>(src_data[i]); |
285 |
| - } |
286 |
| -} |
287 |
| - |
288 | 266 | void convert_dt(ov::element::Type to_dt, ov::element::Type from_dt, char* dst, const char* src, size_t el_num) {
|
289 |
| - if (from_dt == ov::element::i64 && to_dt == ov::element::i32) { |
290 |
| - convert_dt<ov::element::Type_t::i64, ov::element::Type_t::i32>(dst, src, el_num); |
291 |
| - } else if (from_dt == ov::element::f32 && to_dt == ov::element::f16) { |
292 |
| - convert_dt<ov::element::Type_t::f32, ov::element::Type_t::f16>(dst, src, el_num); |
293 |
| - } else if (from_dt == ov::element::f16 && to_dt == ov::element::f32) { |
294 |
| - convert_dt<ov::element::Type_t::f16, ov::element::Type_t::f32>(dst, src, el_num); |
| 267 | + if (from_dt == ov::element::f16 && to_dt == ov::element::f32) { |
| 268 | + ov::reference::convert(reinterpret_cast<const ov::float16*>(src), reinterpret_cast<float*>(dst), el_num); |
| 269 | + } else if (from_dt == ov::element::i64 && to_dt == ov::element::i32) { |
| 270 | + ov::reference::convert(reinterpret_cast<const int64_t*>(src), reinterpret_cast<int32_t*>(dst), el_num); |
| 271 | + } else if (from_dt == ov::element::u64 && to_dt == ov::element::i32) { |
| 272 | + ov::reference::convert(reinterpret_cast<const uint64_t*>(src), reinterpret_cast<int32_t*>(dst), el_num); |
295 | 273 | } else if (from_dt == ov::element::bf16 && to_dt == ov::element::f32) {
|
296 |
| - convert_dt<ov::element::Type_t::bf16, ov::element::Type_t::f32>(dst, src, el_num); |
297 |
| - } else if (from_dt == ov::element::f32 && to_dt == ov::element::bf16) { |
298 |
| - convert_dt<ov::element::Type_t::f32, ov::element::Type_t::bf16>(dst, src, el_num); |
| 274 | + ov::reference::convert(reinterpret_cast<const ov::bfloat16*>(src), reinterpret_cast<float*>(dst), el_num); |
299 | 275 | } else if (from_dt == ov::element::u8 && to_dt == ov::element::i32) {
|
300 |
| - convert_dt<ov::element::Type_t::u8, ov::element::Type_t::i32>(dst, src, el_num); |
301 |
| - } else if (from_dt == ov::element::i8 && to_dt == ov::element::f32) { |
302 |
| - convert_dt<ov::element::Type_t::i8, ov::element::Type_t::f32>(dst, src, el_num); |
| 276 | + ov::reference::convert(reinterpret_cast<const uint8_t*>(src), reinterpret_cast<int32_t*>(dst), el_num); |
| 277 | + } else if (from_dt == ov::element::i8 && to_dt == ov::element::i32) { |
| 278 | + ov::reference::convert(reinterpret_cast<const int8_t*>(src), reinterpret_cast<int32_t*>(dst), el_num); |
303 | 279 | } else if (from_dt == ov::element::u8 && to_dt == ov::element::f32) {
|
304 |
| - convert_dt<ov::element::Type_t::u8, ov::element::Type_t::f32>(dst, src, el_num); |
| 280 | + ov::reference::convert(reinterpret_cast<const uint8_t*>(src), reinterpret_cast<float*>(dst), el_num); |
| 281 | + } else if (from_dt == ov::element::i8 && to_dt == ov::element::f32) { |
| 282 | + ov::reference::convert(reinterpret_cast<const int8_t*>(src), reinterpret_cast<float*>(dst), el_num); |
| 283 | + } else if (from_dt == ov::element::f32 && to_dt == ov::element::f16) { |
| 284 | + ov::reference::convert(reinterpret_cast<const float*>(src), reinterpret_cast<ov::float16*>(dst), el_num); |
| 285 | + } else if (from_dt == ov::element::f32 && to_dt == ov::element::bf16) { |
| 286 | + ov::reference::convert(reinterpret_cast<const float*>(src), reinterpret_cast<ov::bfloat16*>(dst), el_num); |
305 | 287 | } else {
|
306 | 288 | OPENVINO_THROW("Unsupported element types conversion from ", from_dt, " to ", to_dt);
|
307 | 289 | }
|
@@ -444,7 +426,7 @@ void ov::XmlDeserializer::on_adapter(const std::string& name, ov::ValueAccessor<
|
444 | 426 | for (auto attr : child.attributes()) {
|
445 | 427 | if (strcmp(attr.name(), "name") == 0 &&
|
446 | 428 | strcmp(attr.value(), ov::WeightlessCacheAttribute::get_type_info_static().name) == 0) {
|
447 |
| - ov::element::Type original_dt(child.attribute("original_dtype").value()); |
| 429 | + const ov::element::Type original_dt(child.attribute("original_dtype").value()); |
448 | 430 | offset = static_cast<size_t>(pugixml::get_uint64_attr(child, "bin_offset"));
|
449 | 431 | origin_size = static_cast<size_t>(pugixml::get_uint64_attr(child, "original_size"));
|
450 | 432 | actual_size = ((el_num * el_type.bitwidth() + 7) >> 3);
|
|
0 commit comments