hivemind.optim

This module contains decentralized optimizers that wrap your regular PyTorch Optimizer to train with peers. Depending on the exact configuration, Optimizer may perform large synchronous updates equivalent to large-batch training, or perform asynchronous local updates and average model parameters.

hivemind.Optimizer

class hivemind.optim.optimizer.Optimizer(*, dht: DHT, run_id: str, target_batch_size: int, batch_size_per_step: Optional[int] = None, optimizer: Union[Optimizer, Callable[[Union[Iterable[Tensor], Iterable[Dict[str, Any]]]], Optimizer]], params: Optional[Union[Iterable[Tensor], Iterable[Dict[str, Any]]]] = None, scheduler: Optional[Union[LRScheduler, Callable[[Optimizer], LRScheduler]]] = None, matchmaking_time: Optional[float] = 15.0, averaging_timeout: Optional[float] = 60.0, allreduce_timeout: Optional[float] = None, next_chunk_timeout: Optional[float] = None, load_state_timeout: float = 600.0, reuse_grad_buffers: bool = False, offload_optimizer: Optional[bool] = None, delay_optimizer_step: Optional[bool] = None, delay_grad_averaging: bool = False, delay_state_averaging: bool = True, average_state_every: int = 1, use_local_updates: bool = False, client_mode: bool = None, auxiliary: bool = False, grad_compression: CompressionBase = hivemind.NoCompression(), grad_averager_factory: Optional[Callable[[...], TGradientAverager]] = None, state_averaging_compression: CompressionBase = hivemind.NoCompression(), load_state_compression: CompressionBase = hivemind.NoCompression(), average_opt_statistics: Sequence[str] = (), extra_tensors: Sequence[Tensor] = (), averager_opts: Optional[dict] = None, tracker_opts: Optional[dict] = None, performance_ema_alpha: float = 0.1, shutdown_timeout: float = 5, verbose: bool = False)

hivemind.Optimizer wraps your regular PyTorch Optimizer for training collaboratively with peers.

By default, Optimizer is configured to be exactly equivalent to synchronous training with target_batch_size. There are advanced options that make training semi-asynchronous (delay_optimizer_step and delay_grad_averaging) or even fully asynchronous (use_local_updates=True).

Example

The Optimizer can be used as a drop-in replacement for a regular PyTorch Optimizer:

```python
import hivemind
import torch
import transformers

model = transformers.AutoModel.from_pretrained("albert-xxlarge-v2")
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, start=True)  # INITIAL_PEERS: addresses of existing peers
opt = hivemind.Optimizer(
    dht=dht, run_id="run_42", batch_size_per_step=4, target_batch_size=4096,
    params=model.parameters(), optimizer=lambda params: torch.optim.Adam(params),
)

while True:
    loss = compute_loss_on_batch(model, batch_size=4)  # your own loss function
    opt.zero_grad()
    loss.backward()
    opt.step()  # <-- train collaboratively with any peers that use the same prefix (run_42)
```

By default, peers will perform the following steps:

Unlike regular training, your device may join midway through training, when other peers have already made some progress. For this reason, any learning rate schedulers, curriculum and other time-dependent features should be based on optimizer.local_epoch (and not on the number of calls to opt.step). Otherwise, peers that joined training late may end up with different learning rates. To do this automatically, pass your scheduler via the scheduler=... parameter described below.
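To illustrate why the schedule should be keyed on the shared epoch, here is a minimal sketch with a hypothetical linear warm-up. The names warmup_lr and PeerView are illustrative only and not part of hivemind; PeerView.local_epoch plays the role of opt.local_epoch, while num_step_calls is a purely local counter.

```python
from dataclasses import dataclass


def warmup_lr(index: int, base_lr: float = 1e-3, warmup: int = 100) -> float:
    """Hypothetical linear warm-up; `index` is whatever counter drives the schedule."""
    return base_lr * min(1.0, (index + 1) / warmup)


@dataclass
class PeerView:
    local_epoch: int     # shared across the swarm (stands in for opt.local_epoch)
    num_step_calls: int  # private to this peer: how many times it called opt.step()


# A veteran peer has trained from the start; a newcomer just joined and
# downloaded the current state, so its local_epoch already matches.
veteran = PeerView(local_epoch=50, num_step_calls=1600)
newcomer = PeerView(local_epoch=50, num_step_calls=0)

# Keyed on the shared epoch: both peers use the same learning rate.
print(warmup_lr(veteran.local_epoch), warmup_lr(newcomer.local_epoch))
# Keyed on local step() calls: the two peers disagree.
print(warmup_lr(veteran.num_step_calls), warmup_lr(newcomer.num_step_calls))
```

With hivemind itself you would not compute this by hand: a standard PyTorch scheduler such as torch.optim.lr_scheduler.LambdaLR can be passed through scheduler=lambda opt: ..., and the Optimizer steps it once per collaborative epoch.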

What is an epoch?

Optimizer uses the term epoch to describe intervals between synchronizations. One epoch corresponds to processing a certain number of training samples (target_batch_size) in total across all peers. As with PyTorch LR schedulers, **an epoch does not necessarily correspond to a full pass over the training data.** At the end of each epoch, peers perform synchronous actions such as averaging gradients for a global optimizer update, updating the learning rate scheduler, or simply averaging parameters (if using local updates). The purpose of this is to ensure that changing the number of peers does not require changing hyperparameters. For instance, if the number of peers doubles, they will run all-reduce twice as often to keep up with the faster accumulation of samples.
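The effect of the number of peers on epoch length can be worked out with simple arithmetic. The sketch below assumes an idealized swarm in which every peer calls step() once per round, in lockstep and with no latency; local_steps_per_epoch is a hypothetical helper, not a hivemind function.

```python
import math
from typing import Sequence


def local_steps_per_epoch(target_batch_size: int, peer_batch_sizes: Sequence[int]) -> int:
    """Idealized number of step() calls per peer before the swarm finishes one epoch.

    Assumes all peers step in lockstep, each contributing its own batch size per round.
    """
    samples_per_round = sum(peer_batch_sizes)
    return math.ceil(target_batch_size / samples_per_round)


print(local_steps_per_epoch(4096, [4] * 8))   # 8 peers with batch size 4
print(local_steps_per_epoch(4096, [4] * 16))  # twice as many peers: epochs end twice as fast
print(local_steps_per_epoch(4096, [4, 4, 8]))  # heterogeneous peers also count toward the total
```

Because target_batch_size stays fixed, each global update still sees the same number of samples; only the wall-clock time per epoch changes.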

Configuration guide

This guide will help you set up your first collaborative training run. It covers the most important basic options, but ignores features that require significant changes to the training code.

```python
import hivemind

# UPPERCASE names, params, hyperparams_for_target_batch_size and AnyPyTorch* are placeholders
dht = hivemind.DHT(initial_peers=INITIAL_PEERS, client_mode=IF_BEHIND_FIREWALL_OR_VERY_UNRELIABLE, start=True)
opt = hivemind.Optimizer(
    dht=dht,
    run_id="a_unique_name_that_every_participant_will_see_when_training",
    batch_size_per_step=ACTUAL_BATCH_SIZE_OF_THIS_PEER,
    target_batch_size=LARGE_GLOBAL_BATCH,
    # ^--- Each global optimizer step will use gradients from 1x-1.1x of target_batch_size (due to latency);
    # It is recommended to train with very large batch sizes to reduce the % of time spent on communication.
    params=params,
    optimizer=lambda params: AnyPyTorchOptimizer(params, **hyperparams_for_target_batch_size),
    # tune learning rate for your target_batch_size. Here's a good reference: https://arxiv.org/abs/1904.00962
    scheduler=lambda opt: AnyPyTorchScheduler(opt, **hyperparams_for_target_batch_size),
    # scheduler.step will be called automatically each time peers collectively accumulate target_batch_size
    offload_optimizer=True,  # saves GPU memory, but increases RAM usage; generally a good practice to use this
    delay_grad_averaging=OPTIONAL,
    delay_optimizer_step=OPTIONAL,  # train faster, but with 1 round of staleness;
    # setting both to True is equivalent to Delayed Parameter Updates (see https://arxiv.org/abs/2101.06840)
    grad_compression=hivemind.compression.Float16Compression(),
    state_averaging_compression=hivemind.compression.Float16Compression(),
    # ^-- it is usually fine to use pure 16-bit or even lower precision during communication with no precaution;
    # see hivemind/examples/albert for a working example of mixed 8/16-bit compression.
    matchmaking_time=15.0,  # 3-5s for small local runs, 10-15s for training over the internet or with many peers
    averaging_timeout=60.0,  # around 2x the actual time it takes to run all-reduce
    verbose=True,  # periodically report the training progress to the console (e.g. "Averaged with N peers")
)  # and you're done!
```
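To pick a starting value for averaging_timeout before you have measurements, a back-of-the-envelope estimate of all-reduce time can help. The sketch below assumes a bandwidth-optimal all-reduce (such as ring or butterfly all-reduce) in which each peer sends about 2(n-1)/n times the payload, and it ignores latency and compute; the helper names and the 100M-parameter, 8-peer, 1 Gbit/s numbers are hypothetical. It also shows why 16-bit compression roughly halves communication time.

```python
def allreduce_seconds(num_params: int, bytes_per_param: int, num_peers: int,
                      bandwidth_bits_per_s: float) -> float:
    """Rough lower bound on all-reduce time, ignoring latency and compute.

    Assumes each peer sends about 2 * (n - 1) / n times the payload size.
    """
    payload_bits = num_params * bytes_per_param * 8
    sent_bits = 2 * (num_peers - 1) / num_peers * payload_bits
    return sent_bits / bandwidth_bits_per_s


def suggested_averaging_timeout(num_params: int, bytes_per_param: int, num_peers: int,
                                bandwidth_bits_per_s: float) -> float:
    """Rule of thumb from the guide above: about 2x the all-reduce time."""
    return 2.0 * allreduce_seconds(num_params, bytes_per_param, num_peers, bandwidth_bits_per_s)


# Hypothetical run: 100M parameters, 8 peers, 1 Gbit/s links.
fp32 = allreduce_seconds(100_000_000, bytes_per_param=4, num_peers=8, bandwidth_bits_per_s=1e9)
fp16 = allreduce_seconds(100_000_000, bytes_per_param=2, num_peers=8, bandwidth_bits_per_s=1e9)
timeout = suggested_averaging_timeout(100_000_000, 2, 8, 1e9)
print(f"float32: {fp32:.1f}s, float16: {fp16:.1f}s, suggested timeout: {timeout:.1f}s")
```

Since this is a lower bound, treat the result only as a starting point and adjust averaging_timeout once you can observe how long all-reduce actually takes in your swarm.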

Parameters

Note

In large-scale training, peers will inevitably fail and you will see error messages. hivemind.Optimizer is designed to recover from such failures, but will sometimes need a minute or two to re-adjust.

property local_epoch: int

This worker’s current epoch, kept synchronized with peers. If this peer’s local_epoch lags behind the others, it will automatically re-synchronize by downloading the state from another peer. An epoch corresponds to accumulating target_batch_size samples across all active devices.

step(closure: Optional[Callable[[], Tensor]] = None, batch_size: Optional[int] = None, grad_scaler: Optional[GradScaler] = None)

Update training progress after accumulating another local batch. Depending on the configuration, this will report progress to peers, run a global or local optimizer step, average parameters, or schedule background tasks.
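The progress accounting behind step() can be modeled with a toy counter. This is a hypothetical single-process sketch, not hivemind's implementation: in hivemind the sample count is shared by all peers through the DHT, and the epoch ends when their combined samples reach target_batch_size. Based on the signature, the sketch also assumes that an explicit batch_size overrides batch_size_per_step and that at least one of them must be given.

```python
from typing import Optional


class ToyProgress:
    """Hypothetical stand-in for the swarm-wide sample counter, collapsed into one process."""

    def __init__(self, target_batch_size: int, batch_size_per_step: Optional[int] = None):
        self.target_batch_size = target_batch_size
        self.batch_size_per_step = batch_size_per_step
        self.local_epoch = 0
        self.samples_accumulated = 0

    def step(self, batch_size: Optional[int] = None) -> bool:
        """Record one local batch; return True if this call completes an epoch."""
        if batch_size is None:
            if self.batch_size_per_step is None:
                raise ValueError("pass batch_size or set batch_size_per_step")
            batch_size = self.batch_size_per_step
        self.samples_accumulated += batch_size
        if self.samples_accumulated < self.target_batch_size:
            return False  # keep accumulating gradients locally
        # Here real peers would average gradients and run the wrapped optimizer.
        self.local_epoch += 1
        self.samples_accumulated = 0
        return True


progress = ToyProgress(target_batch_size=16, batch_size_per_step=4)
print([progress.step() for _ in range(4)], progress.local_epoch)
print(progress.step(batch_size=16), progress.local_epoch)  # explicit batch_size overrides the default
```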

Parameters

Note

This .step differs from the step of regular PyTorch optimizers in several key ways; see __init__ for details.

zero_grad(set_to_none: bool = True)

Reset the model's gradients. If reuse_grad_buffers=True, this will raise an error.

load_state_from_peers(**kwargs)

Attempt to load the newest collaboration state from other peers within the same run_id.

If successful, this will update parameters, optimizer state, local epoch and learning rate schedule in-place.
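The "newest state, updated in-place" behavior can be pictured with a toy model. The sketch below is hypothetical: ToyPeerState and toy_load_state_from_peers are illustrative names, the state is reduced to an epoch and a list of floats (no optimizer state or scheduler), and nothing is transferred over a network. It copies from the peer with the highest local_epoch and overwrites the existing list rather than rebinding it, which mirrors why in-place updates matter: other objects (for example, an optimizer holding references to model parameters) keep seeing the refreshed values.

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class ToyPeerState:
    local_epoch: int
    params: List[float]


def toy_load_state_from_peers(own: ToyPeerState, others: Sequence[ToyPeerState]) -> bool:
    """Copy the newest state (highest local_epoch) into `own` in-place; False if no peers."""
    if not others:
        return False  # nobody to download from
    newest = max(others, key=lambda state: state.local_epoch)
    own.params[:] = newest.params  # overwrite the existing list instead of rebinding it
    own.local_epoch = newest.local_epoch
    return True


me = ToyPeerState(local_epoch=3, params=[0.0, 0.0])
params_before = me.params  # e.g. a reference held elsewhere
swarm = [ToyPeerState(5, [1.0, 2.0]), ToyPeerState(7, [3.0, 4.0])]
print(toy_load_state_from_peers(me, swarm), me)
```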

class hivemind.optim.grad_scaler.GradScaler(*args, **kwargs)

A wrapper over pytorch GradScaler made specifically for training hivemind.Optimizer with reuse_grad_buffers=True.

Note

If you are not using reuse_grad_buffers=True, you can (and should) train normally without this class, e.g. using standard PyTorch AMP or Apex. This custom GradScaler is more memory-efficient, but requires custom training code.

hivemind.GradScaler makes 3 modifications to the regular PyTorch AMP:

Note

The above modifications are enabled automatically. One can (and should) use hivemind.GradScaler exactly like the regular torch.amp.GradScaler.