Skip to content

Commit 3e945cc

Browse files
ikawrakow (Iwan Kawrakow)
and authored
HellaSwag: speed up by parallelizing log-prob evaluation (#5020)
For Mistral-7B and fp16, time on my system goes down from 536 seconds to 423 seconds for the full evaluation dataset (10042 tasks). Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent ad19812 commit 3e945cc

File tree

1 file changed

+66
-14
lines changed

1 file changed

+66
-14
lines changed

examples/perplexity/perplexity.cpp

Lines changed: 66 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <sstream>
99
#include <thread>
1010
#include <mutex>
11+
#include <atomic>
1112
#include <vector>
1213
#include <array>
1314
#include <fstream>
@@ -444,6 +445,48 @@ static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int>
444445
return result;
445446
}
446447

448+
static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
449+
const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
450+
constexpr int k_token_chunk = 4;
451+
if (eval_results.size() != eval_pairs.size()) {
452+
eval_results.resize(eval_pairs.size());
453+
}
454+
if (eval_pairs.empty()) return;
455+
456+
size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
457+
458+
std::atomic<int> counter(0);
459+
auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
460+
float local_logprobs[k_token_chunk];
461+
while (true) {
462+
size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
463+
if (first >= eval_results.size()) break;
464+
size_t last = std::min(first + k_token_chunk, eval_results.size());
465+
for (size_t i = first; i < last; ++i) {
466+
auto logits = batch_logits + eval_pairs[i].first * n_vocab;
467+
float max_logit = logits[0];
468+
for (int j = 1; j < n_vocab; ++j) {
469+
max_logit = std::max(max_logit, logits[j]);
470+
}
471+
float sum_p = 0.f;
472+
for (int j = 0; j < n_vocab; ++j) {
473+
sum_p += expf(logits[j] - max_logit);
474+
}
475+
local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
476+
}
477+
std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
478+
}
479+
};
480+
481+
for (size_t it = 0; it < max_threads; ++it) {
482+
workers[it] = std::thread(compute);
483+
}
484+
for (size_t it = 0; it < max_threads; ++it) {
485+
workers[it].join();
486+
}
487+
488+
}
489+
447490
static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
448491
// Calculates hellaswag score (acc_norm) from prompt
449492
//
@@ -574,6 +617,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
574617
std::vector<float> tok_logits(n_vocab);
575618
std::vector<float> batch_logits(n_ctx*n_vocab);
576619

620+
std::vector<std::pair<size_t, llama_token>> eval_pairs;
621+
std::vector<float> eval_results;
622+
std::vector<std::thread> workers(std::thread::hardware_concurrency());
623+
577624
auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
578625
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
579626
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
@@ -654,6 +701,24 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
654701
return;
655702
}
656703

704+
// Compute log-probs in parallel
705+
// First we collect all tasks
706+
eval_pairs.clear();
707+
for (size_t i = i0; i < i1; ++i) {
708+
auto & hs_cur = hs_data[i];
709+
size_t li = hs_cur.common_prefix;
710+
for (int s = 0; s < 4; ++s) {
711+
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
712+
eval_pairs.push_back(std::make_pair(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]));
713+
}
714+
++li;
715+
}
716+
}
717+
// Then we do the actual calculation
718+
hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
719+
720+
size_t ir = 0;
721+
657722
// compute the logprobs for each ending of the decoded tasks
658723
for (size_t i = i0; i < i1; ++i) {
659724
auto & hs_cur = hs_data[i];
@@ -662,26 +727,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
662727

663728
const auto first_probs = softmax(tok_logits);
664729

665-
size_t li = hs_cur.common_prefix; // logits index in the batch
666-
667730
for (int s = 0; s < 4; ++s) {
668731
hs_cur.ending_logprob_count[s] = 1;
669732
hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
670-
671-
// Calculate the logprobs over the ending
672733
for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
673-
std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
674-
675-
const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
676-
677-
hs_cur.ending_logprob[s] += std::log(prob);
734+
hs_cur.ending_logprob[s] += eval_results[ir++];
678735
hs_cur.ending_logprob_count[s]++;
679736
}
680-
681-
// account that we skip the last token in the ending
682-
++li;
683-
684-
// Calculate the mean token logprob for acc_norm
685737
hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
686738
}
687739

0 commit comments

Comments (0)