
Commit f088321

Author: Sara Adkins
Enable One-Shot Launch from Finetuning Script (#1907)
* setup StageRunner class
* running one_shot from text_gen script
* cleanup helper fns
* precision support
* formatting
1 parent 3bba1b6 commit f088321

File tree: 8 files changed (+330, -149 lines)


src/sparseml/pytorch/model_load/helpers.py (+60 lines)

@@ -32,6 +32,9 @@
     "apply_recipe_structure_to_model",
     "reload_model_state",
     "reload_model_from_checkpoint",
+    "save_model_and_recipe",
+    "fallback_to_cpu",
+    "parse_dtype",
 ]

 _LOGGER = logging.getLogger(__name__)
@@ -173,3 +176,60 @@ def reload_model_from_checkpoint(model: Module, checkpoint: Optional[str] = None
     # reload the state dict for the model from the checkpoint
     if reload_model_state(model, checkpoint, orig_state_dict):
         _LOGGER.info(f"Reloaded model state from checkpoint {checkpoint}")
+
+
+def save_model_and_recipe(
+    model: Module,
+    save_path: str,
+    tokenizer: Optional[Any] = None,
+):
+    """
+    Save a model, tokenizer and the currently loaded recipe to file
+
+    :param model: pytorch model to save
+    :param save_path: path to save output to
+    :param tokenizer: model tokenizer to save
+    """
+    model.save_pretrained(save_path)
+    if tokenizer is not None:
+        tokenizer.save_pretrained(save_path)
+
+    _LOGGER.info("Saving output to {}".format(os.path.abspath(save_path)))
+
+    recipe_path = os.path.join(save_path, RECIPE_FILE_NAME)
+    session = session_manager.active_session()
+    recipe_yaml_str = session.get_serialized_recipe()
+    with open(recipe_path, "w") as fp:
+        fp.write(recipe_yaml_str)
+
+
+def fallback_to_cpu(device: str) -> str:
+    """
+    Takes in a device string and forces it to cpu if cuda is not available
+
+    :param device: device id to check
+    :return: device modified for CUDA status
+    """
+    if "cuda" in device and not torch.cuda.is_available():
+        _LOGGER.warning(
+            f"Requested {device} but CUDA is not available, falling back to CPU"
+        )
+        return "cpu"
+
+    return device
+
+
+def parse_dtype(dtype_arg: str) -> torch.dtype:
+    """
+    :param dtype_arg: dtype string to parse
+    :return: torch.dtype parsed from input string
+    """
+    dtype = "auto"  # get precision from model by default
+    if dtype_arg == "half" or dtype_arg == "float16":
+        dtype = torch.float16
+    elif dtype_arg == "bfloat16":
+        dtype = torch.bfloat16
+    elif dtype_arg == "full" or dtype_arg == "float32":
+        dtype = torch.float32
+
+    return dtype
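
A minimal usage sketch (not part of the commit) for the two small helpers added above; the argument strings are only illustrative:

    from sparseml.pytorch.model_load.helpers import fallback_to_cpu, parse_dtype

    device = fallback_to_cpu("cuda:0")  # returns "cpu" when torch.cuda.is_available() is False
    dtype = parse_dtype("half")         # torch.float16; unrecognized strings fall back to "auto"
    print(device, dtype)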

src/sparseml/transformers/finetune/data/data_args.py (+4 lines)

@@ -58,6 +58,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "Optional percentages of each split to download"},
     )
+    num_calibration_samples: Optional[int] = field(
+        default=512,
+        metadata={"help": "Number of samples to use for one-shot calibration"},
+    )
     overwrite_cache: bool = field(
         default=False,
         metadata={"help": "Overwrite the cached preprocessed datasets or not."},

src/sparseml/transformers/finetune/data/data_helpers.py (+12, -2 lines)

@@ -38,7 +38,11 @@ def get_raw_dataset(data_args, cache_dir: Optional[str] = None, **kwargs) -> Dat


 def make_dataset_splits(
-    tokenized_datasets: Dict[str, Any], do_train: bool, do_eval: bool, do_predict: bool
+    tokenized_datasets: Dict[str, Any],
+    do_train: bool = False,
+    do_eval: bool = False,
+    do_predict: bool = False,
+    do_oneshot: bool = False,
 ) -> Dict[str, Dataset]:
     """
     Restructures the datasets dictionary based on what tasks will be run
@@ -48,14 +52,15 @@
     :param do_train: Whether to store the train dataset
     :param do_eval: Whether to store the validation dataset
     :param do_predict: Whether to store the test dataset
+    :param do_oneshot: Whether to store the calibration dataset
     :return: Datasets to be used by the requested tasks
     """

     # handles case where all splits are contained in a single dataset
     if "all" in tokenized_datasets and len(tokenized_datasets) == 1:
         tokenized_datasets = tokenized_datasets.get("all")

-    train_split = eval_split = predict_split = None
+    train_split = eval_split = predict_split = calib_split = None
     if do_train:
         if "train" not in tokenized_datasets:
             raise ValueError("--do_train requires a train dataset")
@@ -68,10 +73,15 @@
         if "test" not in tokenized_datasets:
             raise ValueError("--do_predict requires a test dataset")
         predict_split = tokenized_datasets["test"]
+    if do_oneshot:
+        if "calibration" not in tokenized_datasets:
+            raise ValueError("--do_oneshot requires a calibration dataset")
+        calib_split = tokenized_datasets["calibration"]

     split_datasets = {
         "train": train_split,
         "validation": eval_split,
         "test": predict_split,
+        "calibration": calib_split,
     }
     return split_datasets
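
A small sketch (not part of the commit) of the new do_oneshot path; the toy dict below stands in for real tokenized datasets:

    from sparseml.transformers.finetune.data.data_helpers import make_dataset_splits

    tokenized = {"train": ["sample_a"], "calibration": ["sample_b"]}  # toy stand-ins
    splits = make_dataset_splits(tokenized, do_train=True, do_oneshot=True)
    print(splits["calibration"])  # ["sample_b"]; splits that were not requested stay None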

src/sparseml/transformers/finetune/model_args.py (+4 lines)

@@ -64,3 +64,7 @@ class ModelArguments:
             "(necessary to use this script with private models)"
         },
     )
+    precision: str = field(
+        default="auto",
+        metadata={"help": "Precision to cast model weights to, default to auto"},
+    )
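
A hedged sketch of how the new precision argument might map onto a torch dtype via the parse_dtype helper added in this commit; wiring the result into the actual model load is an assumption, not shown in this diff:

    import torch
    from sparseml.pytorch.model_load.helpers import parse_dtype

    precision = "bfloat16"                # e.g. ModelArguments.precision, default "auto"
    torch_dtype = parse_dtype(precision)
    assert torch_dtype == torch.bfloat16  # "auto" is returned for unrecognized strings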
New file (+203 lines)
@@ -0,0 +1,203 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
from typing import List

import torch
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset, RandomSampler
from transformers import AutoTokenizer

import sparseml.core.session as session_manager
from sparseml.core.framework import Framework
from sparseml.pytorch.model_load.helpers import fallback_to_cpu, save_model_and_recipe
from sparseml.transformers.finetune import Trainer, TrainingArguments
from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.finetune.data.data_args import DataTrainingArguments
from sparseml.transformers.finetune.data.data_helpers import make_dataset_splits
from sparseml.transformers.finetune.model_args import ModelArguments


_LOGGER: logging.Logger = logging.getLogger(__name__)


class StageRunner:
    """
    Launcher class for train, eval and one_shot flows. Manages data splits for each
    flow and configurations. In the future this class will also handle alternating
    between the different flows

    LifeCycle
        - populate_datasets()
        - set_trainer()
        - train() / evaluate() / predict()

    :param model_args: Arguments pertaining to model/config/tokenizer
    :param data_args: Arguments pertaining to what data to use for different flows
    :param training_args: Arguments pertaining to training loop configuration
    :model: unwrapped model to run flows on
    """

    def __init__(
        self,
        data_args: "DataTrainingArguments",
        model_args: "ModelArguments",
        training_args: "TrainingArguments",
        model: Module,
    ):
        self._data_args = data_args
        self._model_args = model_args
        self._training_args = training_args

        self.datasets = {}
        self.model = model
        self.trainer = None
        self.tokenizer = None

    def populate_datasets(self, tokenizer: "AutoTokenizer"):
        """
        Loads datasets for each flow based on data_args, stores a Dataset for each
        enabled flow in self.datasets

        :param tokenizer: tokenizer to use for dataset tokenization
        """
        splits = self._data_args.splits
        tokenized_datasets = {}
        if self._data_args.splits is None:
            splits = {"all": None}
        for split_name, split_str in splits.items():
            dataset_manager = TextGenerationDataset.load_from_registry(
                self._data_args.dataset_name,
                data_args=self._data_args,
                split=split_str,
                tokenizer=tokenizer,
            )
            raw_dataset = dataset_manager.get_raw_dataset(self._model_args.cache_dir)
            tokenized_dataset = dataset_manager.tokenize_and_process(raw_dataset)
            tokenized_datasets[split_name] = tokenized_dataset

        self.datasets = make_dataset_splits(
            tokenized_datasets,
            self._training_args.do_train,
            self._training_args.do_eval,
            self._training_args.do_predict,
            self._training_args.do_oneshot,
        )
        self.tokenizer = tokenizer

    def set_trainer(self, trainer: Trainer):
        """
        :param trainer: update trainer
        """
        self.trainer = trainer

    def set_model(self, model: Module):
        """
        :param model: update pytorch model
        """
        self.model = model

    def get_dataset_split(self, split_name: str) -> Dataset:
        """
        Retrieve a dataset split by name

        :param split_name: name of dataset split to return
        :return: dataset split labeled by split_name
        """
        return self.datasets.get(split_name)

    def format_calibration_data(self) -> List[torch.Tensor]:
        """
        Creates a dataloader out of the calibration dataset split, trimming it to
        the desired number of calibration samples

        :return: list of trimmed calibration data tensors
        """
        oneshot_dataset = self.get_dataset_split("calibration")

        dataloader_params = {
            "batch_size": 1,
            "sampler": RandomSampler(oneshot_dataset),
            "collate_fn": self.trainer.data_collator,
        }

        calib_dataloader = DataLoader(oneshot_dataset, **dataloader_params)
        parsed_calib_data = [inp["input_ids"] for inp in calib_dataloader]
        return parsed_calib_data[
            : min(self._data_args.num_calibration_samples, len(parsed_calib_data))
        ]

    def one_shot(self):
        """
        Run oneshot calibration on the active model
        """
        _LOGGER.info("*** One Shot ***")

        calib_data = self.format_calibration_data()
        oneshot_device = fallback_to_cpu(self._training_args.oneshot_device)
        session_manager.apply(
            framework=Framework.pytorch,
            recipe=self._training_args.recipe,
            model=self.model,
            calib_data=calib_data,
            start=-1,
            device=oneshot_device,
            copy_data=False,
        )

        save_model_and_recipe(
            model=self.model,
            save_path=self._training_args.output_dir,
            tokenizer=self.tokenizer,
        )

    def train(self, checkpoint: str):
        """
        Run trainer's training loop on train_dataset, saving the resulting model to
        output_dir

        :param checkpoint: Optional checkpoint to resume from
        """
        train_result = self.trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        metrics["train_samples"] = len(self.get_dataset_split("train"))
        self.trainer.log_metrics("train", metrics)
        self.trainer.save_metrics("train", metrics)

        # this includes saving the state, optimizer and scheduler
        self.trainer.save_model()

    def evaluate(self):
        """
        Run trainer's evaluation loop on eval_dataset, logging the desired metrics
        """
        _LOGGER.info("*** Evaluate ***")
        metrics = self.trainer.evaluate(self.get_dataset_split("validation"))

        metrics["eval_samples"] = len(self.get_dataset_split("validation"))
        self.trainer.log_metrics("eval", metrics)
        self.trainer.save_metrics("eval", metrics)

    def predict(self):
        """
        Run trainer's prediction loop on predict_dataset, logging the desired metrics
        """
        _LOGGER.info("*** Predict ***")
        results = self.trainer.predict(self.dataset["test"])
        metrics = results.metrics

        metrics["predict_samples"] = len(self.dataset["test"])
        self.trainer.log_metrics("predict", metrics)
        self.trainer.save_metrics("predict", metrics)

0 commit comments
