
apex.parallel.optimized_sync_batchnorm — Apex 0.1.0 documentation

import torch
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn import functional as F

import syncbn
from .optimized_sync_batchnorm_kernel import SyncBatchnormFunction

class SyncBatchNorm(_BatchNorm):
"""
synchronized batch normalization module extended from `torch.nn.BatchNormNd`
with the added stats reduction across multiple processes.
:class:`apex.parallel.SyncBatchNorm` is designed to work with
`DistributedDataParallel`.

When running in training mode, the layer reduces stats across all processes
to increase the effective batch size of the normalization layer. This is useful
in applications where the per-process batch size is small and would otherwise
diminish the converged accuracy of the model. The layer uses the collective
communication package from `torch.distributed`.

When running in evaluation mode, the layer falls back to
`torch.nn.functional.batch_norm`.

Args:
    num_features: :math:`C` from an expected input of size
        :math:`(N, C, L)` or :math:`L` from input of size :math:`(N, L)`
    eps: a value added to the denominator for numerical stability.
        Default: 1e-5
    momentum: the value used for the running_mean and running_var
        computation. Can be set to ``None`` for cumulative moving average
        (i.e. simple average). Default: 0.1
    affine: a boolean value that when set to ``True``, this module has
        learnable affine parameters. Default: ``True``
    track_running_stats: a boolean value that when set to ``True``, this
        module tracks the running mean and variance, and when set to ``False``,
        this module does not track such statistics and always uses batch
        statistics in both training and eval modes. Default: ``True``
    process_group: pass in a process group within which the stats of the
        mini-batch are synchronized. ``None`` for using the default process
        group
    channel_last: a boolean value that when set to ``True``, this module
        takes the last dimension of the input tensor to be the channel
        dimension. Default: ``False``

Examples::
    >>> # channel first tensor
    >>> sbn = apex.parallel.SyncBatchNorm(100).cuda()
    >>> inp = torch.randn(10, 100, 14, 14).cuda()
    >>> out = sbn(inp)
    >>> inp = torch.randn(3, 100, 20).cuda()
    >>> out = sbn(inp)
    >>> # channel last tensor
    >>> sbn = apex.parallel.SyncBatchNorm(100, channel_last=True).cuda()
    >>> inp = torch.randn(10, 14, 14, 100).cuda()
"""

def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
             track_running_stats=True, process_group=None, channel_last=False,
             fuse_relu=False):
    super(SyncBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum,
                                        affine=affine,
                                        track_running_stats=track_running_stats)
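    # extra options on top of the standard _BatchNorm state: the communication
    # group used for the stats reduction, NHWC ("channel last") layout, and an
    # optional fused ReLU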
    self.process_group = process_group
    self.channel_last = channel_last
    self.fuse_relu = fuse_relu

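# helpers to override the process group or memory layout after the module has
# been constructed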
def _specify_process_group(self, process_group):
    self.process_group = process_group

def _specify_channel_last(self, channel_last):
    self.channel_last = channel_last

def forward(self, input, z=None):
    # if input.dim() == 2, we switch to channel_last for efficient memory accessing
    channel_last = self.channel_last if input.dim() != 2 else True

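    # only a plain inference case (tracked stats, channel-first layout, no fused
    # ReLU, no extra input z) can use the built-in kernel; everything else goes
    # through the synchronized path below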
    if not self.training and self.track_running_stats and not channel_last \
            and not self.fuse_relu and z is None:
        # fall back to pytorch implementation for inference
        return F.batch_norm(input, self.running_mean, self.running_var,
                            self.weight, self.bias, False, 0.0, self.eps)
    else:
        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
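            # momentum=None means a cumulative moving average over all batches
            # seen so far; otherwise use the configured exponential moving
            # average factor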
            self.num_batches_tracked += 1
            if self.momentum is None:
                exponential_average_factor = 1.0 / float(self.num_batches_tracked)
            else:
                exponential_average_factor = self.momentum
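        # the custom autograd function reduces mean/var across the process group
        # before normalizing, optionally taking the extra input z and applying a
        # fused ReLU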
        return SyncBatchnormFunction.apply(input, z, self.weight, self.bias,
                                           self.running_mean, self.running_var,
                                           self.eps,
                                           self.training or not self.track_running_stats,
                                           exponential_average_factor,
                                           self.process_group, channel_last,
                                           self.fuse_relu)
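For context, a minimal sketch of how this module is typically combined with data-parallel training (not part of this file). It assumes the script is launched with torchrun / torch.distributed.launch so the NCCL rendezvous environment variables are already set; the toy model, tensor shapes, and the choice of `apex.parallel.DistributedDataParallel` are illustrative assumptions:

    # sketch only: one process per GPU, launched by torchrun
    import torch
    import apex

    torch.distributed.init_process_group(backend='nccl')
    local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
    torch.cuda.set_device(local_rank)

    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 100, kernel_size=3, padding=1),
        apex.parallel.SyncBatchNorm(100),  # per-process stats reduced across the group
        torch.nn.ReLU(),
    ).cuda()
    model = apex.parallel.DistributedDataParallel(model)

    # normalization stats are computed over 2 * world_size samples per step
    out = model(torch.randn(2, 3, 224, 224).cuda())

Existing `torch.nn.BatchNorm*d` layers in a model can also be swapped in bulk with `apex.parallel.convert_syncbn_model`.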