Booster API

Author: Mingyan Jiang, Jianghai Chen, Baizhou Zhang

Prerequisite:

Example Code

Introduction

In our new design, colossalai.booster replaces the role of colossalai.initialize to inject features into your training components (e.g. model, optimizer, dataloader) seamlessly. With these new APIs, you can integrate your model with our parallelism features more conveniently. Calling colossalai.booster is also the standard procedure before you run your training loop. In the sections below, we will cover how colossalai.booster works and what you should take note of.

Plugin

Plugin is an important component that manages the parallel configuration (e.g., the Gemini plugin encapsulates the Gemini acceleration solution). Currently supported plugins are as follows:

HybridParallelPlugin: This plugin wraps the hybrid parallel training acceleration solution. It provides an interface for any combination of tensor parallelism, pipeline parallelism, and data parallelism strategies, including DDP and ZeRO.

GeminiPlugin: This plugin wraps the Gemini acceleration solution, i.e. ZeRO with chunk-based memory management.

TorchDDPPlugin: This plugin wraps the DDP acceleration solution of PyTorch. It implements data parallelism at the module level and can run across multiple machines.

LowLevelZeroPlugin: This plugin wraps stage 1 and stage 2 of the Zero Redundancy Optimizer (ZeRO). Stage 1: shards optimizer states across data parallel workers/GPUs. Stage 2: shards optimizer states + gradients across data parallel workers/GPUs.

TorchFSDPPlugin: This plugin wraps the FSDP acceleration solution of PyTorch and can be used to train models with ZeRO data parallelism.

More details about the usage of each plugin can be found in the chapter Booster Plugins.
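
Switching between acceleration solutions is mostly a matter of constructing a different plugin and handing it to the booster. The sketch below is illustrative only; the constructor arguments shown (such as stage, tp_size, and pp_size) are assumptions and should be checked against the Booster Plugins chapter.

# Illustrative sketch: pick one plugin and pass it to the Booster.
# The keyword arguments shown here are assumptions; see the Booster Plugins chapter.
from colossalai.booster import Booster
from colossalai.booster.plugin import (GeminiPlugin, HybridParallelPlugin,
                                       LowLevelZeroPlugin, TorchDDPPlugin)

plugin = TorchDDPPlugin()                             # plain data parallelism
# plugin = LowLevelZeroPlugin(stage=2)                # ZeRO stage 1/2
# plugin = GeminiPlugin()                             # ZeRO with chunk-based memory management
# plugin = HybridParallelPlugin(tp_size=2, pp_size=2) # tensor + pipeline + data parallelism

booster = Booster(plugin=plugin)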

Some plugins support lazy initialization, which can be used to save memory when initializing large models. For more details, please see Lazy Initialization.
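
A minimal sketch of lazy initialization is shown below, assuming the LazyInitContext from colossalai.lazy and a plugin (such as Gemini) that supports it: the model is constructed without allocating its full weights, and materialization is deferred to booster.boost. The model factory is a placeholder.

# Sketch only: defer parameter allocation of a large model until booster.boost.
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin
from colossalai.lazy import LazyInitContext
from colossalai.nn.optimizer import HybridAdam

booster = Booster(plugin=GeminiPlugin())

with LazyInitContext():
    model = build_huge_model()        # hypothetical model factory

optimizer = HybridAdam(model.parameters())
# Parameters are materialized (and sharded by the plugin) during boost.
model, optimizer, *_ = booster.boost(model, optimizer)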

API of booster

class colossalai.booster.Booster(device: typing.Optional[str] = None, mixed_precision: typing.Union[MixedPrecision, str, None] = None, plugin: typing.Optional[Plugin] = None)

Description

Booster is a high-level API for training neural networks. It provides a unified interface for training with different precisions, accelerators, and plugins.

Example

# Following is pseudocode

colossalai.launch(...)
plugin = GeminiPlugin(...)
booster = Booster(mixed_precision='fp16', plugin=plugin)

model = GPT2()
optimizer = HybridAdam(model.parameters())
dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
lr_scheduler = LinearWarmupScheduler()
criterion = GPTLMLoss()

model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(model, optimizer, criterion, dataloader, lr_scheduler)

for epoch in range(max_epochs):
    for input_ids, attention_mask in dataloader:
        outputs = model(input_ids.cuda(), attention_mask.cuda())
        loss = criterion(outputs.logits, input_ids)
        booster.backward(loss, optimizer)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

function backward(loss: Tensor, optimizer: Optimizer)

Description

Execute the backward pass during a training step.
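
In practice this call simply replaces loss.backward(), so the plugin (e.g. mixed precision or ZeRO) can hook into the backward pass. A short sketch, with inputs, targets, and criterion as placeholders:

# Inside the training loop, after model and optimizer have been boosted:
outputs = model(inputs)
loss = criterion(outputs, targets)
booster.backward(loss, optimizer)   # instead of loss.backward()
optimizer.step()
optimizer.zero_grad()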

function boost(model: Module, optimizer: typing.Optional[Optimizer] = None, criterion: typing.Optional[Callable] = None, dataloader: typing.Optional[DataLoader] = None, lr_scheduler: typing.Optional[_LRScheduler] = None)

Returns

List[Union[nn.Module, Optimizer, LRScheduler, DataLoader]]: The list of boosted input arguments.

Description

Wrap and inject features into the passed-in model, optimizer, criterion, lr_scheduler, and dataloader.

function enable_lora(model: Module, pretrained_dir: typing.Optional[str] = None, lora_config: peft.LoraConfig = None, bnb_quantization_config: typing.Optional[BnbQuantizationConfig] = None, quantize: bool = False)

Description

Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided, the LoRA configuration and weights are loaded from that directory. LoRA in Colossal-AI is implemented with the Hugging Face peft library, so the arguments for LoRA configuration are the same as those of peft.
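
A minimal LoRA fine-tuning sketch is shown below. The LoraConfig fields come from the peft library; the concrete values, checkpoint path, and the assumption that enable_lora returns the wrapped model are illustrative only.

# Sketch: wrap a model with LoRA adapters, boost it, then save only the adapters.
from peft import LoraConfig

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)
model = booster.enable_lora(model, lora_config=lora_config)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop ...

booster.save_lora_as_pretrained(model, "./lora_checkpoint")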

function execute_pipeline(data_iter: typing.Iterator, model: Module, criterion: typing.Callable[[typing.Any, typing.Any], torch.Tensor], optimizer: typing.Optional[Optimizer] = None, return_loss: bool = True, return_outputs: bool = False)

Parameters

data_iter (Iterator) -- The iterator for getting the next batch of data. Usually there are two ways to obtain this argument:

  1. wrap the dataloader as an iterator: iter(dataloader)
  2. get the next batch from the dataloader and wrap this batch as an iterator: iter([batch])

Returns

Dict[str, Any]: Output dict in the form of {'loss': ..., 'outputs': ...}. ret_dict['loss'] is the loss of the forward pass if return_loss is set to True, else None. ret_dict['outputs'] is the Hugging Face style model outputs during the forward pass if return_outputs is set to True, else None.

Description

Execute the forward & backward passes when utilizing pipeline parallelism. Return the loss or Hugging Face style model outputs if needed.

Warning: This function is tailored for pipeline parallel training. Do not run the forward/backward pass in the conventional way (model(input) / loss.backward()) when doing pipeline parallel training with the booster, as this will cause unexpected errors.
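
A sketch of one pipeline-parallel training step is shown below. It assumes the booster was created with a HybridParallelPlugin configured for pipeline parallelism and that criterion follows the (outputs, inputs) convention described above; all names are placeholders.

# Sketch: one training step under pipeline parallelism.
# model, optimizer and dataloader are assumed to be already boosted.
train_iter = iter(dataloader)

ret = booster.execute_pipeline(
    train_iter,
    model,
    criterion,        # callable taking (outputs, inputs) and returning a loss tensor
    optimizer,
    return_loss=True,
)
loss = ret["loss"]    # may be None on stages that do not compute the loss
optimizer.step()
optimizer.zero_grad()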

function load_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)

Description

Load lr scheduler from checkpoint.

function load_model(model: typing.Union[Module, ModelWrapper], checkpoint: str, strict: bool = True)

Description

Load model from checkpoint.

function load_optimizer(optimizer: Optimizer, checkpoint: str)

Description

Load optimizer from checkpoint.
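
Loading mirrors saving: recreate the raw objects, boost them with the same plugin, then restore each component from its checkpoint. A short sketch with placeholder paths:

# Sketch: restore a boosted model, optimizer and lr scheduler from checkpoints.
model, optimizer, _, _, lr_scheduler = booster.boost(model, optimizer, lr_scheduler=lr_scheduler)

booster.load_model(model, "./ckpt/model")
booster.load_optimizer(optimizer, "./ckpt/optimizer")
booster.load_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")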

function no_sync(model: Module = None, optimizer: OptimizerWrapper = None)

Returns

contextmanager: Context to disable gradient synchronization.

Description

Context manager to disable gradient synchronization across DP process groups. Currently, Torch DDP and Low-Level ZeRO-1 are supported.
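
A typical use case is gradient accumulation: synchronization is skipped for intermediate micro-batches and only performed on the last one. The sketch below assumes a Torch DDP-boosted model (for Low-Level ZeRO-1, pass the optimizer instead of the model); the accumulation step count is a placeholder.

# Sketch: gradient accumulation, syncing gradients only on the last micro-batch.
accumulation_steps = 4

for step, (inputs, targets) in enumerate(dataloader):
    if (step + 1) % accumulation_steps == 0:
        # last micro-batch: allow gradient synchronization
        loss = criterion(model(inputs), targets) / accumulation_steps
        booster.backward(loss, optimizer)
        optimizer.step()
        optimizer.zero_grad()
    else:
        # intermediate micro-batch: skip gradient synchronization
        with booster.no_sync(model):
            loss = criterion(model(inputs), targets) / accumulation_steps
            booster.backward(loss, optimizer)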

function save_lora_as_pretrained(model: typing.Union[Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False)

Description

Save the LoRA adapters and adapter configuration file to a pretrained checkpoint directory.
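
To resume from adapters saved this way, the checkpoint directory can be passed back through enable_lora's pretrained_dir argument, as in the sketch below (the base-model factory is a placeholder):

# Sketch: reload previously saved LoRA adapters into a fresh base model.
new_model = build_base_model()   # hypothetical factory for the base model
new_model = booster.enable_lora(new_model, pretrained_dir="./lora_checkpoint")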

function save_lr_scheduler(lr_scheduler: _LRScheduler, checkpoint: str)

Description

Save lr scheduler to checkpoint.

function save_model(model: typing.Union[Module, ModelWrapper], checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024, use_safetensors: bool = False, use_async: bool = False)

Description

Save model to checkpoint.

function save_optimizer(optimizer: Optimizer, checkpoint: str, shard: bool = False, gather_dtensor: bool = True, prefix: typing.Optional[str] = None, size_per_shard: int = 1024, use_async: bool = False)

Description

Save optimizer to checkpoint.
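
Putting the saving APIs together, a checkpointing sketch with a sharded model checkpoint might look like the following; paths and sharding sizes are placeholders.

# Sketch: save model (sharded), optimizer and lr scheduler through the booster.
booster.save_model(model, "./ckpt/model", shard=True, size_per_shard=1024, use_safetensors=True)
booster.save_optimizer(optimizer, "./ckpt/optimizer", shard=True, size_per_shard=1024)
booster.save_lr_scheduler(lr_scheduler, "./ckpt/lr_scheduler.pt")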

Usage

In a typical workflow, you should launch the distributed environment at the beginning of the training script and create the objects needed for training (such as models, optimizers, loss functions, and data loaders). Then call booster.boost to inject features into these objects. After that, you can use the booster APIs together with the returned objects to run the rest of your training process.

A pseudo-code example is shown below:

import torch
from torch.optim import SGD
from torchvision.models import resnet18

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

def train():
    # launch the colossalai distributed environment
    # (rank, world_size and port are assumed to be provided by the launcher)
    colossalai.launch(rank=rank, world_size=world_size, port=port, host='localhost')

    # create plugin and objects for training
    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)
    model = resnet18()
    criterion = lambda x: x.mean()
    optimizer = SGD((model.parameters()), lr=0.001)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)

    # use booster.boost to wrap the training objects
    model, optimizer, criterion, _, scheduler = booster.boost(model, optimizer, criterion, lr_scheduler=scheduler)

    # do training as normal, except that the backward should be called by booster
    x = torch.randn(4, 3, 224, 224)
    x = x.to('cuda')
    output = model(x)
    loss = criterion(output)
    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()
    scheduler.step()
    optimizer.zero_grad()

    # checkpointing using booster api
    save_path = "./model"
    booster.save_model(model, save_path, shard=True, size_per_shard=10, use_safetensors=True)

    new_model = resnet18()
    booster.load_model(new_model, save_path)

For more design details, please see this page.