BackboneFinetuning — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=multiplicative, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, rounding=12)[source]¶
Bases: BaseFinetuning
Finetune a backbone model based on a user-defined learning rate schedule.
When the backbone learning rate reaches the current model learning rate and should_align
is set to True, it will align with it for the rest of the training.
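For intuition, here is a rough, self-contained sketch of the documented schedule with illustrative values (it mimics the behavior described above; it is not the library's internal code):

# Illustrative values; only the scheduling behavior is taken from the docs.
model_lr = 1e-3                      # learning rate of the non-backbone parameters
ratio = 0.1                          # backbone_initial_ratio_lr
lambda_func = lambda epoch: 1.5      # multiplicative growth per epoch
should_align = True
backbone_lr = model_lr * ratio       # backbone lr at the unfreeze epoch
for epoch in range(10, 16):          # unfreeze_backbone_at_epoch=10
    print(f"epoch {epoch}: backbone_lr={backbone_lr:.6f}")
    backbone_lr *= lambda_func(epoch)
    if should_align and backbone_lr > model_lr:
        backbone_lr = model_lr       # aligned with the model lr for the rest of training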
Parameters:
- unfreeze_backbone_at_epoch¶ (int) – Epoch at which the backbone will be unfrozen.
- lambda_func¶ (Callable) – Scheduling function for increasing backbone learning rate.
- backbone_initial_ratio_lr¶ (float) – Used to scale down the backbone learning rate compared to the rest of the model.
- backbone_initial_lr¶ (Optional[float]) – Optional initial learning rate for the backbone. By default, current_learning_rate * backbone_initial_ratio_lr is used.
- should_align¶ (bool) – Whether to align with current learning rate when backbone learning reaches it.
- initial_denom_lr¶ (float) – When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.
- train_bn¶ (bool) – Whether to make Batch Normalization trainable.
- verbose¶ (bool) – Display the current learning rate for the model and backbone.
- rounding¶ (int) – Precision for displaying the learning rate.
Example:
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BackboneFinetuning
multiplicative = lambda epoch: 1.5
backbone_finetuning = BackboneFinetuning(200, multiplicative)
trainer = Trainer(callbacks=[backbone_finetuning])
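The callback expects the LightningModule to expose its feature extractor as a backbone attribute (see on_fit_start below). A minimal sketch, where the model architecture and layer sizes are purely illustrative:

import torch
from torch import nn
import lightning.pytorch as pl
from lightning.pytorch.callbacks import BackboneFinetuning

class FinetuneModel(pl.LightningModule):  # illustrative model, not from the docs
    def __init__(self):
        super().__init__()
        # The callback looks for this exact attribute name.
        self.backbone = nn.Sequential(nn.Linear(32, 64), nn.ReLU())
        self.head = nn.Linear(64, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return nn.functional.cross_entropy(self.head(self.backbone(x)), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

trainer = pl.Trainer(callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=10)])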
finetune_function(pl_module, epoch, optimizer)[source]¶
Called when the epoch begins.
Return type: None
freeze_before_training(pl_module)[source]¶
Override to add your freeze logic.
Return type: None
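Both hooks can also be overridden in a custom BaseFinetuning subclass. A hedged sketch of such a strategy, where the class name and epoch value are illustrative and freeze() / unfreeze_and_add_param_group() are helpers inherited from BaseFinetuning:

from lightning.pytorch.callbacks import BaseFinetuning

class HeadFirstFinetuning(BaseFinetuning):  # illustrative subclass
    def __init__(self, unfreeze_at_epoch=10):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # Freeze the backbone before fit starts.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer):
        # Runs at the start of each epoch; unfreeze once the target epoch is reached.
        if epoch == self.unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
                initial_denom_lr=10.0,  # new param group starts at current_lr / 10
            )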
load_state_dict(state_dict)[source]¶
Called when loading a checkpoint, implement to reload callback state given callback’s state_dict.
Parameters:
- state_dict¶ (dict[str, Any]) – the callback state returned by state_dict.
Return type: None
on_fit_start(trainer, pl_module)[source]¶
Called when fit begins.
Raises:
MisconfigurationException – If LightningModule has no nn.Module backbone attribute.
Return type: None
state_dict()[source]¶
Called when saving a checkpoint, implement to generate callback’s state_dict.
Return type: dict[str, Any]
Returns:
A dictionary containing callback state.
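Together, state_dict() and load_state_dict() let the callback's internal state (such as the tracked backbone learning rate) survive a resume from checkpoint. A sketch continuing the model example above, with an illustrative checkpoint path:

# Continues the FinetuneModel sketch above; the checkpoint path is illustrative.
model = FinetuneModel()
backbone_finetuning = BackboneFinetuning(unfreeze_backbone_at_epoch=10)
trainer = pl.Trainer(max_epochs=20, callbacks=[backbone_finetuning])
trainer.fit(model)  # checkpoints include the callback state via state_dict()

# Resuming: Trainer feeds the saved state back through load_state_dict().
trainer = pl.Trainer(max_epochs=40, callbacks=[BackboneFinetuning(unfreeze_backbone_at_epoch=10)])
trainer.fit(model, ckpt_path="path/to/last.ckpt")  # illustrative path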