jax.numpy.histogramdd (JAX documentation)

jax.numpy.histogramdd

jax.numpy.histogramdd(sample, bins=10, range=None, weights=None, density=None)

Compute an N-dimensional histogram.

JAX implementation of numpy.histogramdd().
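Because this function mirrors NumPy's, one way to see the correspondence is to run both on the same data. The sketch below assumes NumPy is installed (JAX depends on it) and uses a small hypothetical sample with explicit ranges. NumPy may compute the bin edges at a different floating-point precision than JAX's default float32, so the edges are compared with a tolerance.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical sample: 6 points in 2 dimensions, shape (N, D) = (6, 2).
sample = np.array([[0.0, 1.0],
                   [0.5, 1.5],
                   [1.0, 2.0],
                   [1.5, 0.5],
                   [2.0, 1.0],
                   [0.2, 0.8]], dtype=np.float32)
ranges = [(0, 2), (0, 2)]

np_hist, np_edges = np.histogramdd(sample, bins=3, range=ranges)
jnp_hist, jnp_edges = jnp.histogramdd(sample, bins=3, range=ranges)

# Counts agree exactly; edges agree up to floating-point precision.
print(np.allclose(np_hist, jnp_hist))  # True
print(all(np.allclose(e_np, e_jnp) for e_np, e_jnp in zip(np_edges, jnp_edges)))  # True
```

Points lying exactly on the last edge (here 2.0) are counted in the final bin by both implementations.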

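Parameters:

sample: array of shape (N, D) holding N data points in D dimensions.

bins: the number of bins per dimension (default 10). Either a single integer used for every dimension, or a sequence of length D whose entries are an integer number of bins or an array of monotonically increasing bin edges for that dimension.

range: optional sequence of D (min, max) pairs giving the bounds of each dimension. If range is None, or an entry is None, the bounds are taken from the minimum and maximum of the data along that dimension.

weights: optional array of shape (N,) giving the weight of each data point. If not specified, every point has weight 1.

density: if True, return a probability density (counts per unit volume, normalized to integrate to 1 over the binned region). Otherwise return the (weighted) count in each bin.
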
Returns:

A tuple (histogram, bin_edges), where histogram is a D-dimensional array holding the (optionally weighted or normalized) value of each bin, and bin_edges is a list of D arrays giving the bin boundaries along each dimension.

Return type:

tuple[Array, list[Array]]
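To make the return structure concrete, the following sketch uses a small hypothetical 2-D sample with a different number of bins per dimension (a list of integers for bins, which numpy.histogramdd also accepts). The histogram gets one axis per dimension, and bin_edges holds one array of nbins + 1 edges per dimension.

```python
import jax.numpy as jnp

# Hypothetical data: four points in two dimensions, shape (N, D) = (4, 2).
sample = jnp.array([[0.1, 0.2],
                    [0.4, 0.9],
                    [0.6, 0.3],
                    [0.9, 0.8]])

# Two bins along the first dimension, four along the second, both over [0, 1].
hist, edges = jnp.histogramdd(sample, bins=[2, 4], range=[(0, 1), (0, 1)])

print(hist.shape)                # (2, 4): one entry per bin
print([e.shape for e in edges])  # [(3,), (5,)]: nbins + 1 edges per dimension
print(hist.tolist())             # [[1, 0, 0, 1], [0, 1, 0, 1]]
```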

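See also

jax.numpy.histogram(): compute the histogram of a 1-D array.
jax.numpy.histogram2d(): compute a 2-dimensional histogram.
jax.numpy.histogram_bin_edges(): compute the bin edges for a histogram.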

Examples

A histogram over 100 points in three dimensions:

>>> import jax
>>> import jax.numpy as jnp
>>> key = jax.random.key(42)
>>> a = jax.random.normal(key, (100, 3))
>>> counts, bin_edges = jnp.histogramdd(a, bins=6,
...                                     range=[(-3, 3), (-3, 3), (-3, 3)])
>>> counts.shape
(6, 6, 6)
>>> bin_edges
[Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32),
 Array([-3., -2., -1.,  0.,  1.,  2.,  3.], dtype=float32)]
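The weights parameter (one weight per point, shape (N,)) makes each point add its weight, rather than 1, to the bin it falls in. A minimal sketch with hypothetical points and weights:

```python
import jax.numpy as jnp

# Hypothetical data: three 2-D points and one weight per point.
sample = jnp.array([[0.1, 0.1],
                    [0.2, 0.3],
                    [0.9, 0.9]])
weights = jnp.array([1.0, 2.0, 0.5])

hist, edges = jnp.histogramdd(sample, bins=2, range=[(0, 1), (0, 1)],
                              weights=weights)

# The first two points share the lower-left bin (1.0 + 2.0); the third is alone.
print(hist.tolist())  # [[3.0, 0.0], [0.0, 0.5]]
```

Since every point lies inside the range here, the bin values sum to weights.sum(), i.e. 3.5.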

Using density=True returns a histogram normalized so that its integral over all bins is 1:

>>> density, bin_edges = jnp.histogramdd(a, density=True)
>>> bin_widths = map(jnp.diff, bin_edges)
>>> dx, dy, dz = jnp.meshgrid(*bin_widths, indexing='ij')
>>> normed = jnp.sum(density * dx * dy * dz)
>>> jnp.allclose(normed, 1.0)
Array(True, dtype=bool)
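Equivalently, with density=True each bin holds its count divided by the total number of counted (in-range) points and by the bin's volume. The sketch below checks that relationship against an unnormalized histogram of the same data; the seed, bin count and range are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical data: 100 points in 3 dimensions.
a = jax.random.normal(jax.random.key(0), (100, 3))
rng = [(-4, 4), (-4, 4), (-4, 4)]

counts, edges = jnp.histogramdd(a, bins=4, range=rng)
density, _ = jnp.histogramdd(a, bins=4, range=rng, density=True)

# Volume of each bin, built from the per-dimension bin widths.
dx, dy, dz = jnp.meshgrid(*[jnp.diff(e) for e in edges], indexing='ij')
volume = dx * dy * dz

# density = counts / (total in-range count * bin volume)
expected = counts / (counts.sum() * volume)
print(bool(jnp.allclose(density, expected)))  # True
```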