#include <sstream>
#include <thread>
#include <mutex>
+ #include <atomic>
#include <vector>
#include <array>
#include <fstream>
@@ -444,6 +445,48 @@ static std::vector<float> evaluate_tokens(llama_context * ctx, std::vector<int>
    return result;
}

+ // Compute the log-probability of each (logits row, token) pair in parallel.
+ static void hellaswag_compute_logprobs(const float * batch_logits, int n_vocab, std::vector<std::thread>& workers,
+         const std::vector<std::pair<size_t, llama_token>>& eval_pairs, std::vector<float>& eval_results) {
+     constexpr int k_token_chunk = 4; // number of pairs a worker claims per fetch
+     if (eval_results.size() != eval_pairs.size()) {
+         eval_results.resize(eval_pairs.size());
+     }
+     if (eval_pairs.empty()) return;
+
+     size_t max_threads = std::min((eval_pairs.size() + k_token_chunk - 1)/k_token_chunk, workers.size());
+
+     std::atomic<int> counter(0); // shared cursor into eval_pairs; workers claim chunks via fetch_add
+     auto compute = [&counter, &eval_pairs, &eval_results, batch_logits, n_vocab] () {
+         float local_logprobs[k_token_chunk];
+         while (true) {
+             size_t first = counter.fetch_add(k_token_chunk, std::memory_order_relaxed);
+             if (first >= eval_results.size()) break;
+             size_t last = std::min(first + k_token_chunk, eval_results.size());
+             for (size_t i = first; i < last; ++i) {
+                 auto logits = batch_logits + eval_pairs[i].first * n_vocab;
+                 // numerically stable log-softmax: subtract the max logit before exponentiating
+                 float max_logit = logits[0];
+                 for (int j = 1; j < n_vocab; ++j) {
+                     max_logit = std::max(max_logit, logits[j]);
+                 }
+                 float sum_p = 0.f;
+                 for (int j = 0; j < n_vocab; ++j) {
+                     sum_p += expf(logits[j] - max_logit);
+                 }
+                 local_logprobs[i - first] = logits[eval_pairs[i].second] - max_logit - std::log(sum_p);
+             }
+             std::memcpy(eval_results.data() + first, local_logprobs, (last - first)*sizeof(float));
+         }
+     };
+
+     for (size_t it = 0; it < max_threads; ++it) {
+         workers[it] = std::thread(compute);
+     }
+     for (size_t it = 0; it < max_threads; ++it) {
+         workers[it].join();
+     }
+ }
+
static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
    // Calculates hellaswag score (acc_norm) from prompt
    //
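The per-pair work inside hellaswag_compute_logprobs above is just a numerically stable log-softmax: log p(token) = logit[token] - max_logit - log(sum_j exp(logit[j] - max_logit)). A minimal standalone sketch of that computation on toy logits, with a hypothetical `token_logprob` helper that is not part of the patch:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// log p(token) = logit[token] - max_logit - log(sum_j exp(logit[j] - max_logit))
static float token_logprob(const std::vector<float> & logits, int token) {
    float max_logit = logits[0];
    for (float l : logits) {
        max_logit = std::max(max_logit, l);
    }
    float sum_exp = 0.f;
    for (float l : logits) {
        sum_exp += std::exp(l - max_logit);
    }
    return logits[token] - max_logit - std::log(sum_exp);
}

int main() {
    const std::vector<float> logits = {1.0f, 2.0f, 0.5f}; // toy 3-token vocabulary
    std::printf("log p(token 1) = %f\n", token_logprob(logits, 1));
    return 0;
}
```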
@@ -574,6 +617,10 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
    std::vector<float> tok_logits(n_vocab);
    std::vector<float> batch_logits(n_ctx*n_vocab);

+     std::vector<std::pair<size_t, llama_token>> eval_pairs;
+     std::vector<float> eval_results;
+     std::vector<std::thread> workers(std::thread::hardware_concurrency());
+
    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
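One hedged side note on sizing the pool: std::thread::hardware_concurrency() may return 0 when the core count cannot be determined, which would leave `workers` empty. A hypothetical guard, not part of this patch, could fall back to a single worker:

```cpp
#include <algorithm>
#include <thread>
#include <vector>

// Hypothetical helper (not in the patch): never create an empty worker pool,
// even if hardware_concurrency() reports 0.
static std::vector<std::thread> make_workers() {
    const unsigned n = std::thread::hardware_concurrency();
    return std::vector<std::thread>(std::max(1u, n));
}
```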
@@ -654,6 +701,24 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
            return;
        }

+         // Compute log-probs in parallel
+         // First we collect all tasks
+         eval_pairs.clear();
+         for (size_t i = i0; i < i1; ++i) {
+             auto & hs_cur = hs_data[i];
+             size_t li = hs_cur.common_prefix;
+             for (int s = 0; s < 4; ++s) {
+                 for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
+                     eval_pairs.push_back(std::make_pair(hs_cur.i_batch + li++, hs_cur.seq_tokens[s][j + 1]));
+                 }
+                 ++li; // account for the last token of the ending, whose logits are not scored
+             }
+         }
+         // Then we do the actual calculation
+         hellaswag_compute_logprobs(batch_logits.data(), n_vocab, workers, eval_pairs, eval_results);
+
+         size_t ir = 0;
+
        // compute the logprobs for each ending of the decoded tasks
        for (size_t i = i0; i < i1; ++i) {
            auto & hs_cur = hs_data[i];
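Each collected pair is (row into batch_logits, the token that row should predict), and hellaswag_compute_logprobs fills eval_results in exactly that order, so the scoring loop below can consume the results with a single running index `ir`. A small sketch of that ordering for one toy task, detached from the real hs_data and batch layout (the row numbers are only illustrative):

```cpp
#include <cstdio>
#include <utility>
#include <vector>

int main() {
    // Toy task: tokens 0..2 are the shared prefix, then two endings (the real code uses 4).
    std::vector<std::vector<int>> seq_tokens = {
        {10, 11, 12, 20, 21},       // prefix + ending 0 (2 tokens)
        {10, 11, 12, 30, 31, 32},   // prefix + ending 1 (3 tokens)
    };
    const size_t common_prefix = 3;
    const size_t i_batch = 0; // row of this task's first logits in the batch (illustrative)

    std::vector<std::pair<size_t, int>> eval_pairs;
    size_t li = common_prefix;
    for (const auto & seq : seq_tokens) {
        for (size_t j = common_prefix; j < seq.size() - 1; j++) {
            eval_pairs.push_back({i_batch + li++, seq[j + 1]});
        }
        ++li; // the last token of each ending produces logits that are never scored
    }
    for (const auto & p : eval_pairs) {
        std::printf("logits row %zu -> score token %d\n", p.first, p.second);
    }
    return 0;
}
```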
@@ -662,26 +727,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
            const auto first_probs = softmax(tok_logits);

-             size_t li = hs_cur.common_prefix; // logits index in the batch
-
            for (int s = 0; s < 4; ++s) {
                hs_cur.ending_logprob_count[s] = 1;
                hs_cur.ending_logprob[s] = std::log(first_probs[hs_cur.seq_tokens[s][hs_cur.common_prefix]]);
-
-                 // Calculate the logprobs over the ending
                for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
-                     std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
-
-                     const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];
-
-                     hs_cur.ending_logprob[s] += std::log(prob);
+                     hs_cur.ending_logprob[s] += eval_results[ir++];
                    hs_cur.ending_logprob_count[s]++;
                }
-
-                 // account that we skip the last token in the ending
-                 ++li;
-
-                 // Calculate the mean token logprob for acc_norm
                hs_cur.ending_logprob[s] /= hs_cur.ending_logprob_count[s];
            }
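Taken together, the patch replaces a per-token softmax over the full vocabulary inside the scoring loop with one batched, multi-threaded pass. The scheduling idea is a plain chunked work queue: an atomic counter hands out blocks of k_token_chunk pairs and each thread loops until the counter runs past the end of the work. A self-contained sketch of just that pattern on a toy workload (names and sizes are illustrative, not llama.cpp API):

```cpp
#include <algorithm>
#include <atomic>
#include <cmath>
#include <cstdio>
#include <thread>
#include <vector>

int main() {
    const size_t n_items = 23;  // toy number of work items
    constexpr int k_chunk = 4;  // same chunking idea as k_token_chunk
    std::vector<float> results(n_items);

    std::atomic<int> counter(0);
    auto worker = [&]() {
        while (true) {
            // claim the next chunk of indices; relaxed ordering is enough for a plain counter
            size_t first = counter.fetch_add(k_chunk, std::memory_order_relaxed);
            if (first >= results.size()) break;
            size_t last = std::min(first + (size_t) k_chunk, results.size());
            for (size_t i = first; i < last; ++i) {
                results[i] = std::sqrt((float) i); // stand-in for the per-token log-prob work
            }
        }
    };

    // never launch more threads than there are chunks
    size_t n_threads = std::min((n_items + k_chunk - 1)/(size_t) k_chunk,
                                (size_t) std::max(1u, std::thread::hardware_concurrency()));
    std::vector<std::thread> workers(n_threads);
    for (auto & t : workers) t = std::thread(worker);
    for (auto & t : workers) t.join();

    std::printf("results[%zu] = %f\n", n_items - 1, (double) results[n_items - 1]);
    return 0;
}
```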