BaseFinetuning — PyTorch Lightning 2.5.1.post0 documentation
class lightning.pytorch.callbacks.BaseFinetuning[source]¶
Bases: Callback
This class implements the base logic for writing your own Finetuning Callback. Override the freeze_before_training and finetune_function methods with your own logic.

freeze_before_training: This method is called before configure_optimizers and should be used to freeze any module's parameters.

finetune_function: This method is called on every train epoch start and should be used to unfreeze any parameters. Those parameters need to be added in a new param_group within the optimizer.

Note
Make sure to filter the parameters based on requires_grad.
Example:

    from torch.optim import Adam

    class MyModel(pl.LightningModule):
        def configure_optimizers(self):
            # Make sure to filter the parameters based on `requires_grad`
            return Adam(filter(lambda p: p.requires_grad, self.parameters()))

    class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch=10):
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            # Freeze any module you want.
            # Here, we are freezing `feature_extractor`.
            self.freeze(pl_module.feature_extractor)

        def finetune_function(self, pl_module, current_epoch, optimizer):
            # When `current_epoch` is 10, feature_extractor will start training.
            if current_epoch == self._unfreeze_at_epoch:
                self.unfreeze_and_add_param_group(
                    modules=pl_module.feature_extractor,
                    optimizer=optimizer,
                    train_bn=True,
                )
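A hedged usage sketch follows, assuming the MyModel and FeatureExtractorFreezeUnfreeze classes from the example above and a hypothetical my_datamodule; the callback is passed to the Trainer like any other callback:

    import lightning.pytorch as pl

    model = MyModel()
    trainer = pl.Trainer(
        max_epochs=20,
        callbacks=[FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=10)],
    )
    # trainer.fit(model, datamodule=my_datamodule)  # `my_datamodule` is assumed to exist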
static filter_on_optimizer(optimizer, params)[source]¶
This function is used to exclude any parameter which already exists in this optimizer.
Parameters:
- optimizer¶ (Optimizer) – Optimizer used for parameter exclusion
- params¶ (Iterable) – Iterable of parameters used to check against the provided optimizer
Return type:
list
Returns:
List of parameters not contained in this optimizer's param groups
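A minimal sketch of how filter_on_optimizer could be used, assuming a hypothetical backbone/head split; only the head parameters, which are not yet tracked by the optimizer, are returned:

    from torch import nn
    from torch.optim import Adam
    from lightning.pytorch.callbacks import BaseFinetuning

    backbone = nn.Linear(4, 4)
    head = nn.Linear(4, 2)
    optimizer = Adam(backbone.parameters(), lr=1e-3)

    # Exclude anything already present in the optimizer's param groups.
    new_params = BaseFinetuning.filter_on_optimizer(optimizer, head.parameters())
    # `new_params` contains only the head's parameters; the backbone's are filtered out.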
static filter_params(modules, train_bn=True, requires_grad=True)[source]¶
Yields the requires_grad parameters of a given module or list of modules.
Parameters:
- modules¶ (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules
- train_bn¶ (bool) – Whether to include the parameters of BatchNorm layers
- requires_grad¶ (bool) – Whether to create a generator for trainable or non-trainable parameters.
Return type:
Generator
Returns:
Generator over the filtered parameters
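A minimal sketch, assuming a small hypothetical model; with train_bn=False the BatchNorm parameters are skipped and only the remaining trainable parameters are yielded:

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())

    # Yield only trainable parameters, skipping BatchNorm layers.
    trainable = list(BaseFinetuning.filter_params(model, train_bn=False, requires_grad=True))
    # `trainable` holds the Conv2d weight and bias, but not the BatchNorm parameters.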
finetune_function(pl_module, epoch, optimizer)[source]¶
Override to add your unfreeze logic.
Return type:
None
static flatten_modules(modules)[source]¶
This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves.
Parameters:
modules¶ (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules
Return type:
list
Returns:
List of modules
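A short sketch with a hypothetical nested container; the nested Sequential is flattened into its leaf modules:

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    net = nn.Sequential(nn.Sequential(nn.Linear(8, 8), nn.ReLU()), nn.Linear(8, 2))

    leaves = BaseFinetuning.flatten_modules(net)
    # `leaves` is [Linear(8, 8), ReLU(), Linear(8, 2)] -- the containers themselves are dropped.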
static freeze(modules, train_bn=True)[source]¶
Freezes the parameters of the provided modules.
Parameters:
- modules¶ (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules
- train_bn¶ (bool) – If True, leave the BatchNorm layers in training mode
Return type:
None
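A minimal sketch, assuming a hypothetical feature extractor; with train_bn=True everything is frozen except the BatchNorm layers, which stay trainable and in training mode:

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    feature_extractor = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))

    BaseFinetuning.freeze(feature_extractor, train_bn=True)
    # Conv2d parameters now have requires_grad=False; the BatchNorm parameters remain trainable.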
freeze_before_training(pl_module)[source]¶
Override to add your freeze logic.
Return type:
None
static freeze_module(module)[source]¶
Freezes the parameters of the provided module.
Parameters:
module¶ (Module) – A given module
Return type:
None
load_state_dict(state_dict)[source]¶
Called when loading a checkpoint, implement to reload callback state given the callback's state_dict.
Parameters:
state_dict¶ (dict[str, Any]) – the callback state returned by state_dict.
Return type:
None
static make_trainable(modules)[source]¶
Unfreezes the parameters of the provided modules.
Parameters:
modules¶ (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules
Return type:
None
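A short sketch, assuming a hypothetical classification head that was frozen earlier; make_trainable re-enables its gradients and puts it back in training mode:

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    head = nn.Linear(8, 2)
    BaseFinetuning.freeze(head)           # requires_grad is now False for all head parameters
    BaseFinetuning.make_trainable(head)   # requires_grad is True again and the module is back in train mode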
on_fit_start(trainer, pl_module)[source]¶
Called when fit begins.
Return type:
None
on_train_epoch_start(trainer, pl_module)[source]¶
Called when the epoch begins.
Return type:
None
setup(trainer, pl_module, stage)[source]¶
Called when fit, validate, test, predict, or tune begins.
Return type:
None
state_dict()[source]¶
Called when saving a checkpoint, implement to generate the callback's state_dict.
Return type:
dict[str, Any]
Returns:
A dictionary containing callback state.
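A hedged sketch of how a subclass could persist extra state across checkpoints by extending state_dict and load_state_dict; the _unfreeze_at_epoch attribute and the trivial hook overrides are illustrative only:

    from lightning.pytorch.callbacks import BaseFinetuning

    class StatefulFinetuning(BaseFinetuning):
        def __init__(self, unfreeze_at_epoch=10):
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            pass  # placeholder; freeze the relevant modules here

        def finetune_function(self, pl_module, current_epoch, optimizer):
            pass  # placeholder; unfreeze and add param groups here

        def state_dict(self):
            state = super().state_dict()  # keep the base class bookkeeping
            state["unfreeze_at_epoch"] = self._unfreeze_at_epoch
            return state

        def load_state_dict(self, state_dict):
            self._unfreeze_at_epoch = state_dict.pop("unfreeze_at_epoch")
            super().load_state_dict(state_dict)  # restore the base class bookkeeping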
static unfreeze_and_add_param_group(modules, optimizer, lr=None, initial_denom_lr=10.0, train_bn=True)[source]¶
Unfreezes a module and adds its parameters to an optimizer.
Parameters:
- modules¶ (Union[Module, Iterable[Union[Module, Iterable]]]) – A module or iterable of modules to unfreeze. Their parameters will be added to an optimizer as a new param group.
- optimizer¶ (Optimizer) – The provided optimizer will receive the new parameters via add_param_group.
- lr¶ (Optional[float]) – Learning rate for the new param group.
- initial_denom_lr¶ (float) – If no lr is provided, the learning rate from the first param group will be used and divided by initial_denom_lr.
- train_bn¶ (bool) – Whether to train the BatchNormalization layers.
Return type:
None
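A minimal sketch, assuming a hypothetical frozen backbone and a trainable head; the backbone is unfrozen and added as a new param group at the head's learning rate divided by initial_denom_lr (10.0 by default):

    from torch import nn
    from torch.optim import Adam
    from lightning.pytorch.callbacks import BaseFinetuning

    backbone = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
    head = nn.Linear(16, 2)

    BaseFinetuning.freeze(backbone)
    optimizer = Adam(head.parameters(), lr=1e-2)

    BaseFinetuning.unfreeze_and_add_param_group(modules=backbone, optimizer=optimizer)
    # `optimizer.param_groups` now has a second group holding the backbone parameters
    # with lr = 1e-2 / 10.0 = 1e-3, and the backbone's requires_grad flags are True again.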