Distributed - PyTorch Metric Learning

Wrap a tuple loss or miner with these when using PyTorch's DistributedDataParallel (i.e. multiprocessing).
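Both wrappers are meant to be created inside each spawned process. A minimal sketch of a single-machine DistributedDataParallel setup is shown below; the gloo backend, port number, toy model, and random data are placeholder assumptions, not part of this library:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist


def worker(rank, world_size):
    # Standard per-process DDP initialization (single machine, CPU-only gloo group).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Placeholder embedding model; any trunk that outputs embeddings works here.
    model = DDP(torch.nn.Linear(128, 64))

    # Wrap the tuple loss so it works across processes.
    loss_func = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())

    # One dummy training step per process, with random placeholder data.
    embeddings = model(torch.randn(32, 128))
    labels = torch.randint(0, 10, (32,))
    loss = loss_func(embeddings, labels)
    loss.backward()

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 2
    mp.spawn(worker, args=(world_size,), nprocs=world_size)
```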

DistributedLossWrapper

```python
utils.distributed.DistributedLossWrapper(loss, efficient=False)
```

Parameters:

* **loss**: The loss function to be wrapped.
* **efficient**:
    * **True**: each process uses only its own embeddings as anchors, with the gathered embeddings as positives/negatives. Gradients will not be identical to those of non-distributed code, but memory usage and computation are reduced.
    * **False**: each process uses the gathered embeddings as both anchors and positives/negatives. Gradients match non-distributed code, at the cost of redundant computation.

Example usage:

```python
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

loss_func = losses.ContrastiveLoss()
loss_func = pml_dist.DistributedLossWrapper(loss_func)

# in each process during training
loss = loss_func(embeddings, labels)
```
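The efficient flag from the signature above is passed at construction time. A minimal sketch, assuming the memory/speed trade-off described in the parameter list:

```python
from pytorch_metric_learning import losses
from pytorch_metric_learning.utils import distributed as pml_dist

# efficient=True trades exact gradient equivalence with the non-distributed
# loss for lower memory usage and faster computation (see Parameters above).
loss_func = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss(), efficient=True)
```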

DistributedMinerWrapper

```python
utils.distributed.DistributedMinerWrapper(miner, efficient=False)
```

Parameters:

* **miner**: The miner to be wrapped.
* **efficient**: Set this to True if you are also using efficient=True for DistributedLossWrapper.

Example usage:

```python
from pytorch_metric_learning import miners
from pytorch_metric_learning.utils import distributed as pml_dist

miner = miners.MultiSimilarityMiner()
miner = pml_dist.DistributedMinerWrapper(miner)

# in each process
indices_tuple = miner(embeddings, labels)

# pass into a DistributedLossWrapper
loss = loss_func(embeddings, labels, indices_tuple)
```
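Putting the two wrappers together, one training step inside each DistributedDataParallel process might look like the following sketch. It assumes the process group and DDP-wrapped model from the setup example near the top of this page, plus a hypothetical per-process batch (data, labels) and optimizer:

```python
import torch

from pytorch_metric_learning import losses, miners
from pytorch_metric_learning.utils import distributed as pml_dist

# Assumes an initialized process group and a DDP-wrapped `model`
# (see the setup sketch near the top of this page); `data` and `labels`
# are a placeholder batch from this process's dataloader.
miner = pml_dist.DistributedMinerWrapper(miners.MultiSimilarityMiner())
loss_func = pml_dist.DistributedLossWrapper(losses.ContrastiveLoss())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
embeddings = model(data)
indices_tuple = miner(embeddings, labels)            # mine pairs across all processes
loss = loss_func(embeddings, labels, indices_tuple)  # loss over gathered embeddings
loss.backward()
optimizer.step()
```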