
Commit 76fa17c

Fix whisper kwargs and generation config (#30018)
* clean-up whisper kwargs
* failing test
1 parent 9b5a645 · commit 76fa17c

1 file changed: +15, -64 lines changed

src/transformers/models/whisper/generation_whisper.py

Lines changed: 15 additions & 64 deletions
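For orientation, a minimal sketch of the calling pattern the diff below standardizes on: generation options such as num_beams and max_new_tokens are read from the model's generation_config rather than fished out of loose generate() kwargs. The checkpoint name and option values are illustrative assumptions, not part of this commit.

    # Sketch only: checkpoint name and settings are assumptions.
    from transformers import WhisperForConditionalGeneration

    model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

    # Options set on generation_config are what generate() now consults.
    model.generation_config.num_beams = 2
    model.generation_config.max_new_tokens = 128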
@@ -511,7 +511,6 @@ def generate(
         self._set_language_and_task(
             language=language, task=task, is_multilingual=is_multilingual, generation_config=generation_config
         )
-        self._set_token_ids(generation_config=generation_config, config=self.config, kwargs=kwargs)
         self._set_num_frames(
             return_token_timestamps=return_token_timestamps, generation_config=generation_config, kwargs=kwargs
         )
@@ -546,13 +545,13 @@ def generate(
             logits_processor=logits_processor,
             begin_index=begin_index,  # begin index is index of first generated decoder token
             is_shortform=is_shortform,
-            num_beams=kwargs.get("num_beams", 1),
+            num_beams=generation_config.num_beams,
         )
 
         # 5. If we're in shortform mode, simply generate the whole input at once and return the output
         if is_shortform:
             if temperature is not None:
-                kwargs["temperature"] = temperature
+                generation_config.temperature = temperature
 
             decoder_input_ids = kwargs.pop("decoder_input_ids", None)
             if decoder_input_ids is None:
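A toy illustration of the override pattern in this hunk: an explicitly passed temperature wins over the value stored on the generation config. The stub class and helper are hypothetical, not transformers code.

    # Hypothetical stub mirroring "explicit argument overrides config".
    class CfgStub:
        temperature = 1.0

    def apply_temperature(generation_config, temperature=None):
        # Only overwrite when the caller actually passed a value, as in the hunk above.
        if temperature is not None:
            generation_config.temperature = temperature
        return generation_config

    cfg = apply_temperature(CfgStub(), temperature=0.4)
    assert cfg.temperature == 0.4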
@@ -564,8 +563,8 @@ def generate(
                     [prompt_ids[None].repeat(decoder_input_ids.shape[0], 1), decoder_input_ids], dim=-1
                 )
 
-            if kwargs.get("max_new_tokens", 0) + decoder_input_ids.shape[-1] > self.config.max_target_positions:
-                max_new_tokens = kwargs.get("max_new_tokens", 0)
+            max_new_tokens = generation_config.max_new_tokens if generation_config.max_new_tokens is not None else 0
+            if max_new_tokens + decoder_input_ids.shape[-1] > self.config.max_target_positions:
                 raise ValueError(
                     f"The length of `decoder_input_ids`, which equals `prompt_ids` plus special start tokens, is {decoder_input_ids.shape[-1]}, and the `max_new_tokens` "
                     f"is {max_new_tokens}. Thus, the combined length of "
@@ -666,11 +665,10 @@ def generate(
             )
 
             # 6.6 set max new tokens or max length
-            kwargs = self._set_max_new_tokens_and_length(
+            self._set_max_new_tokens_and_length(
                 config=self.config,
                 decoder_input_ids=decoder_input_ids,
                 generation_config=generation_config,
-                kwargs=kwargs,
             )
 
             # 6.7 Set current `begin_index` for all logit processors
@@ -770,9 +768,9 @@ def generate_with_fallback(
 
         for fallback_idx, temperature in enumerate(temperatures):
             generation_config.do_sample = temperature is not None and temperature > 0.0
-
             generation_config.temperature = temperature if generation_config.do_sample else 1.0
-            generation_config.num_beams = kwargs.get("num_beams", 1) if not generation_config.do_sample else 1
+            if generation_config.do_sample:
+                generation_config.num_beams = 1
 
             generate_kwargs = copy.copy(kwargs)
             for key in ["do_sample", "temperature", "num_beams"]:
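A minimal sketch of how the fallback loop reconciles sampling and beam search: whenever a fallback temperature enables sampling, beam search is disabled by forcing num_beams to 1, and otherwise the configured beam count is left untouched. The temperature schedule and stub config are illustrative assumptions.

    temperatures = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0)  # assumed fallback schedule

    class CfgStub:
        do_sample = False
        temperature = 1.0
        num_beams = 5  # whatever the user configured

    cfg = CfgStub()
    for temperature in temperatures:
        cfg.do_sample = temperature is not None and temperature > 0.0
        cfg.temperature = temperature if cfg.do_sample else 1.0
        if cfg.do_sample:
            cfg.num_beams = 1  # sampling and beam search are mutually exclusive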
@@ -1095,20 +1093,15 @@ def replace_or_add(lst: List[int], num: int, itr: Iterator[int]):
         task = getattr(generation_config, "task", None)
         language = getattr(generation_config, "language", None)
 
-        if kwargs.get("forced_decoder_ids", None) is not None:
-            forced_decoder_ids = kwargs["forced_decoder_ids"]
-        elif hasattr(generation_config, "forced_decoder_ids") and generation_config.forced_decoder_ids is not None:
-            forced_decoder_ids = generation_config.forced_decoder_ids
-
+        forced_decoder_ids = generation_config.forced_decoder_ids
+        if forced_decoder_ids is not None:
             if language is None and task is None and forced_decoder_ids[0][1] is None:
                 logger.warning_once(
                     "Due to a bug fix in https://github.com/huggingface/transformers/pull/28687 transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English."
                     "This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`."
                 )
         elif hasattr(config, "forced_decoder_ids") and config.forced_decoder_ids is not None:
             forced_decoder_ids = config.forced_decoder_ids
-        else:
-            forced_decoder_ids = None
 
         if forced_decoder_ids is not None and task is not None:
             logger.info(
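The lookup order this hunk leaves in place can be summarized as a small helper; a sketch under the assumption that generation_config always exposes a forced_decoder_ids attribute while the model config may not:

    def resolve_forced_decoder_ids(generation_config, config):
        # 1) generation_config takes precedence ...
        if generation_config.forced_decoder_ids is not None:
            return generation_config.forced_decoder_ids
        # 2) ... then the model config ...
        if getattr(config, "forced_decoder_ids", None) is not None:
            return config.forced_decoder_ids
        # 3) ... otherwise there is nothing to force.
        return None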
@@ -1288,21 +1281,6 @@ def _check_decoder_input_ids(kwargs):
                 "Passing `decoder_input_ids` is deprecated. Consider passing `prompt_ids` instead.",
             )
 
-    @staticmethod
-    def _set_token_ids(generation_config, config, kwargs):
-        eos_token_id = kwargs.pop("eos_token_id", None)
-        decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
-
-        eos_token_id = eos_token_id if eos_token_id is not None else generation_config.eos_token_id
-        decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else generation_config.decoder_start_token_id
-        )
-
-        generation_config.eos_token_id = eos_token_id if eos_token_id is not None else config.eos_token_id
-        generation_config.decoder_start_token_id = (
-            decoder_start_token_id if decoder_start_token_id is not None else config.decoder_start_token_id
-        )
-
     @staticmethod
     def _set_num_frames(return_token_timestamps, generation_config, kwargs):
         if return_token_timestamps:
@@ -1313,7 +1291,6 @@ def _set_num_frames(return_token_timestamps, generation_config, kwargs):
                     "Model generation config has no `alignment_heads`, token-level timestamps not available. "
                     "See https://gist.github.com/hollance/42e32852f24243b748ae6bc1f985b13a on how to add this property to the generation config."
                 )
-
             generation_config.num_frames = kwargs.pop("num_frames", None)
 
     @staticmethod
@@ -1517,47 +1494,21 @@ def _prepare_decoder_input_ids(
         return decoder_input_ids, kwargs
 
     @staticmethod
-    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config, kwargs):
+    def _set_max_new_tokens_and_length(config, decoder_input_ids, generation_config):
         num_initial_tokens = min(config.max_target_positions // 2 - 1, decoder_input_ids.shape[-1] - 1)
 
-        passed_max_length = kwargs.pop("max_length", None)
-        passed_max_new_tokens = kwargs.pop("max_new_tokens", None)
-        max_length_config = getattr(generation_config, "max_length", None)
-        max_new_tokens_config = getattr(generation_config, "max_new_tokens", None)
-
-        max_new_tokens = None
-        max_length = None
-
         # Make sure we don't get larger than `max_length`
-        if passed_max_length is not None and passed_max_new_tokens is None:
-            max_length = min(passed_max_length + num_initial_tokens, config.max_target_positions)
-            logger.info(
-                f"Increase max_length from {passed_max_length} to {max_length} since input is conditioned on previous segment."
-            )
-        elif max_length_config is not None and passed_max_new_tokens is None and max_new_tokens_config is None:
+        if generation_config.max_length is not None and generation_config.max_new_tokens is None:
             max_length = min(generation_config.max_length + num_initial_tokens, config.max_target_positions)
             logger.info(
-                f"Increase max_length from {max_length_config} to {max_length} since input is conditioned on previous segment."
+                f"Increase max_length from {generation_config.max_length} to {max_length} since input is conditioned on previous segment."
             )
         elif (
-            passed_max_new_tokens is not None
-            and passed_max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
+            generation_config.max_new_tokens is not None
+            and generation_config.max_new_tokens + decoder_input_ids.shape[-1] > config.max_target_positions
         ):
             max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-        elif (
-            passed_max_new_tokens is None
-            and max_new_tokens_config is not None
-            and max_new_tokens_config + decoder_input_ids.shape[-1] > config.max_target_positions
-        ):
-            max_new_tokens = config.max_target_positions - decoder_input_ids.shape[-1]
-
-        if max_new_tokens is not None:
-            kwargs["max_new_tokens"] = max_new_tokens
-
-        if max_length is not None:
-            kwargs["max_length"] = max_length
-
-        return kwargs
+            generation_config.max_new_tokens = max_new_tokens
 
     @staticmethod
     def _retrieve_compression_ratio(tokens, vocab_size):
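A toy run of the simplified helper above, showing the in-place clamp on the generation config; the stub class and numbers are illustrative assumptions:

    class CfgStub:
        max_length = None
        max_new_tokens = 300  # assumed user setting

    max_target_positions = 448  # Whisper decoder context size
    decoder_len = 200           # assumed length of decoder_input_ids

    cfg = CfgStub()
    if cfg.max_new_tokens is not None and cfg.max_new_tokens + decoder_len > max_target_positions:
        # 300 + 200 = 500 > 448, so clamp in place to 448 - 200 = 248.
        cfg.max_new_tokens = max_target_positions - decoder_len
    assert cfg.max_new_tokens == 248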
