
Commit 165ebc9

Browse files
Fixes input length management for generative evals (#103)
--------- Co-authored-by: Nathan Habib <[email protected]>
1 parent 286912e commit 165ebc9

File tree

1 file changed: +43 -31 lines


src/lighteval/models/base_model.py

Lines changed: 43 additions & 31 deletions
@@ -407,46 +407,55 @@ def greedy_until(
                # stop_tokens and max_tokens genrated) which is not necessarily
                # the case! Because of that we only use batch size of 1
                stop_tokens = batch[0].stop_sequence
+               max_new_tokens = batch[0].generation_size
+
+               # The main question for this step is the following:
+               # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+               # of loosing some meaning, or have some generations that are exceedingly short?
+               # The choice we go for here is to avoid truncating the prompt if we can, since it
+               # should have been managed by the prompt creator/few shot manager if requested by the user.
                context = [c.context for c in batch]
-               max_context_size_allowed = self.max_length
-               if batch[0].generation_size is None:
-                   # No constraints on max tokens except the model and data
-                   # Max generation possible is the max_length - the smallest context
-                   smallest_context = min([len(c) for c in context])
-                   if smallest_context < self.max_length:
-                       max_generated_tokens = self.max_length - smallest_context
-                       max_context_size_allowed = self.max_length
-                   else:
-                       # The max context size is smaller than the smallest context
-                       max_generated_tokens = 1
-                       max_context_size_allowed = self.max_length - 1
-                       hlog_warn(
-                           f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in {[i.task_name for i in batch]}. This is likely to lead to some errors."
-                       )
-               else:
-                   max_generated_tokens = batch[0].generation_size
-                   max_context_size_allowed = self.max_length - max_generated_tokens
+               smallest_context = min(len(c) for c in context)
+               biggest_context = max(len(c) for c in context)
+               if smallest_context > self.max_length:
+                   hlog_warn(
+                       f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+                       + str({i.task_name for i in batch})
+                       + ". This is likely to lead to some errors."  # noqa C401
+                   )

+               if (
+                   biggest_context > self.max_length
+               ):  # There will be truncation of at least one sample, maximum generation size will be one
+                   max_new_tokens = 1
+               else:  # We can't allow generation of more than max_length
+                   max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)
+
+               # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
+               # Will do left truncation and padding, as defined when creating the tokenizer
                tokenized = self.tokenizer(
                    context,
-                   padding=True,
-                   truncation=True,
+                   truncation="longest_first",  # we truncate to the model max length if needed
+                   padding="longest",  # we pad to the longest sequence
                    return_tensors="pt",
-                   max_length=max_context_size_allowed,
+                   max_length=self.max_length - 1,  # we always allow minimum one token of generation
                    add_special_tokens=self.add_special_tokens,
                ).to(self.device)

                prepared_batch = Batch(
                    input_ids=tokenized["input_ids"],
                    input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
                    input_mask=tokenized["attention_mask"],
-                   truncated=[0] * len(tokenized["input_ids"]),
-                   padded=[0] * len(tokenized["input_ids"]),
+                   truncated=[
+                       len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0
+                       for c in context
+                   ],
+                   padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
                )

                cur_reponses = self._generate(
                    batch=prepared_batch,
-                   max_tokens=max_generated_tokens,
+                   max_new_tokens=max_new_tokens,
                    stop_tokens=stop_tokens,
                    returns_logits=returns_logits,
                )
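
Stepping back from the hunk above, the new length management boils down to one clamping rule: the requested generation size becomes a cap, and the prompt is never truncated just to honour it. Below is a minimal, self-contained sketch of that rule; the function and parameter names (clamp_max_new_tokens, contexts, requested) are illustrative, not lighteval's API.

# Sketch only: mirrors the clamping logic introduced above, assuming contexts are
# already tokenized and max_length is the model's maximum sequence length.
def clamp_max_new_tokens(contexts: list[list[int]], max_length: int, requested: int) -> int:
    biggest_context = max(len(c) for c in contexts)
    if biggest_context > max_length:
        # At least one prompt will be truncated by the tokenizer anyway,
        # so keep generation to the minimum of one token.
        return 1
    # Otherwise, cap the requested generation size by the room left after the longest prompt.
    return min(max_length - biggest_context, requested)

# Example: a 2048-token model, prompts of 100 and 2000 tokens, 256 tokens requested
# -> only 48 new tokens fit after the longest prompt.
assert clamp_max_new_tokens([[0] * 100, [0] * 2000], max_length=2048, requested=256) == 48

Compared with the deleted branch, which shrank the allowed context to make room for the requested tokens, the prompt now keeps priority and the generation budget shrinks instead.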
@@ -457,7 +466,7 @@ def greedy_until(
     def _generate(
         self,
         batch: Batch,
-        max_tokens: int,
+        max_new_tokens: int,
         stop_tokens: list[str],
         returns_logits: Optional[bool] = False,
     ) -> list[GenerateReturn]:
@@ -470,7 +479,7 @@ def _generate(
         outputs = self.model.generate(
             input_ids=batch.input_ids,
             attention_mask=batch.input_mask,
-            max_new_tokens=max_tokens,
+            max_new_tokens=max_new_tokens,
             stopping_criteria=stopping_criteria,
             do_sample=False,
             pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id,
@@ -598,7 +607,7 @@ def _loglikelihood_tokens(
             dataloader = self.accelerator.prepare(dataloader)

         for batch in tqdm(dataloader, disable=self.disable_tqdm):
-            prepared_batch = self.prepare_batch(
+            prepared_batch = self.prepare_batch_logprob(
                 batch,
                 padding_length=max_context_continuation_size_allowed,
                 max_context=max_context_continuation_size_allowed,
@@ -682,10 +691,13 @@ def _loglikelihood_tokens(

         return dataset.get_original_order(res)

-    def prepare_batch(
+    def prepare_batch_logprob(
         self, batch: list[Request], padding_length: int, max_context: Optional[int] = None, single_token: bool = False
     ):
-        """Tokenize a batch of inputs and return also the length, truncations and padding"""
+        """Tokenize a batch of inputs and return also the length, truncations and padding.
+        This step is done manually since we tokenize log probability inputs together with their continuation,
+        to manage possible extra spaces added at the start by tokenizers, see tok_encode_pair.
+        """
         if single_token:
             inputs = [request.tokenized_context for request in batch]
         else:
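
The new docstring points at tok_encode_pair, the helper that tokenizes a context together with its continuation so that a space token merged at the boundary by the tokenizer is attributed consistently. As a rough, hypothetical sketch only (not lighteval's actual implementation), such a pair-encoding helper can look like this, assuming a Hugging Face tokenizer:

# Hypothetical sketch of pair encoding; the real tok_encode_pair may differ.
def encode_pair(tokenizer, context: str, continuation: str) -> tuple[list[int], list[int]]:
    # Encode the joined string once, and the context alone once.
    whole_ids = tokenizer(context + continuation, add_special_tokens=False)["input_ids"]
    context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
    # Everything after the context prefix of the joint encoding counts as continuation,
    # so boundary tokens (e.g. a leading space absorbed by BPE) are handled consistently.
    return whole_ids[: len(context_ids)], whole_ids[len(context_ids) :]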
@@ -719,7 +731,7 @@ def prepare_batch(
                 raise ValueError("Negative padding")

             padded.append(padding_length - sequence_len)
-            # Right padding - it likely would be better to do left padding
+            # Right padding, since we ignore these logprobs in the end
             tokens = F.pad(tokens, (0, padding_length - sequence_len), value=self.tokenizer.pad_token_id)

             # We create the attention mask to ignore padding
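
The reworded comment makes the intent explicit: padding on the right is fine because the attention mask mentioned in the trailing context line zeroes those positions out, so their log probabilities are never read. A small standalone illustration of that padding step; pad_token_id and padding_length are arbitrary example values, not lighteval's runtime ones.

import torch
import torch.nn.functional as F

pad_token_id = 0
padding_length = 8
tokens = torch.tensor([11, 22, 33, 44, 55])  # an already tokenized sequence
sequence_len = tokens.shape[0]

# Right padding up to padding_length with the pad token.
padded_tokens = F.pad(tokens, (0, padding_length - sequence_len), value=pad_token_id)
# The attention mask marks real tokens with 1 and padding with 0, so padded positions
# are ignored when reading out log probabilities.
attention_mask = torch.cat([torch.ones(sequence_len), torch.zeros(padding_length - sequence_len)]).long()

print(padded_tokens)   # tensor([11, 22, 33, 44, 55,  0,  0,  0])
print(attention_mask)  # tensor([1, 1, 1, 1, 1, 0, 0, 0])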
@@ -806,7 +818,7 @@ def _loglikelihood_single_token(
             dataloader = self.accelerator.prepare(dataloader)

         for batch in tqdm(dataloader, disable=self.disable_tqdm, position=1):
-            prepared_batch = self.prepare_batch(
+            prepared_batch = self.prepare_batch_logprob(
                 batch, padding_length=max_context, max_context=max_context, single_token=True
             )

