
Commit 0576346

🏛️ Fix CI and Iterative SFT (#3614)
1 parent e63588a commit 0576346

File tree

3 files changed: +19 -12 lines


tests/test_bco_trainer.py

Lines changed: 6 additions & 6 deletions
@@ -194,9 +194,9 @@ def test_tokenize_and_process_tokens(self):
             batched=True,
             batch_size=2,
         )
-        self.assertListEqual(tokenized_dataset["prompt"], dataset["prompt"])
-        self.assertListEqual(tokenized_dataset["completion"], dataset["completion"])
-        self.assertListEqual(tokenized_dataset["label"], dataset["label"])
+        self.assertListEqual(tokenized_dataset["prompt"][:], dataset["prompt"][:])
+        self.assertListEqual(tokenized_dataset["completion"][:], dataset["completion"][:])
+        self.assertListEqual(tokenized_dataset["label"][:], dataset["label"][:])
         self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
         self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
         self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
@@ -212,9 +212,9 @@ def test_tokenize_and_process_tokens(self):
             "max_prompt_length": trainer.max_prompt_length,
         }
         processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs)
-        self.assertListEqual(processed_dataset["prompt"], dataset["prompt"])
-        self.assertListEqual(processed_dataset["completion"], dataset["completion"])
-        self.assertListEqual(processed_dataset["label"], dataset["label"])
+        self.assertListEqual(processed_dataset["prompt"][:], dataset["prompt"][:])
+        self.assertListEqual(processed_dataset["completion"][:], dataset["completion"][:])
+        self.assertListEqual(processed_dataset["label"][:], dataset["label"][:])
         self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
         self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
         self.assertListEqual(
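
The `[:]` slices above are the CI fix: with a recent `datasets` release, indexing a `Dataset` by column name can return a lazy Column view rather than a plain Python list, and `assertListEqual` only accepts real lists. Slicing the column materializes it. A minimal sketch of the idea, assuming such a `datasets` version; the toy dataset below is illustrative only:

    # Illustrative sketch, assuming a `datasets` version where column access
    # returns a lazy Column view instead of a plain list.
    from datasets import Dataset

    dataset = Dataset.from_dict({"prompt": ["Hi", "Hey"], "label": [True, False]})

    column = dataset["prompt"]           # may be a Column view, not a list
    materialized = dataset["prompt"][:]  # slicing yields a plain Python list

    assert isinstance(materialized, list)
    assert materialized == ["Hi", "Hey"]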

tests/test_kto_trainer.py

Lines changed: 6 additions & 6 deletions
@@ -153,9 +153,9 @@ def test_tokenize_and_process_tokens(self):
             batched=True,
             batch_size=2,
         )
-        self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
-        self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
-        self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
+        self.assertListEqual(tokenized_dataset["prompt"][:], train_dataset["prompt"][:])
+        self.assertListEqual(tokenized_dataset["completion"][:], train_dataset["completion"][:])
+        self.assertListEqual(tokenized_dataset["label"][:], train_dataset["label"][:])
         self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
         self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
         self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
@@ -193,9 +193,9 @@ def test_tokenize_and_process_tokens(self):
             "max_prompt_length": trainer.max_prompt_length,
         }
         processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
-        self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
-        self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
-        self.assertListEqual(processed_dataset["label"], train_dataset["label"])
+        self.assertListEqual(processed_dataset["prompt"][:], train_dataset["prompt"][:])
+        self.assertListEqual(processed_dataset["completion"][:], train_dataset["completion"][:])
+        self.assertListEqual(processed_dataset["label"][:], train_dataset["label"][:])
         self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
         self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
         self.assertListEqual(

trl/trainer/iterative_sft_trainer.py

Lines changed: 7 additions & 0 deletions
@@ -357,6 +357,13 @@ def step(
                 "No 'labels' or 'text_labels' are provided. When using an encoder-decoder architecture, 'labels' or 'text_labels' must be passed."
             )

+        # Convert Column to list if not already
+        input_ids = input_ids[:] if input_ids is not None else None
+        attention_mask = attention_mask[:] if attention_mask is not None else None
+        labels = labels[:] if labels is not None else None
+        texts = texts[:] if texts is not None else None
+        texts_labels = texts_labels[:] if texts_labels is not None else None
+
         input_ids, attention_mask, labels, texts, texts_labels = self._step_safety_checker(
             input_ids, attention_mask, labels, texts, texts_labels
         )
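
The same materialization guards `IterativeSFTTrainer.step()`: each argument may arrive as a `datasets` Column (or `None`), and the `[:]` slice turns a Column into a list while leaving an existing list unchanged, so `_step_safety_checker` always sees plain lists. A hedged usage sketch, assuming an already configured trainer and a tokenized dataset (both placeholder names here):

    # Usage sketch (assumed setup): `trainer` is an IterativeSFTTrainer and
    # `tokenized_dataset` is a tokenized `datasets.Dataset`; both are placeholders.
    # After this fix, dataset columns can be passed to step() directly.
    trainer.step(
        input_ids=tokenized_dataset["input_ids"],            # Column or list
        attention_mask=tokenized_dataset["attention_mask"],   # Column or list
    )

    # Equivalent manual conversion, mirroring the pattern added in step():
    input_ids = tokenized_dataset["input_ids"][:]  # materialize as a plain list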
