State

class composer.State(model, rank_zero_seed, run_name, device, max_duration=None, device_train_microbatch_size=None, auto_microbatching=False, train_dataloader=None, evaluators=None, dataloader=None, dataloader_label=None, dataloader_len=-1, dataset_state=None, dataset_resumption=None, precision=Precision.FP32, precision_config=None, optimizers=None, scaler=None, save_metrics=False, algorithms=None, callbacks=None, deepspeed_config=None, parallelism_config=None)[source]#

The state of the trainer.

Contains variables that the trainer tracks throughout the training loop. Note that all the necessary parts of state (i.e., serialized_attributes) are serialized when the trainer is checkpointed, so that the trainer can be restored and training can continue from a checkpoint. algorithms are able to modify an instance of this class in place.

Note

An instance of this class is automatically constructed by the Trainer constructor. A user need not instantiate this class.

Parameters

batch#

The batch. This will be the entire batch during Event.AFTER_DATALOADER, or a microbatch between Event.BATCH_START and Event.BATCH_END.

Type

types.Batch

device#

The device used by this process. The trainer moves the model and loaded data to this device. This can be used in callbacks and algorithms to move data onto the correct device.
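For example, a minimal sketch of a callback that uses state.device to move freshly loaded data onto the training device. The callback name MoveAuxiliaryData is hypothetical, and batch_to_device is assumed from the Device interface; check the Device documentation for the exact API:

from composer.core import Callback, State
from composer.loggers import Logger

class MoveAuxiliaryData(Callback):
    """Hypothetical callback that moves loaded data onto the training device."""

    def after_dataloader(self, state: State, logger: Logger) -> None:
        # state.device mirrors the device the trainer placed the model on.
        state.batch = state.device.batch_to_device(state.batch)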

Type

Device

train_metrics#

The current train metrics, organized by metric name. train_metrics will be deep-copied to ensure that each evaluator updates only its train_metrics.

For example:

>>> trainer = Trainer(
...     ...,
...     train_dataloader=train_dataloader,
...     eval_dataloader=eval_dataloader,
... )
>>> trainer.fit()
>>> trainer.state.train_metrics
{'MulticlassAccuracy': MulticlassAccuracy()}

Type

dict[str, Metric]

eval_metrics#

The current evaluation metrics, organized by dataloader label and then by metric name. If not using an Evaluator, the eval dataloader is labeled 'eval'. Otherwise, in the case of having multiple evaluation datasets, the evaluator label is used. See the Multiple Datasets Documentation for more information. eval_metrics will be deep-copied to ensure that each evaluator updates only its eval_metrics.

For example:

>>> from composer.metrics import CrossEntropy
>>> trainer = Trainer(
...     ...,
...     train_dataloader=train_dataloader,
...     eval_dataloader=eval_dataloader,
... )
>>> trainer.fit()
>>> trainer.state.eval_metrics
{'eval': {'CrossEntropy': CrossEntropy(), 'MulticlassAccuracy': MulticlassAccuracy()}}

Or, when using an Evaluator for multiple evaluation datasets:

>>> from composer.core import Evaluator
>>> trainer = Trainer(
...     ...,
...     train_dataloader=train_dataloader,
...     eval_dataloader=[
...         Evaluator(label='eval1', dataloader=eval_1_dl, metric_names=['MulticlassAccuracy']),
...         Evaluator(label='eval2', dataloader=eval_2_dl, metric_names=['MulticlassAccuracy']),
...     ],
... )
>>> trainer.fit()
>>> trainer.state.eval_metrics
{'eval1': {'MulticlassAccuracy': MulticlassAccuracy()}, 'eval2': {'MulticlassAccuracy': MulticlassAccuracy()}}

Type

dict[str, dict[str, Metric]]

eval_timestamp#

The timestamp for the current evaluation dataloader. This timestamp is reset before the dataloader is evaluated. The epoch attribute for this timestamp is always 0.

Type

Timestamp

device_train_microbatch_size#

The size of each train microbatch per device.

Type

int | float

loss#

The most recently computed loss.

Type

Tensor | Sequence[Tensor] | dict[Any, Tensor]

model#

The training model.

Note

When using DeepSpeed or multi-rank training, the model will be wrapped with DeepSpeedEngine or DistributedDataParallel, respectively.
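For example, a callback that needs the underlying ComposerModel can unwrap it first; a minimal sketch using the is_model_ddp property documented below:

# Unwrap DistributedDataParallel to reach the underlying ComposerModel;
# DDP stores the wrapped model on its .module attribute.
model = state.model.module if state.is_model_ddp else state.model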

Type

Module

outputs#

The most recently computed output from the model’s forward pass.

Type

Tensor | Sequence[Tensor]

predict_timestamp#

The timestamp for the current prediction dataloader. This timestamp is reset before the dataloader is used. The epoch attribute for this timestamp is always 0.

Type

Timestamp

profiler#

The profiler (if profiling is enabled), or None if not profiling.

Type

Profiler

rank_zero_seed#

The seed of the rank zero process.

Type

int

run_name#

The name for this training run.

Type

str

scaler#

The gradient scaler if using mixed-precision training, or None if not using mixed-precision training.

Type

torch.amp.GradScaler

serialized_attributes#

The names of the attributes that are serialized in a checkpoint.

By default, the following attributes are serialized:

Attribute        Description
model            The model under training.
optimizers       The optimizers being used to train the model.
schedulers       The learning rate schedulers.
algorithms       The algorithms used for training.
callbacks        The callbacks used for training.
scaler           The gradient scaler in use for mixed-precision training.
timestamp        The timestamp that tracks training loop progress.
rank_zero_seed   The seed of the rank zero process.
train_metrics    The current training metrics.
eval_metrics     The current evaluation metrics.
run_name         The run name for training.
dataset_state    The dataset iteration state.

Type

list[str]

timestamp#

The current training timestamp.
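For example, a callback can read training progress from the timestamp; Timestamp exposes fields such as epoch and batch as Time objects, whose .value attribute is the raw count:

# Completed epochs and optimization steps so far:
epochs_done = state.timestamp.epoch.value
batches_done = state.timestamp.batch.value
print(f'epoch={epochs_done}, batch={batches_done}')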

Type

Timestamp

property algorithms#

The algorithms.

batch_get_item(key)[source]#

Gets an element from the batch, specified either by a key or by a user-specified function.

See batch_get in utils/batch_helpers.py for examples.

Parameters

key (str | int | tuple[Callable, Callable] | Any, optional) – A key to index into the batch or a user-specified function to do the extracting. A pair of callables is also supported for cases where a get and set function pair are both passed (like in Algorithms). The getter is assumed to be the first of the pair.

Returns

The part of the batch specified by the key. This could be any type, depending on what the batch is composed of.
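A minimal sketch, assuming the batch is an (inputs, targets) tuple:

# Index directly into a tuple-typed batch:
inputs = state.batch_get_item(0)
targets = state.batch_get_item(1)

# Or pass a (getter, setter) pair, as algorithms often do.
# Only the getter (the first of the pair) is used here:
get_inputs = lambda batch: batch[0]
set_inputs = lambda batch, x: (x, batch[1])
inputs = state.batch_get_item((get_inputs, set_inputs))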

batch_set_item(key, value)[source]#

Sets the element of the batch specified by key (or by the setter of a get/set function pair) to the given value.

This is not an in-place operation: for tuple-typed batches, a new batch object must be created to apply the modification.

See batch_set in utils/batch_helpers.py for examples.

Parameters

key (str | int | tuple[Callable, Callable] | Any, optional) – A key to index into the batch, or a user-specified function to do the setting. A pair of callables is also supported for cases where a get and set function pair are both passed (like in Algorithms); the setter is assumed to be the second of the pair.

value (Any) – The value to set at the position specified by key.

Returns

batch (Any) – The updated batch with value set at key.
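A minimal sketch, again assuming an (inputs, targets) tuple batch; per the Returns entry above, the updated batch is returned rather than modified in place:

import torch

# Replace the inputs with a transformed copy and store the updated batch:
new_inputs = torch.flip(state.batch_get_item(0), dims=[0])
state.batch = state.batch_set_item(0, new_inputs)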

property callbacks#

The callbacks.

property dataloader#

The active dataloader.

property dataloader_label#

The dataloader label for the active dataloader.

By default, the training dataloader is called 'train'. The evaluator dataloader is called 'eval', or, when multiple evaluators are used, the name of the evaluator. However, the dataloader label can be explicitly specified in Trainer.fit() and Trainer.eval().

Returns

Optional[str] – The dataloader label, or None if no dataloader is set.

property dataloader_len#

The number of batches per dataloader iteration (e.g. epoch), as used by the trainer.

Note

If not explicitly specified, this value is an approximation, as it depends on len(self.dataloader). See the PyTorch DataLoader Documentation for more information.

Returns

Optional[Time[int]] – The number of batches per dataloader iteration, or None if no dataloader is set.

property deepspeed_enabled#

Indicates whether DeepSpeed is enabled.

property deepspeed_model#

Cast model to DeepSpeedEngine.

property evaluators#

The evaluators.

property fsdp_enabled#

Indicates if FSDP is enabled.

get_elapsed_duration()[source]#

Get the elapsed training duration.

Returns

Optional[Time[float]] – The elapsed duration, in TimeUnit.DURATION. Time(0.0, TimeUnit.DURATION) represents the beginning of training and Time(1.0, TimeUnit.DURATION) represents a completed training process. Returns None if max_duration is None.
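For example, a sketch of gating an algorithm or callback on training progress:

elapsed = state.get_elapsed_duration()
if elapsed is not None and elapsed.value >= 0.5:
    # More than halfway through max_duration.
    ...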

get_model_state_dict()[source]#

Collect the state dict for the model.

Returns

dict[str, Any] – The state dict for the model.

get_optim_state_dict()[source]#

Collect the state dict for the optimizer.

Returns

dict[str, Any] – The state dict for the optimizer.
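A minimal sketch of persisting both state dicts by hand (for most use cases, prefer the trainer's built-in checkpointing):

import torch

torch.save(
    {
        'model': state.get_model_state_dict(),
        'optim': state.get_optim_state_dict(),
    },
    'manual_checkpoint.pt',  # hypothetical path
)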

property is_model_ddp#

Whether the model is an instance of DistributedDataParallel.

load_model_state(state_dict, logger, strict, exclude_algorithms=None, algorithm_passes=None)[source]#

Loads the model’s state from a state_dict.

Parameters

state_dict (dict[str, Any]) – The state dict to load.

logger (Logger) – The logger.

strict (bool) – Whether the keys (i.e., model parameter names) in the state dict must perfectly match the keys in the model instance.

exclude_algorithms (list[str], optional) – A list of algorithm names to exclude from autoloading. (default: None)

algorithm_passes (list[AlgorithmPass], optional) – A list of algorithm passes to sort autoloaded algorithms into the correct order. (default: None)

load_optim_state(state_dict, strict=True)[source]#

Load the optimizer state.

Parameters

state_dict (dict[str, Any]) – The state dict to load.

strict (bool) – Whether the keys in the state dict must perfectly match the keys in the optimizer instance. (default: True)

load_state_dict(state, logger, strict=False, exclude_algorithms=None, algorithm_passes=None)[source]#

Loads the state.

Parameters

state (dict[str, Any]) – An object returned from a call to state_dict().

logger (Logger) – The logger.

strict (bool) – Whether the keys in the state must perfectly match the keys in this State instance. (default: False)

exclude_algorithms (list[str], optional) – A list of algorithm names to exclude from autoloading. (default: None)

algorithm_passes (list[AlgorithmPass], optional) – A list of algorithm passes to sort autoloaded algorithms into the correct order. (default: None)

property max_duration#

The maximum training duration.

property optimizers#

The optimizers.

property precision#

The numerical precision to use for training.

See Precision for the supported precisions.
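For example, an algorithm or callback might branch on the active precision (the member name below is assumed; check the Precision enum for the exact values):

from composer.core import Precision

if state.precision == Precision.AMP_FP16:
    # Mixed-precision training: gradients are scaled via state.scaler.
    ...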

property precision_config#

The config for the FP8 scaling strategy.

See parameters for DelayedScaling.

property schedulers#

The schedulers.

property seed#

The seed for the current rank.

set_dataloader(dataloader=None, dataloader_label=None, dataloader_len=-1)[source]#

Update the active dataloader and dataloader label.

Parameters

dataloader (Iterable, optional) – The dataloader. (default: None)

dataloader_label (str, optional) – The label for the dataloader. Required if dataloader is not None. (default: None)

dataloader_len (int | Time[int]) – The number of batches per dataloader iteration (e.g. epoch), as used by the trainer. Set to -1 to iterate over the entire dataset. (default: -1)

state_dict()[source]#

Collect the state dicts of our serializable attributes.

Returns

dict[str, Any] – The state dict.

stop_training()[source]#

Gracefully stop training.

The current batch of training will finish, and any scheduled evaluation and logging for that batch, as well as any epoch-end events, will still run before training stops.
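A minimal sketch of a hypothetical early-stopping callback built on stop_training() (Composer also ships dedicated early-stopping callbacks, which should be preferred in practice):

import torch

from composer.core import Callback, State
from composer.loggers import Logger

class LossThresholdStopper(Callback):
    """Hypothetical callback: stop training once the loss falls below a threshold."""

    def __init__(self, threshold: float = 0.05):
        self.threshold = threshold

    def batch_end(self, state: State, logger: Logger) -> None:
        # state.loss may be a Tensor, a sequence, or a dict (see the loss attribute).
        if isinstance(state.loss, torch.Tensor) and state.loss.item() < self.threshold:
            state.stop_training()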

property train_dataloader#

Get the train dataloader.

Returns

Iterable | DataLoader, optional – The dataloader.