BackboneFinetuning — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=multiplicative, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, rounding=12)[source]

Bases: BaseFinetuning

Finetune a backbone model based on a user-defined learning rate schedule.

When the backbone learning rate reaches the current model learning rate and should_align is set to True, it will align with it for the rest of the training.
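The alignment behavior can be illustrated with a small standalone sketch. The helper below is hypothetical, not Lightning's actual code: the backbone learning rate starts at a fraction of the model learning rate, is multiplied by lambda_func each epoch, and, when should_align is True, is capped at the model learning rate once it catches up.

```python
def backbone_lr_schedule(current_lr, lambda_func, unfreeze_epoch, num_epochs,
                         initial_denom_lr=10.0, should_align=True, rounding=12):
    """Illustrative only: list the backbone lr for each epoch after unfreezing."""
    lrs = []
    # at unfreezing, the backbone starts at current_lr / initial_denom_lr
    backbone_lr = current_lr / initial_denom_lr
    for epoch in range(unfreeze_epoch, num_epochs):
        lrs.append(round(backbone_lr, rounding))
        next_lr = backbone_lr * lambda_func(epoch)
        # once the backbone lr reaches the model lr, align for the rest of training
        if should_align and next_lr > current_lr:
            next_lr = current_lr
        backbone_lr = next_lr
    return lrs
```

With current_lr=1e-2 and a multiplicative factor of 1.5, the backbone learning rate grows from 1e-3 until it hits 1e-2 and then stays there.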

Parameters:

unfreeze_backbone_at_epoch (int) – Epoch at which the backbone will be unfrozen.

lambda_func (Callable) – Scheduling function for increasing the backbone learning rate.

backbone_initial_ratio_lr (float) – Used to scale down the backbone learning rate compared to the rest of the model.

backbone_initial_lr (Optional[float]) – Optional initial learning rate for the backbone; if not set, it is derived from the current learning rate and backbone_initial_ratio_lr.

should_align (bool) – Whether to align with the current learning rate when the backbone learning rate reaches it.

initial_denom_lr (float) – When unfreezing the backbone, the initial learning rate is the current learning rate divided by initial_denom_lr.

train_bn (bool) – Whether to make Batch Normalization trainable.

verbose (bool) – Display the current learning rate for the model and backbone.

rounding (int) – Precision for displaying the learning rate.

Example:

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BackboneFinetuning

multiplicative = lambda epoch: 1.5
backbone_finetuning = BackboneFinetuning(200, multiplicative)
trainer = Trainer(callbacks=[backbone_finetuning])

finetune_function(pl_module, epoch, optimizer)[source]

Called when the epoch begins.

Return type:

None

freeze_before_training(pl_module)[source]

Override to add your freeze logic.
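For BackboneFinetuning, this hook freezes the backbone before training starts. Conceptually, freezing disables gradients for the backbone's parameters while optionally keeping Batch Normalization trainable (mirroring the train_bn flag); the helper below is a hedged sketch of that idea, not Lightning's exact implementation:

```python
import torch.nn as nn


def freeze_module(module, train_bn=True):
    """Illustrative only: disable gradients for every parameter in ``module``,
    optionally leaving BatchNorm layers trainable."""
    for m in module.modules():
        is_bn = isinstance(m, nn.modules.batchnorm._BatchNorm)
        if is_bn and train_bn:
            continue  # leave BatchNorm trainable when train_bn=True
        for p in m.parameters(recurse=False):
            p.requires_grad = False
```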

Return type:

None

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement to reload the callback state given the callback's state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_fit_start(trainer, pl_module)[source]

Raises:

MisconfigurationException – If the LightningModule has no nn.Module backbone attribute.

Return type:

None

state_dict()[source]

Called when saving a checkpoint. Implement to generate the callback's state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
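Together, state_dict and load_state_dict form the checkpoint round trip: the Trainer stores the returned dict in the checkpoint and hands it back on resume. A minimal sketch of the contract (the class and key names below are illustrative, not necessarily those BackboneFinetuning uses):

```python
class FinetuningStateDemo:
    """Illustrative only: a callback-like object that persists one value."""

    def __init__(self):
        self.previous_backbone_lr = None

    def state_dict(self):
        # returned dict is stored in the checkpoint by the Trainer
        return {"previous_backbone_lr": self.previous_backbone_lr}

    def load_state_dict(self, state_dict):
        # called with the stored dict when resuming from a checkpoint
        self.previous_backbone_lr = state_dict["previous_backbone_lr"]
```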