tf.keras.metrics.SparseCategoricalAccuracy  |  TensorFlow v2.16.1

tf.keras.metrics.SparseCategoricalAccuracy

Calculates how often predictions match integer labels.

Inherits From: MeanMetricWrapper, Mean, Metric

tf.keras.metrics.SparseCategoricalAccuracy(
    name='sparse_categorical_accuracy', dtype=None
)

Used in the notebooks

Used in the guide and tutorials: Migrate early stopping; Effective Tensorflow 2; Mixed precision; Use TPUs; Multi-GPU and distributed training; Custom training with tf.distribute.Strategy; Scalable model compression; Using DTensors with Keras; TensorFlow 2 quickstart for experts; Custom training: walkthrough.
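Conceptually, for a 1-D vector of integer labels y_true and a 2-D array of per-class scores y_pred, the metric computes the weighted fraction of matches:

acc = np.dot(sample_weight, np.equal(y_true, np.argmax(y_pred, axis=1))) / np.sum(sample_weight)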

You can pass class logits as y_pred, since the argmax of the logits is the same as the argmax of the corresponding probabilities.
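A minimal NumPy sketch of this computation, assuming y_true holds integer class indices (shape (N,) or (N, 1)) and y_pred holds per-class scores of shape (N, C). The helper name sparse_categorical_accuracy_np is hypothetical and not part of the Keras API; it normalizes by the sum of the weights, which is what dividing total by count amounts to. Because softmax preserves the ordering of scores within a row, logits and probabilities produce the same accuracy.

```python
import numpy as np

def sparse_categorical_accuracy_np(y_true, y_pred, sample_weight=None):
    """Hypothetical helper: weighted fraction of rows whose argmax equals the label."""
    y_true = np.asarray(y_true).reshape(-1)  # accept labels shaped (N,) or (N, 1)
    y_pred = np.asarray(y_pred)
    if sample_weight is None:
        sample_weight = np.ones_like(y_true, dtype=float)  # weights default to 1
    sample_weight = np.asarray(sample_weight, dtype=float)
    matches = np.equal(y_true, np.argmax(y_pred, axis=-1)).astype(float)
    return float(np.dot(sample_weight, matches) / np.sum(sample_weight))

# Logits and their softmax probabilities share the same argmax per row.
logits = np.array([[2.0, 5.0, 1.0], [0.5, 3.0, -1.0]])
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
labels = [2, 1]
print(sparse_categorical_accuracy_np(labels, logits))  # 0.5
print(sparse_categorical_accuracy_np(labels, probs))   # 0.5
```

Reshaping y_true to 1-D matters: comparing an (N, 1) label array directly against an (N,) argmax vector would broadcast to an (N, N) matrix instead of an element-wise comparison.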

This metric creates two local variables, total and count, that are used to compute the frequency with which y_pred matches y_true. This frequency is ultimately returned as sparse categorical accuracy: an idempotent operation that simply divides total by count.

If sample_weight is None, weights default to 1. Use sample_weight of 0 to mask values.
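To make the total/count bookkeeping concrete, here is a pure-Python sketch of the same idea. The class name TotalCountAccuracy is hypothetical; it mirrors only the behavior described above (add weighted matches to total, add weights to count, divide on demand), not the actual Keras implementation. A sample with weight 0 adds nothing to either variable, which is why it is masked.

```python
class TotalCountAccuracy:
    """Hypothetical stand-in illustrating the metric's two state variables."""

    def __init__(self):
        self.reset_state()

    def reset_state(self):
        self.total = 0.0  # running sum of weight * match
        self.count = 0.0  # running sum of weights

    def update_state(self, y_true, y_pred, sample_weight=None):
        if sample_weight is None:
            sample_weight = [1.0] * len(y_pred)  # weights default to 1
        for label, scores, w in zip(y_true, y_pred, sample_weight):
            # Accept labels shaped like [[2], [1]] as well as [2, 1].
            label = label[0] if isinstance(label, (list, tuple)) else label
            predicted = max(range(len(scores)), key=scores.__getitem__)  # argmax
            self.total += w * float(predicted == label)
            self.count += w

    def result(self):
        # Reads the state without modifying it, so repeated calls agree.
        # Returning 0.0 when no weight has been seen is a choice made for this sketch.
        return self.total / self.count if self.count else 0.0


m = TotalCountAccuracy()
m.update_state([[2], [1], [0]],
               [[0.1, 0.6, 0.3], [0.05, 0.95, 0], [0.9, 0.05, 0.05]],
               sample_weight=[1, 1, 0])  # third sample is correct but masked
print(m.result())  # 0.5
```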

Args
name: (Optional) String name of the metric instance.
dtype: (Optional) Data type of the metric result.

Example:

>>> m = keras.metrics.SparseCategoricalAccuracy()
>>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]])
>>> m.result()
0.5

>>> m.reset_state()
>>> m.update_state([[2], [1]], [[0.1, 0.6, 0.3], [0.05, 0.95, 0]],
...                sample_weight=[0.7, 0.3])
>>> m.result()
0.3

Usage with compile() API:

model.compile(optimizer='sgd',
              loss='sparse_categorical_crossentropy',
              metrics=[keras.metrics.SparseCategoricalAccuracy()])

Attributes
dtype
variables

Methods

add_variable

add_variable(
    shape, initializer, dtype=None, aggregation='sum', name=None
)

add_weight

add_weight(
    shape=(), initializer=None, dtype=None, name=None
)

from_config

@classmethod
from_config(
    config
)

get_config

get_config()

Return the serializable config of the metric.

reset_state

reset_state()

Reset all of the metric state variables.

This function is called between epochs/steps, when a metric is evaluated during training.
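Because result() divides the accumulated total by the accumulated count, the value reported after several update_state() calls is the accuracy over every sample seen since the last reset, not the average of per-batch accuracies. The plain-Python sketch below uses made-up per-batch match indicators (hypothetical data, unit weights) and resets the state at each epoch boundary to show the difference.

```python
# Hypothetical per-batch match indicators (1 = argmax equals the label).
epochs = [
    [[1, 1, 0, 0], [1, 1]],  # epoch 1: batches of 4 and 2 samples
    [[1, 1, 1, 1], [0, 1]],  # epoch 2
]

epoch_results = []
for batches in epochs:
    total, count = 0.0, 0.0  # equivalent of reset_state() at each epoch start
    batch_means = []
    for matches in batches:  # one update_state() call per batch
        total += sum(matches)
        count += len(matches)
        batch_means.append(sum(matches) / len(matches))
    epoch_results.append(total / count)  # equivalent of result() at epoch end
    # Epoch 1: 4/6 (about 0.667) overall, versus 0.75 as a mean of batch means.
    print(total / count, sum(batch_means) / len(batch_means))
```

The two numbers differ whenever batches have different sizes, which is why the metric keeps a total and a count rather than a list of batch accuracies.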

result

result()

Compute the current metric value.

Returns
A scalar tensor, or a dictionary of scalar tensors.

stateless_reset_state

stateless_reset_state()

stateless_result

stateless_result(
    metric_variables
)

stateless_update_state

stateless_update_state(
    metric_variables, *args, **kwargs
)
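As we understand the Keras 3 design, the stateless_* methods support functional training loops (for example with the JAX backend): stateless_reset_state() returns fresh variable values, stateless_update_state() takes the current values and returns updated ones instead of mutating the metric, and stateless_result() computes the result from values passed in. The pure-Python sketch below only imitates that pattern of threading state through return values; reset_state_fn, update_state_fn, and result_fn are hypothetical names, and the (total, count) tuple is an assumed stand-in for the metric's variables. It does not call Keras.

```python
from typing import Sequence, Tuple

State = Tuple[float, float]  # hypothetical (total, count) pair


def reset_state_fn() -> State:
    # Analogue of stateless_reset_state(): returns initial values.
    return (0.0, 0.0)


def update_state_fn(state: State, y_true: Sequence[int],
                    y_pred: Sequence[Sequence[float]]) -> State:
    # Analogue of stateless_update_state(): returns a new state, input untouched.
    total, count = state
    for label, scores in zip(y_true, y_pred):
        predicted = max(range(len(scores)), key=scores.__getitem__)  # argmax
        total += float(predicted == label)
        count += 1.0
    return (total, count)


def result_fn(state: State) -> float:
    # Analogue of stateless_result(): computes the value from explicit state.
    total, count = state
    return total / count if count else 0.0


state = reset_state_fn()
state = update_state_fn(state, [2, 1], [[0.1, 0.6, 0.3], [0.05, 0.95, 0.0]])
print(result_fn(state))  # 0.5
```

Keeping the state outside the object is what lets a pure function (one a compiler such as JAX can trace) update the metric without side effects.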

update_state

update_state(
    y_true, y_pred, sample_weight=None
)

Accumulate statistics for the metric.

__call__

__call__(
    *args, **kwargs
)

Call self as a function.