|
60 | 60 | "import numpy as np\n",
|
61 | 61 | "from transformers import (\n",
|
62 | 62 | " AutoModelForSequenceClassification,\n",
|
63 |
| - " AutoConfig, \n", |
64 |
| - " AutoTokenizer, \n", |
65 |
| - " EvalPrediction, \n", |
66 |
| - " default_data_collator\n", |
| 63 | + " AutoConfig,\n", |
| 64 | + " AutoTokenizer,\n", |
| 65 | + " EvalPrediction,\n", |
| 66 | + " default_data_collator,\n", |
67 | 67 | ")\n",
|
68 | 68 | "from datasets import load_dataset, load_metric"
|
69 | 69 | ]
|
|
96 | 96 | "dataset[\"train\"].to_csv(\"rotten_tomatoes-train.csv\")\n",
|
97 | 97 | "dataset[\"validation\"].to_csv(\"rotten_tomatoes-validation.csv\")\n",
|
98 | 98 | "data_files = {\n",
|
99 |
| - " \"train\": \"rotten_tomatoes-train.csv\",\n", |
100 |
| - " \"validation\": \"rotten_tomatoes-validation.csv\"\n", |
| 99 | + " \"train\": \"rotten_tomatoes-train.csv\",\n", |
| 100 | + " \"validation\": \"rotten_tomatoes-validation.csv\",\n", |
101 | 101 | "}\n",
|
102 | 102 | "dataset_from_json = load_dataset(\"csv\", data_files=data_files)"
|
103 | 103 | ]
|
|
163 | 163 | "source": [
|
164 | 164 | "metric = load_metric(\"accuracy\")\n",
|
165 | 165 | "\n",
|
| 166 | + "\n", |
166 | 167 | "def compute_metrics(p: EvalPrediction):\n",
|
167 |
| - " preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions\n", |
168 |
| - " preds = np.argmax(preds, axis=1)\n", |
169 |
| - " result = metric.compute(predictions=preds, references=p.label_ids)\n", |
170 |
| - " if len(result) > 1:\n", |
171 |
| - " result[\"combined_score\"] = np.mean(list(result.values())).item()\n", |
172 |
| - " return result" |
| 168 | + " preds = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions\n", |
| 169 | + " preds = np.argmax(preds, axis=1)\n", |
| 170 | + " result = metric.compute(predictions=preds, references=p.label_ids)\n", |
| 171 | + " if len(result) > 1:\n", |
| 172 | + " result[\"combined_score\"] = np.mean(list(result.values())).item()\n", |
| 173 | + " return result" |
173 | 174 | ]
|
174 | 175 | },
|
175 | 176 | {
|
|
209 | 210 | "outputs": [],
|
210 | 211 | "source": [
|
211 | 212 | "# downloads 90% pruned upstream BERT trained on MLM objective\n",
|
212 |
| - "model_stub = \"zoo:nlp/masked_language_modeling/obert-base/pytorch/huggingface/wikipedia_bookcorpus/pruned90-none\" \n", |
213 |
| - "model_path = Model(model_stub, download_path=\"./model\").training.path \n", |
| 213 | + "model_stub = \"zoo:nlp/masked_language_modeling/obert-base/pytorch/huggingface/wikipedia_bookcorpus/pruned90-none\"\n", |
| 214 | + "model_path = Model(model_stub, download_path=\"./model\").training.path\n", |
214 | 215 | "\n",
|
215 | 216 | "# downloads transfer recipe for MNLI (pruned90_quant)\n",
|
216 |
| - "transfer_stub = \"zoo:nlp/sentiment_analysis/obert-base/pytorch/huggingface/sst2/pruned90_quant-none\"\n", |
217 |
| - "recipe_path = Model(transfer_stub, download_path=\"./transfer_recipe\").recipes.default.path" |
| 217 | + "transfer_stub = (\n", |
| 218 | + " \"zoo:nlp/sentiment_analysis/obert-base/pytorch/huggingface/sst2/pruned90_quant-none\"\n", |
| 219 | + ")\n", |
| 220 | + "recipe_path = Model(\n", |
| 221 | + " transfer_stub, download_path=\"./transfer_recipe\"\n", |
| 222 | + ").recipes.default.path" |
218 | 223 | ]
|
219 | 224 | },
|
220 | 225 | {
|
|
363 | 368 | "# initialize model using familiar HF AutoModel\n",
|
364 | 369 | "model_kwargs = {\"config\": model_config}\n",
|
365 | 370 | "model_kwargs[\"state_dict\"], s_delayed = SparseAutoModel._loadable_state_dict(model_path)\n",
|
366 |
| - "model = AutoModelForSequenceClassification.from_pretrained(model_path,**model_kwargs,)\n", |
367 |
| - "SparseAutoModel.log_model_load(model, model_path, \"student\", s_delayed) # prints metrics on sparsity profile\n", |
| 371 | + "model = AutoModelForSequenceClassification.from_pretrained(\n", |
| 372 | + " model_path,\n", |
| 373 | + " **model_kwargs,\n", |
| 374 | + ")\n", |
| 375 | + "SparseAutoModel.log_model_load(\n", |
| 376 | + " model, model_path, \"student\", s_delayed\n", |
| 377 | + ") # prints metrics on sparsity profile\n", |
368 | 378 | "\n",
|
369 | 379 | "# initialize teacher using familiar HF AutoModel\n",
|
370 | 380 | "teacher_kwargs = {\"config\": teacher_config}\n",
|
371 |
| - "teacher_kwargs[\"state_dict\"], t_delayed = SparseAutoModel._loadable_state_dict(teacher_path)\n", |
372 |
| - "teacher = AutoModelForSequenceClassification.from_pretrained(teacher_path,**teacher_kwargs,)\n", |
| 381 | + "teacher_kwargs[\"state_dict\"], t_delayed = SparseAutoModel._loadable_state_dict(\n", |
| 382 | + " teacher_path\n", |
| 383 | + ")\n", |
| 384 | + "teacher = AutoModelForSequenceClassification.from_pretrained(\n", |
| 385 | + " teacher_path,\n", |
| 386 | + " **teacher_kwargs,\n", |
| 387 | + ")\n", |
373 | 388 | "SparseAutoModel.log_model_load(teacher, teacher_path, \"teacher\", t_delayed)"
|
374 | 389 | ]
|
375 | 390 | },
|
|
393 | 408 | "outputs": [],
|
394 | 409 | "source": [
|
395 | 410 | "MAX_LEN = 128\n",
|
| 411 | + "\n", |
| 412 | + "\n", |
396 | 413 | "def preprocess_fn(examples):\n",
|
397 |
| - " args = None\n", |
398 |
| - " if INPUT_COL_2 is None:\n", |
399 |
| - " args = (examples[INPUT_COL_1], )\n", |
400 |
| - " else:\n", |
401 |
| - " args = (examples[INPUT_COL_1], examples[INPUT_COL_2])\n", |
402 |
| - " result = tokenizer(*args, \n", |
403 |
| - " padding=\"max_length\", \n", |
404 |
| - " max_length=min(tokenizer.model_max_length, MAX_LEN), \n", |
405 |
| - " truncation=True)\n", |
406 |
| - " return result\n", |
| 414 | + " args = None\n", |
| 415 | + " if INPUT_COL_2 is None:\n", |
| 416 | + " args = (examples[INPUT_COL_1],)\n", |
| 417 | + " else:\n", |
| 418 | + " args = (examples[INPUT_COL_1], examples[INPUT_COL_2])\n", |
| 419 | + " result = tokenizer(\n", |
| 420 | + " *args,\n", |
| 421 | + " padding=\"max_length\",\n", |
| 422 | + " max_length=min(tokenizer.model_max_length, MAX_LEN),\n", |
| 423 | + " truncation=True,\n", |
| 424 | + " )\n", |
| 425 | + " return result\n", |
| 426 | + "\n", |
407 | 427 | "\n",
|
408 | 428 | "tokenized_dataset = dataset_from_json.map(\n",
|
409 |
| - " preprocess_fn,\n", |
410 |
| - " batched=True,\n", |
411 |
| - " desc=\"Running tokenizer on dataset\"\n", |
| 429 | + " preprocess_fn, batched=True, desc=\"Running tokenizer on dataset\"\n", |
412 | 430 | ")"
|
413 | 431 | ]
|
414 | 432 | },
|
|
447 | 465 | " save_total_limit=1,\n",
|
448 | 466 | " per_device_train_batch_size=32,\n",
|
449 | 467 | " per_device_eval_batch_size=32,\n",
|
450 |
| - " fp16=True)\n", |
| 468 | + " fp16=True,\n", |
| 469 | + ")\n", |
451 | 470 | "\n",
|
452 | 471 | "trainer = Trainer(\n",
|
453 | 472 | " model=model,\n",
|
454 | 473 | " model_state_path=model_path,\n",
|
455 | 474 | " recipe=recipe_path,\n",
|
456 | 475 | " teacher=teacher,\n",
|
457 |
| - " metadata_args=[\"per_device_train_batch_size\",\"per_device_eval_batch_size\",\"fp16\"],\n", |
| 476 | + " metadata_args=[\"per_device_train_batch_size\", \"per_device_eval_batch_size\", \"fp16\"],\n", |
458 | 477 | " args=training_args,\n",
|
459 | 478 | " train_dataset=tokenized_dataset[\"train\"],\n",
|
460 | 479 | " eval_dataset=tokenized_dataset[\"validation\"],\n",
|
461 | 480 | " tokenizer=tokenizer,\n",
|
462 | 481 | " data_collator=default_data_collator,\n",
|
463 |
| - " compute_metrics=compute_metrics)" |
| 482 | + " compute_metrics=compute_metrics,\n", |
| 483 | + ")" |
464 | 484 | ]
|
465 | 485 | },
|
466 | 486 | {
|
|
0 commit comments