Callbacks

Callbacks are objects that customize the behavior of the training loop in the PyTorch Trainer (this feature is not yet implemented in TensorFlow). They can inspect the training loop state (for progress reporting, logging to TensorBoard or other ML platforms, etc.) and make decisions (like early stopping).

Callbacks are “read only” pieces of code: apart from the TrainerControl object they return, they cannot change anything in the training loop. For customizations that require changes in the training loop, you should subclass Trainer and override the methods you need (see Trainer for examples).

By default, TrainingArguments.report_to is set to "all", so a Trainer will use all of the integration callbacks below whose package is installed (in addition to the default flow and progress callbacks).

If a package is installed but you don’t wish to use the accompanying integration, you can change TrainingArguments.report_to to a list of just those integrations you want to use (e.g. ["azure_ml", "wandb"]).
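
For example, a minimal sketch that restricts reporting to two integrations (assuming the wandb and tensorboard packages are installed; the output directory is illustrative):

from transformers import TrainingArguments

# Report only to Weights & Biases and TensorBoard, even if other
# integration packages (e.g. comet_ml, mlflow) are installed.
args = TrainingArguments(
    output_dir="output",
    report_to=["wandb", "tensorboard"],
)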

The main class that implements callbacks is TrainerCallback. It gets the TrainingArguments used to instantiate the Trainer, can access that Trainer’s internal state via TrainerState, and can take some actions on the training loop via TrainerControl.

Available Callbacks

Here is the list of the available TrainerCallback in the library:

class transformers.integrations.CometCallback

( )

A TrainerCallback that sends the logs to Comet ML.

Sets up the optional Comet ML integration.

For a number of configurable items in the environment, see the Comet documentation.

class transformers.DefaultFlowCallback

( )

A TrainerCallback that handles the default flow of the training loop for logs, evaluation and checkpoints.

class transformers.ProgressCallback

( max_str_len: int = 100 )

A TrainerCallback that displays the progress of training or evaluation. You can modify max_str_len to control the length at which logged strings are truncated.
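
As a hedged sketch, assuming a Trainer has already been created with the default callbacks (model, args, and train_dataset are assumed to be defined), the default instance can be swapped for a customized one:

from transformers import ProgressCallback, Trainer

trainer = Trainer(model=model, args=args, train_dataset=train_dataset)

# Replace the default progress callback with one that truncates
# logged strings at 50 characters instead of the default 100.
trainer.remove_callback(ProgressCallback)
trainer.add_callback(ProgressCallback(max_str_len=50))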

class transformers.EarlyStoppingCallback

( early_stopping_patience: int = 1, early_stopping_threshold: typing.Optional[float] = 0.0 )

A TrainerCallback that handles early stopping.

This callback depends on the TrainingArguments argument load_best_model_at_end to set best_metric in TrainerState. Note that if the TrainingArguments argument save_steps differs from eval_steps, the early stopping will not occur until the next save step.
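
A minimal sketch of wiring this up (model and datasets are assumed to be defined; the metric name and thresholds are illustrative, and on older library versions the eval_strategy argument is named evaluation_strategy):

from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="output",
    eval_strategy="steps",        # evaluation must be enabled
    eval_steps=500,
    save_steps=500,               # keep save_steps aligned with eval_steps
    load_best_model_at_end=True,  # required so best_metric is tracked in TrainerState
    metric_for_best_model="eval_loss",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # stop if eval_loss fails to improve by at least 0.001 for 3 consecutive evaluations
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)],
)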

class transformers.integrations.TensorBoardCallback

( tb_writer = None )

A TrainerCallback that sends the logs to TensorBoard.
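
A hedged sketch of passing an existing SummaryWriter instead of letting the callback create its own (the log directory is illustrative, and trainer is assumed to exist):

from torch.utils.tensorboard import SummaryWriter

from transformers.integrations import TensorBoardCallback

# Reuse an existing writer so other code can log to the same run directory.
writer = SummaryWriter(log_dir="runs/my_experiment")
trainer.add_callback(TensorBoardCallback(tb_writer=writer))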

class transformers.integrations.WandbCallback

( )

A TrainerCallback that logs metrics, media, and model checkpoints to Weights & Biases.

Sets up the optional Weights & Biases (wandb) integration.

One can subclass and override this method to customize the setup if needed; find more information in the Weights & Biases documentation. The integration can also be configured through a number of environment variables.

class transformers.integrations.MLflowCallback

( )

A TrainerCallback that sends the logs to MLflow. Can be disabled by setting the environment variable DISABLE_MLFLOW_INTEGRATION=TRUE.

Sets up the optional MLflow integration. The integration can be configured through environment variables.
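
For example, to opt out even though mlflow is installed, the variable can be set before the Trainer is created (a minimal sketch):

import os

# Disable the MLflow integration for this process; set this before
# the Trainer is instantiated.
os.environ["DISABLE_MLFLOW_INTEGRATION"] = "TRUE"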

class transformers.integrations.NeptuneCallback

( api_token: typing.Optional[str] = None, project: typing.Optional[str] = None, name: typing.Optional[str] = None, base_namespace: str = 'finetuning', run = None, log_parameters: bool = True, log_checkpoints: typing.Optional[str] = None, **neptune_run_kwargs )

A TrainerCallback that sends the logs to Neptune.

For instructions and examples, see the Transformers integration guide in the Neptune documentation.
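
A hedged sketch of passing a configured instance to the Trainer (the project name and other values are illustrative; model, args, and datasets are assumed to be defined):

from transformers import Trainer
from transformers.integrations import NeptuneCallback

# Credentials can also come from the NEPTUNE_API_TOKEN / NEPTUNE_PROJECT environment variables.
neptune_callback = NeptuneCallback(
    project="my-workspace/my-project",
    base_namespace="finetuning",
    log_parameters=True,
)
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[neptune_callback],
)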

class transformers.integrations.ClearMLCallback

( )

A TrainerCallback that sends the logs to ClearML.


class transformers.integrations.DagsHubCallback

( )

A TrainerCallback that logs to DagsHub. Extends MLflowCallback.

Sets up the DagsHub logging integration. The integration can be configured through environment variables.

class transformers.integrations.FlyteCallback

( save_log_history: bool = True, sync_checkpoints: bool = True )

A TrainerCallback that sends the logs to Flyte. NOTE: This callback only works within a Flyte task.

Example:

from flytekit import current_context, task


@task
def train_hf_transformer():
    cp = current_context().checkpoint
    trainer = Trainer(..., callbacks=[FlyteCallback()])
    output = trainer.train(resume_from_checkpoint=cp.restore())

class transformers.integrations.DVCLiveCallback

( live: typing.Optional[typing.Any] = None, log_model: typing.Union[typing.Literal['all'], bool, NoneType] = None, **kwargs )

A TrainerCallback that sends the logs to DVCLive.

Sets up the optional DVCLive integration. The integration can be configured through environment variables; to customize this callback beyond those, see the DVCLive documentation.
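
A hedged sketch of passing a preconfigured Live object (assuming the dvclive package is installed and trainer already exists; the experiment directory is illustrative):

from dvclive import Live

from transformers.integrations import DVCLiveCallback

# Reuse an existing Live instance and log the final model at the end of training.
trainer.add_callback(DVCLiveCallback(live=Live("my_experiment"), log_model=True))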

class transformers.integrations.SwanLabCallback

( )

A TrainerCallback that logs metrics, media, and model checkpoints to SwanLab.

Sets up the optional SwanLab (swanlab) integration.

One can subclass and override this method to customize the setup if needed; find more information in the SwanLab documentation. The integration can also be configured through a number of environment variables.

TrainerCallback

class transformers.TrainerCallback

( )

A class for objects that will inspect the state of the training loop at some events and take some decisions.

The control object is the only one that can be changed by the callback; in that case, the event that changes it should return the modified version.

The arguments args, state and control are positional for all events; all the others are grouped in kwargs. You can unpack the ones you need in the signature of the event. As an example, see the code of the simple PrinterCallback below.

Example:

class PrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
            print(logs)

on_epoch_begin

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the beginning of an epoch.

on_epoch_end

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the end of an epoch.

on_evaluate

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called after an evaluation phase.

on_init_end

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the end of the initialization of the Trainer.

on_log

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called after logging the last logs.

on_optimizer_step

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called after the optimizer step but before gradients are zeroed out. Useful for monitoring gradients.

on_pre_optimizer_step

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called before the optimizer step but after gradient clipping. Useful for monitoring gradients.

on_predict

( args: TrainingArguments, state: TrainerState, control: TrainerControl, metrics, **kwargs )

Event called after a successful prediction.

on_prediction_step

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called after a prediction step.

on_save

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called after a checkpoint save.

on_step_begin

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the beginning of a training step. If using gradient accumulation, one training step might take several inputs.

on_step_end

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.

on_substep_end

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the end of a substep during gradient accumulation.

on_train_begin

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the beginning of training.

on_train_end

( args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs )

Event called at the end of training.

Here is an example of how to register a custom callback with the PyTorch Trainer:

class MyCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")


trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback],
)

Another way to register a callback is to call trainer.add_callback() as follows:

trainer = Trainer(...)
trainer.add_callback(MyCallback)

# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())

TrainerState

class transformers.TrainerState

( epoch: typing.Optional[float] = None, global_step: int = 0, max_steps: int = 0, logging_steps: int = 500, eval_steps: int = 500, save_steps: int = 500, train_batch_size: typing.Optional[int] = None, num_train_epochs: int = 0, num_input_tokens_seen: int = 0, total_flos: float = 0, log_history: list = None, best_metric: typing.Optional[float] = None, best_global_step: typing.Optional[int] = None, best_model_checkpoint: typing.Optional[str] = None, is_local_process_zero: bool = True, is_world_process_zero: bool = True, is_hyper_param_search: bool = False, trial_name: typing.Optional[str] = None, trial_params: dict = None, stateful_callbacks: list = None )

A class containing the Trainer inner state that will be saved along with the model and optimizer when checkpointing and passed to the TrainerCallback.

Throughout this class, one step is to be understood as one update step. When using gradient accumulation, one update step may require several forward and backward passes: if you use gradient_accumulation_steps=n, then one update step requires going through n batches.

Calculates and stores the absolute values for the logging, eval, and save steps, in case any of them was given as a proportion of the total training steps.

init_training_references

( trainer, max_steps, num_train_epochs, trial )

Stores the initial training references needed in self.

Create an instance from the content of json_path.

Save the content of this instance in JSON format inside json_path.
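
As an illustrative sketch, a callback can read any of the fields above from the state it receives, for example to report progress after each evaluation:

from transformers import TrainerCallback


class StateReporter(TrainerCallback):
    """Illustrative callback that prints progress information from TrainerState."""

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if state.is_world_process_zero:
            print(
                f"update step {state.global_step}/{state.max_steps}, "
                f"epoch {state.epoch}, best metric so far: {state.best_metric}"
            )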

TrainerControl

class transformers.TrainerControl

( should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False )

A class that handles the Trainer control flow. This class is used by the TrainerCallback to activate some switches in the training loop.
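
As a hedged sketch of how a callback can flip these switches, here is an illustrative callback that saves a checkpoint and stops training once a wall-clock budget is exceeded:

import time

from transformers import TrainerCallback


class TimeBudgetCallback(TrainerCallback):
    """Illustrative callback: stop training after max_seconds of wall-clock time."""

    def __init__(self, max_seconds):
        self.max_seconds = max_seconds
        self.start_time = None

    def on_train_begin(self, args, state, control, **kwargs):
        self.start_time = time.time()

    def on_step_end(self, args, state, control, **kwargs):
        if time.time() - self.start_time > self.max_seconds:
            control.should_save = True           # save one last checkpoint
            control.should_training_stop = True  # then exit the training loop
        return control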
