Threshold - PyTorch 2.7 documentation

class torch.nn.Threshold(threshold, value, inplace=False)

Thresholds each element of the input Tensor.

Threshold is defined as:

y = \begin{cases}
x, & \text{if } x > \text{threshold} \\
\text{value}, & \text{otherwise}
\end{cases}
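The definition above can be checked against an element-wise selection. The following is a minimal sketch, not the library's implementation: it assumes a threshold of 0.5 and a replacement value of 20.0 (both chosen only for illustration) and compares the module's output with a reference built from torch.where. Note that the comparison is strict, so an element equal to the threshold is replaced.

```python
import torch
import torch.nn as nn

# Illustrative inputs; 0.5 sits exactly on the threshold.
x = torch.tensor([-1.0, 0.5, 0.75, 2.0])

# Keep x where x > 0.5, otherwise substitute 20.0.
m = nn.Threshold(0.5, 20.0)
y = m(x)

# Reference result written directly from the piecewise definition.
expected = torch.where(x > 0.5, x, torch.full_like(x, 20.0))

print(y)                         # tensor([20.0000, 20.0000,  0.7500,  2.0000])
print(torch.equal(y, expected))  # True
```

Because the inequality is strict, the element equal to 0.5 is replaced along with the element below it.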

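Parameters

threshold (float): The value to threshold at.
value (float): The value to replace with.
inplace (bool): If True, performs the operation in place. Default: False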

Shape:

Input: (*), where * means any number of dimensions.
Output: (*), same shape as the input.

Examples:

import torch
import torch.nn as nn

m = nn.Threshold(0.1, 20)  # elements <= 0.1 are replaced with 20
input = torch.randn(2)
output = m(input)
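The inplace flag in the class signature can be illustrated in the same way. This is a small sketch under assumed values (threshold 0.5, replacement 0.0) showing that with inplace=True the input tensor itself is overwritten and the returned tensor shares its storage.

```python
import torch
import torch.nn as nn

# Illustrative values only: threshold 0.5, replacement 0.0.
m = nn.Threshold(0.5, 0.0, inplace=True)

x = torch.tensor([0.25, 0.5, 1.5])
out = m(x)

# With inplace=True the input is modified and no new storage is allocated.
same_storage = out.data_ptr() == x.data_ptr()
print(same_storage)  # True
print(x)             # tensor([0.0000, 0.0000, 1.5000])
```

In-place operation saves memory, but the original input values are lost, so it is best avoided when the input is needed again later.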