jax.nn.squareplus — JAX documentation
jax.nn.squareplus(x, b=4)
Squareplus activation function.
Computes the element-wise function
\[\mathrm{squareplus}(x) = \frac{x + \sqrt{x^2 + b}}{2}\]
as described in https://arxiv.org/abs/2112.11687.
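As a quick sanity check, the formula above can be written out directly with `jax.numpy` and compared against `jax.nn.squareplus`. This is a minimal sketch. `squareplus_ref` is a hypothetical helper written only for illustration and is not part of JAX.

```python
import jax
import jax.numpy as jnp


def squareplus_ref(x, b=4.0):
    """Hypothetical reference helper: a literal transcription of (x + sqrt(x^2 + b)) / 2."""
    x = jnp.asarray(x, dtype=jnp.float32)
    return (x + jnp.sqrt(x * x + b)) / 2


x = jnp.linspace(-5.0, 5.0, 11)
y_ref = squareplus_ref(x)      # hand-written formula
y_lib = jax.nn.squareplus(x)   # library function with the default b=4

# Both should agree element-wise up to floating-point rounding.
print(jnp.max(jnp.abs(y_ref - y_lib)))

# At x = 0 the formula gives sqrt(b) / 2, i.e. sqrt(4) / 2 = 1 for the default b.
print(jax.nn.squareplus(0.0))
```

Because `sqrt(x^2 + b) > |x|` whenever `b > 0`, the output is strictly positive and increases monotonically with `x`.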
Parameters:
- x (ArrayLike) – input array
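- b (ArrayLike) – smoothness parameter (default 4); larger values widen the curved region around x = 0, and b = 0 reduces the function to ReLU, max(x, 0)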
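The role of `b` follows from the formula. With `b = 0` the square root becomes `|x|`, so the function collapses to `(x + |x|) / 2`, which is ReLU. At `x = 0` the output is `sqrt(b) / 2`, so a larger `b` lifts the curve at the origin and spreads the smooth transition over a wider range of `x`. The sketch below checks both facts numerically. Passing an array for `b` assumes it broadcasts against `x` like any other array operand, as its `ArrayLike` annotation suggests.

```python
import jax
import jax.numpy as jnp

x = jnp.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# b = 0: sqrt(x^2) = |x|, so squareplus reduces to (x + |x|) / 2 = max(x, 0).
relu_like = jax.nn.squareplus(x, b=0.0)
print(relu_like)
print(jax.nn.relu(x))

# Value at the origin for several b at once (assumes b broadcasts with x).
bs = jnp.array([0.0, 1.0, 4.0, 16.0])
at_origin = jax.nn.squareplus(0.0, b=bs)  # sqrt(b) / 2 for each b
print(at_origin)

# Far from the origin the effect of b fades: the output approaches x for large positive x.
far = jax.nn.squareplus(100.0, b=bs)
print(far)
```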
Return type: