
Commit d95123f

Improve speed of ZSTD_compressSequencesAndLiterals() using RVV
1 parent f9938c2 commit d95123f

File tree

lib/common/compiler.h
lib/compress/zstd_compress.c

2 files changed: +157 -0 lines changed

lib/common/compiler.h

Lines changed: 6 additions & 0 deletions
@@ -218,6 +218,9 @@
 #  if defined(__ARM_NEON) || defined(_M_ARM64)
 #    define ZSTD_ARCH_ARM_NEON
 #  endif
+#  if defined(__riscv) && defined(__riscv_vector)
+#    define ZSTD_ARCH_RISCV_RVV
+#  endif
 #
 #  if defined(ZSTD_ARCH_X86_AVX2)
 #    include <immintrin.h>
@@ -227,6 +230,9 @@
 #  elif defined(ZSTD_ARCH_ARM_NEON)
 #    include <arm_neon.h>
 #  endif
+#  if defined(ZSTD_ARCH_RISCV_RVV)
+#    include <riscv_vector.h>
+#  endif
 #endif

 /* C-language Attributes are added in C23. */
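Note: the new gate mirrors the NEON detection just above it. GCC and Clang predefine `__riscv` and `__riscv_vector` only when the target enables the V extension (e.g. `-march=rv64gcv`). A minimal standalone sketch, not part of the commit, of how this detection behaves:

#include <stdio.h>

#if defined(__riscv) && defined(__riscv_vector)
#  include <riscv_vector.h>
#  define ZSTD_ARCH_RISCV_RVV
#endif

int main(void) {
#ifdef ZSTD_ARCH_RISCV_RVV
    /* vsetvlmax reports how many 32-bit lanes fit in one LMUL=1 register */
    printf("RVV enabled: %zu e32 lanes per m1 register\n",
           __riscv_vsetvlmax_e32m1());
#else
    printf("RVV not enabled; the scalar paths are used\n");
#endif
    return 0;
}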

lib/compress/zstd_compress.c

Lines changed: 151 additions & 0 deletions
@@ -7284,6 +7284,93 @@ static size_t convertSequences_noRepcodes(
     return longLen;
 }

+#elif defined ZSTD_ARCH_RISCV_RVV
+#include <riscv_vector.h>
+/*
+ * Convert `vl` sequences per iteration, using RVV intrinsics:
+ * - offset -> offBase = offset + 2
+ * - litLength -> (U16) litLength
+ * - matchLength -> (U16)(matchLength - 3)
+ * - rep is ignored
+ * Store only 8 bytes per SeqDef (offBase[4], litLength[2], mlBase[2]).
+ *
+ * @returns 0 on success, with no long length detected
+ * @returns > 0 if there is one long length (> 65535),
+ * indicating the position, and type.
+ */
+static size_t convertSequences_noRepcodes(SeqDef* dstSeqs, const ZSTD_Sequence* inSeqs, size_t nbSequences) {
+    size_t longLen = 0;
+
+    /* RVV depends on the specific definition of target structures */
+    ZSTD_STATIC_ASSERT(sizeof(ZSTD_Sequence) == 16);
+    ZSTD_STATIC_ASSERT(offsetof(ZSTD_Sequence, offset) == 0);
+    ZSTD_STATIC_ASSERT(offsetof(ZSTD_Sequence, litLength) == 4);
+    ZSTD_STATIC_ASSERT(offsetof(ZSTD_Sequence, matchLength) == 8);
+    ZSTD_STATIC_ASSERT(sizeof(SeqDef) == 8);
+    ZSTD_STATIC_ASSERT(offsetof(SeqDef, offBase) == 0);
+    ZSTD_STATIC_ASSERT(offsetof(SeqDef, litLength) == 4);
+    ZSTD_STATIC_ASSERT(offsetof(SeqDef, mlBase) == 6);
+    size_t vl = 0;
+    for (size_t i = 0; i < nbSequences; i += vl) {
+
+        vl = __riscv_vsetvl_e32m2(nbSequences - i);
+        // Segmented load: deinterleave the four 32-bit structure members
+        // (offset, litLength, matchLength, rep) into four separate vectors
+        vuint32m2x4_t v_tuple = __riscv_vlseg4e32_v_u32m2x4(
+            (const uint32_t*)&inSeqs[i],
+            vl
+        );
+        vuint32m2_t v_offset = __riscv_vget_v_u32m2x4_u32m2(v_tuple, 0);
+        vuint32m2_t v_lit    = __riscv_vget_v_u32m2x4_u32m2(v_tuple, 1);
+        vuint32m2_t v_match  = __riscv_vget_v_u32m2x4_u32m2(v_tuple, 2);
+        // offset + ZSTD_REP_NUM
+        vuint32m2_t v_offBase = __riscv_vadd_vx_u32m2(v_offset, ZSTD_REP_NUM, vl);
+        // Flag lengths that overflow 16 bits, then narrow them to 16 bits
+        vbool16_t lit_overflow = __riscv_vmsgtu_vx_u32m2_b16(v_lit, 65535, vl);
+        vuint16m1_t v_lit_clamped = __riscv_vncvt_x_x_w_u16m1(v_lit, vl);
+
+        vbool16_t ml_overflow = __riscv_vmsgtu_vx_u32m2_b16(v_match, 65535 + MINMATCH, vl);
+        vuint16m1_t v_ml_clamped = __riscv_vncvt_x_x_w_u16m1(__riscv_vsub_vx_u32m2(v_match, MINMATCH, vl), vl);
+
+        // Pack the two 16-bit fields into one 32-bit value (little-endian):
+        // the lower 16 bits hold litLength, the upper 16 bits hold mlBase
+        vuint32m2_t v_lit_ml_combined = __riscv_vsll_vx_u32m2(
+            __riscv_vwcvtu_x_x_v_u32m2(v_ml_clamped, vl),  // widen mlBase back to 32 bits
+            16,
+            vl
+        );
+        v_lit_ml_combined = __riscv_vor_vv_u32m2(
+            v_lit_ml_combined,
+            __riscv_vwcvtu_x_x_v_u32m2(v_lit_clamped, vl),
+            vl
+        );
+        // Segmented store: interleave (offBase, litLength|mlBase) pairs,
+        // writing exactly 8 bytes per SeqDef
+        vuint32m2x2_t store_data = __riscv_vcreate_v_u32m2x2(
+            v_offBase,
+            v_lit_ml_combined
+        );
+        __riscv_vsseg2e32_v_u32m2x2(
+            (uint32_t*)&dstSeqs[i],
+            store_data,
+            vl
+        );
+        // Find the first index, if any, at which an overflow occurred
+        int first_ml = __riscv_vfirst_m_b16(ml_overflow, vl);
+        int first_lit = __riscv_vfirst_m_b16(lit_overflow, vl);
+
+        if (UNLIKELY(first_ml != -1)) {
+            assert(longLen == 0);
+            longLen = i + first_ml + 1;
+        }
+        if (UNLIKELY(first_lit != -1)) {
+            assert(longLen == 0);
+            longLen = i + first_lit + 1 + nbSequences;
+        }
+    }
+    return longLen;
+}
+
 /* the vector implementation could also be ported to SSSE3,
  * but since this implementation is targeting modern systems (>= Sapphire Rapid),
  * it's not useful to develop and maintain code for older pre-AVX2 platforms */
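Note: for readers unfamiliar with RVV intrinsics, this is a scalar sketch of the per-element transformation the vector loop performs (reference only, assuming the field layouts checked by the static asserts above; it is not the fallback path that ships in zstd_compress.c):

static size_t convertSequences_noRepcodes_ref(
        SeqDef* dstSeqs, const ZSTD_Sequence* inSeqs, size_t nbSequences)
{
    size_t longLen = 0;
    size_t n;
    for (n = 0; n < nbSequences; n++) {
        dstSeqs[n].offBase   = inSeqs[n].offset + ZSTD_REP_NUM;       /* offset -> offBase */
        dstSeqs[n].litLength = (U16)inSeqs[n].litLength;              /* truncated to 16 bits */
        dstSeqs[n].mlBase    = (U16)(inSeqs[n].matchLength - MINMATCH);
        /* at most one long length may appear per batch; its position and
         * kind are encoded in the return value, as in the vector loop */
        if (inSeqs[n].matchLength > 65535 + MINMATCH) {
            assert(longLen == 0);
            longLen = n + 1;                  /* long matchLength at index n */
        }
        if (inSeqs[n].litLength > 65535) {
            assert(longLen == 0);
            longLen = n + 1 + nbSequences;    /* long litLength at index n */
        }
    }
    return longLen;
}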
@@ -7451,6 +7538,70 @@ BlockSummary ZSTD_get1BlockSummary(const ZSTD_Sequence* seqs, size_t nbSeqs)
     }
 }

+#elif defined ZSTD_ARCH_RISCV_RVV
+
+BlockSummary ZSTD_get1BlockSummary(const ZSTD_Sequence* seqs, size_t nbSeqs)
+{
+    size_t totalMatchSize = 0;
+    size_t litSize = 0;
+    size_t i = 0;
+    int found_terminator = 0;
+    size_t vl_max = __riscv_vsetvlmax_e32m1();
+    vuint32m1_t v_lit_sum = __riscv_vmv_v_x_u32m1(0, vl_max);
+    vuint32m1_t v_match_sum = __riscv_vmv_v_x_u32m1(0, vl_max);
+
+    for (; i < nbSeqs; ) {
+        size_t vl = __riscv_vsetvl_e32m2(nbSeqs - i);
+
+        // Segmented load: deinterleave the four 32-bit structure members
+        // of each 16-byte ZSTD_Sequence into separate vectors
+        vuint32m2x4_t v_tuple = __riscv_vlseg4e32_v_u32m2x4(
+            (const uint32_t*)&seqs[i],
+            vl
+        );
+        vuint32m2_t v_lit   = __riscv_vget_v_u32m2x4_u32m2(v_tuple, 1);
+        vuint32m2_t v_match = __riscv_vget_v_u32m2x4_u32m2(v_tuple, 2);
+
+        // Check whether any element has a matchLength of 0,
+        // which marks the block terminator
+        vbool16_t mask = __riscv_vmseq_vx_u32m2_b16(v_match, 0, vl);
+        int first_zero = __riscv_vfirst_m_b16(mask, vl);
+
+        if (first_zero >= 0) {
+            // Terminator found: shorten the effective length to its
+            // index + 1, so the terminator's litLength is still counted
+            vl = first_zero + 1;
+
+            // accumulate the literal and match lengths of the remaining elements
+            v_lit_sum = __riscv_vredsum_vs_u32m2_u32m1(v_lit, v_lit_sum, vl);
+            v_match_sum = __riscv_vredsum_vs_u32m2_u32m1(v_match, v_match_sum, vl);
+
+            i += vl;
+            found_terminator = 1;
+            assert(seqs[i - 1].offset == 0);
+            break;
+        } else {
+            // No terminator in this batch: accumulate all vl elements
+            v_lit_sum = __riscv_vredsum_vs_u32m2_u32m1(v_lit, v_lit_sum, vl);
+            v_match_sum = __riscv_vredsum_vs_u32m2_u32m1(v_match, v_match_sum, vl);
+            i += vl;
+        }
+    }
+    litSize = __riscv_vmv_x_s_u32m1_u32(v_lit_sum);
+    totalMatchSize = __riscv_vmv_x_s_u32m1_u32(v_match_sum);
+
+    if (!found_terminator && i == nbSeqs) {
+        BlockSummary bs;
+        bs.nbSequences = ERROR(externalSequences_invalid);
+        return bs;
+    }
+    {   BlockSummary bs;
+        bs.nbSequences = i;
+        bs.blockSize = litSize + totalMatchSize;
+        bs.litSize = litSize;
+        return bs;
+    }
+}
+
 #else

 BlockSummary ZSTD_get1BlockSummary(const ZSTD_Sequence* seqs, size_t nbSeqs)
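Note: equivalently, in scalar form (a reference sketch, not part of the commit): the terminator is the first sequence with matchLength == 0; its litLength (the block's trailing literals) is still counted, and a missing terminator is reported as an error:

static BlockSummary ZSTD_get1BlockSummary_ref(const ZSTD_Sequence* seqs, size_t nbSeqs)
{
    size_t totalMatchSize = 0;
    size_t litSize = 0;
    size_t n;
    for (n = 0; n < nbSeqs; n++) {
        litSize += seqs[n].litLength;
        if (seqs[n].matchLength == 0) {   /* block terminator */
            assert(seqs[n].offset == 0);
            break;
        }
        totalMatchSize += seqs[n].matchLength;
    }
    if (n == nbSeqs) {                    /* no terminator found */
        BlockSummary bs;
        bs.nbSequences = ERROR(externalSequences_invalid);
        return bs;
    }
    {   BlockSummary bs;
        bs.nbSequences = n + 1;           /* terminator included */
        bs.blockSize = litSize + totalMatchSize;
        bs.litSize = litSize;
        return bs;
    }
}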
