⚙️ Using the Trainer

The Composer Trainer implements a highly optimized PyTorch training loop for neural networks. Using the trainer gives you several superpowers:

Note

We use the two-way callback system developed by Howard et al. (2020) to flexibly add the logic of our speedup methods during training.

Below are simple examples for getting started with the Composer Trainer along with code snippets for more advanced usage such as using speedup methods, checkpointing, and distributed training.

Getting Started

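Create a model class that meets the ComposerModel interface, minimally implementing the forward() method (the model’s forward pass) and the loss() method (which computes the loss from the outputs and the batch).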

For more information, see the ComposerModel guide.

A minimal example of a ResNet-18 model is shown here:

import torchvision
import torch.nn.functional as F

from composer.models import ComposerModel

class ResNet18(ComposerModel):

    def __init__(self):
        super().__init__()
        self.model = torchvision.models.resnet18()

    def forward(self, batch):
        inputs, _ = batch
        return self.model(inputs)

    def loss(self, outputs, batch):
        _, targets = batch
        return F.cross_entropy(outputs, targets)

Then, the model can be passed to the trainer with the relevant torch objects.

import torch
from composer import Trainer

model = ResNet18()

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    optimizers=torch.optim.Adam(model.parameters(), lr=0.01),
    max_duration=10,  # epochs
    device='gpu',
)

trainer.fit()

When training is complete, both training and evaluation metrics can be accessed from the trainer state.

print(trainer.state.train_metrics)
print(trainer.state.eval_metrics)

In the background, we automatically add the ProgressBarLogger to log training progress to the console.

A few tips and tricks for using our Trainer:

For a full list of Trainer options, see Trainer. Below, we illustrate some example use cases.

Training Loop

Behind the scenes, our trainer handles much of the engineering for distributed training, gradient accumulation, device movement, gradient scaling, and more. The pseudocode for our trainer loop as it interacts with the ComposerModel is as follows:

# training loop
for batch in train_dataloader:

    outputs = model.forward(batch)
    loss = model.loss(outputs, batch)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# eval loop
for batch in eval_dataloader:
    outputs = model.eval_forward(batch)
    for metric in model.get_metrics(is_train=False).values():
        model.update_metric(batch, outputs, metric)

For the actual code, see the Trainer.fit() and Trainer.eval() methods.
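To make the pseudocode concrete, here is a toy, dependency-free imitation of the same loop structure. It is only a sketch under stated assumptions: a one-parameter model y = w * x with a mean-squared-error loss, and a hand-written backward() standing in for PyTorch autograd. ToyModel, ToySGD, and fit are hypothetical names, not Composer or PyTorch classes.

```python
# Toy imitation of the training-loop structure above.
# ToyModel, ToySGD and fit are hypothetical stand-ins, not Composer/PyTorch APIs.


class ToyModel:
    """Fits y = w * x with a mean-squared-error loss."""

    def __init__(self):
        self.w = 0.0
        self.grad = 0.0
        self._cache = None

    def forward(self, batch):
        inputs, _ = batch
        return [self.w * x for x in inputs]

    def loss(self, outputs, batch):
        inputs, targets = batch
        # Remember what backward() needs: a crude substitute for an autograd graph.
        self._cache = (inputs, outputs, targets)
        return sum((o - t) ** 2 for o, t in zip(outputs, targets)) / len(targets)

    def backward(self):
        inputs, outputs, targets = self._cache
        # d/dw mean((w*x - t)^2) = mean(2 * (w*x - t) * x); accumulate like .grad does
        self.grad += sum(2 * (o - t) * x for x, o, t in zip(inputs, outputs, targets)) / len(targets)


class ToySGD:
    def __init__(self, model, lr):
        self.model = model
        self.lr = lr

    def step(self):
        self.model.w -= self.lr * self.model.grad

    def zero_grad(self):
        self.model.grad = 0.0


def fit(model, optimizer, train_dataloader, epochs):
    loss = None
    for _ in range(epochs):
        for batch in train_dataloader:
            outputs = model.forward(batch)
            loss = model.loss(outputs, batch)
            model.backward()  # plays the role of loss.backward()
            optimizer.step()
            optimizer.zero_grad()
    return loss


train_dataloader = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]  # data from y = 3x
model = ToyModel()
final_loss = fit(model, ToySGD(model, lr=0.02), train_dataloader, epochs=20)
print(round(model.w, 4))  # approaches 3.0
```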

Quick Tour

Below is a quick tour of various elements with code snippets for your reference. See the more detailed sections in the navigation menu.

Events & State

The core principle of the Composer trainer is to make it easy to inject custom logic to run at various points in the training loop. To do this, we run events before and after each of the lines in the loop above. For example:

engine.run_event("before_forward")
outputs = model.forward(batch)
engine.run_event("after_forward")

Algorithms and callbacks (see below) register themselves to run on one or more events.

We also maintain a State which stores the trainer’s state, such as the model, optimizers, dataloader, current batch, etc. (see State). This allows algorithms to modify the state during the various events above.
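The interplay of events, state, algorithms, and callbacks can be sketched in a few lines of plain Python. This is a simplified illustration, not Composer’s implementation: ToyState, ScaleInputs, EventRecorder, and ToyEngine are hypothetical, and the choice to run algorithms before callbacks within an event is an assumption of this toy. The algorithm modifies the state at "before_forward", while the callback only observes.

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class ToyState:
    """Hypothetical, much-reduced analogue of the trainer State."""
    batch: Any = None
    outputs: Any = None
    events_seen: List[str] = field(default_factory=list)


class ScaleInputs:
    """Hypothetical algorithm: rescales the batch just before the forward pass."""

    def __init__(self, factor):
        self.factor = factor

    def match(self, event, state):
        return event == "before_forward"

    def apply(self, event, state):
        state.batch = [x * self.factor for x in state.batch]


class EventRecorder:
    """Hypothetical callback: observes every event without modifying state."""

    def run_event(self, event, state):
        state.events_seen.append(event)


class ToyEngine:
    def __init__(self, state, algorithms, callbacks):
        self.state = state
        self.algorithms = algorithms
        self.callbacks = callbacks

    def run_event(self, event):
        # In this toy, algorithms (which may modify state) run first, in list
        # order; callbacks then observe the resulting state.
        for algorithm in self.algorithms:
            if algorithm.match(event, self.state):
                algorithm.apply(event, self.state)
        for callback in self.callbacks:
            callback.run_event(event, self.state)


def toy_forward(batch):
    return sum(batch)


state = ToyState(batch=[1, 2, 3])
engine = ToyEngine(state, algorithms=[ScaleInputs(10)], callbacks=[EventRecorder()])

engine.run_event("before_forward")
state.outputs = toy_forward(state.batch)
engine.run_event("after_forward")
print(state.outputs, state.events_seen)  # 60 ['before_forward', 'after_forward']
```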

Algorithms

The Composer trainer is designed to easily apply our library of algorithms to both train more efficiently and build better models. These can be enabled by passing instances of the appropriate algorithm classes to the algorithms argument.

from composer import Trainer
from composer.algorithms import LayerFreezing, MixUp

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='2ep',
    algorithms=[
        LayerFreezing(freeze_start=0.5, freeze_level=0.1),
        MixUp(alpha=0.1),
    ],
)

# the algorithms will automatically be applied during the appropriate
# points of the training loop
trainer.fit()

We handle inserting the algorithms into the training loop at the right points and in the right order.

Optimizers & Schedulers

You can easily specify which optimizer and learning rate scheduler to use during training. Composer supports both PyTorch schedulers as well as Composer’s custom schedulers.

from composer import Trainer
from composer.models.tasks import ComposerClassifier
import torchvision.models as models
from torch.optim import SGD
from torch.optim.lr_scheduler import LinearLR

model = ComposerClassifier(module=models.resnet18(), num_classes=1000)
optimizer = SGD(model.parameters(), lr=0.1)
scheduler = LinearLR(optimizer)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='90ep',
    optimizers=optimizer,
    schedulers=scheduler,
)

Composer’s own custom schedulers are versions that support the Time abstraction. Time-related inputs such as step or T_max can be provided in many units, from epochs ("10ep") to batches ("2048ba") to a fraction of the total training duration ("0.7dur").
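One way to see how these units relate is to convert each one into a fraction of total training. The helper below is a hypothetical sketch, not Composer’s Time class: it handles only the three units named above and assumes a fixed number of batches per epoch.

```python
def to_fraction(time_str, max_epochs, batches_per_epoch):
    """Convert a time string such as '10ep', '2048ba', or '0.7dur' into the
    fraction of total training it represents (hypothetical helper)."""
    total_batches = max_epochs * batches_per_epoch
    if time_str.endswith("dur"):
        return float(time_str[: -len("dur")])
    if time_str.endswith("ep"):
        return int(time_str[: -len("ep")]) * batches_per_epoch / total_batches
    if time_str.endswith("ba"):
        return int(time_str[: -len("ba")]) / total_batches
    raise ValueError(f"unsupported time unit in {time_str!r}")


# Assumed setup: 90 epochs of 100 batches each (9000 batches in total).
for t in ["10ep", "2048ba", "0.7dur"]:
    print(t, round(to_fraction(t, max_epochs=90, batches_per_epoch=100), 4))
```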

For example, the schedule below decays the learning rate by a factor of 10 (gamma=0.1) at 30%, 50%, and 90% of the way through the training process:

from composer import Trainer
from composer.optim.scheduler import MultiStepScheduler

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='90ep',
    schedulers=MultiStepScheduler(
        milestones=['0.3dur', '0.5dur', '0.9dur'],
        gamma=0.1,
    ),
)

See 📉 Schedulers for details.

Training on GPU

Control which device you use for training with the device parameter, and we will handle the data movement and other systems-related engineering. We currently support the cpu, gpu, and tpu devices.

from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='2ep',
    device='cpu',
)

Training on M1 chips (beta)

To train models on Apple M-series chips, we support device='mps'. Note that this requires torch >= 1.12 and macOS 12.3+.

For more details, see the PyTorch Release Blog.

from composer import Trainer

trainer = Trainer(
    ...,
    device='mps',
)

Training on TPU (beta)

Beta support: train your models on single-core TPUs in bf16 precision. You will need to have torch_xla installed, following the instructions at https://github.com/pytorch/xla.

import torch_xla.core.xla_model as xm
from composer import Trainer

# the model must first be moved to the XLA device before it is passed to the trainer
model = model.to(xm.xla_device())

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='2ep',
    device='tpu',
)

Note

We will add multi-core support in future releases.

Distributed Training

It’s also simple to do data-parallel training on multiple GPUs. Composer provides a launcher command that works with the trainer and handles all the torch.distributed setup for you.

# run_trainer.py

from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='160ep',
    device='gpu',
)
trainer.fit()

Access the Composer launcher via the composer command-line program. Specify the number of GPUs you’d like to use with the -n flag along with the file containing your training script. Use composer --help to see a full list of configurable options.

# run training on 8 GPUs
$ composer -n 8 run_trainer.py

For multiple GPUs, the batch_size for each dataloader should be the per-device batch size. For example, to use a total batch size of 2048 with data parallel across 8 GPUs, the dataloader should set batch_size=256.
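The arithmetic is a simple division of the global batch size by the number of devices. The helper below is hypothetical; rejecting batch sizes that do not divide evenly is a choice made by this sketch, not something the launcher is documented to enforce.

```python
def per_device_batch_size(global_batch_size, num_devices):
    """Batch size each dataloader should use so that num_devices data-parallel
    ranks together process global_batch_size samples per step."""
    if global_batch_size % num_devices != 0:
        raise ValueError(
            f"global batch size {global_batch_size} is not divisible by {num_devices} devices"
        )
    return global_batch_size // num_devices


print(per_device_batch_size(2048, 8))  # 256
```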

Warning

For distributed training, your dataloader should use the torch.utils.data.distributed.DistributedSampler. If you are running multi-node and each rank does not have a copy of the dataset, then a normal sampler can be used.
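To illustrate why the sampler matters, the sketch below shows how indices can be split so that each rank sees a disjoint, equally sized shard. To our understanding it mirrors the non-shuffled behavior of PyTorch’s DistributedSampler (shuffle=False, drop_last=False), which pads by repeating indices from the start of the dataset; shard_indices is a hypothetical name, and a non-empty dataset is assumed.

```python
import math


def shard_indices(dataset_len, num_replicas, rank):
    """Indices one rank would see under a strided, non-shuffled sharding scheme
    (sketch; assumes dataset_len > 0)."""
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    # Pad by cycling through the dataset so every rank gets the same number of samples.
    indices = (list(range(dataset_len)) * math.ceil(total_size / dataset_len))[:total_size]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


for rank in range(4):
    print(rank, shard_indices(10, num_replicas=4, rank=rank))
# 0 [0, 4, 8]
# 1 [1, 5, 9]
# 2 [2, 6, 0]
# 3 [3, 7, 1]
```

Without such a sampler, every rank would iterate over the full dataset, so each optimizer step would see duplicated samples rather than a larger effective batch.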

DeepSpeed Integration

Composer comes with DeepSpeed support, allowing you to leverage its full set of features that make it easier to train large models across (1) any type of GPU and (2) multiple nodes. For more details on DeepSpeed, see their website.

To enable DeepSpeed, pass a config, as specified in the DeepSpeed docs, to the deepspeed_config argument.

# run_trainer.py

from composer import Trainer

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='160ep',
    device='gpu',
    deepspeed_config={
        "train_batch_size": 2048,
        "fp16": {"enabled": True},
    },
)

Providing an empty dictionary to DeepSpeed is also valid. The DeepSpeed defaults will be used and other fields (such as precision) will be inferred from the trainer.

Warning

The deepspeed_config must not conflict with any other parameters passed to the trainer.

FSDP Integration (beta)

Composer comes with preliminary support for PyTorch’s Fully Sharded Data Parallel (FSDP), which enables you to train large models across multiple nodes. For more details on FSDP, see their website.

To enable FSDP, pass an FSDP config to the trainer as shown below:

fsdp_config = {
    'sharding_strategy': 'FULL_SHARD',
    'cpu_offload': False,  # Not supported yet
    'mixed_precision': 'DEFAULT',
    'backward_prefetch': 'BACKWARD_POST',
    'activation_checkpointing': False,
    'activation_cpu_offload': False,
    'verbose': True,
}

trainer = Trainer(
    model=composer_model,
    parallelism_config={
        'fsdp': fsdp_config,
    },
    ...
)

trainer.fit()

Callbacks

You can insert arbitrary callbacks to be run at various points during the training loop. The Composer library provides several useful callbacks for things such as monitoring throughput and memory usage during training, but you can also implement your own.

from composer import Trainer
from composer.callbacks import SpeedMonitor

# include a callback for tracking throughput/step during training
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='160ep',
    device='gpu',
    callbacks=[SpeedMonitor(window_size=100)],
)

Numerics

The trainer automatically handles multiple precision types such as fp32 or, for GPUs, amp (automatic mixed precision), which is PyTorch’s built-in method for training in 16-bit floating point. For more details on amp, see torch.cuda.amp and the paper by Micikevicius et al. (2018).

We recommend using amp on GPUs to accelerate your training.

from composer import Trainer

# use mixed precision during training
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='160ep',
    device='gpu',
    precision='amp',
)

Checkpointing

The Composer trainer makes it easy to (1) save checkpoints at various points during training and (2) load them back to resume training later.

from composer import Trainer

# Saving checkpoints
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='160ep',
    device='gpu',
    # Checkpointing params
    save_folder='checkpoints',
    save_interval='1ep',
)

# will save checkpoints to the 'checkpoints' folder every epoch
trainer.fit()

from composer import Trainer

# Loading checkpoints
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    eval_dataloader=eval_dataloader,
    max_duration='160ep',
    device='gpu',
    # Checkpointing params
    load_path='path/to/checkpoint/mosaic_states.pt',
)

# will load the trainer state (including model weights) from the
# load_path before resuming training
trainer.fit()
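The key point of resuming is that a checkpoint stores the training progress together with the weights, so training continues from the saved epoch instead of restarting. The toy below illustrates that with a JSON file; it is a hypothetical sketch, and Composer’s actual checkpoint format and contents differ. The weight update is a placeholder for one epoch of training.

```python
import json
import os
import tempfile


def save_checkpoint(path, state):
    with open(path, "w") as f:
        json.dump(state, f)


def load_checkpoint(path):
    with open(path) as f:
        return json.load(f)


def train(state, max_epochs, checkpoint_path):
    """Run the remaining epochs, saving a checkpoint at the end of each one."""
    while state["epoch"] < max_epochs:
        state["weights"] = [w + 1.0 for w in state["weights"]]  # placeholder for one epoch of updates
        state["epoch"] += 1
        save_checkpoint(checkpoint_path, state)
    return state


with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "checkpoint.json")
    # First run is "interrupted" after 2 epochs.
    train({"epoch": 0, "weights": [0.0, 0.0]}, max_epochs=2, checkpoint_path=path)
    # Second run loads the checkpoint and continues from epoch 2.
    resumed = train(load_checkpoint(path), max_epochs=5, checkpoint_path=path)

print(resumed)  # {'epoch': 5, 'weights': [5.0, 5.0]}
```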

Gradient Accumulation

Composer supports gradient accumulation, which allows training arbitrary logical batch sizes on any hardware by splitting each batch into microbatches of size device_train_microbatch_size.

from composer import Trainer

trainer = Trainer(
    ...,
    device_train_microbatch_size=2,
)
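Why accumulation leaves the update unchanged can be checked on a tiny example. With a mean-reduced loss, summing each microbatch’s gradient weighted by its share of the full batch reproduces the full-batch gradient exactly. This is a sketch under those assumptions (a scalar linear model and hypothetical helper names), not Composer’s internal code.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w*x - y)^2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, microbatch_size):
    """Sum per-microbatch gradients, each weighted by its share of the batch."""
    n = len(xs)
    total = 0.0
    for start in range(0, n, microbatch_size):
        mx = xs[start:start + microbatch_size]
        my = ys[start:start + microbatch_size]
        total += mse_grad(w, mx, my) * len(mx) / n
    return total


xs, ys, w = [1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 4.0, 6.0, 8.0, 10.0], 0.5
print(mse_grad(w, xs, ys), accumulated_grad(w, xs, ys, microbatch_size=2))
```

The weighting matters when the last microbatch is smaller than the others (here, a microbatch of one sample); averaging the three microbatch gradients without weights would give a different answer.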

If device_train_microbatch_size='auto', Composer will try to automatically determine the largest device_train_microbatch_size that the current hardware supports. To support automatic microbatching, Composer initially sets device_train_microbatch_size=batch_size. During training, if a CUDA out-of-memory (OOM) exception is encountered, indicating that the current microbatch size is too large for the hardware, Composer catches the exception, halves device_train_microbatch_size, and continues training. As a secondary benefit, automatic gradient accumulation can adjust dynamically throughout training. For example, when using ProgressiveResizing, the input size increases over the course of training. Composer decreases device_train_microbatch_size only when required, such as when a CUDA OOM is triggered by larger images, which allows faster training at the start until image sizes are scaled up. Note that this feature is experimental and may not work with all algorithms.
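The halving strategy described above can be simulated without a GPU. In this hypothetical sketch, SimulatedOOM stands in for a CUDA out-of-memory error and make_train_step fakes a training step that fails above a fixed memory limit; none of these names are Composer APIs. The start parameter models continuing the search from the current size when inputs grow later in training.

```python
class SimulatedOOM(RuntimeError):
    """Hypothetical stand-in for a CUDA out-of-memory error."""


def make_train_step(max_microbatch_that_fits):
    """Return a fake training step that 'runs out of memory' above a size limit."""
    def train_step(batch_size, microbatch_size):
        if microbatch_size > max_microbatch_that_fits:
            raise SimulatedOOM(f"microbatch of {microbatch_size} does not fit")
        return -(-batch_size // microbatch_size)  # number of microbatches (ceiling division)
    return train_step


def find_microbatch_size(train_step, batch_size, start=None):
    """Halve the microbatch size on each OOM until a step succeeds."""
    microbatch_size = batch_size if start is None else start
    while True:
        try:
            train_step(batch_size, microbatch_size)
            return microbatch_size
        except SimulatedOOM:
            if microbatch_size == 1:
                raise  # even a single sample does not fit; give up
            microbatch_size = max(1, microbatch_size // 2)


size = find_microbatch_size(make_train_step(48), batch_size=256)
print(size)  # 32 (256 -> 128 -> 64 -> 32)

# Later, larger inputs lower the limit; the search continues from the current size.
size = find_microbatch_size(make_train_step(20), batch_size=256, start=size)
print(size)  # 16
```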

Reproducibility

The random seed can be provided to the trainer directly, e.g.

from composer import Trainer

trainer = Trainer(
    ...,
    seed=42,
)

If no seed is provided, a random seed will be generated from the system time.

Since the model and dataloaders are initialized outside of the Trainer, for complete determinism we recommend calling seed_all() and/or configure_deterministic_mode() before creating any objects. For example:

import torch.nn as nn
from composer import Trainer
from composer.utils import reproducibility

reproducibility.configure_deterministic_mode()
reproducibility.seed_all(42)

model = MyModel()

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)

# model will now be deterministically initialized, since the seed is set
model.apply(init_weights)
trainer = Trainer(model=model, seed=42)

Note that the Trainer must still be seeded.
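Why the seed must be set before the model is created comes down to every random draw consuming state from a shared generator. The stdlib-only sketch below shows that reseeding before initialization reproduces the same weights. seed_everything and init_weights are hypothetical stand-ins: the real seed_all() presumably also seeds other generators (such as NumPy’s and PyTorch’s), whereas this sketch seeds only Python’s random module.

```python
import random


def seed_everything(seed):
    """Hypothetical, reduced stand-in for seed_all(): seeds only Python's `random`."""
    random.seed(seed)


def init_weights(n):
    """Stand-in for model initialization that draws from the seeded generator."""
    return [random.uniform(-1.0, 1.0) for _ in range(n)]


seed_everything(42)
first = init_weights(4)
seed_everything(42)
second = init_weights(4)
print(first == second)  # True: same seed before initialization, same weights
```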

This was just a quick tour of the features available within our trainer. Please see the other guides and notebooks for further details.