@@ -492,6 +492,8 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
492
492
// quantization
493
493
//
494
494
495
+ #define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
496
+
495
497
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
496
498
// multiply int8_t, add results pairwise twice
497
499
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
@@ -551,7 +553,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
551
553
static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
552
554
{
553
555
const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
554
- const __m256i bytes = _mm256_set_m128i (_mm_srli_epi16(tmp, 4), tmp);
556
+ const __m256i bytes = MM256_SET_M128I (_mm_srli_epi16(tmp, 4), tmp);
555
557
const __m256i lowMask = _mm256_set1_epi8( 0xF );
556
558
return _mm256_and_si256(lowMask, bytes);
557
559
}
@@ -624,7 +626,7 @@ static inline __m256i bytes_from_bits_32(const uint8_t * x) {
624
626
bytesh = _mm_or_si128(bytesh, bit_mask);
625
627
bytesl = _mm_cmpeq_epi8(bytesl, _mm_set1_epi64x(-1));
626
628
bytesh = _mm_cmpeq_epi8(bytesh, _mm_set1_epi64x(-1));
627
- return _mm256_set_m128i (bytesh, bytesl);
629
+ return MM256_SET_M128I (bytesh, bytesl);
628
630
}
629
631
630
632
// Unpack 32 4-bit fields into 32 bytes
@@ -637,15 +639,15 @@ static inline __m256i bytes_from_nibbles_32(const uint8_t * rsi)
637
639
const __m128i lowMask = _mm_set1_epi8(0xF);
638
640
tmpl = _mm_and_si128(lowMask, tmpl);
639
641
tmph = _mm_and_si128(lowMask, tmph);
640
- return _mm256_set_m128i (tmph, tmpl);
642
+ return MM256_SET_M128I (tmph, tmpl);
641
643
}
642
644
643
645
// add int16_t pairwise and return as float vector
644
646
static inline __m256 sum_i16_pairs_float(const __m128i xh, const __m128i xl) {
645
647
const __m128i ones = _mm_set1_epi16(1);
646
648
const __m128i summed_pairsl = _mm_madd_epi16(ones, xl);
647
649
const __m128i summed_pairsh = _mm_madd_epi16(ones, xh);
648
- const __m256i summed_pairs = _mm256_set_m128i (summed_pairsh, summed_pairsl);
650
+ const __m256i summed_pairs = MM256_SET_M128I (summed_pairsh, summed_pairsl);
649
651
return _mm256_cvtepi32_ps(summed_pairs);
650
652
}
651
653
@@ -2350,7 +2352,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void *
2350
2352
const __m128i i32_1 = mul_sum_i8_pairs(bx, by);
2351
2353
2352
2354
// Convert int32_t to float
2353
- __m256 p = _mm256_cvtepi32_ps(_mm256_set_m128i (i32_0, i32_1));
2355
+ __m256 p = _mm256_cvtepi32_ps(MM256_SET_M128I (i32_0, i32_1));
2354
2356
2355
2357
// Apply the scale, and accumulate
2356
2358
acc = _mm256_add_ps(_mm256_mul_ps( d, p ), acc);
@@ -2826,7 +2828,7 @@ static void ggml_vec_dot_q5_0_q8_0(const int n, float * restrict s, const void *
2826
2828
__m128i bxh = _mm256_extractf128_si256(bx, 1);
2827
2829
bxl = _mm_or_si128(bxl, bxhil);
2828
2830
bxh = _mm_or_si128(bxh, bxhih);
2829
- bx = _mm256_set_m128i (bxh, bxl);
2831
+ bx = MM256_SET_M128I (bxh, bxl);
2830
2832
2831
2833
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
2832
2834
@@ -3082,7 +3084,7 @@ static void ggml_vec_dot_q5_1_q8_1(const int n, float * restrict s, const void *
3082
3084
__m128i bxh = _mm256_extractf128_si256(bx, 1);
3083
3085
bxl = _mm_or_si128(bxl, bxhil);
3084
3086
bxh = _mm_or_si128(bxh, bxhih);
3085
- bx = _mm256_set_m128i (bxh, bxl);
3087
+ bx = MM256_SET_M128I (bxh, bxl);
3086
3088
3087
3089
const __m256 dy = _mm256_set1_ps(y[i].d);
3088
3090
const __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
0 commit comments