
jax.numpy.histogram2d

jax.numpy.histogram2d(x, y, bins=10, range=None, weights=None, density=None)

Compute a 2-dimensional histogram.

JAX implementation of numpy.histogram2d().
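Since this function mirrors numpy.histogram2d(), both should give the same bins and counts for the same inputs. The sketch below assumes NumPy and JAX are both installed and reuses the sample points from the examples further down. NumPy returns floating-point counts and float64 edges, while JAX defaults to 32-bit types, so the sketch compares values with allclose rather than checking for identical dtypes.

```python
import numpy as np
import jax.numpy as jnp

# Sample points reused from the examples on this page.
x = np.array([1, 2, 3, 10, 11, 15, 19, 25])
y = np.array([2, 5, 6, 8, 13, 16, 17, 18])

# Reference result from NumPy.
np_counts, np_xe, np_ye = np.histogram2d(x, y, bins=8)

# JAX result on the same data (32-bit by default).
jnp_counts, jnp_xe, jnp_ye = jnp.histogram2d(jnp.asarray(x), jnp.asarray(y), bins=8)

# Compare values rather than dtypes, since the two libraries differ there.
counts_match = bool(np.allclose(np_counts, np.asarray(jnp_counts)))
edges_match = bool(np.allclose(np_xe, np.asarray(jnp_xe))) and bool(
    np.allclose(np_ye, np.asarray(jnp_ye))
)
print(counts_match, edges_match)
```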

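Parameters:

- x: one-dimensional array of x-coordinates of the points to be binned.
- y: one-dimensional array of y-coordinates of the points to be binned, with the same length as x.
- bins: number of bins along each axis (default: 10). May also be a pair of integers giving the number of bins per axis, a single array of bin edges used for both axes, or a pair of arrays giving the bin edges for each axis.
- range: optional bounds of the form [(xmin, xmax), (ymin, ymax)]. If not specified, each axis spans the minimum to the maximum of its data. Points outside these bounds are not counted.
- weights: optional array with one weight per data point. If not specified, each point contributes 1 to its bin.
- density: if True, normalize the histogram so that its integral over the binned area is 1. If False or None (default), return the (weighted) count in each bin.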

Returns:

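A tuple of arrays (histogram, x_edges, y_edges). histogram has shape (nx, ny), where nx and ny are the numbers of bins along x and y, and holds the (optionally weighted) count in each bin, or the density if density=True. x_edges and y_edges have lengths nx + 1 and ny + 1 and specify the bin boundaries along each axis.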

Return type:

tuple[Array, Array, Array]
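To make the shape relationship concrete, the sketch below requests a different number of bins along each axis by passing a pair of integers for bins, then inspects the shapes of the three returned arrays. The sample points are the same ones used in the examples below.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])

# A pair of integers requests a different number of bins along each axis.
counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=[3, 4])

# counts[i, j]: i runs over x-bins and j over y-bins; each edge array has
# one more entry than the number of bins along its axis.
print(counts.shape, x_edges.shape, y_edges.shape)  # (3, 4) (4,) (5,)

# With the default range every point is binned, so the counts sum to len(x).
print(int(counts.sum()))  # 8
```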

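See also

- jax.numpy.histogram(): compute the histogram of a one-dimensional array.
- jax.numpy.histogramdd(): compute an N-dimensional histogram.
- jax.numpy.histogram_bin_edges(): compute the bin edges for a histogram.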
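A 2D histogram is the two-dimensional case of jax.numpy.histogramdd(). The sketch below assumes, as in NumPy, that stacking x and y into an (N, 2) sample and passing it to histogramdd with the same bins reproduces the histogram2d result, and checks that on the sample data used in the examples.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])

# 2D histogram computed directly.
counts_2d, x_edges, y_edges = jnp.histogram2d(x, y, bins=5)

# Same points stacked into an (N, 2) sample for histogramdd, which returns
# the histogram and a list of per-axis edge arrays.
sample = jnp.stack([x, y], axis=1)
counts_dd, edges_dd = jnp.histogramdd(sample, bins=5)

same_counts = bool(jnp.array_equal(counts_2d, counts_dd))
same_edges = bool(jnp.allclose(x_edges, edges_dd[0])) and bool(
    jnp.allclose(y_edges, edges_dd[1])
)
print(same_counts, same_edges)
```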

Examples

```python
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
>>> y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8)
>>> counts.shape
(8, 8)
>>> x_edges
Array([ 1.,  4.,  7., 10., 13., 16., 19., 22., 25.], dtype=float32)
>>> y_edges
Array([ 2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)
```
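The edges above follow from the default range: with range=None, each axis spans the minimum to the maximum of its own data and is split into equal-width bins. The sketch below checks this against jnp.linspace on the same sample points; the bin widths work out to (25 - 1) / 8 = 3 along x and (18 - 2) / 8 = 2 along y.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])
counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=8)

# With range=None, each axis spans [min, max] of its data in bins + 1 edges.
x_ok = bool(jnp.allclose(x_edges, jnp.linspace(x.min(), x.max(), 8 + 1)))
y_ok = bool(jnp.allclose(y_edges, jnp.linspace(y.min(), y.max(), 8 + 1)))

# All bins along an axis share the same width.
x_widths = jnp.diff(x_edges)
y_widths = jnp.diff(y_edges)
widths_ok = bool(jnp.allclose(x_widths, 3.0)) and bool(jnp.allclose(y_widths, 2.0))
print(x_ok, y_ok, widths_ok)
```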

Specifying the bin range:

```python
>>> counts, x_edges, y_edges = jnp.histogram2d(x, y, range=[(0, 25), (0, 25)], bins=5)
>>> counts.shape
(5, 5)
>>> x_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)
>>> y_edges
Array([ 0.,  5., 10., 15., 20., 25.], dtype=float32)
```
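A range narrower than the data drops the points that fall outside it. In the sketch below, on the same sample points, only four points lie inside the box [0, 10] x [0, 10]. The point with x = 10 is still counted because the upper edge of the range belongs to the last bin.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])

# Restrict both axes to [0, 10]; points outside this box are not counted.
counts, x_edges, y_edges = jnp.histogram2d(x, y, bins=2, range=[(0, 10), (0, 10)])

# Inside the box: (1, 2), (2, 5), (3, 6) and (10, 8).
print(counts)
print(int(counts.sum()))  # 4
```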

Specifying the bin edges explicitly:

```python
>>> x_edges = jnp.array([0, 10, 20, 30])
>>> y_edges = jnp.array([0, 10, 20, 30])
>>> counts, _, _ = jnp.histogram2d(x, y, bins=[x_edges, y_edges])
>>> counts
Array([[3, 0, 0],
       [1, 3, 0],
       [0, 1, 0]], dtype=int32)
```
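The weights parameter turns counts into sums. The sketch below reuses the explicit edges from the example above with hypothetical weights, chosen only for illustration, where point i has weight i. Each bin then holds the total weight of the points that fall in it.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])
edges = jnp.array([0, 10, 20, 30])

# Hypothetical per-point weights for illustration: point i gets weight i.
w = jnp.arange(8.0)

# Each bin holds the sum of the weights of its points instead of a count.
weighted, _, _ = jnp.histogram2d(x, y, bins=[edges, edges], weights=w)
print(weighted)

# Every point is in range, so the total equals the sum of all weights.
print(float(weighted.sum()))  # 28.0
```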

Passing density=True returns a histogram normalized so that its integral over the binned area is 1:

```python
>>> density, x_edges, y_edges = jnp.histogram2d(x, y, density=True)
>>> dx = jnp.diff(x_edges)
>>> dy = jnp.diff(y_edges)
>>> normed_sum = jnp.sum(density * dx[:, None] * dy[None, :])
>>> jnp.allclose(normed_sum, 1.0)
Array(True, dtype=bool)
```
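The normalization can also be written out explicitly. Assuming the NumPy convention that the density in each cell is its count divided by the total count and by the cell area, the sketch below rebuilds the density histogram from the raw counts over the same default bins.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 10, 11, 15, 19, 25])
y = jnp.array([2, 5, 6, 8, 13, 16, 17, 18])

# Raw counts and the density-normalized histogram over the same default bins.
counts, x_edges, y_edges = jnp.histogram2d(x, y)
density, _, _ = jnp.histogram2d(x, y, density=True)

# Area of each (x-bin, y-bin) cell, shape (10, 10).
area = jnp.diff(x_edges)[:, None] * jnp.diff(y_edges)[None, :]

# density = counts / (total count * cell area).
expected = counts / (counts.sum() * area)
density_ok = bool(jnp.allclose(density, expected))
print(density_ok)
```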