@@ -511,7 +511,6 @@ def generate(
511
511
self ._set_language_and_task (
512
512
language = language , task = task , is_multilingual = is_multilingual , generation_config = generation_config
513
513
)
514
- self ._set_token_ids (generation_config = generation_config , config = self .config , kwargs = kwargs )
515
514
self ._set_num_frames (
516
515
return_token_timestamps = return_token_timestamps , generation_config = generation_config , kwargs = kwargs
517
516
)
@@ -546,13 +545,13 @@ def generate(
546
545
logits_processor = logits_processor ,
547
546
begin_index = begin_index , # begin index is index of first generated decoder token
548
547
is_shortform = is_shortform ,
549
- num_beams = kwargs . get ( " num_beams" , 1 ) ,
548
+ num_beams = generation_config . num_beams ,
550
549
)
551
550
552
551
# 5. If we're in shortform mode, simple generate the whole input at once and return the output
553
552
if is_shortform :
554
553
if temperature is not None :
555
- kwargs [ " temperature" ] = temperature
554
+ generation_config . temperature = temperature
556
555
557
556
decoder_input_ids = kwargs .pop ("decoder_input_ids" , None )
558
557
if decoder_input_ids is None :
@@ -564,8 +563,8 @@ def generate(
564
563
[prompt_ids [None ].repeat (decoder_input_ids .shape [0 ], 1 ), decoder_input_ids ], dim = - 1
565
564
)
566
565
567
- if kwargs . get ( " max_new_tokens" , 0 ) + decoder_input_ids . shape [ - 1 ] > self . config . max_target_positions :
568
- max_new_tokens = kwargs . get ( "max_new_tokens" , 0 )
566
+ max_new_tokens = generation_config . max_new_tokens if generation_config . max_new_tokens is not None else 0
567
+ if max_new_tokens + decoder_input_ids . shape [ - 1 ] > self . config . max_target_positions :
569
568
raise ValueError (
570
569
f"The length of `decoder_input_ids` equal `prompt_ids` plus special start tokens is { decoder_input_ids .shape [- 1 ]} , and the `max_new_tokens` "
571
570
f"is { max_new_tokens } . Thus, the combined length of "
@@ -666,11 +665,10 @@ def generate(
666
665
)
667
666
668
667
# 6.6 set max new tokens or max length
669
- kwargs = self ._set_max_new_tokens_and_length (
668
+ self ._set_max_new_tokens_and_length (
670
669
config = self .config ,
671
670
decoder_input_ids = decoder_input_ids ,
672
671
generation_config = generation_config ,
673
- kwargs = kwargs ,
674
672
)
675
673
676
674
# 6.7 Set current `begin_index` for all logit processors
@@ -770,9 +768,9 @@ def generate_with_fallback(
770
768
771
769
for fallback_idx , temperature in enumerate (temperatures ):
772
770
generation_config .do_sample = temperature is not None and temperature > 0.0
773
-
774
771
generation_config .temperature = temperature if generation_config .do_sample else 1.0
775
- generation_config .num_beams = kwargs .get ("num_beams" , 1 ) if not generation_config .do_sample else 1
772
+ if generation_config .do_sample :
773
+ generation_config .num_beams = 1
776
774
777
775
generate_kwargs = copy .copy (kwargs )
778
776
for key in ["do_sample" , "temperature" , "num_beams" ]:
@@ -1095,20 +1093,15 @@ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
1095
1093
task = getattr (generation_config , "task" , None )
1096
1094
language = getattr (generation_config , "language" , None )
1097
1095
1098
- if kwargs .get ("forced_decoder_ids" , None ) is not None :
1099
- forced_decoder_ids = kwargs ["forced_decoder_ids" ]
1100
- elif hasattr (generation_config , "forced_decoder_ids" ) and generation_config .forced_decoder_ids is not None :
1101
- forced_decoder_ids = generation_config .forced_decoder_ids
1102
-
1096
+ forced_decoder_ids = generation_config .forced_decoder_ids
1097
+ if forced_decoder_ids is not None :
1103
1098
if language is None and task is None and forced_decoder_ids [0 ][1 ] is None :
1104
1099
logger .warning_once (
1105
1100
"Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
1106
1101
"This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
1107
1102
)
1108
1103
elif hasattr (config , "forced_decoder_ids" ) and config .forced_decoder_ids is not None :
1109
1104
forced_decoder_ids = config .forced_decoder_ids
1110
- else :
1111
- forced_decoder_ids = None
1112
1105
1113
1106
if forced_decoder_ids is not None and task is not None :
1114
1107
logger .info (
@@ -1288,21 +1281,6 @@ def _check_decoder_input_ids(kwargs):
1288
1281
"Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead." ,
1289
1282
)
1290
1283
1291
- @staticmethod
1292
- def _set_token_ids (generation_config , config , kwargs ):
1293
- eos_token_id = kwargs .pop ("eos_token_id" , None )
1294
- decoder_start_token_id = kwargs .pop ("decoder_start_token_id" , None )
1295
-
1296
- eos_token_id = eos_token_id if eos_token_id is not None else generation_config .eos_token_id
1297
- decoder_start_token_id = (
1298
- decoder_start_token_id if decoder_start_token_id is not None else generation_config .decoder_start_token_id
1299
- )
1300
-
1301
- generation_config .eos_token_id = eos_token_id if eos_token_id is not None else config .eos_token_id
1302
- generation_config .decoder_start_token_id = (
1303
- decoder_start_token_id if decoder_start_token_id is not None else config .decoder_start_token_id
1304
- )
1305
-
1306
1284
@staticmethod
1307
1285
def _set_num_frames (return_token_timestamps , generation_config , kwargs ):
1308
1286
if return_token_timestamps :
@@ -1313,7 +1291,6 @@ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
1313
1291
"Model generation config has no `alignment_heads`, token-level timestamps not available. "
1314
1292
"See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
1315
1293
)
1316
-
1317
1294
generation_config .num_frames = kwargs .pop ("num_frames" , None )
1318
1295
1319
1296
@staticmethod
@@ -1517,47 +1494,21 @@ def _prepare_decoder_input_ids(
1517
1494
return decoder_input_ids , kwargs
1518
1495
1519
1496
@staticmethod
1520
- def _set_max_new_tokens_and_length (config , decoder_input_ids , generation_config , kwargs ):
1497
+ def _set_max_new_tokens_and_length (config , decoder_input_ids , generation_config ):
1521
1498
num_initial_tokens = min (config .max_target_positions // 2 - 1 , decoder_input_ids .shape [- 1 ] - 1 )
1522
1499
1523
- passed_max_length = kwargs .pop ("max_length" , None )
1524
- passed_max_new_tokens = kwargs .pop ("max_new_tokens" , None )
1525
- max_length_config = getattr (generation_config , "max_length" , None )
1526
- max_new_tokens_config = getattr (generation_config , "max_new_tokens" , None )
1527
-
1528
- max_new_tokens = None
1529
- max_length = None
1530
-
1531
1500
# Make sure we don't get larger than `max_length`
1532
- if passed_max_length is not None and passed_max_new_tokens is None :
1533
- max_length = min (passed_max_length + num_initial_tokens , config .max_target_positions )
1534
- logger .info (
1535
- f"Increase max_length from { passed_max_length } to { max_length } since input is conditioned on previous segment."
1536
- )
1537
- elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None :
1501
+ if generation_config .max_length is not None and generation_config .max_new_tokens is None :
1538
1502
max_length = min (generation_config .max_length + num_initial_tokens , config .max_target_positions )
1539
1503
logger .info (
1540
- f"Increase max_length from { max_length_config } to { max_length } since input is conditioned on previous segment."
1504
+ f"Increase max_length from { generation_config . max_length } to { max_length } since input is conditioned on previous segment."
1541
1505
)
1542
1506
elif (
1543
- passed_max_new_tokens is not None
1544
- and passed_max_new_tokens + decoder_input_ids .shape [- 1 ] > config .max_target_positions
1507
+ generation_config . max_new_tokens is not None
1508
+ and generation_config . max_new_tokens + decoder_input_ids .shape [- 1 ] > config .max_target_positions
1545
1509
):
1546
1510
max_new_tokens = config .max_target_positions - decoder_input_ids .shape [- 1 ]
1547
- elif (
1548
- passed_max_new_tokens is None
1549
- and max_new_tokens_config is not None
1550
- and max_new_tokens_config + decoder_input_ids .shape [- 1 ] > config .max_target_positions
1551
- ):
1552
- max_new_tokens = config .max_target_positions - decoder_input_ids .shape [- 1 ]
1553
-
1554
- if max_new_tokens is not None :
1555
- kwargs ["max_new_tokens" ] = max_new_tokens
1556
-
1557
- if max_length is not None :
1558
- kwargs ["max_length" ] = max_length
1559
-
1560
- return kwargs
1511
+ generation_config .max_new_tokens = max_new_tokens
1561
1512
1562
1513
@staticmethod
1563
1514
def _retrieve_compression_ratio (tokens , vocab_size ):
0 commit comments