@@ -153,9 +153,9 @@ def test_tokenize_and_process_tokens(self):
             batched=True,
             batch_size=2,
         )
-        self.assertListEqual(tokenized_dataset["prompt"], train_dataset["prompt"])
-        self.assertListEqual(tokenized_dataset["completion"], train_dataset["completion"])
-        self.assertListEqual(tokenized_dataset["label"], train_dataset["label"])
+        self.assertListEqual(tokenized_dataset["prompt"][:], train_dataset["prompt"][:])
+        self.assertListEqual(tokenized_dataset["completion"][:], train_dataset["completion"][:])
+        self.assertListEqual(tokenized_dataset["label"][:], train_dataset["label"][:])
         self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
         self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
         self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13])
@@ -193,9 +193,9 @@ def test_tokenize_and_process_tokens(self):
             "max_prompt_length": trainer.max_prompt_length,
         }
         processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2)
-        self.assertListEqual(processed_dataset["prompt"], train_dataset["prompt"])
-        self.assertListEqual(processed_dataset["completion"], train_dataset["completion"])
-        self.assertListEqual(processed_dataset["label"], train_dataset["label"])
+        self.assertListEqual(processed_dataset["prompt"][:], train_dataset["prompt"][:])
+        self.assertListEqual(processed_dataset["completion"][:], train_dataset["completion"][:])
+        self.assertListEqual(processed_dataset["label"][:], train_dataset["label"][:])
         self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091])
         self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1])
         self.assertListEqual(