
Commit 94f4876

[generate] model defaults being inherited only happens for newer models (#36881)

1 parent: f19d018

2 files changed: 54 additions & 26 deletions

src/transformers/generation/utils.py

Lines changed: 42 additions & 18 deletions
@@ -23,6 +23,7 @@
 import numpy as np
 import torch
 import torch.distributed as dist
+from packaging import version
 from torch import nn
 from torch.nn import functional as F
 

@@ -1552,7 +1553,7 @@ def _prepare_generated_length(
         return generation_config
 
     def _prepare_generation_config(
-        self, generation_config: Optional[GenerationConfig], **kwargs: Dict
+        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict
     ) -> Tuple[GenerationConfig, Dict]:
         """
         Prepares the base generation config, then applies any generation configuration options from kwargs. This
@@ -1591,23 +1592,38 @@ def _prepare_generation_config(
 
         generation_config = copy.deepcopy(generation_config)
 
-        # If `generation_config` is provided, let's fallback ALL default values to the model's generation config
         if not using_model_generation_config:
-            modified_values = {}
-            default_generation_config = GenerationConfig()
-            for key, default_value in default_generation_config.__dict__.items():
-                if key.startswith("_"):  # metadata
-                    continue
-                custom_gen_config_value = getattr(generation_config, key)
-                model_gen_config_value = getattr(self.generation_config, key)
-                if custom_gen_config_value == default_value and model_gen_config_value != default_value:
-                    modified_values[key] = model_gen_config_value
-                    setattr(generation_config, key, model_gen_config_value)
-            if len(modified_values) > 0:
-                logger.warning_once(
-                    f"`generation_config` default values have been modified to match model-specific defaults: "
-                    f"{modified_values}. If this is not desired, please set these values explicitly."
-                )
+            # If `generation_config` is provided:
+            # - `use_model_defaults`: let's fallback ALL default values to the model's generation config
+            # - otherwise: legacy behavior, let's just make sure we have the tokens defined
+            model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version)
+            if use_model_defaults is True or (
+                use_model_defaults is None and model_base_version >= version.parse("4.50.0")
+            ):
+                modified_values = {}
+                default_generation_config = GenerationConfig()
+                for key, default_value in default_generation_config.__dict__.items():
+                    if key.startswith("_") or key == "transformers_version":  # metadata
+                        continue
+                    custom_gen_config_value = getattr(generation_config, key)
+                    model_gen_config_value = getattr(self.generation_config, key)
+                    if custom_gen_config_value == default_value and model_gen_config_value != default_value:
+                        modified_values[key] = model_gen_config_value
+                        setattr(generation_config, key, model_gen_config_value)
+                if len(modified_values) > 0:
+                    logger.warning_once(
+                        f"`generation_config` default values have been modified to match model-specific defaults: "
+                        f"{modified_values}. If this is not desired, please set these values explicitly."
+                    )
+            else:
+                if generation_config.bos_token_id is None:
+                    generation_config.bos_token_id = self.generation_config.bos_token_id
+                if generation_config.eos_token_id is None:
+                    generation_config.eos_token_id = self.generation_config.eos_token_id
+                if generation_config.pad_token_id is None:
+                    generation_config.pad_token_id = self.generation_config.pad_token_id
+                if generation_config.decoder_start_token_id is None:
+                    generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
 
         # Finally, apply any passed kwargs
         model_kwargs = generation_config.update(**kwargs)
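
The gate above only inherits model defaults when the checkpoint is new enough. A minimal standalone sketch of that version check (the `saved_version` value here is hypothetical):

from packaging import version

saved_version = "4.50.0.dev0"  # e.g. generation_config.transformers_version
# `base_version` strips pre-release suffixes such as ".dev0", so checkpoints
# saved with nightly builds still compare cleanly against the 4.50.0 cutoff.
model_base_version = version.parse(version.parse(saved_version).base_version)

use_model_defaults = None  # the new argument's default
inherits = use_model_defaults is True or (
    use_model_defaults is None and model_base_version >= version.parse("4.50.0")
)
print(inherits)  # True: this checkpoint is new enough to opt in automatically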
@@ -1967,6 +1983,7 @@ def generate(
         streamer: Optional["BaseStreamer"] = None,
         negative_prompt_ids: Optional[torch.Tensor] = None,
         negative_prompt_attention_mask: Optional[torch.Tensor] = None,
+        use_model_defaults: Optional[bool] = None,
         **kwargs,
     ) -> Union[GenerateOutput, torch.LongTensor]:
         r"""
@@ -2031,6 +2048,11 @@ def generate(
                 size. This is an experimental feature, subject to breaking API changes in future versions.
             negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                 Attention_mask for `negative_prompt_ids`.
+            use_model_defaults (`bool`, *optional*):
+                When it is `True`, unset parameters in `generation_config` will be set to the model-specific default
+                generation configuration (`model.generation_config`), as opposed to the global defaults
+                (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be
+                `True`.
             kwargs (`Dict[str, Any]`, *optional*):
                 Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
                 forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
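
With the documented flag, callers can opt in or out regardless of when the checkpoint was saved. A usage sketch (the checkpoint id is borrowed from the test file below; generated text will vary):

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

model = AutoModelForCausalLM.from_pretrained("google/gemma-3-1b-it")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
inputs = tokenizer("Hello", return_tensors="pt")

# Force inheritance: unset fields fall back to model.generation_config
# (automatic for checkpoints saved from v4.50 on).
out = model.generate(**inputs, generation_config=GenerationConfig(max_new_tokens=5), use_model_defaults=True)

# Legacy behavior: only the special tokens are inherited from the model config.
out = model.generate(**inputs, generation_config=GenerationConfig(max_new_tokens=5), use_model_defaults=False)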
@@ -2058,7 +2080,9 @@ def generate(
         tokenizer = kwargs.pop("tokenizer", None)  # Pull this out first, we only use it for stopping criteria
         assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
 
-        generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
+        generation_config, model_kwargs = self._prepare_generation_config(
+            generation_config, use_model_defaults, **kwargs
+        )
         self._validate_model_kwargs(model_kwargs.copy())
         self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)

tests/models/gemma3/test_modeling_gemma3.py

Lines changed: 12 additions & 8 deletions
@@ -575,8 +575,8 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):
 
     def test_generation_beyond_sliding_window_with_generation_config(self):
         """
-        Same as `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684 --
-        ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
+        Similar to `test_generation_beyond_sliding_window`, but passing a GenerationConfig. Regression test for #36684
+        -- ensures `cache_implementation='hybrid'` is correctly inherited from the base `model.generation_config`.
         """
         model_id = "google/gemma-3-1b-it"
         attn_implementation = "sdpa"
@@ -594,12 +594,16 @@ def test_generation_beyond_sliding_window_with_generation_config(self):
 
         # Make sure prefill is larger than sliding window
         input_size = inputs.input_ids.shape[-1]
-        self.assertTrue(input_size > model.config.sliding_window)
+        self.assertGreater(input_size, model.config.sliding_window)
 
-        generation_config = GenerationConfig(max_new_tokens=20)
+        generation_config = GenerationConfig(max_new_tokens=5, min_new_tokens=5)
+        out = model.generate(**inputs, generation_config=generation_config)
 
-        out = model.generate(**inputs, generation_config=generation_config)[:, input_size:]
-        output_text = tokenizer.batch_decode(out)
+        # Generation works beyond sliding window
+        self.assertGreater(out.shape[1], model.config.sliding_window)
+        self.assertEqual(out.shape[1], input_size + 5)
 
-        EXPECTED_COMPLETIONS = [" and I'm going to take a walk.\n\nI really enjoy the scenery, and I'", ", green, yellow, orange, purple, brown, black, white, gray.\n\nI'"]  # fmt: skip
-        self.assertEqual(output_text, EXPECTED_COMPLETIONS)
+        # Note: Auto-inheritance only works for models saved starting from 4.50.0
+        model.generation_config.transformers_version = "4.49.0"
+        with self.assertRaises(RuntimeError):  # errors out because it is not using hybrid cache
+            out = model.generate(**inputs, generation_config=generation_config)
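
Conversely, for a checkpoint that reports a pre-4.50 version like the one simulated above, passing `use_model_defaults=True` explicitly should restore the inheritance. A sketch, reusing `model`, `inputs`, and `generation_config` from the test:

# Opting back in overrides the version gate, so cache_implementation="hybrid"
# is inherited from model.generation_config again and no RuntimeError is raised.
out = model.generate(**inputs, generation_config=generation_config, use_model_defaults=True)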
