Skip to content

Commit db44c93

Browse files
authored
Rolling management (#78)
1 parent 0dcb495 commit db44c93

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/lighteval/models/base_model.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,7 @@ def loglikelihood_rolling(
560560
requests,
561561
override_bs=override_bs,
562562
return_bool_score=False,
563+
rolling=True,
563564
)
564565
return results
565566

@@ -568,6 +569,7 @@ def _loglikelihood_tokens(
568569
requests: list[LoglikelihoodRequest],
569570
override_bs: int = -1,
570571
return_bool_score: bool = True,
572+
rolling: bool = False,
571573
) -> list[LoglikelihoodReturn]:
572574
dataset = LoglikelihoodDataset(requests=requests, dataset_splits=self.DATASET_SPLITS)
573575
starting_batch_size = STARTING_BATCH_SIZE
@@ -576,9 +578,12 @@ def _loglikelihood_tokens(
576578
for split_start, split_end in tqdm(dataset.splits_start_end_iterator()):
577579
context_enc = dataset[0].tokenized_context
578580
continuation_enc = dataset[0].tokenized_continuation
579-
max_context_continuation_size_allowed = len(
580-
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
581-
)
581+
if rolling: # we take all the sequence in rolling mode
582+
max_context_continuation_size_allowed = len(context_enc + continuation_enc)
583+
else: # in normal mode, we left cut the context if needed
584+
max_context_continuation_size_allowed = len(
585+
(context_enc + continuation_enc)[-(self.max_length + 1) :][:-1]
586+
)
582587

583588
batch_size = self._get_batch_size(
584589
override_bs=override_bs,

0 commit comments

Comments
 (0)