torchmetrics.Metric — PyTorch-Metrics 1.7.1 documentation

The base Metric class is an abstract base class that is used as the building block for all other Module metrics.

class torchmetrics.Metric(**kwargs)[source]

Base class for all metrics present in the Metrics API.

This class is inherited by all metrics and implements the following functionality:

  1. Handles the transfer of metric states to the correct device.
  2. Handles the synchronization of metric states across processes.
  3. Provides properties and methods to control the overall behavior of the metric and its states.

The three core methods of the base class are add_state(), forward() and reset(), which should almost never be overwritten by child classes. Instead, child classes should override update() and compute().

Parameters:

kwargs (Any) –

additional keyword arguments, see Advanced metric settings for more info.
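To see how the pieces fit together, here is a minimal sketch of a custom metric (the class name MySum, its state total, and its arguments are illustrative, not part of the API): states are registered with add_state(), accumulated in update(), and turned into a final value in compute():

>>> import torch
>>> from torchmetrics import Metric
>>> class MySum(Metric):  # hypothetical example metric
...     def __init__(self, **kwargs):
...         super().__init__(**kwargs)
...         # tensor state, summed across processes when synced
...         self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
...     def update(self, value):
...         # called once per batch; accumulates into the state
...         self.total += value.sum()
...     def compute(self):
...         # produces the final metric value from the accumulated state
...         return self.total
>>> metric = MySum()
>>> metric.update(torch.tensor([1.0, 2.0]))
>>> metric.update(torch.tensor([3.0]))
>>> metric.compute()
tensor(6.)
>>> metric.reset()  # states go back to their defaults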

add_state(name, default, dist_reduce_fx=None, persistent=False)[source]

Add metric state variable. Only used by subclasses.

Metric state variables are either a Tensor or an empty list, which can be appended to by the metric. Each state variable must have a unique name associated with it. State variables are accessible as attributes of the metric, i.e. if name is "my_state" then its value can be accessed from an instance metric as metric.my_state. Metric states behave like buffers and parameters of Module, as they are also updated when .to() is called. Unlike parameters and buffers, metric states are not by default saved in the module's state_dict.

Parameters:

  * name (str) – The name of the state variable. The variable will then be accessible at self.name.
  * default (Union[list, Tensor]) – Default value of the state; can either be a Tensor or an empty list. The state will be reset to this value when self.reset() is called.
  * dist_reduce_fx (Optional) – Function to reduce state across multiple processes in distributed mode. If the value is "sum", "mean", "cat", "min" or "max", torch.sum, torch.mean, torch.cat, torch.min and torch.max are used respectively, each with argument dim=0. Note that the "cat" reduction only makes sense if the state is a list and not a tensor. A custom function can also be passed.
  * persistent (Optional) – whether the state will be saved as part of the module's state_dict (default is False).

Return type:

None

Note

Setting dist_reduce_fx to None will return the metric state synchronized across different processes. However, there won’t be any reduction function applied to the synchronized metric state.

The metric states would be synced as follows:

  * If the metric state is a Tensor, the synced value will be a stacked Tensor across the process dimension. The original Tensor has shape (...) and the synced Tensor has shape (process_count, ...).
  * If the metric state is a list, the synced value will be a list containing the combined elements from all processes.

Important

When passing a custom function to dist_reduce_fx, expect the synchronized metric state to follow the format discussed in the above note.

Caution

The values inserted into a list state are deleted whenever reset() is called. This allows device memory to be automatically reallocated, but may produce unexpected effects when referencing list states. To retain such values after reset() is called, you must first copy them to another object.

Raises:

  * ValueError – If default is not a tensor or an empty list.
  * ValueError – If dist_reduce_fx is not callable or one of "mean", "sum", "cat", "min", "max" or None.
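Example (illustrative sketch of the caution above, using a hypothetical metric Collect with a list state values): values appended during update() are cleared by reset(), so copy them first if you need them afterwards:

>>> import torch
>>> from torchmetrics import Metric
>>> class Collect(Metric):  # hypothetical metric with a list state
...     def __init__(self, **kwargs):
...         super().__init__(**kwargs)
...         self.add_state("values", default=[], dist_reduce_fx="cat")
...     def update(self, value):
...         self.values.append(value)
...     def compute(self):
...         return torch.cat(self.values)
>>> metric = Collect()
>>> metric.update(torch.tensor([1.0, 2.0]))
>>> kept = [v.clone() for v in metric.values]  # copy before reset
>>> metric.reset()
>>> metric.values  # the list state is empty again
[]
>>> kept
[tensor([1., 2.])]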

clone()[source]

Make a copy of the metric.

Return type:

Metric

abstract compute()[source]

Override this method to compute the final metric value.

This method will automatically synchronize state variables when running in distributed backend.

Return type:

Any

double()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

float()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

forward(*args, **kwargs)[source]

Aggregate and evaluate batch input directly.

Serves the dual purpose of both computing the metric on the current batch of inputs and adding the batch statistics to the overall accumulating metric state. Input arguments are the exact same as the corresponding update method. The returned output is the exact same as the output of compute.

Parameters:

  * args (Any) – Any arguments as required by the metric update method.
  * kwargs (Any) – Any keyword arguments as required by the metric update method.

Return type:

Any

Returns:

The output of the compute method evaluated on the current batch.

Raises:

TorchMetricsUserError – If the metric is already synced and forward is called again.
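Example (illustrative sketch using the SumMetric aggregation metric): calling the metric invokes forward, which returns the value for the current batch while also accumulating into the overall state:

>>> import torch
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric(torch.tensor([1.0, 2.0]))  # forward: value for this batch only
tensor(3.)
>>> metric(torch.tensor([4.0]))
tensor(4.)
>>> metric.compute()  # accumulated over both batches
tensor(7.)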

half()[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

merge_state(incoming_state)[source]

Merge incoming metric state to the current state of the metric.

Parameters:

incoming_state (Union[dict[str, Any], Metric]) – either a dict containing a metric state similar to the metric itself or an instance of the metric class.

Raises:

Return type:

None

Example with a metric instance:

>>> from torchmetrics.aggregation import SumMetric
>>> metric1 = SumMetric()
>>> metric2 = SumMetric()
>>> metric1.update(1)
>>> metric2.update(2)
>>> metric1.merge_state(metric2)
>>> metric1.compute()
tensor(3.)

Example with a dict:

>>> import torch
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(1)
>>> # SumMetric has one state variable called sum_value
>>> metric.merge_state({"sum_value": torch.tensor(2)})
>>> metric.compute()
tensor(3.)

persistent(mode=False)[source]

Change, after initialization, whether metric states should be saved to the metric's state_dict.

Return type:

None
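A brief sketch (assuming the metric's states are non-persistent by default, as documented above; the state name sum_value is specific to SumMetric):

>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(5.0)
>>> "sum_value" in metric.state_dict()  # non-persistent states are left out
False
>>> metric.persistent(True)  # opt the states into state_dict
>>> "sum_value" in metric.state_dict()
True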

plot(*_, **__)[source]

Override this method to plot the metric value.

Return type:

Any

reset()[source]

Reset metric state variables to their default value.

Return type:

None

set_dtype(dst_type)[source]

Transfer all metric states to a specific dtype. Special version of the standard type method.

Parameters:

dst_type (Union[str, dtype]) – the desired type as string or dtype object

Return type:

Metric
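For example, a minimal sketch of casting a metric's states to double precision:

>>> import torch
>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric = metric.set_dtype(torch.float64)  # cast all metric states to float64
>>> metric.update(1.0)
>>> metric.compute().dtype
torch.float64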

state_dict(destination=None, prefix='', keep_vars=False)[source]

Get the current state of the metric as a dictionary.

Parameters:

Return type:

dict[str, Any]

sync(dist_sync_fn=None, process_group=None, should_sync=True, distributed_available=None)[source]

Sync function for manually controlling when metric states should be synced across processes.

Parameters:

  * dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
  * process_group (Optional[Any]) – Specify the process group on which synchronization is called (default: None, which selects the entire world).
  * should_sync (bool) – Whether to apply state synchronization. This only has an effect when running in a distributed setting.
  * distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.

Raises:

TorchMetricsUserError – If the metric is already synced and sync is called again.

Return type:

None

sync_context(dist_sync_fn=None, process_group=None, should_sync=True, should_unsync=True, distributed_available=None)[source]

Context manager to synchronize states.

This context manager is used in a distributed setting and makes sure that the local cache states are restored after yielding the synchronized state.

Parameters:

  * dist_sync_fn (Optional[Callable]) – Function to be used to perform state synchronization.
  * process_group (Optional[Any]) – Specify the process group on which synchronization is called (default: None, which selects the entire world).
  * should_sync (bool) – Whether to apply state synchronization. This only has an effect when running in a distributed setting.
  * should_unsync (bool) – Whether to restore the local cache state so that the metric can continue to be accumulated after the context exits.
  * distributed_available (Optional[Callable]) – Function to determine if we are running inside a distributed setting.

Return type:

Generator
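A sketch of manual synchronization in a distributed job (in a single-process run the context is effectively a no-op): inside the context the states are synced across processes; on exit the local per-process states are restored:

>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update(1.0)
>>> with metric.sync_context(should_sync=True, should_unsync=True):
...     synced_value = metric.compute()  # computed on the globally reduced state
>>> metric.compute()  # back on the local, per-process state
tensor(1.)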

type(dst_type)[source]

Override default and prevent dtype casting.

Please use Metric.set_dtype() instead.

Return type:

Metric

unsync(should_unsync=True)[source]

Unsync function for manually controlling when metric states should be reverted to their local values.

Parameters:

should_unsync (bool) – Whether to perform unsync

Return type:

None

abstract update(*_, **__)[source]

Override this method to update the state variables of your metric class.

Return type:

None

property device: device[source]

Return the device of the metric.

property dtype: dtype[source]

Return the default dtype of the metric.

property metric_state: dict[str, Union[List[torch.Tensor], torch.Tensor]][source]

Get the current state of the metric.

property update_called: bool[source]

Returns True if update or forward has been called since initialization or last reset.

property update_count: int[source]

Get the number of times update and/or forward has been called since initialization or last reset.
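A short sketch of these introspection properties on an aggregation metric (the state name sum_value is specific to SumMetric):

>>> from torchmetrics.aggregation import SumMetric
>>> metric = SumMetric()
>>> metric.update_called
False
>>> metric.update(1.0)
>>> metric.update(2.0)
>>> metric.update_called, metric.update_count
(True, 2)
>>> metric.metric_state
{'sum_value': tensor(3.)}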