|
23 | 23 | import numpy as np
|
24 | 24 | import torch
|
25 | 25 | import torch.distributed as dist
|
| 26 | +from packaging import version |
26 | 27 | from torch import nn
|
27 | 28 | from torch.nn import functional as F
|
28 | 29 |
|
@@ -1552,7 +1553,7 @@ def _prepare_generated_length(
|
1552 | 1553 | return generation_config
|
1553 | 1554 |
|
1554 | 1555 | def _prepare_generation_config(
|
1555 |
| - self, generation_config: Optional[GenerationConfig], **kwargs: Dict |
| 1556 | + self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Dict |
1556 | 1557 | ) -> Tuple[GenerationConfig, Dict]:
|
1557 | 1558 | """
|
1558 | 1559 | Prepares the base generation config, then applies any generation configuration options from kwargs. This
|
@@ -1591,23 +1592,38 @@ def _prepare_generation_config(
|
1591 | 1592 |
|
1592 | 1593 | generation_config = copy.deepcopy(generation_config)
|
1593 | 1594 |
|
1594 |
| - # If `generation_config` is provided, let's fallback ALL default values to the model's generation config |
1595 | 1595 | if not using_model_generation_config:
|
1596 |
| - modified_values = {} |
1597 |
| - default_generation_config = GenerationConfig() |
1598 |
| - for key, default_value in default_generation_config.__dict__.items(): |
1599 |
| - if key.startswith("_"): # metadata |
1600 |
| - continue |
1601 |
| - custom_gen_config_value = getattr(generation_config, key) |
1602 |
| - model_gen_config_value = getattr(self.generation_config, key) |
1603 |
| - if custom_gen_config_value == default_value and model_gen_config_value != default_value: |
1604 |
| - modified_values[key] = model_gen_config_value |
1605 |
| - setattr(generation_config, key, model_gen_config_value) |
1606 |
| - if len(modified_values) > 0: |
1607 |
| - logger.warning_once( |
1608 |
| - f"`generation_config` default values have been modified to match model-specific defaults: " |
1609 |
| - f"{modified_values}. If this is not desired, please set these values explicitly." |
1610 |
| - ) |
| 1596 | + # If `generation_config` is provided: |
| 1597 | + # - `use_model_defaults`: let's fall back ALL default values to the model's generation config |
| 1598 | + # - otherwise: legacy behavior, let's just make sure we have the tokens defined |
| 1599 | + model_base_version = version.parse(version.parse(self.generation_config.transformers_version).base_version) |
| 1600 | + if use_model_defaults is True or ( |
| 1601 | + use_model_defaults is None and model_base_version >= version.parse("4.50.0") |
| 1602 | + ): |
| 1603 | + modified_values = {} |
| 1604 | + default_generation_config = GenerationConfig() |
| 1605 | + for key, default_value in default_generation_config.__dict__.items(): |
| 1606 | + if key.startswith("_") or key == "transformers_version": # metadata |
| 1607 | + continue |
| 1608 | + custom_gen_config_value = getattr(generation_config, key) |
| 1609 | + model_gen_config_value = getattr(self.generation_config, key) |
| 1610 | + if custom_gen_config_value == default_value and model_gen_config_value != default_value: |
| 1611 | + modified_values[key] = model_gen_config_value |
| 1612 | + setattr(generation_config, key, model_gen_config_value) |
| 1613 | + if len(modified_values) > 0: |
| 1614 | + logger.warning_once( |
| 1615 | + f"`generation_config` default values have been modified to match model-specific defaults: " |
| 1616 | + f"{modified_values}. If this is not desired, please set these values explicitly." |
| 1617 | + ) |
| 1618 | + else: |
| 1619 | + if generation_config.bos_token_id is None: |
| 1620 | + generation_config.bos_token_id = self.generation_config.bos_token_id |
| 1621 | + if generation_config.eos_token_id is None: |
| 1622 | + generation_config.eos_token_id = self.generation_config.eos_token_id |
| 1623 | + if generation_config.pad_token_id is None: |
| 1624 | + generation_config.pad_token_id = self.generation_config.pad_token_id |
| 1625 | + if generation_config.decoder_start_token_id is None: |
| 1626 | + generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id |
1611 | 1627 |
|
1612 | 1628 | # Finally, apply any passed kwargs
|
1613 | 1629 | model_kwargs = generation_config.update(**kwargs)
|
@@ -1967,6 +1983,7 @@ def generate(
|
1967 | 1983 | streamer: Optional["BaseStreamer"] = None,
|
1968 | 1984 | negative_prompt_ids: Optional[torch.Tensor] = None,
|
1969 | 1985 | negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
| 1986 | + use_model_defaults: Optional[bool] = None, |
1970 | 1987 | **kwargs,
|
1971 | 1988 | ) -> Union[GenerateOutput, torch.LongTensor]:
|
1972 | 1989 | r"""
|
@@ -2031,6 +2048,11 @@ def generate(
|
2031 | 2048 | size. This is an experimental feature, subject to breaking API changes in future versions.
|
2032 | 2049 | negative_prompt_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
2033 | 2050 | Attention_mask for `negative_prompt_ids`.
|
| 2051 | + use_model_defaults (`bool`, *optional*): |
| 2052 | + When it is `True`, unset parameters in `generation_config` will be set to the model-specific default |
| 2053 | + generation configuration (`model.generation_config`), as opposed to the global defaults |
| 2054 | + (`GenerationConfig()`). If unset, models saved starting from `v4.50` will consider this flag to be |
| 2055 | + `True`. |
2034 | 2056 | kwargs (`Dict[str, Any]`, *optional*):
|
2035 | 2057 | Ad hoc parametrization of `generation_config` and/or additional model-specific kwargs that will be
|
2036 | 2058 | forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
|
@@ -2058,7 +2080,9 @@ def generate(
|
2058 | 2080 | tokenizer = kwargs.pop("tokenizer", None) # Pull this out first, we only use it for stopping criteria
|
2059 | 2081 | assistant_tokenizer = kwargs.pop("assistant_tokenizer", None) # only used for assisted generation
|
2060 | 2082 |
|
2061 |
| - generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs) |
| 2083 | + generation_config, model_kwargs = self._prepare_generation_config( |
| 2084 | + generation_config, use_model_defaults, **kwargs |
| 2085 | + ) |
2062 | 2086 | self._validate_model_kwargs(model_kwargs.copy())
|
2063 | 2087 | self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
|
2064 | 2088 |
|
|
0 commit comments