@@ -407,46 +407,55 @@ def greedy_until(
                # stop_tokens and max_tokens generated) which is not necessarily
                # the case! Because of that we only use batch size of 1
                stop_tokens = batch[0].stop_sequence
+                max_new_tokens = batch[0].generation_size
+
+                # The main question for this step is the following:
+                # Would we rather truncate the prompt to allow generation to go to max_new_tokens, at the risk
+                # of losing some meaning, or have some generations that are exceedingly short?
+                # The choice we go for here is to avoid truncating the prompt if we can, since it
+                # should have been managed by the prompt creator/few shot manager if requested by the user.
                context = [c.context for c in batch]
-                max_context_size_allowed = self.max_length
-                if batch[0].generation_size is None:
-                    # No constraints on max tokens except the model and data
-                    # Max generation possible is the max_length - the smallest context
-                    smallest_context = min([len(c) for c in context])
-                    if smallest_context < self.max_length:
-                        max_generated_tokens = self.max_length - smallest_context
-                        max_context_size_allowed = self.max_length
-                    else:
-                        # The max context size is smaller than the smallest context
-                        max_generated_tokens = 1
-                        max_context_size_allowed = self.max_length - 1
-                        hlog_warn(
-                            f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in {[i.task_name for i in batch]}. This is likely to lead to some errors."
-                        )
-                else:
-                    max_generated_tokens = batch[0].generation_size
-                    max_context_size_allowed = self.max_length - max_generated_tokens
+                smallest_context = min(len(c) for c in context)
+                biggest_context = max(len(c) for c in context)
+                if smallest_context > self.max_length:
+                    hlog_warn(
+                        f"The smallest context of your batch ({smallest_context}) is bigger than the maximum context size allowed by the model ({self.max_length}) for a task in"
+                        + str({i.task_name for i in batch})
+                        + ". This is likely to lead to some errors."  # noqa C401
+                    )

+                if (
+                    biggest_context > self.max_length
+                ):  # There will be truncation of at least one sample, maximum generation size will be one
+                    max_new_tokens = 1
+                else:  # We can't allow generation of more than max_length
+                    max_new_tokens = min(self.max_length - biggest_context, max_new_tokens)
+
+                # See doc https://huggingface.co/docs/transformers/v4.38.2/en/pad_truncation#padding-and-truncation
+                # Will do left truncation and padding, as defined when creating the tokenizer
                tokenized = self.tokenizer(
                    context,
-                    padding=True,
-                    truncation=True,
+                    truncation="longest_first",  # we truncate to the model max length if needed
+                    padding="longest",  # we pad to the longest sequence
                    return_tensors="pt",
-                    max_length=max_context_size_allowed,
+                    max_length=self.max_length - 1,  # we always allow minimum one token of generation
                    add_special_tokens=self.add_special_tokens,
                ).to(self.device)

                prepared_batch = Batch(
                    input_ids=tokenized["input_ids"],
                    input_lengths=[len(item == 1) for item in tokenized["attention_mask"]],
                    input_mask=tokenized["attention_mask"],
-                    truncated=[0] * len(tokenized["input_ids"]),
-                    padded=[0] * len(tokenized["input_ids"]),
+                    truncated=[
+                        len(c) - tokenized["input_ids"].shape[1] if len(c) > tokenized["input_ids"].shape[1] else 0
+                        for c in context
+                    ],
+                    padded=[sum(mask == 0) for mask in tokenized["attention_mask"]],
                )

                cur_reponses = self._generate(
                    batch=prepared_batch,
-                    max_tokens=max_generated_tokens,
+                    max_new_tokens=max_new_tokens,
                    stop_tokens=stop_tokens,
                    returns_logits=returns_logits,
                )
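As a rough illustration of the budget logic introduced above, here is a minimal, self-contained sketch (not the library code; `generation_budget`, `context_lengths` and `requested` are made-up names standing in for the tokenized prompt lengths, `self.max_length` and `batch[0].generation_size`):

```python
# Illustrative sketch only, not part of the PR.
def generation_budget(context_lengths: list[int], max_length: int, requested: int) -> int:
    biggest_context = max(context_lengths)
    if biggest_context > max_length:
        # At least one prompt will be truncated, so only a single new token is allowed.
        return 1
    # Otherwise, generate at most what still fits in the context window.
    return min(max_length - biggest_context, requested)


assert generation_budget([10, 50], max_length=100, requested=256) == 50
assert generation_budget([10, 120], max_length=100, requested=256) == 1
```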
@@ -457,7 +466,7 @@ def greedy_until(
    def _generate(
        self,
        batch: Batch,
-        max_tokens: int,
+        max_new_tokens: int,
        stop_tokens: list[str],
        returns_logits: Optional[bool] = False,
    ) -> list[GenerateReturn]:
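The `stop_tokens` passed here end up as the `stopping_criteria` used in the `generate` call below. As a hedged sketch of what such a criterion can look like with the standard `transformers.StoppingCriteria` interface (the class name and fields are illustrative, not the repo's actual implementation):

```python
import torch
from transformers import StoppingCriteria


class StopOnSequence(StoppingCriteria):
    """Illustrative stop-string criterion; the repo builds its own criteria elsewhere."""

    def __init__(self, tokenizer, stop_sequence: str, prompt_length: int):
        self.tokenizer = tokenizer
        self.stop_sequence = stop_sequence
        self.prompt_length = prompt_length

    def __call__(self, input_ids: torch.LongTensor, scores, **kwargs) -> bool:
        # Decode only the newly generated tokens, then look for the stop string.
        generated = self.tokenizer.decode(input_ids[0, self.prompt_length:])
        return self.stop_sequence in generated
```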
@@ -470,7 +479,7 @@ def _generate(
        outputs = self.model.generate(
            input_ids=batch.input_ids,
            attention_mask=batch.input_mask,
-            max_new_tokens=max_tokens,
+            max_new_tokens=max_new_tokens,
            stopping_criteria=stopping_criteria,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id if self.tokenizer.pad_token_id else self.tokenizer.eos_token_id,
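For reference, the same greedy call can be reproduced standalone; this sketch uses gpt2 purely as an example model, with kwargs mirroring the ones above:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tokenizer(["The capital of France is"], return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=8,  # plays the role of max_new_tokens in the diff
        do_sample=False,  # greedy decoding
        pad_token_id=tokenizer.eos_token_id,  # gpt2 has no pad token
    )
# Print only the newly generated part.
print(tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:]))
```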
@@ -598,7 +607,7 @@ def _loglikelihood_tokens(
                dataloader = self.accelerator.prepare(dataloader)

            for batch in tqdm(dataloader, disable=self.disable_tqdm):
-                prepared_batch = self.prepare_batch(
+                prepared_batch = self.prepare_batch_logprob(
                    batch,
                    padding_length=max_context_continuation_size_allowed,
                    max_context=max_context_continuation_size_allowed,
@@ -682,10 +691,13 @@ def _loglikelihood_tokens(

        return dataset.get_original_order(res)

-    def prepare_batch(
+    def prepare_batch_logprob(
        self, batch: list[Request], padding_length: int, max_context: Optional[int] = None, single_token: bool = False
    ):
-        """Tokenize a batch of inputs and return also the length, truncations and padding"""
+        """Tokenize a batch of inputs and return also the length, truncations and padding.
+        This step is done manually since we tokenize log probability inputs together with their continuation,
+        to manage possible extra spaces added at the start by tokenizers, see tok_encode_pair.
+        """
        if single_token:
            inputs = [request.tokenized_context for request in batch]
        else:
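The docstring's point about extra spaces is easy to see directly with a tokenizer. A hedged illustration (gpt2 chosen arbitrarily; the exact splitting done by tok_encode_pair is not reproduced here):

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer only
context, continuation = "The capital of France is", " Paris"

# Many tokenizers fold the leading space into the next word, so encoding the
# two pieces separately may not reproduce the ids of the joined string.
separate = tok.encode(context) + tok.encode(continuation)
together = tok.encode(context + continuation)
print(separate == together)  # may be False, depending on the tokenizer

# One pair-aware approach is to encode the joined string once and split the ids
# at the context boundary; see tok_encode_pair in the codebase for the actual handling.
```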
@@ -719,7 +731,7 @@ def prepare_batch(
                raise ValueError("Negative padding")

            padded.append(padding_length - sequence_len)
-            # Right padding - it likely would be better to do left padding
+            # Right padding, since we ignore these logprobs in the end
            tokens = F.pad(tokens, (0, padding_length - sequence_len), value=self.tokenizer.pad_token_id)

            # We create the attention mask to ignore padding
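To see this padding bookkeeping in isolation, a toy sketch (values, pad id and the mask construction are illustrative of the comment above, not copied from the method):

```python
import torch
import torch.nn.functional as F

pad_token_id = 0
padding_length = 8
tokens = torch.tensor([11, 22, 33, 44, 55])  # context + continuation ids
sequence_len = tokens.shape[0]

# Right padding; padded positions get attention 0, so their logprobs are ignored downstream.
tokens = F.pad(tokens, (0, padding_length - sequence_len), value=pad_token_id)
attention_mask = (tokens != pad_token_id).long()
print(tokens.tolist())  # [11, 22, 33, 44, 55, 0, 0, 0]
print(attention_mask.tolist())  # [1, 1, 1, 1, 1, 0, 0, 0]
```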
@@ -806,7 +818,7 @@ def _loglikelihood_single_token(
                dataloader = self.accelerator.prepare(dataloader)

            for batch in tqdm(dataloader, disable=self.disable_tqdm, position=1):
-                prepared_batch = self.prepare_batch(
+                prepared_batch = self.prepare_batch_logprob(
                    batch, padding_length=max_context, max_context=max_context, single_token=True
                )