5
5
#include " cpu_convert.h"
6
6
#include " cpu_memcpy.h"
7
7
#include " utils/bfloat16.hpp"
8
+ #include < mkldnn_selective_build.h>
8
9
#include < type_traits>
10
+ #include < tuple>
9
11
#include < ie_parallel.hpp>
10
12
11
13
using namespace InferenceEngine ;
12
14
15
+ namespace {
16
+
13
17
template <typename srcType, typename dstType>
14
18
void convert (const void *srcPtr, void *dstPtr, const size_t size) {
15
19
if (std::is_same<srcType, dstType>::value) {
@@ -24,45 +28,41 @@ void convert(const void *srcPtr, void *dstPtr, const size_t size) {
24
28
}
25
29
}
26
30
27
- template <typename srcType>
28
- void convertFrom (const void *srcPtr, void *dstPtr, Precision dstPrc, const size_t size) {
29
- switch (dstPrc) {
30
- case Precision::U8:
31
- convert<srcType, PrecisionTrait<Precision::U8>::value_type>(srcPtr, dstPtr, size);
32
- break ;
33
- case Precision::I8:
34
- convert<srcType, PrecisionTrait<Precision::I8>::value_type>(srcPtr, dstPtr, size);
35
- break ;
36
- case Precision::U16:
37
- convert<srcType, PrecisionTrait<Precision::U16>::value_type>(srcPtr, dstPtr, size);
38
- break ;
39
- case Precision::I16:
40
- convert<srcType, PrecisionTrait<Precision::I16>::value_type>(srcPtr, dstPtr, size);
41
- break ;
42
- case Precision::I32:
43
- convert<srcType, PrecisionTrait<Precision::I32>::value_type>(srcPtr, dstPtr, size);
44
- break ;
45
- case Precision::U64:
46
- convert<srcType, PrecisionTrait<Precision::U64>::value_type>(srcPtr, dstPtr, size);
47
- break ;
48
- case Precision::I64:
49
- convert<srcType, PrecisionTrait<Precision::I64>::value_type>(srcPtr, dstPtr, size);
50
- break ;
51
- case Precision::FP32:
52
- convert<srcType, PrecisionTrait<Precision::FP32>::value_type>(srcPtr, dstPtr, size);
53
- break ;
54
- case Precision::BF16:
55
- convert<srcType, MKLDNNPlugin::bfloat16_t >(srcPtr, dstPtr, size);
56
- break ;
57
- case Precision::BOOL:
58
- convert<srcType, PrecisionTrait<Precision::BOOL>::value_type>(srcPtr, dstPtr, size);
59
- break ;
60
- default :
61
- THROW_IE_EXCEPTION << " cpu_convert can't convert to: " << dstPrc << " precision" ;
31
+ template <Precision::ePrecision p>
32
+ struct PrecisionInfo {
33
+ using value_type = typename PrecisionTrait<p>::value_type;
34
+ };
35
+
36
+ template <>
37
+ struct PrecisionInfo <Precision::BF16> {
38
+ using value_type = MKLDNNPlugin::bfloat16_t ;
39
+ };
40
+
41
// State handed to the ConvertPrecision functor for a single cpu_convert call.
// Field order matters: the struct is filled by positional aggregate
// initialization, so members must not be reordered.
struct ConvertContext {
    const void *srcPtr;  // source buffer passed on to convert<>()
    void *dstPtr;        // destination buffer passed on to convert<>()
    size_t size;         // forwarded unchanged to convert<>() (presumably an element count - confirm)
    bool converted;      // set to true by ConvertPrecision once a conversion has run
};
47
+
48
+ template <typename T>
49
+ struct ConvertPrecision {
50
+ using src_t = typename std::tuple_element<0 , T>::type;
51
+ using dst_t = typename std::tuple_element<1 , T>::type;
52
+
53
+ void operator ()(ConvertContext & ctx) {
54
+ convert<src_t , dst_t >(ctx.srcPtr , ctx.dstPtr , ctx.size );
55
+ ctx.converted = true ;
62
56
}
63
- }
57
+ };
58
+
59
+ } // namespace
60
+
61
// Builds one OV_CASE2 entry for the conversion from precision ST to precision DT.
// The first two arguments are the Precision enumerators and the last two are the
// matching element types taken from PrecisionInfo (presumably OV_SWITCH matches on
// the enumerator pair and instantiates the functor with the type pair - confirm
// against the selective-build header).
// NOTE: there must be no whitespace between the macro name and '('. With a space
// the preprocessor defines an object-like macro and every MKLDNN_CVT(a, b) use
// expands incorrectly.
#define MKLDNN_CVT(ST, DT) \
    OV_CASE2(Precision::ST, Precision::DT, \
             PrecisionInfo<Precision::ST>::value_type, \
             PrecisionInfo<Precision::DT>::value_type)
64
62
65
63
void cpu_convert (const void *srcPtr, void *dstPtr, Precision srcPrc, Precision dstPrc, const size_t size) {
64
+ using namespace MKLDNNPlugin ;
65
+
66
66
if (srcPtr == nullptr || dstPtr == nullptr )
67
67
THROW_IE_EXCEPTION << " cpu_convert has null data pointer" ;
68
68
@@ -71,38 +71,42 @@ void cpu_convert(const void *srcPtr, void *dstPtr, Precision srcPrc, Precision d
71
71
return ;
72
72
}
73
73
74
- switch (srcPrc) {
75
- case Precision::U8:
76
- convertFrom<PrecisionTrait<Precision::U8>::value_type>(srcPtr, dstPtr, dstPrc, size);
77
- break ;
78
- case Precision::I8:
79
- convertFrom<PrecisionTrait<Precision::I8>::value_type>(srcPtr, dstPtr, dstPrc, size);
80
- break ;
81
- case Precision::U16:
82
- convertFrom<PrecisionTrait<Precision::U16>::value_type>(srcPtr, dstPtr, dstPrc, size);
83
- break ;
84
- case Precision::I16:
85
- convertFrom<PrecisionTrait<Precision::I16>::value_type>(srcPtr, dstPtr, dstPrc, size);
86
- break ;
87
- case Precision::I32:
88
- convertFrom<PrecisionTrait<Precision::I32>::value_type>(srcPtr, dstPtr, dstPrc, size);
89
- break ;
90
- case Precision::U64:
91
- convertFrom<PrecisionTrait<Precision::U64>::value_type>(srcPtr, dstPtr, dstPrc, size);
92
- break ;
93
- case Precision::I64:
94
- convertFrom<PrecisionTrait<Precision::I64>::value_type>(srcPtr, dstPtr, dstPrc, size);
95
- break ;
96
- case Precision::FP32:
97
- convertFrom<PrecisionTrait<Precision::FP32>::value_type>(srcPtr, dstPtr, dstPrc, size);
98
- break ;
99
- case Precision::BF16:
100
- convertFrom<MKLDNNPlugin::bfloat16_t >(srcPtr, dstPtr, dstPrc, size);
101
- break ;
102
- case Precision::BOOL:
103
- convertFrom<PrecisionTrait<Precision::BOOL>::value_type>(srcPtr, dstPtr, dstPrc, size);
104
- break ;
105
- default :
106
- THROW_IE_EXCEPTION << " cpu_convert can't convert from: " << srcPrc << " precision" ;
107
- }
74
+ ConvertContext ctx = { srcPtr, dstPtr, size, false };
75
+
76
+ OV_SWITCH (MKLDNNPlugin, ConvertPrecision, ctx, std::tie (srcPrc, dstPrc),
77
+ MKLDNN_CVT (U8, I8), MKLDNN_CVT (U8, U16), MKLDNN_CVT (U8, I16),
78
+ MKLDNN_CVT (U8, I32), MKLDNN_CVT (U8, U64), MKLDNN_CVT (U8, I64),
79
+ MKLDNN_CVT (U8, FP32), MKLDNN_CVT (U8, BF16), MKLDNN_CVT (U8, BOOL),
80
+ MKLDNN_CVT (I8, U8), MKLDNN_CVT (I8, U16), MKLDNN_CVT (I8, I16),
81
+ MKLDNN_CVT (I8, I32), MKLDNN_CVT (I8, U64), MKLDNN_CVT (I8, I64),
82
+ MKLDNN_CVT (I8, FP32), MKLDNN_CVT (I8, BF16), MKLDNN_CVT (I8, BOOL),
83
+ MKLDNN_CVT (U16, U8), MKLDNN_CVT (U16, I8), MKLDNN_CVT (U16, I16),
84
+ MKLDNN_CVT (U16, I32), MKLDNN_CVT (U16, U64), MKLDNN_CVT (U16, I64),
85
+ MKLDNN_CVT (U16, FP32), MKLDNN_CVT (U16, BF16), MKLDNN_CVT (U16, BOOL),
86
+ MKLDNN_CVT (I16, U8), MKLDNN_CVT (I16, I8), MKLDNN_CVT (I16, U16),
87
+ MKLDNN_CVT (I16, I32), MKLDNN_CVT (I16, U64), MKLDNN_CVT (I16, I64),
88
+ MKLDNN_CVT (I16, FP32), MKLDNN_CVT (I16, BF16), MKLDNN_CVT (I16, BOOL),
89
+ MKLDNN_CVT (I32, U8), MKLDNN_CVT (I32, I8), MKLDNN_CVT (I32, U16),
90
+ MKLDNN_CVT (I32, I16), MKLDNN_CVT (I32, U64), MKLDNN_CVT (I32, I64),
91
+ MKLDNN_CVT (I32, FP32), MKLDNN_CVT (I32, BF16), MKLDNN_CVT (I32, BOOL),
92
+ MKLDNN_CVT (U64, U8), MKLDNN_CVT (U64, I8), MKLDNN_CVT (U64, U16),
93
+ MKLDNN_CVT (U64, I16), MKLDNN_CVT (U64, I32), MKLDNN_CVT (U64, I64),
94
+ MKLDNN_CVT (U64, FP32), MKLDNN_CVT (U64, BF16), MKLDNN_CVT (U64, BOOL),
95
+ MKLDNN_CVT (I64, U8), MKLDNN_CVT (I64, I8), MKLDNN_CVT (I64, U16),
96
+ MKLDNN_CVT (I64, I16), MKLDNN_CVT (I64, I32), MKLDNN_CVT (I64, U64),
97
+ MKLDNN_CVT (I64, FP32), MKLDNN_CVT (I64, BF16), MKLDNN_CVT (I64, BOOL),
98
+ MKLDNN_CVT (FP32, U8), MKLDNN_CVT (FP32, I8), MKLDNN_CVT (FP32, U16),
99
+ MKLDNN_CVT (FP32, I16), MKLDNN_CVT (FP32, I32), MKLDNN_CVT (FP32, U64),
100
+ MKLDNN_CVT (FP32, I64), MKLDNN_CVT (FP32, BF16), MKLDNN_CVT (FP32, BOOL),
101
+ MKLDNN_CVT (BF16, U8), MKLDNN_CVT (BF16, I8), MKLDNN_CVT (BF16, U16),
102
+ MKLDNN_CVT (BF16, I16), MKLDNN_CVT (BF16, I32), MKLDNN_CVT (BF16, U64),
103
+ MKLDNN_CVT (BF16, I64), MKLDNN_CVT (BF16, FP32), MKLDNN_CVT (BF16, BOOL),
104
+ MKLDNN_CVT (BOOL, U8), MKLDNN_CVT (BOOL, I8), MKLDNN_CVT (BOOL, U16),
105
+ MKLDNN_CVT (BOOL, I16), MKLDNN_CVT (BOOL, I32), MKLDNN_CVT (BOOL, U64),
106
+ MKLDNN_CVT (BOOL, I64), MKLDNN_CVT (BOOL, FP32), MKLDNN_CVT (BOOL, BF16));
107
+
108
+ if (!ctx.converted )
109
+ THROW_IE_EXCEPTION << " cpu_convert can't convert from: " << srcPrc << " precision to: " << dstPrc;
108
110
}
111
+
112
+ #undef MKLDNN_CVT
0 commit comments