Classwise Wrapper — PyTorch-Metrics 1.7.1 documentation

Module Interface

class torchmetrics.wrappers.ClasswiseWrapper(metric, labels=None, prefix=None, postfix=None)

Wrapper metric for altering the output of classification metrics.

This wrapper works with classification metrics that return multiple values (one value per class), so that label information can be included in the output automatically.

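Parameters:

metric (Metric) – base metric that should be wrapped. It is assumed that the metric outputs a single tensor that is split along the first dimension.
labels (Optional[list[str]]) – list of names for the individual classes. Should have the same length as the number of classes; if not provided, the class index is used.
prefix (Optional[str]) – string prepended to each output key. If neither prefix nor postfix is given, the lowercased class name of the wrapped metric followed by an underscore is used (e.g. multiclassaccuracy_0).
postfix (Optional[str]) – string appended to each output key.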

Example:

Basic example where the output of a metric is unwrapped into a dictionary whose keys combine the metric name and the class index:

>>> from torch import randint, randn
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_0': tensor(0.5000), 'multiclassaccuracy_1': tensor(0.7500), 'multiclassaccuracy_2': tensor(0.)}
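The per-class numbers come from MulticlassAccuracy with average=None, which for single-label multiclass inputs scores each class c as TP_c / (TP_c + FN_c): the fraction of samples whose target is c that are predicted as c (probability rows are reduced to a prediction by argmax). A minimal sketch of that computation with plain lists instead of tensors is shown below; the function names are hypothetical, and returning 0.0 for a class that never occurs in target is an assumption of the sketch.

```python
from typing import List


def argmax(row: List[float]) -> int:
    # Index of the largest entry; on ties, the first such index is returned.
    return max(range(len(row)), key=lambda i: row[i])


def per_class_accuracy(probs: List[List[float]], target: List[int], num_classes: int) -> List[float]:
    tp = [0] * num_classes       # samples of class c predicted as c
    support = [0] * num_classes  # samples whose target is c, i.e. TP + FN
    for row, t in zip(probs, target):
        support[t] += 1
        if argmax(row) == t:
            tp[t] += 1
    # Assumption of this sketch: a class that never occurs in target scores 0.0.
    return [tp[c] / support[c] if support[c] else 0.0 for c in range(num_classes)]


probs = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4], [0.6, 0.3, 0.1]]
target = [0, 1, 1, 2]
scores = per_class_accuracy(probs, target, num_classes=3)
print({f"multiclassaccuracy_{c}": s for c, s in enumerate(scores)})
# {'multiclassaccuracy_0': 1.0, 'multiclassaccuracy_1': 0.5, 'multiclassaccuracy_2': 0.0}
```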

Example:

Using custom names via prefix and postfix:

>>> from torch import randint, randn
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-")
>>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc")
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric_pre(preds, target)
{'acc-0': tensor(0.3333), 'acc-1': tensor(0.6667), 'acc-2': tensor(0.)}
>>> metric_post(preds, target)
{'0-acc': tensor(0.3333), '1-acc': tensor(0.6667), '2-acc': tensor(0.)}

Example:

Providing labels as a list of strings:

>>> from torch import randint, randn
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...     MulticlassAccuracy(num_classes=3, average=None),
...     labels=["horse", "fish", "dog"]
... )
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_horse': tensor(0.), 'multiclassaccuracy_fish': tensor(0.3333), 'multiclassaccuracy_dog': tensor(0.4000)}
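The key names in the examples above follow one pattern: a prefix, then the class label (or index), then a postfix, where the default prefix is the lowercased class name of the wrapped metric plus an underscore. The pure-Python sketch below reproduces that pattern with plain floats instead of tensors. The helper classwise_dict and its metric_name argument are hypothetical; applying both strings when prefix and postfix are set together, and rejecting a labels list of the wrong length, are assumptions of this sketch rather than documented behavior.

```python
from typing import Dict, List, Optional, Sequence


def classwise_dict(
    values: Sequence[float],
    metric_name: str,
    labels: Optional[List[str]] = None,
    prefix: Optional[str] = None,
    postfix: Optional[str] = None,
) -> Dict[str, float]:
    """Hypothetical helper reproducing the key pattern seen in the examples."""
    if labels is not None and len(labels) != len(values):
        # Sketch-only check: one label per class value.
        raise ValueError("labels must have one entry per class")
    if not prefix and not postfix:
        # Default: lowercased metric class name plus "_", e.g. "multiclassaccuracy_0".
        prefix, postfix = f"{metric_name.lower()}_", ""
    else:
        # Assumption: when both are given, both are applied.
        prefix, postfix = prefix or "", postfix or ""
    # Fall back to the class index when no labels are given.
    names = labels if labels is not None else [str(i) for i in range(len(values))]
    return {f"{prefix}{name}{postfix}": value for name, value in zip(names, values)}


scores = [0.3333, 0.6667, 0.0]
print(classwise_dict(scores, "MulticlassAccuracy"))
print(classwise_dict(scores, "MulticlassAccuracy", prefix="acc-"))
print(classwise_dict(scores, "MulticlassAccuracy", postfix="-acc"))
print(classwise_dict(scores, "MulticlassAccuracy", labels=["horse", "fish", "dog"]))
```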

Example:

ClasswiseWrapper can also be used in combination with MetricCollection. In this case, the outputs of all wrapped metrics are flattened into a single dictionary:

>>> from torch import randint, randn
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels),
...     'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)}
... )
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_horse': tensor(0.6667), 'multiclassaccuracy_fish': tensor(0.3333), 'multiclassaccuracy_dog': tensor(0.5000), 'multiclassrecall_horse': tensor(0.6667), 'multiclassrecall_fish': tensor(0.3333), 'multiclassrecall_dog': tensor(0.5000)}
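With the default naming, each wrapped metric prefixes its keys with its own lowercased class name, so the per-metric dictionaries can be merged into one flat dictionary without key collisions. The accuracy and recall entries agree because, for multiclass inputs, per-class accuracy with average=None is TP / (TP + FN), which is per-class recall. The sketch below mimics the flattening in pure Python using the values printed above; flatten_outputs is a hypothetical helper, and raising on a duplicate key is a safeguard of this sketch, not documented MetricCollection behavior.

```python
from typing import Dict


def flatten_outputs(outputs: Dict[str, Dict[str, float]]) -> Dict[str, float]:
    """Merge per-metric result dicts into one flat dict (hypothetical helper)."""
    flat: Dict[str, float] = {}
    for inner in outputs.values():
        for key, value in inner.items():
            # Sketch-only safeguard: refuse to silently overwrite an entry.
            if key in flat:
                raise KeyError(f"duplicate key {key!r}")
            flat[key] = value
    return flat


labels = ["horse", "fish", "dog"]
per_class = [0.6667, 0.3333, 0.5]  # values printed in the example above
outputs = {
    "multiclassaccuracy": {f"multiclassaccuracy_{lab}": v for lab, v in zip(labels, per_class)},
    "multiclassrecall": {f"multiclassrecall_{lab}": v for lab, v in zip(labels, per_class)},
}
print(flatten_outputs(outputs))
```

If two wrapped metrics produced identical keys (for example, both given the same prefix), this sketch would raise instead of letting one result overwrite the other.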

compute()

Compute metric.

Return type:

dict[str, Tensor]

forward(*args, **kwargs)

Calculate on batch and accumulate to global state.

Return type:

Any
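In other words, forward returns the value for the current batch while also folding that batch into the accumulated state, whereas compute returns the value over everything seen since the last reset. The hypothetical class below sketches that contract for per-class accuracy with plain-Python counters. It is not the torchmetrics implementation (which keeps tensor states and can synchronize them across processes); its method names only mirror the ones documented on this page, and the "acc-" key prefix is illustrative.

```python
from typing import Dict, List, Optional


class ClasswiseAccuracySketch:
    """Hypothetical stand-in that accumulates per-class TP and support counts."""

    def __init__(self, num_classes: int, labels: Optional[List[str]] = None) -> None:
        self.num_classes = num_classes
        self.labels = labels or [str(c) for c in range(num_classes)]
        self.reset()

    def reset(self) -> None:
        # Clear the accumulated (global) state.
        self.tp = [0] * self.num_classes
        self.support = [0] * self.num_classes

    def update(self, preds: List[int], target: List[int]) -> None:
        # Add one batch of hard predictions to the global counters.
        for p, t in zip(preds, target):
            self.support[t] += 1
            self.tp[t] += int(p == t)

    def compute(self) -> Dict[str, float]:
        # Per-class value over all batches since the last reset (0.0 for unseen classes).
        return {
            f"acc-{lab}": (self.tp[c] / self.support[c] if self.support[c] else 0.0)
            for c, lab in enumerate(self.labels)
        }

    def forward(self, preds: List[int], target: List[int]) -> Dict[str, float]:
        # Value for this batch alone, computed on a fresh copy of the state ...
        batch = ClasswiseAccuracySketch(self.num_classes, self.labels)
        batch.update(preds, target)
        # ... while the same batch is also accumulated into the global state.
        self.update(preds, target)
        return batch.compute()


metric = ClasswiseAccuracySketch(num_classes=2)
print(metric.forward([0, 1], [0, 1]))  # batch 1: both classes correct
print(metric.forward([1, 1], [0, 1]))  # batch 2: class 0 missed
print(metric.compute())                # accumulated over both batches
```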

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

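Parameters:

val – either a single result from calling forward or compute, or a list of such results. If no value is provided, compute is called and its result is plotted.
ax – a matplotlib axis object. If provided, the plot is added to that axis.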

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed
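The two examples below differ only in the val argument: it is either omitted, so the metric's current result is plotted, or given as a list of results from successive forward calls. One way such an argument could be normalized before drawing is sketched here in pure Python; resolve_plot_values is a hypothetical helper that takes a compute callable in place of a metric object, and it does not reproduce the library's actual plotting code.

```python
from typing import Callable, Dict, List, Optional, Sequence, Union

Values = Dict[str, float]


def resolve_plot_values(
    compute: Callable[[], Values],
    val: Optional[Union[Values, Sequence[Values]]] = None,
) -> List[Values]:
    """Hypothetical helper: normalize the val argument into a list of results."""
    if val is None:
        # No value given: fall back to the current accumulated result.
        return [compute()]
    if isinstance(val, dict):
        # A single result from forward() or compute().
        return [val]
    # Several results, e.g. one per training step.
    return list(val)


history = [{"acc-0": 0.5}, {"acc-0": 0.75}]
print(len(resolve_plot_values(lambda: {"acc-0": 0.6}, history)))  # 2 results to draw
print(resolve_plot_values(lambda: {"acc-0": 0.6}))  # falls back to compute()
```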

Example plotting a single value

>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()


Example plotting multiple values

>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)


reset()

Reset metric.

Return type:

None

update(*args, **kwargs)

Update state.

Return type:

None

property higher_is_better: Optional[bool]

Return whether a higher value of the metric is better.
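A wrapper has no optimization direction of its own, so a natural design is to report the direction of the metric it wraps. The sketch below assumes exactly that delegation; both class names are hypothetical, and the Optional[bool] return type leaves room for None when a metric defines no direction.

```python
from typing import Optional


class WrappedMetricSketch:
    """Hypothetical base metric exposing an optimization direction."""

    def __init__(self, higher_is_better: Optional[bool]) -> None:
        self.higher_is_better = higher_is_better


class WrapperSketch:
    """Hypothetical wrapper that forwards the direction of the metric it wraps."""

    def __init__(self, metric: WrappedMetricSketch) -> None:
        self.metric = metric

    @property
    def higher_is_better(self) -> Optional[bool]:
        # Assumption: the wrapper simply delegates to the wrapped metric.
        return self.metric.higher_is_better


print(WrapperSketch(WrappedMetricSketch(True)).higher_is_better)  # True
print(WrapperSketch(WrappedMetricSketch(None)).higher_is_better)  # None
```

Code that selects a best checkpoint could read this property to decide whether to maximize or minimize the tracked value.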