jax.numpy.convolve (JAX documentation)

jax.numpy.convolve

jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)

Convolution of two one-dimensional arrays.

JAX implementation of numpy.convolve().
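Because this function mirrors numpy.convolve(), its results should agree with NumPy up to floating-point rounding. The following is a minimal check, assuming both numpy and jax are installed; the input values are arbitrary illustrations:

```python
import numpy as np
import jax.numpy as jnp

a = [1.0, 2.0, 3.0, 2.0, 1.0]
v = [4.0, 1.0, 2.0]

# Compare the JAX and NumPy results for each of the three modes.
for mode in ("full", "same", "valid"):
    jax_out = jnp.convolve(jnp.array(a), jnp.array(v), mode=mode)
    np_out = np.convolve(np.array(a), np.array(v), mode=mode)
    assert jax_out.shape == np_out.shape
    assert np.allclose(np.asarray(jax_out), np_out)
```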

The convolution of one-dimensional arrays is defined as:

\[c_k = \sum_j a_{k - j} v_j\]
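The sum runs over every j for which both a_{k-j} and v_j exist; terms with an out-of-range index contribute zero. Below is a direct, deliberately slow translation of the formula into Python, intended only as an illustration. The helper name convolve_by_definition is hypothetical and is not part of JAX:

```python
import numpy as np
import jax.numpy as jnp

def convolve_by_definition(a, v):
    """Full 1D convolution computed term by term from c_k = sum_j a[k-j] * v[j]."""
    a = np.asarray(a, dtype=np.float32)
    v = np.asarray(v, dtype=np.float32)
    n_out = len(a) + len(v) - 1          # length of the full result
    c = np.zeros(n_out, dtype=np.float32)
    for k in range(n_out):
        for j in range(len(v)):
            if 0 <= k - j < len(a):      # out-of-range a[k-j] contributes zero
                c[k] += a[k - j] * v[j]
    return c

x = [1, 2, 3, 2, 1]
y = [4, 1, 2]
print(convolve_by_definition(x, y))              # [ 4.  9. 16. 15. 12.  5.  2.]
print(jnp.convolve(jnp.array(x), jnp.array(y)))  # [ 4.  9. 16. 15. 12.  5.  2.]
```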

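Parameters:

a (ArrayLike): left-hand input to the convolution. Must be one-dimensional.
v (ArrayLike): right-hand input to the convolution. Must be one-dimensional.
mode (str): controls the size of the output. One of:
  'full' (default): return the full 1D convolution.
  'same': return a centered 1D convolution of length max(len(a), len(v)).
  'valid': return only the portion where a and v fully overlap.
precision (PrecisionLike, keyword-only): precision of the computation; see jax.lax.Precision for the available values.
preferred_element_type (DTypeLike | None, keyword-only): dtype in which to accumulate and return the result. If None (the default), the default accumulation dtype is used.
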
Returns:

Array containing the convolved result.

Return type:

Array
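As in the examples, integer inputs come back as a floating-point Array with the default arguments (float32, assuming JAX's default 32-bit mode with jax_enable_x64 off). The precision argument is keyword-only, as the * in the signature indicates, and accepts a jax.lax.Precision value. A small sketch under those assumptions:

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])   # integer inputs
y = jnp.array([4, 1, 2])

out = jnp.convolve(x, y)
print(isinstance(out, jax.Array))  # True
print(out.dtype)                   # float32 (assumes jax_enable_x64 is off)

# precision must be passed by keyword; positional use would fail.
hi = jnp.convolve(x, y, mode="valid", precision=jax.lax.Precision.HIGHEST)
print(hi)                          # [16. 15. 12.]
```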

Examples

A few 1D convolution examples:

>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])

By default, jax.numpy.convolve returns the full convolution, using implicit zero-padding at the edges:

>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)
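One way to picture the implicit zero-padding: pad a with len(v) - 1 zeros on each side, flip v, and slide it across the padded array, taking a dot product at each offset. The sketch below uses jnp.pad to do exactly that; the helper full_by_padding is hypothetical and illustrative, not the library's implementation:

```python
import jax.numpy as jnp

def full_by_padding(a, v):
    """Full convolution as sliding dot products over a zero-padded input."""
    a = jnp.asarray(a, dtype=jnp.float32)
    v = jnp.asarray(v, dtype=jnp.float32)
    n = v.shape[0]
    padded = jnp.pad(a, (n - 1, n - 1))   # zeros stand in for out-of-range a
    v_flipped = jnp.flip(v)               # convolution reverses the kernel
    n_out = a.shape[0] + n - 1
    return jnp.stack([jnp.dot(padded[i:i + n], v_flipped) for i in range(n_out)])

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 1, 2])
print(full_by_padding(x, y))   # [ 4.  9. 16. 15. 12.  5.  2.]
```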

Specifying mode='same' returns a centered convolution the same size as the longer input (here x):

>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)
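Here "centered" means a slice of the full result: for inputs of lengths M >= N, the 'same' output starts (N - 1) // 2 elements into the full output and keeps M elements. A sketch checking that relationship on the example arrays (the slicing rule as written assumes M >= N):

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])  # M = 5
y = jnp.array([4, 1, 2])        # N = 3

full = jnp.convolve(x, y)                  # length M + N - 1 = 7
start = (y.shape[0] - 1) // 2              # offset of the centered window: 1
same_from_full = full[start:start + x.shape[0]]

print(same_from_full)                      # [ 9. 16. 15. 12.  5.]
print(jnp.array_equal(same_from_full, jnp.convolve(x, y, mode="same")))  # True
```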

Specifying mode='valid' returns only the portion where the two arrays fully overlap:

>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)
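With M = len(a) and N = len(v), the three modes produce outputs of length M + N - 1 ('full'), max(M, N) ('same'), and max(M, N) - min(M, N) + 1 ('valid'), matching numpy.convolve. For M >= N, the 'valid' result is the slice full[N - 1:M]. A quick check of those lengths and of the slice on the example arrays:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 1, 2])
M, N = x.shape[0], y.shape[0]

expected_len = {
    "full": M + N - 1,                    # 7
    "same": max(M, N),                    # 5
    "valid": max(M, N) - min(M, N) + 1,   # 3
}
for mode, n_out in expected_len.items():
    assert jnp.convolve(x, y, mode=mode).shape == (n_out,)

# 'valid' keeps only the positions where y lies entirely inside x.
full = jnp.convolve(x, y)
print(full[N - 1:M])   # [16. 15. 12.]
```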

For complex-valued inputs:

>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)
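Since convolution is bilinear and neither input is conjugated, a complex convolution can be assembled from four real ones: writing a = a_r + i a_i and v = v_r + i v_i, the real part is conv(a_r, v_r) - conv(a_i, v_i) and the imaginary part is conv(a_r, v_i) + conv(a_i, v_r). A sketch verifying this on the complex example above:

```python
import jax.numpy as jnp

x1 = jnp.array([3+1j, 2, 4-3j])
y1 = jnp.array([1, 2-3j, 4+5j])

# Split each input into real-valued float32 parts.
ar, ai = x1.real, x1.imag
vr, vi = y1.real, y1.imag

real_part = jnp.convolve(ar, vr) - jnp.convolve(ai, vi)
imag_part = jnp.convolve(ar, vi) + jnp.convolve(ai, vr)
from_real = real_part + 1j * imag_part

print(jnp.allclose(from_real, jnp.convolve(x1, y1)))  # True
```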