Commit 49f8db9

rahul-tuli, corey-nm, and KSGulin
authored
[Cherry-pick] Code for layerwise distillation (#1311)
* Saving all hooks during quantization block fusing (#1280)
  * Saving all hooks during quantization block fusing
  * Clean up delete get block hooks
* Layer-Wise Distillation (#1272)
  * Initial commit with Alex's work
  * Update `student_names` -> `student_layer_names`; update `teacher_names` -> `teacher_layer_names`
  * Intermediate commit
  * Styling
  * Reorg initialize
  * More cleanups
  * Update docstring
  * Moving finalize logic to update
  * Tests passing a bit
  * Fixing lifecycle tests
  * Changing projection to dict
  * Cleanup
  * Adding quantization hooks test
  * Add failing test for optimizer serialization
  * Monkey patching optimizer state_dict method
  * Apply suggestions from code review
  * Update src/sparseml/pytorch/sparsification/distillation/modifier_per_layer.py
  * Adding missing docstrings
  * Respond to review on modifier/optimizer state_dict
  * Add a test for modifier load before forward pass
  * Updating comments
  * Fix failing test
  * Add more asserts based on @bfineran's comments
  * Rename `_DISTILL_PARAM_GROUP_KEY` -> `DISTILL_PARAM_GROUP_KEY`; add `DISTILL_PARAM_GROUP_KEY` to `__all__`
  * Move state dict patching to a helper function
  * Quality

Co-authored-by: Konstantin Gulin <[email protected]>
Co-authored-by: Corey Lowman <[email protected]>
Co-authored-by: corey-nm <[email protected]>
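The layer-wise distillation described above can be sketched with plain PyTorch forward hooks: activations are captured from matching student and teacher layers and a per-layer MSE term is summed into the distillation loss. Everything here (`TinyNet`, `capture`, `per_layer_distillation_loss`, the layer names) is illustrative and not the actual sparseml API.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    # Minimal stand-in for a student/teacher model pair with named layers.
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(4, 4)
        self.layer2 = nn.Linear(4, 2)

    def forward(self, x):
        return self.layer2(self.layer1(x))


def capture(module, store, key):
    # Register a forward hook that records the module's output activation.
    def hook(mod, inputs, output):
        store[key] = output

    return module.register_forward_hook(hook)


def per_layer_distillation_loss(student, teacher, x, layer_names):
    # Capture intermediate activations from both models, then sum a
    # per-layer MSE between corresponding student/teacher activations.
    student_acts, teacher_acts = {}, {}
    handles = []
    for name in layer_names:
        handles.append(capture(getattr(student, name), student_acts, name))
        handles.append(capture(getattr(teacher, name), teacher_acts, name))
    student(x)
    with torch.no_grad():  # no gradients through the teacher
        teacher(x)
    for h in handles:  # always remove hooks after use
        h.remove()
    return sum(
        nn.functional.mse_loss(student_acts[n], teacher_acts[n])
        for n in layer_names
    )


torch.manual_seed(0)
loss = per_layer_distillation_loss(
    TinyNet(), TinyNet(), torch.randn(3, 4), ["layer1", "layer2"]
)
print(loss.item() >= 0.0)  # prints True
```

In a real modifier this loss would be weighted and added to the task loss; the commit's "projection to dict" change suggests per-layer projections are also applied when student and teacher activation shapes differ.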
1 parent 48bb176 commit 49f8db9

File tree

5 files changed: +1010 -10 lines changed

src/sparseml/pytorch/sparsification/distillation/__init__.py (+1)

@@ -15,3 +15,4 @@
 # limitations under the License.

 from .modifier_distillation import *
+from .modifier_per_layer import *

src/sparseml/pytorch/sparsification/distillation/modifier_distillation_base.py (+1)

@@ -332,6 +332,7 @@ def loss_update(
             teacher_outputs=teacher_outputs,
             student_labels=student_labels,
             teacher_labels=teacher_labels,
+            optimizer=optimizer,
         )

         total_loss = self.compute_total_loss(loss, distillation_loss)
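Passing the optimizer into the loss update ties in with the commit message's "monkey patching optimizer state_dict method": the modifier adds distillation-only parameters as an extra param group, and the optimizer's `state_dict` is patched so checkpoints stay loadable without the modifier. A hedged sketch of that idea, assuming a tag key named after the commit's `DISTILL_PARAM_GROUP_KEY` (the helper names here are illustrative, not the sparseml implementation):

```python
import types

import torch

DISTILL_PARAM_GROUP_KEY = "distillation_param_group"


def add_distillation_param_group(optimizer, params):
    # Extra keys on a param group are preserved by the optimizer, so the
    # group can be tagged as distillation-only.
    optimizer.add_param_group({"params": params, DISTILL_PARAM_GROUP_KEY: True})


def patch_optimizer_state_dict(optimizer):
    # Monkey patch state_dict so serialized checkpoints omit the
    # distillation-only param group.
    original = optimizer.state_dict

    def state_dict(self):
        sd = original()
        sd["param_groups"] = [
            g for g in sd["param_groups"] if not g.get(DISTILL_PARAM_GROUP_KEY)
        ]
        return sd

    optimizer.state_dict = types.MethodType(state_dict, optimizer)


model_param = torch.nn.Parameter(torch.zeros(2))
distill_param = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([model_param], lr=0.1)
add_distillation_param_group(opt, [distill_param])
patch_optimizer_state_dict(opt)
print(len(opt.state_dict()["param_groups"]))  # prints 1
```

The optimizer still trains both groups (`len(opt.param_groups)` is 2); only the serialized form is filtered, which matches the commit's "failing test for optimizer serialization" being fixed by the patch.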
