Classwise Wrapper — PyTorch-Metrics 1.7.1 documentation
Module Interface
class torchmetrics.wrappers.ClasswiseWrapper(metric, labels=None, prefix=None, postfix=None)
Wrapper metric for altering the output of classification metrics.
This metric works together with classification metrics that return multiple values (one value per class), so that label information is automatically included in the output.
Parameters:
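- metric (Metric) – base metric to wrap. The metric is assumed to output a single tensor, which is split along its first dimension (one entry per class), e.g. a classification metric created with average=None.
- labels (Optional[list[str]]) – list of strings naming the classes, in class-index order. If not provided, the class indices are used in the output keys.
- prefix (Optional[str]) – string prepended to each output key. If neither prefix nor postfix is given, the lowercased class name of the wrapped metric followed by an underscore is used instead (e.g. multiclassaccuracy_).
- postfix (Optional[str]) – string appended to each output key.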
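The way output keys are built from these parameters can be sketched in plain Python. This is a minimal sketch inferred from the example outputs on this page: `classwise_keys` is a hypothetical helper, not a torchmetrics function, and the behavior when both prefix and postfix are given (both applied) is an assumption.

```python
from typing import List, Optional, Sequence


def classwise_keys(
    metric_name: str,
    num_classes: int,
    labels: Optional[Sequence[str]] = None,
    prefix: Optional[str] = None,
    postfix: Optional[str] = None,
) -> List[str]:
    """Hypothetical helper mirroring the key naming seen in the examples."""
    if not prefix and not postfix:
        # Default: lowercased class name of the wrapped metric plus "_"
        prefix, postfix = f"{metric_name.lower()}_", ""
    else:
        # A custom prefix or postfix replaces the default class-name prefix
        # (assumption: when both are given, both are applied)
        prefix, postfix = prefix or "", postfix or ""
    # Class indices stand in for the names when no labels are given
    names = list(labels) if labels is not None else [str(i) for i in range(num_classes)]
    return [f"{prefix}{name}{postfix}" for name in names]


print(classwise_keys("MulticlassAccuracy", 3))
# ['multiclassaccuracy_0', 'multiclassaccuracy_1', 'multiclassaccuracy_2']
print(classwise_keys("MulticlassAccuracy", 3, prefix="acc-"))
# ['acc-0', 'acc-1', 'acc-2']
```

Each key then maps to the corresponding entry of the wrapped metric's per-class output tensor.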
Example::
Basic example where the output of a metric is unwrapped into a dictionary with the class index as keys:
>>> from torch import randint, randn
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric(preds, target)  # exact values depend on the random inputs
{'multiclassaccuracy_0': tensor(0.5000), 'multiclassaccuracy_1': tensor(0.7500), 'multiclassaccuracy_2': tensor(0.)}
Example::
Using custom names via prefix and postfix:
>>> from torch import randint, randn
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric_pre = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), prefix="acc-")
>>> metric_post = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), postfix="-acc")
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric_pre(preds, target)
{'acc-0': tensor(0.3333), 'acc-1': tensor(0.6667), 'acc-2': tensor(0.)}
>>> metric_post(preds, target)
{'0-acc': tensor(0.3333), '1-acc': tensor(0.6667), '2-acc': tensor(0.)}
Example::
Providing labels as a list of strings:
>>> from torch import randint, randn
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(
...     MulticlassAccuracy(num_classes=3, average=None),
...     labels=["horse", "fish", "dog"]
... )
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_horse': tensor(0.), 'multiclassaccuracy_fish': tensor(0.3333), 'multiclassaccuracy_dog': tensor(0.4000)}
Example::
ClasswiseWrapper can also be used in combination with MetricCollection. In this case, all outputs are flattened into a single dictionary:
>>> from torch import randint, randn
>>> from torchmetrics import MetricCollection
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy, MulticlassRecall
>>> labels = ["horse", "fish", "dog"]
>>> metric = MetricCollection(
...     {'multiclassaccuracy': ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None), labels),
...      'multiclassrecall': ClasswiseWrapper(MulticlassRecall(num_classes=3, average=None), labels)}
... )
>>> preds = randn(10, 3).softmax(dim=-1)
>>> target = randint(3, (10,))
>>> metric(preds, target)
{'multiclassaccuracy_horse': tensor(0.6667), 'multiclassaccuracy_fish': tensor(0.3333), 'multiclassaccuracy_dog': tensor(0.5000), 'multiclassrecall_horse': tensor(0.6667), 'multiclassrecall_fish': tensor(0.3333), 'multiclassrecall_dog': tensor(0.5000)}
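The flattening step can be pictured with a simplified sketch, assuming that dictionary-valued results (such as those from ClasswiseWrapper) contribute their own keys unchanged while any other result is stored under its collection key. `flatten_collection_output` is hypothetical; the real MetricCollection also applies its own optional prefix and postfix, which this sketch ignores, and plain floats stand in for tensors.

```python
from collections.abc import Mapping
from typing import Any, Dict


def flatten_collection_output(outputs: Mapping) -> Dict[str, Any]:
    """Hypothetical helper: merge per-metric results into one flat dictionary."""
    flat: Dict[str, Any] = {}
    for name, result in outputs.items():
        if isinstance(result, Mapping):
            # Assumption: dict results (e.g. from ClasswiseWrapper) keep their own keys
            flat.update(result)
        else:
            # Assumption: a single value is stored under its collection key
            flat[name] = result
    return flat


# Per-metric outputs, values copied from the example above (two classes shown)
outputs = {
    "multiclassaccuracy": {"multiclassaccuracy_horse": 0.6667, "multiclassaccuracy_fish": 0.3333},
    "multiclassrecall": {"multiclassrecall_horse": 0.6667, "multiclassrecall_fish": 0.3333},
}
print(flatten_collection_output(outputs))
# {'multiclassaccuracy_horse': 0.6667, 'multiclassaccuracy_fish': 0.3333,
#  'multiclassrecall_horse': 0.6667, 'multiclassrecall_fish': 0.3333}
```

Because the per-class keys already carry the metric name, the flattened dictionary stays unambiguous as long as each wrapper produces distinct keys.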
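compute()
Compute the metric from the accumulated state and return the per-class values as a dictionary.
Return type:
dict[str, Tensor]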
forward(*args, **kwargs)
Calculate the metric on the current batch, accumulate it into the global state, and return the per-class batch values as a dictionary.
Return type:
Any
plot(val=None, ax=None)
Plot a single value or multiple values from the metric.
Parameters:
- val (Union[Tensor, Sequence[Tensor], None]) – either a single result from calling metric.forward or metric.compute, or a list of such results. If no value is provided, metric.compute is called automatically and its result is plotted.
- ax (Optional[Axes]) – a matplotlib axis object. If provided, the plot is added to that axis.
Return type:
tuple[Figure, Union[Axes, ndarray]]
Returns:
Figure and Axes object
Raises:
ModuleNotFoundError – If matplotlib is not installed
Example plotting a single value
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> metric.update(torch.randint(3, (20,)), torch.randint(3, (20,)))
>>> fig_, ax_ = metric.plot()
Example plotting multiple values
>>> import torch
>>> from torchmetrics.wrappers import ClasswiseWrapper
>>> from torchmetrics.classification import MulticlassAccuracy
>>> metric = ClasswiseWrapper(MulticlassAccuracy(num_classes=3, average=None))
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.randint(3, (20,)), torch.randint(3, (20,))))
>>> fig_, ax_ = metric.plot(values)
reset()
Reset the metric state.
Return type:
None
update(*args, **kwargs)
Update the metric state with a new batch.
Return type:
None
property higher_is_better: Optional[bool]
Return whether a higher metric value is better.