Check rewards shapes in RewardTrainer #3577

Open · wants to merge 7 commits into main
85 changes: 84 additions & 1 deletion tests/test_reward_trainer.py
@@ -17,12 +17,13 @@

import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers.testing_utils import require_peft
from transformers.utils import is_peft_available

from trl import RewardConfig, RewardTrainer, maybe_apply_chat_template
from trl.trainer.reward_trainer import _tokenize
from trl.trainer.utils import RewardDataCollatorWithPadding


if is_peft_available():
@@ -233,3 +234,85 @@ def test_tags(self):
model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset
)
self.assertEqual(trainer.model.model_tags, trainer._tag_names)

def test_collator_args(self):
"""Tests whether the Trainer passes data collator args to the default data collator"""
pad_to_multiple_of = 31415926
with tempfile.TemporaryDirectory() as tmp_dir:
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")

training_args = RewardConfig(
output_dir=tmp_dir,
report_to="none",
pad_to_multiple_of=pad_to_multiple_of,
bf16=False,
)

trainer = RewardTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
)

self.assertEqual(trainer.data_collator.pad_to_multiple_of, pad_to_multiple_of)

def test_custom_collator(self):
"""Tests passing an instantiated data collator to the Trainer"""

with tempfile.TemporaryDirectory() as tmp_dir:
dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")

training_args = RewardConfig(
output_dir=tmp_dir,
report_to="none",
)

collator = RewardDataCollatorWithPadding(
tokenizer=self.tokenizer,
)

RewardTrainer(
model=self.model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
data_collator=collator,
)

def test_train_with_wrong_model(self):
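"""Tests that compute_loss warns on unexpectedly shaped rewards and raises a RuntimeError when the chosen and rejected reward shapes differ"""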
with tempfile.TemporaryDirectory() as tmp_dir:
misconfigured_model = AutoModelForCausalLM.from_pretrained(self.model_id)

dummy_dataset = load_dataset("trl-internal-testing/zen", "conversational_preference", split="train")

training_args = RewardConfig(output_dir=tmp_dir, max_steps=3, report_to="none")

trainer = RewardTrainer(
model=misconfigured_model,
args=training_args,
processing_class=self.tokenizer,
train_dataset=dummy_dataset,
)

with self.assertWarns(expected_warning=Warning):
trainer.compute_loss(
model=trainer.model,
inputs={
"input_ids_chosen": torch.ones((2, 8), dtype=torch.int).to(trainer.model.device),
"attention_mask_chosen": torch.ones((2, 8), dtype=torch.int).to(trainer.model.device),
"input_ids_rejected": torch.ones((2, 8), dtype=torch.int).to(trainer.model.device),
"attention_mask_rejected": torch.ones((2, 8), dtype=torch.int).to(trainer.model.device),
},
)

with self.assertRaises(expected_exception=RuntimeError):
trainer.compute_loss(
model=trainer.model,
inputs={
"input_ids_chosen": torch.ones((2, 8), dtype=torch.int).to(trainer.model.device),
"attention_mask_chosen": torch.ones((2, 8), dtype=torch.int).to(trainer.model.device),
"input_ids_rejected": torch.ones((2, 6), dtype=torch.int).to(trainer.model.device),
"attention_mask_rejected": torch.ones((2, 6), dtype=torch.int).to(trainer.model.device),
},
)
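
Note (not part of the diff): the shape conditions exercised by test_train_with_wrong_model can be reproduced without loading a model. The sketch below uses plain tensors standing in for the logits of a sequence-classification head (2-D) versus a causal LM head (3-D); the batch size, sequence lengths, and the vocabulary size of 32000 are arbitrary placeholders.

import torch

# A sequence-classification head returns logits of shape (batch_size, num_labels):
# 2-D, so the new check passes silently.
rewards_seq_cls = torch.zeros(2, 1)

# A causal LM head returns logits of shape (batch_size, seq_len, vocab_size):
# 3-D, which triggers the RuntimeWarning on the first training step.
rewards_causal_lm = torch.zeros(2, 8, 32000)

print(len(rewards_seq_cls.shape) == 2)    # True  -> no warning
print(len(rewards_causal_lm.shape) == 2)  # False -> RuntimeWarning

# With a causal LM, chosen and rejected sequences of different lengths yield
# logits of different shapes, which the new check reports as a RuntimeError.
rewards_chosen = torch.zeros(2, 8, 32000)
rewards_rejected = torch.zeros(2, 6, 32000)
print(rewards_chosen.shape != rewards_rejected.shape)  # True -> RuntimeError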
6 changes: 6 additions & 0 deletions trl/trainer/reward_config.py
@@ -35,6 +35,8 @@ class may differ from those in [`~transformers.TrainingArguments`].
max_length (`int` or `None`, *optional*, defaults to `1024`):
Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
limit. This argument is required if you want to use the default data collator.
pad_to_multiple_of (`int` or `None`, *optional*, defaults to `None`):
If set, the sequences will be padded to a multiple of this value.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
dataset_num_proc (`int`, *optional*, defaults to `None`):
@@ -81,6 +83,10 @@ class may differ from those in [`~transformers.TrainingArguments`].
"exceed the limit. This argument is required if you want to use the default data collator."
},
)
pad_to_multiple_of: Optional[int] = field(
default=None,
metadata={"help": "If set, the sequences will be padded to a multiple of this value."},
)
disable_dropout: bool = field(
default=True,
metadata={"help": "Whether to disable dropout in the model and reference model."},
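Note (not part of the diff): a minimal usage sketch for the new pad_to_multiple_of field. The output directory and the value 64 are arbitrary placeholders; when no custom data_collator is passed, RewardTrainer forwards this value to the default RewardDataCollatorWithPadding.

from trl import RewardConfig

# Sketch only: the new field sits alongside the existing RewardConfig options.
training_args = RewardConfig(
    output_dir="reward-model",  # placeholder output directory
    max_length=1024,
    pad_to_multiple_of=64,      # new: pad every batch up to a multiple of 64 tokens
    report_to="none",
)

# With these args and no explicit data_collator, RewardTrainer builds roughly:
#   RewardDataCollatorWithPadding(tokenizer=processing_class,
#                                 pad_to_multiple_of=training_args.pad_to_multiple_of)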
40 changes: 34 additions & 6 deletions trl/trainer/reward_trainer.py
@@ -179,9 +179,10 @@ def __init__(
"A processing_class must be specified when using the default RewardDataCollatorWithPadding"
)

max_length = args.max_length

data_collator = RewardDataCollatorWithPadding(processing_class)
data_collator = RewardDataCollatorWithPadding(
tokenizer=processing_class,
pad_to_multiple_of=args.pad_to_multiple_of,
)

if args.remove_unused_columns:
try: # for bc before https://github.com/huggingface/transformers/pull/25435
@@ -222,7 +223,8 @@ def __init__(
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
# user might get surprised if N samples are missing from training.
train_dataset = train_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length and len(x["input_ids_rejected"]) <= max_length,
lambda x: len(x["input_ids_chosen"]) <= args.max_length
and len(x["input_ids_rejected"]) <= args.max_length,
num_proc=args.dataset_num_proc,
)
if eval_dataset is not None:
@@ -239,8 +241,8 @@
# get truncated => noisy signal the chosen/rejected label gets lost. The downside is that the
# user might get surprised if N samples are missing from training.
eval_dataset = eval_dataset.filter(
lambda x: len(x["input_ids_chosen"]) <= max_length
and len(x["input_ids_rejected"]) <= max_length,
lambda x: len(x["input_ids_chosen"]) <= args.max_length
and len(x["input_ids_rejected"]) <= args.max_length,
num_proc=args.dataset_num_proc,
)

@@ -279,6 +281,32 @@ def compute_loss(
attention_mask=inputs["attention_mask_rejected"],
return_dict=True,
)["logits"]

if self.state.global_step == 0:
# Only runs on the first training step as a check
if len(rewards_chosen.shape) != 2 or len(rewards_rejected.shape) != 2:
# Make sure that the rewards are defined at the sequence level
warnings.warn(
message="The output of the model is of unexpected shape. "
f"Chosen rewards: {rewards_chosen.shape}. "
f"Rejected rewards: {rewards_rejected.shape}. "
"The expected output does not have a sequence length. "
"This can happen if the model is not setup for sequence classification. "
"Please check your model configuration.",
category=RuntimeWarning,
)

if rewards_chosen.shape != rewards_rejected.shape:
raise RuntimeError(
"The output of the model is incompatible. "
f"Chosen rewards: {rewards_chosen.shape}. "
f"Rejected rewards: {rewards_rejected.shape}. "
"The shapes of the rewards should match exactly. "
"This will raise a RuntimeError when computing the loss. "
"This can happen if the model is not setup for sequence classification. "
"Please check your model configuration.",
)

# calculate loss, optionally modulate with margin
if "margin" in inputs:
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - inputs["margin"]).mean()
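Note (not part of the diff): the loss computed right after the new check is the pairwise logistic (Bradley-Terry style) loss, optionally shifted by a per-example margin. The sketch below uses dummy rewards of shape (batch_size, 1), the shape the check expects from a sequence-classification head; the margin is given the same shape here purely to keep the broadcasting explicit.

import torch
import torch.nn as nn

# Dummy sequence-level rewards of shape (batch_size, num_labels=1).
rewards_chosen = torch.tensor([[1.2], [0.3]])
rewards_rejected = torch.tensor([[0.1], [0.4]])
margin = torch.tensor([[0.0], [0.5]])  # per-example margin

# Without a margin: push chosen rewards above rejected rewards.
loss = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected).mean()

# With a margin: require chosen rewards to beat rejected rewards by at least the margin.
loss_with_margin = -nn.functional.logsigmoid(rewards_chosen - rewards_rejected - margin).mean()

print(loss.item(), loss_with_margin.item())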
37 changes: 23 additions & 14 deletions trl/trainer/utils.py
@@ -349,28 +349,30 @@ def __call__(self, examples: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
@dataclass
class RewardDataCollatorWithPadding:
r"""
Reward DataCollator class that pads the inputs to the maximum length of the batch.
Reward DataCollator class that pads the inputs to the maximum length of the batch.

Args:
tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used for encoding the data.
padding (`Union[bool, str, `PaddingStrategy`]`, `optional`, defaults to `True`):
padding_strategy to pass to the tokenizer.
pad_to_multiple_of (`int` or `None`, `optional`, defaults to `None`):
If set will pad the sequence to a multiple of the provided value.
return_tensors (`str`, `optional`, defaults to `"pt"`):
The tensor type to use.
Assumes the inputs are already tokenized but still represented as lists of ints rather than tensors.

Requires the columns 'input_ids_chosen', 'input_ids_rejected', 'attention_mask_chosen', 'attention_mask_rejected' to be present in the batch.

Args:
tokenizer (`PreTrainedTokenizerBase`):
The tokenizer used for encoding the data.
pad_to_multiple_of (`int` or `None`, `optional`, defaults to `None`):
If set, the sequences will be padded to a multiple of the provided value.
return_tensors (`str`, `optional`, defaults to `"pt"`):
The tensor type to use.
"""

tokenizer: PreTrainedTokenizerBase
padding: Union[bool, str] = True
pad_to_multiple_of: Optional[int] = None
return_tensors: str = "pt"

def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
def __call__(self, features: list[dict[str, list[int]]]) -> dict[str, torch.Tensor]:
features_chosen = []
features_rejected = []
margin = []

# check if we have a margin. If we do, we need to batch it as well
has_margin = "margin" in features[0]
for feature in features:
@@ -391,36 +393,43 @@ def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]:
"attention_mask": feature["attention_mask_chosen"],
}
)

features_rejected.append(
{
"input_ids": feature["input_ids_rejected"],
"attention_mask": feature["attention_mask_rejected"],
}
)

if has_margin:
margin.append(feature["margin"])

batch_chosen = self.tokenizer.pad(
features_chosen,
padding=self.padding,
padding=True,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)

batch_rejected = self.tokenizer.pad(
features_rejected,
padding=self.padding,
padding=True,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=self.return_tensors,
)

batch = {
"input_ids_chosen": batch_chosen["input_ids"],
"attention_mask_chosen": batch_chosen["attention_mask"],
"input_ids_rejected": batch_rejected["input_ids"],
"attention_mask_rejected": batch_rejected["attention_mask"],
"return_loss": True,
}

if has_margin:
margin = torch.tensor(margin, dtype=torch.float)
batch["margin"] = margin

return batch


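Note (not part of the diff): a usage sketch for the collator above, assuming pre-tokenized chosen/rejected pairs. The GPT-2 tokenizer is only an example (any tokenizer with a pad token works), and the toy token ids are arbitrary.

from transformers import AutoTokenizer
from trl.trainer.utils import RewardDataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer
tokenizer.pad_token = tokenizer.eos_token          # gpt2 has no pad token by default

collator = RewardDataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

features = [
    {
        "input_ids_chosen": [10, 11, 12],
        "attention_mask_chosen": [1, 1, 1],
        "input_ids_rejected": [10, 11],
        "attention_mask_rejected": [1, 1],
    },
    {
        "input_ids_chosen": [13, 14],
        "attention_mask_chosen": [1, 1],
        "input_ids_rejected": [15, 16, 17, 18],
        "attention_mask_rejected": [1, 1, 1, 1],
    },
]

batch = collator(features)
print(batch["input_ids_chosen"].shape)    # torch.Size([2, 8]) with pad_to_multiple_of=8
print(batch["input_ids_rejected"].shape)  # torch.Size([2, 8])
print(batch["return_loss"])               # True

Each side of the pair is padded independently, so a model that returns token-level logits can produce chosen and rejected outputs with different sequence lengths; that is exactly the mismatch the new shape check in compute_loss reports.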