File tree Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Expand file tree Collapse file tree 1 file changed +8
-3
lines changed Original file line number Diff line number Diff line change @@ -560,6 +560,7 @@ def loglikelihood_rolling(
560
560
requests ,
561
561
override_bs = override_bs ,
562
562
return_bool_score = False ,
563
+ rolling = True ,
563
564
)
564
565
return results
565
566
@@ -568,6 +569,7 @@ def _loglikelihood_tokens(
568
569
requests : list [LoglikelihoodRequest ],
569
570
override_bs : int = - 1 ,
570
571
return_bool_score : bool = True ,
572
+ rolling : bool = False ,
571
573
) -> list [LoglikelihoodReturn ]:
572
574
dataset = LoglikelihoodDataset (requests = requests , dataset_splits = self .DATASET_SPLITS )
573
575
starting_batch_size = STARTING_BATCH_SIZE
@@ -576,9 +578,12 @@ def _loglikelihood_tokens(
576
578
for split_start , split_end in tqdm (dataset .splits_start_end_iterator ()):
577
579
context_enc = dataset [0 ].tokenized_context
578
580
continuation_enc = dataset [0 ].tokenized_continuation
579
- max_context_continuation_size_allowed = len (
580
- (context_enc + continuation_enc )[- (self .max_length + 1 ) :][:- 1 ]
581
- )
581
+ if rolling : # we take all the sequence in rolling mode
582
+ max_context_continuation_size_allowed = len (context_enc + continuation_enc )
583
+ else : # in normal mode, we left cut the context if needed
584
+ max_context_continuation_size_allowed = len (
585
+ (context_enc + continuation_enc )[- (self .max_length + 1 ) :][:- 1 ]
586
+ )
582
587
583
588
batch_size = self ._get_batch_size (
584
589
override_bs = override_bs ,
You can’t perform that action at this time.
0 commit comments