BaseFinetuning — PyTorch Lightning 2.5.1.post0 documentation

class lightning.pytorch.callbacks.BaseFinetuning[source]

Bases: Callback

This class implements the base logic for writing your own Finetuning Callback.

Override freeze_before_training and finetune_function methods with your own logic.

freeze_before_training: This method is called before configure_optimizers and should be used to freeze any module's parameters.

finetune_function: This method is called at the start of every training epoch and should be used to unfreeze any parameters. Those parameters need to be added to a new param_group within the optimizer.

Note

Make sure to filter the parameters based on requires_grad.

Example:

    import lightning.pytorch as pl
    from torch.optim import Adam
    from lightning.pytorch.callbacks import BaseFinetuning

    class MyModel(pl.LightningModule):

        def configure_optimizers(self):
            # Make sure to filter the parameters based on requires_grad
            return Adam(filter(lambda p: p.requires_grad, self.parameters()))

    class FeatureExtractorFreezeUnfreeze(BaseFinetuning):

        def __init__(self, unfreeze_at_epoch=10):
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            # Freeze any module you want.
            # Here, we are freezing feature_extractor.
            self.freeze(pl_module.feature_extractor)

        def finetune_function(self, pl_module, current_epoch, optimizer):
            # When current_epoch is 10, feature_extractor will start training.
            if current_epoch == self._unfreeze_at_epoch:
                self.unfreeze_and_add_param_group(
                    modules=pl_module.feature_extractor,
                    optimizer=optimizer,
                    train_bn=True,
                )
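
As a quick sketch of how such a callback is typically wired up (assuming the MyModel class above and a train_loader defined elsewhere), it is passed to the Trainer like any other callback:

    import lightning.pytorch as pl

    # Hypothetical wiring: MyModel and train_loader are assumed to be defined elsewhere.
    model = MyModel()
    trainer = pl.Trainer(
        max_epochs=20,
        callbacks=[FeatureExtractorFreezeUnfreeze(unfreeze_at_epoch=10)],
    )
    trainer.fit(model, train_loader)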

static filter_on_optimizer(optimizer, params)[source]

This method excludes any parameter that already exists in the optimizer's param groups.

Parameters:

optimizer (Optimizer) – The optimizer whose existing param groups are checked.

params (Iterable) – The parameters to filter.

Return type:

list

Returns:

List of parameters not already contained in the optimizer's param groups
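
A minimal sketch with throwaway modules; parameters already registered with the optimizer are dropped, while freshly unfrozen ones pass through:

    from torch import nn
    from torch.optim import Adam
    from lightning.pytorch.callbacks import BaseFinetuning

    backbone = nn.Linear(4, 4)
    head = nn.Linear(4, 2)
    optimizer = Adam(head.parameters())

    # The head's parameters are already in the optimizer and are excluded;
    # the backbone's weight and bias are not, so they are returned.
    new_params = BaseFinetuning.filter_on_optimizer(optimizer, list(backbone.parameters()))
    assert len(new_params) == 2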

static filter_params(modules, train_bn=True, requires_grad=True)[source]

Yields the requires_grad parameters of a given module or list of modules.

Parameters:

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

train_bn (bool) – Whether to include BatchNorm parameters. Defaults to True.

requires_grad (bool) – Whether to yield trainable or non-trainable parameters. Defaults to True.

Return type:

Generator

Returns:

Generator
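
For illustration, a minimal sketch (assuming a throwaway nn.Sequential) that collects only the trainable non-BatchNorm parameters:

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    model = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8), nn.ReLU())

    # With train_bn=False the BatchNorm parameters are skipped, so only the
    # Linear layer's weight and bias are yielded.
    params = list(BaseFinetuning.filter_params(model, train_bn=False, requires_grad=True))
    assert len(params) == 2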

finetune_function(pl_module, epoch, optimizer)[source]

Override to add your unfreeze logic.

Return type:

None

static flatten_modules(modules)[source]

This function flattens a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that hold parameters directly.

Parameters:

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

Return type:

list[Module]

Returns:

List of modules
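
A small sketch of the flattening behaviour on a nested nn.Sequential (module shapes are arbitrary):

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    model = nn.Sequential(nn.Sequential(nn.Linear(4, 4), nn.ReLU()), nn.Linear(4, 2))

    # The nested containers are unrolled; only the leaf modules remain.
    leaves = BaseFinetuning.flatten_modules(model)
    assert [type(m).__name__ for m in leaves] == ["Linear", "ReLU", "Linear"]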

static freeze(modules, train_bn=True)[source]

Freezes the parameters of the provided modules.

Parameters:

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

train_bn (bool) – If True, leave the BatchNorm layers trainable. Defaults to True.

Return type:

None

Returns:

None
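
A minimal sketch, assuming a throwaway backbone; with train_bn=False the BatchNorm layers are frozen along with everything else:

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    backbone = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU())

    # Disable gradient updates for every parameter of the backbone.
    BaseFinetuning.freeze(backbone, train_bn=False)
    assert all(not p.requires_grad for p in backbone.parameters())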

freeze_before_training(pl_module)[source]

Override to add your freeze logic.

Return type:

None

static freeze_module(module)[source]

Freezes the parameters of the provided module.

Parameters:

module (Module) – A given module

Return type:

None
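
For a single module, the effect looks like the following sketch (assuming a throwaway layer):

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    layer = nn.Linear(16, 16)

    # Gradient updates for the layer's own parameters are disabled.
    BaseFinetuning.freeze_module(layer)
    assert not layer.weight.requires_grad and not layer.bias.requires_grad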

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement to reload callback state given the callback's state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

static make_trainable(modules)[source]

Unfreezes the parameters of the provided modules.

Parameters:

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

Return type:

None
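
A short sketch pairing make_trainable with freeze (throwaway modules assumed):

    from torch import nn
    from lightning.pytorch.callbacks import BaseFinetuning

    backbone = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4))
    BaseFinetuning.freeze(backbone)

    # Re-enable gradient updates for every parameter in the backbone.
    BaseFinetuning.make_trainable(backbone)
    assert all(p.requires_grad for p in backbone.parameters())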

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Return type:

None

on_train_epoch_start(trainer, pl_module)[source]

Called when the train epoch begins.

Return type:

None

setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type:

None

state_dict()[source]

Called when saving a checkpoint. Implement to generate the callback's state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
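
If a subclass carries extra state it wants checkpointed, one possible sketch is to extend both hooks around the base implementation, building on the FeatureExtractorFreezeUnfreeze example above (the unfreeze_at_epoch key is purely hypothetical):

    class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
        ...

        def state_dict(self):
            # Keep the base class state and add a hypothetical extra field.
            state = super().state_dict()
            state["unfreeze_at_epoch"] = self._unfreeze_at_epoch
            return state

        def load_state_dict(self, state_dict):
            # Restore our field first, then hand the rest to the base class.
            self._unfreeze_at_epoch = state_dict.pop("unfreeze_at_epoch", self._unfreeze_at_epoch)
            super().load_state_dict(state_dict)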

static unfreeze_and_add_param_group(modules, optimizer, lr=None, initial_denom_lr=10.0, train_bn=True)[source]

Unfreezes a module and adds its parameters to an optimizer.

Parameters:

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A module or iterable of modules to unfreeze. Their parameters will be added to the optimizer as a new param group.

optimizer (Optimizer) – The optimizer that will receive the new param group.

lr (Optional[float]) – Learning rate for the new param group.

initial_denom_lr (float) – If no lr is provided, the learning rate from the optimizer's first param group divided by initial_denom_lr is used.

train_bn (bool) – Whether to train the BatchNorm modules. Defaults to True.

Return type:

None
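
To see the method in isolation, a minimal sketch with a frozen backbone and an optimizer that initially tracks only the head (modules and sizes are arbitrary):

    from torch import nn
    from torch.optim import Adam
    from lightning.pytorch.callbacks import BaseFinetuning

    backbone = nn.Linear(32, 32)
    head = nn.Linear(32, 10)
    BaseFinetuning.freeze(backbone)
    optimizer = Adam(head.parameters(), lr=1e-3)

    # Unfreeze the backbone and register its parameters as a new param group.
    # With lr=None, the new group's learning rate is the optimizer's current
    # lr divided by initial_denom_lr (here 1e-3 / 10).
    BaseFinetuning.unfreeze_and_add_param_group(
        modules=backbone,
        optimizer=optimizer,
        initial_denom_lr=10.0,
    )
    assert len(optimizer.param_groups) == 2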