apex.parallel — Apex 0.1.0 documentation


class apex.parallel.DistributedDataParallel(module, message_size=10000000, delay_allreduce=False, shared_param=None, allreduce_trigger_params=None, retain_allreduce_buffers=False, allreduce_always_fp32=False, num_allreduce_streams=1, allreduce_communicators=None, gradient_average=True, gradient_predivide_factor=1.0, gradient_average_split_factor=None, prof=False)[source]

apex.parallel.DistributedDataParallel is a module wrapper that enables easy multiprocess distributed data parallel training, similar to torch.nn.parallel.DistributedDataParallel. Parameters are broadcast across participating processes on initialization, and gradients are allreduced and averaged over processes during backward().

DistributedDataParallel is optimized for use with NCCL. It achieves high performance by overlapping communication with computation during backward() and bucketing smaller gradient transfers to reduce the total number of transfers required.

DistributedDataParallel is designed to work with the upstream launch utility script torch.distributed.launch with --nproc_per_node <= the number of GPUs per node. When used with this launcher, DistributedDataParallel assumes 1:1 mapping of processes to GPUs. It also assumes that your script calls torch.cuda.set_device(args.rank) before creating the model.

https://github.com/NVIDIA/apex/tree/master/examples/simple/distributed shows detailed usage. https://github.com/NVIDIA/apex/tree/master/examples/imagenet shows another example that combines DistributedDataParallel with mixed precision training.
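For orientation, here is a minimal, hypothetical setup sketch of the pattern described above (the Linear model and argument parsing are placeholders, not part of apex; args.local_rank is the flag torch.distributed.launch passes to each process):

import argparse
import torch
import apex.parallel

parser = argparse.ArgumentParser()
parser.add_argument("--local_rank", type=int, default=0)
args = parser.parse_args()

# One process per GPU, as the launcher assumes
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")

model = torch.nn.Linear(1024, 1024).cuda()            # placeholder network
model = apex.parallel.DistributedDataParallel(model)  # parameters broadcast here

out = model(torch.randn(8, 1024).cuda())
out.sum().backward()   # gradients allreduced and averaged during backward()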

Parameters

Warning

If gradient_average=False, the pre-allreduce division (grads.mul_(1.0/gradient_predivide_factor)) will still be applied, but the post-allreduce gradient averaging (grads.mul_(gradient_predivide_factor/world_size)) will be omitted.
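As an illustrative sketch of these scaling semantics (not apex's actual internals; apply_gradient_scaling is a hypothetical helper):

def apply_gradient_scaling(grads, world_size,
                           gradient_predivide_factor=1.0,
                           gradient_average=True):
    # grads: list of gradient tensors on this process
    for g in grads:
        g.mul_(1.0 / gradient_predivide_factor)   # always applied, pre-allreduce
    # ... allreduce (sum across all processes) would happen here ...
    if gradient_average:
        for g in grads:
            # combined with the pre-division, the net effect is
            # division of the summed gradients by world_size
            g.mul_(gradient_predivide_factor / world_size)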

forward(*inputs, **kwargs)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should afterwards call the Module instance itself rather than this method, since the instance takes care of running the registered hooks while this method silently ignores them.

class apex.parallel.Reducer(module_or_grads_list)[source]

apex.parallel.Reducer is a simple class that helps allreduce a module's parameters across processes. Reducer is intended to give the user additional control: unlike DistributedDataParallel, Reducer will not automatically allreduce parameters during backward(). Instead, Reducer waits for the user to call <reducer_instance>.reduce() manually. This makes it possible, for example, to delay the allreduce so that it is carried out every few iterations instead of every iteration (see the sketch after the parameter description below).

Like DistributedDataParallel, Reducer averages any tensors it allreduces over the number of participating processes.

Reducer is designed to work with the upstream launch utility script torch.distributed.launch with --nproc_per_node <= the number of GPUs per node. When used with this launcher, Reducer assumes 1:1 mapping of processes to GPUs. It also assumes that your script calls torch.cuda.set_device(args.rank) before creating the model.

Parameters

module_or_grads_list – Either a network definition (module) being run in multi-gpu/distributed mode, or an iterable of gradients to be reduced. If a module is passed in, the Reducer constructor will sync the parameters across processes (broadcasting from rank 0) to make sure they’re all initialized with the same values. If a list of gradients (that came from some module) is passed in, the user is responsible for manually syncing that module’s parameters at the beginning of training.
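A hedged sketch of the delayed-allreduce pattern (model, loader, criterion, and optimizer are placeholders; the process group is assumed to be initialized already):

reducer = apex.parallel.Reducer(model)   # broadcasts parameters from rank 0

for step, (inp, target) in enumerate(loader):
    loss = criterion(model(inp), target)
    loss.backward()                      # gradients accumulate locally; no allreduce
    if (step + 1) % 4 == 0:              # reduce only every 4 iterations
        reducer.reduce()                 # allreduce and average across processes
        optimizer.step()
        optimizer.zero_grad()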

class apex.parallel.SyncBatchNorm(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, process_group=None, channel_last=False, fuse_relu=False)[source]

Synchronized batch normalization module extended from torch.nn.BatchNormNd, with added stats reduction across multiple processes. apex.parallel.SyncBatchNorm is designed to work with DistributedDataParallel.

When running in training mode, the layer reduces stats across all processes to increase the effective batch size for the normalization layer (a conceptual sketch of this reduction follows below). This is useful in applications where the per-process batch size is small enough that it would otherwise diminish the converged accuracy of the model. The module uses the collective communication package from torch.distributed.

When running in evaluation mode, the layer falls back to torch.nn.functional.batch_norm.
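Conceptually, the training-mode reduction can be pictured as follows (an illustrative sketch built directly on torch.distributed, not apex's fused implementation; global_batch_stats is a hypothetical helper):

import torch
import torch.distributed as dist

def global_batch_stats(x):
    # x: channel-first tensor of shape (N, C, ...); returns per-channel
    # mean and (biased) variance over the global batch
    dims = [0] + list(range(2, x.dim()))
    local_sum = x.sum(dim=dims)
    local_sqsum = (x * x).sum(dim=dims)
    count = torch.tensor(float(x.numel() / x.size(1)), device=x.device)
    dist.all_reduce(local_sum)     # sums across all processes
    dist.all_reduce(local_sqsum)
    dist.all_reduce(count)
    mean = local_sum / count
    var = local_sqsum / count - mean * mean
    return mean, var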

Parameters

Examples:

channel first tensor:

sbn = apex.parallel.SyncBatchNorm(100).cuda()
inp = torch.randn(10, 100, 14, 14).cuda()
out = sbn(inp)
inp = torch.randn(3, 100, 20).cuda()
out = sbn(inp)

channel last tensor:

sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
inp = torch.randn(10, 14, 14, 100).cuda()
out = sbn(inp)

forward(input, z=None)[source]

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass needs to be defined within this function, one should afterwards call the Module instance itself rather than this method, since the instance takes care of running the registered hooks while this method silently ignores them.

Utility functions

apex.parallel.convert_syncbn_model(module, process_group=None, channel_last=False)[source]

Recursively traverse module and its children to replace all instances of torch.nn.modules.batchnorm._BatchNorm with apex.parallel.SyncBatchNorm.

All torch.nn.BatchNormNd layers wrap around torch.nn.modules.batchnorm._BatchNorm, so this function lets you easily switch to synchronized BN.

Parameters

module (torch.nn.Module) – input module

Example:

# model is an instance of torch.nn.Module
import apex
sync_bn_model = apex.parallel.convert_syncbn_model(model)
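A common follow-on, sketched here under the assumption that the process group is already initialized and model is any torch.nn.Module, is to convert the model and then wrap it for distributed training:

import apex
sync_bn_model = apex.parallel.convert_syncbn_model(model).cuda()
sync_bn_model = apex.parallel.DistributedDataParallel(sync_bn_model)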