Skip to content

Commit 1707d4c

Browse files
ggerganov authored and iThalay committed
bench : pass memcpy threads from cli
1 parent 5ec5b9d commit 1707d4c

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

whisper.cpp

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -6138,7 +6138,7 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
61386138

61396139
// multi-thread
61406140

6141-
for (uint32_t n_threads = 1; n_threads <= std::thread::hardware_concurrency(); n_threads++) {
6141+
for (uint32_t k = 1; k <= n_threads; k++) {
61426142
char * src = (char *) malloc(size);
61436143
char * dst = (char *) malloc(size);
61446144

@@ -6149,8 +6149,8 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
61496149
double tsum = 0.0;
61506150

61516151
auto helper = [&](int th) {
6152-
const int64_t i0 = (th + 0)*size/n_threads;
6153-
const int64_t i1 = (th + 1)*size/n_threads;
6152+
const int64_t i0 = (th + 0)*size/k;
6153+
const int64_t i1 = (th + 1)*size/k;
61546154

61556155
for (size_t i = 0; i < n; i++) {
61566156
memcpy(dst + i0, src + i0, i1 - i0);
@@ -6161,22 +6161,22 @@ WHISPER_API const char * whisper_bench_memcpy_str(int n_threads) {
61616161

61626162
const int64_t t0 = ggml_time_us();
61636163

6164-
std::vector<std::thread> threads(n_threads - 1);
6165-
for (uint32_t th = 0; th < n_threads - 1; ++th) {
6164+
std::vector<std::thread> threads(k - 1);
6165+
for (uint32_t th = 0; th < k - 1; ++th) {
61666166
threads[th] = std::thread(helper, th);
61676167
}
61686168

6169-
helper(n_threads - 1);
6169+
helper(k - 1);
61706170

6171-
for (uint32_t th = 0; th < n_threads - 1; ++th) {
6171+
for (uint32_t th = 0; th < k - 1; ++th) {
61726172
threads[th].join();
61736173
}
61746174

61756175
const int64_t t1 = ggml_time_us();
61766176

61776177
tsum += (t1 - t0)*1e-6;
61786178

6179-
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), n_threads);
6179+
snprintf(strbuf, sizeof(strbuf), "memcpy: %7.2f GB/s (%2d thread)\n", (double) (n*size)/(tsum*1e9), k);
61806180
s += strbuf;
61816181

61826182
// needed to prevent the compiler from optimizing the memcpy away

0 commit comments

Comments
 (0)