distillation_model — Model Optimizer 0.27.1
Meta-model wrapper to support knowledge-distillation learning.
Classes
| Class | Description |
|---|---|
| DistillationModel | Class to encapsulate multiple teacher and student models as a single model. |
class DistillationModel
Bases: DynamicModule
Class to encapsulate multiple teacher and student models as a single model.
compute_kd_loss(student_loss=None, loss_reduction_fn=None, skip_balancer=False)
Compute total loss for distillation backpropagation.
Parameters:
- student_loss (Tensor | None) – Original loss computed from the student’s output.
- loss_reduction_fn (Callable) – Callable applied to each loss tensor prior to balancing. Useful for loss-masking situations, where the callable may capture different arguments (e.g., a mask) on each iteration.
- skip_balancer (bool) – Whether or not to use the loss balancer to reduce the loss dict into a scalar.
Returns:
If skip_balancer is False, the scalar total loss weighted between student_loss
and the distillation losses. If skip_balancer is True, a dict of the student model's output loss and the layer-wise distillation losses.
Return type:
Tensor | dict[str, Tensor]
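A minimal sketch of the two return modes, assuming model is an already-converted DistillationModel and student_loss is a scalar Tensor computed from its output (both names are placeholders):

```python
# Balanced: the configured loss balancer reduces everything to one scalar.
total_loss = model.compute_kd_loss(student_loss=student_loss)

# Unbalanced: skip the balancer and receive the raw loss dict instead.
loss_dict = model.compute_kd_loss(student_loss=student_loss, skip_balancer=True)
for name, loss in loss_dict.items():
    print(name, loss.item())
```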
forward(*args, **kwargs)
Implement forward pass.
Parameters:
- *args – Positional inputs to the student and teacher model.
- **kwargs – Named inputs to the student and teacher model.
Returns:
The student model’s output.
Return type:
Any
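A hedged end-to-end training step tying forward and compute_kd_loss together; dataloader, optimizer, and the cross-entropy task loss are illustrative stand-ins, not part of this API:

```python
import torch

task_loss_fn = torch.nn.CrossEntropyLoss()  # hypothetical task loss

for inputs, labels in dataloader:  # dataloader/optimizer assumed to exist
    optimizer.zero_grad()
    # A single call runs both student and teacher; only the student's
    # output is returned, while the intermediate outputs needed for the
    # distillation losses are stored internally.
    student_out = model(inputs)
    student_loss = task_loss_fn(student_out, labels)
    # Weigh the task loss against the layer-wise distillation losses.
    loss = model.compute_kd_loss(student_loss=student_loss)
    loss.backward()
    optimizer.step()
```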
hide_loss_modules(enable=True)
Context manager to temporarily hide the loss modules from the model.
hide_teacher_model(enable=True)
Context manager to temporarily hide the teacher model from the model.
load_state_dict(state_dict, *args, **kwargs)
Override to potentially load the state dict without the teacher's or the loss modules' entries.
Return type:
Any
property loss_balancer: DistillationLossBalancer | None
Fetch the loss balancer, if any.
property loss_modules: ModuleList
Fetch the loss modules list.
modify(teacher_model, criterion, loss_balancer=None, expose_minimal_state_dict=True)
Constructor.
Parameters:
- teacher_model (Module) – A teacher model which this class encapsulates.
- criterion (dict[tuple[str, str], _Loss]) – A dictionary mapping a tuple of student and teacher model layer names to the loss function to apply to that layer pair. See the sketch after this list.
- loss_balancer (DistillationLossBalancer | None) – Instance of DistillationLossBalancer which reduces distillation and non-distillation losses into a single value using some weighing scheme.
- expose_minimal_state_dict (bool) – If True, hides the teacher's state dict when calling state_dict on this class. This avoids saving the teacher's state unnecessarily during checkpointing. Note: set to False if using FSDP.
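modify is normally invoked during model conversion rather than called by hand; the sketch below only illustrates the expected shape of the criterion argument (the layer names and the MSE loss choice are assumptions, not defaults):

```python
import torch.nn as nn

# Maps (student_layer_name, teacher_layer_name) pairs to a loss module.
# Layer names and the MSE choice are hypothetical, for illustration only.
criterion = {
    ("backbone.layer4", "backbone.layer4"): nn.MSELoss(),
    ("head", "head"): nn.MSELoss(),
}
```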
only_student_forward(enable=True)
Context manager to temporarily disable forward passes on the teacher model.
only_teacher_forward(enable=True)
Context manager to temporarily disable forward passes on the student model.
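Since these are context managers, they slot directly into training and evaluation code; a brief sketch with placeholder inputs:

```python
# Validate the student alone, skipping the teacher's forward pass.
with model.only_student_forward():
    val_out = model(val_inputs)  # val_inputs is a placeholder

# Temporarily hide the teacher, e.g. while printing or inspecting
# the student-only module tree.
with model.hide_teacher_model():
    print(model)
```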
state_dict(*args, **kwargs)
Override to potentially return the state dict without the teacher's entries.
Return type:
dict[str, Any]
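With expose_minimal_state_dict left at its default of True, checkpoints stay student-sized; a sketch, with the file path purely illustrative:

```python
import torch

# Saves only the student's weights (teacher and loss modules are hidden).
torch.save(model.state_dict(), "student_ckpt.pt")

# The load override tolerates the missing teacher/loss-module entries.
model.load_state_dict(torch.load("student_ckpt.pt"))
```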
property teacher_model: ModuleList
Fetch the teacher model.
train(mode=True)
Override to prevent warnings about stored intermediate outputs in future forward passes.
Parameters:
mode (bool) – Whether to set training mode (True) or evaluation mode (False).