jax.numpy.convolve (JAX documentation)
jax.numpy.convolve
jax.numpy.convolve(a, v, mode='full', *, precision=None, preferred_element_type=None)
Convolution of two one-dimensional arrays.
JAX implementation of numpy.convolve().
Convolution of one-dimensional arrays is defined as:
\[c_k = \sum_j a_{k - j} v_j\]
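Read literally, the definition sums a_{k-j} v_j over every j for which both indices are in range, which amounts to treating entries outside the arrays as zero. The naive sketch below computes only the default "full" mode with a double loop (cost proportional to len(a) * len(v)). The helper name `convolve_reference` is hypothetical and not part of JAX; it is meant as a readable cross-check, not a replacement for jnp.convolve.

```python
import numpy as np
import jax.numpy as jnp


def convolve_reference(a, v):
    """Hypothetical naive reference for c_k = sum_j a_{k-j} v_j ("full" mode)."""
    a = np.asarray(a)
    v = np.asarray(v)
    n, m = len(a), len(v)
    # One output per possible overlap position: n + m - 1 values in total.
    out = np.zeros(n + m - 1, dtype=np.result_type(a, v, np.float32))
    for k in range(n + m - 1):
        for j in range(m):
            # Terms whose index k - j falls outside a are skipped,
            # which is the same as padding a with zeros.
            if 0 <= k - j < n:
                out[k] += a[k - j] * v[j]
    return out


x = [1, 2, 3, 2, 1]
y = [4, 1, 2]
print(convolve_reference(x, y))  # values: 4, 9, 16, 15, 12, 5, 2
print(jnp.convolve(jnp.array(x), jnp.array(y)))
```

The loop is intentionally unoptimized; jnp.convolve lowers to an XLA convolution instead.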
Parameters:
- a (ArrayLike) – left-hand input to the convolution. Must have a.ndim == 1.
- v (ArrayLike) – right-hand input to the convolution. Must have v.ndim == 1.
- mode (str) – controls the size of the output. Available modes are:
  - "full": (default) output the full convolution of the inputs.
  - "same": return a centered portion of the "full" output that is the same size as the longer input (a in the examples below).
  - "valid": return the portion of the "full" output that does not depend on padding at the array edges.
- precision (PrecisionLike) – specifies the precision of the computation. Refer to jax.lax.Precision for a description of available values.
- preferred_element_type (DTypeLike | None) – the dtype in which to accumulate results and in which to return the result. Default is None, which means the default accumulation type for the input types.
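The two keyword-only arguments can be exercised as below. This is a sketch that assumes a default (32-bit) JAX configuration and a backend that supports bfloat16 convolutions; the variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 1, 2])

# preferred_element_type=None (the default): integer inputs are promoted to a
# floating-point dtype (float32 unless 64-bit mode is enabled).
default_out = jnp.convolve(x, y)
print(default_out.dtype)

# Low-precision inputs, with accumulation and output requested in float32.
xb = x.astype(jnp.bfloat16)
yb = y.astype(jnp.bfloat16)
wide_out = jnp.convolve(xb, yb, preferred_element_type=jnp.float32)
print(wide_out.dtype)  # float32

# precision takes a jax.lax.Precision value (None means the backend default).
precise_out = jnp.convolve(x.astype(jnp.float32), y.astype(jnp.float32),
                           precision=jax.lax.Precision.HIGHEST)
print(precise_out)
```

Because every value here is a small integer, all three results agree exactly; the parameters matter for accuracy and speed on real-valued data and on accelerators.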
Returns:
Array containing the convolved result.
Return type:
Array
Examples
A few 1D convolution examples:
>>> import jax.numpy as jnp
>>> x = jnp.array([1, 2, 3, 2, 1])
>>> y = jnp.array([4, 1, 2])
By default, jax.numpy.convolve returns the full convolution, using implicit zero-padding at the edges:
>>> jnp.convolve(x, y)
Array([ 4.,  9., 16., 15., 12.,  5.,  2.], dtype=float32)
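The implicit zero-padding can be made explicit: padding x with len(y) - 1 zeros on each side and taking the "valid" convolution of the padded array reproduces the default "full" result. A minimal sketch using jnp.pad:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 1, 2])

# len(y) - 1 zeros on each side, so every partial overlap becomes a full one.
x_padded = jnp.pad(x, len(y) - 1)
explicit = jnp.convolve(x_padded, y, mode="valid")
implicit = jnp.convolve(x, y)  # mode="full" by default
print(jnp.allclose(explicit, implicit))  # True
```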
Specifying mode='same' returns a centered convolution the same size as the first input (here also the longer of the two):
>>> jnp.convolve(x, y, mode='same')
Array([ 9., 16., 15., 12.,  5.], dtype=float32)
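The modes differ only in how many entries of the full result they keep. Assuming JAX follows NumPy's length conventions here (as "JAX implementation of numpy.convolve()" suggests), with n the longer and m the shorter input length, "full" gives n + m - 1 values, "same" gives n, and "valid" gives n - m + 1. The sketch below prints these shapes and also swaps the arguments to show that the "same" size follows the longer input rather than literally the first argument.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 2, 1])  # length 5 (n)
y = jnp.array([4, 1, 2])        # length 3 (m)

# Expected: full -> (7,), same -> (5,), valid -> (3,)
for mode in ("full", "same", "valid"):
    print(mode, jnp.convolve(x, y, mode=mode).shape)

# Shorter array first: "same" still returns 5 values.
swapped = jnp.convolve(y, x, mode="same")
print(swapped.shape)  # (5,)
```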
Specifying mode='valid' returns only the portion where the two arrays fully overlap:
>>> jnp.convolve(x, y, mode='valid')
Array([16., 15., 12.], dtype=float32)
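Both "same" and "valid" are slices of the "full" output. The sketch below uses a hypothetical helper, modes_from_full (not part of JAX), with n the longer and m the shorter input length: "same" starts (m - 1) // 2 entries into the full result and keeps n values, while "valid" keeps entries m - 1 through n - 1. The start offset for even-length kernels is an assumption based on NumPy's centering convention, so it is worth comparing against jnp.convolve before relying on it.

```python
import jax.numpy as jnp


def modes_from_full(a, v):
    """Hypothetical helper: slice the "full" result into "same" and "valid"."""
    full = jnp.convolve(a, v, mode="full")
    n, m = max(len(a), len(v)), min(len(a), len(v))
    start = (m - 1) // 2           # left trim for the centered "same" window
    same = full[start:start + n]   # n values
    valid = full[m - 1:n]          # n - m + 1 values with complete overlap
    return same, valid


x = jnp.array([1, 2, 3, 2, 1])
y = jnp.array([4, 1, 2])
same, valid = modes_from_full(x, y)
print(same)   # values 9, 16, 15, 12, 5
print(valid)  # values 16, 15, 12
print(jnp.allclose(same, jnp.convolve(x, y, mode="same")))    # True
print(jnp.allclose(valid, jnp.convolve(x, y, mode="valid")))  # True
```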
For complex-valued inputs:
>>> x1 = jnp.array([3+1j, 2, 4-3j])
>>> y1 = jnp.array([1, 2-3j, 4+5j])
>>> jnp.convolve(x1, y1)
Array([ 3. +1.j, 11. -7.j, 15.+10.j,  7. -8.j, 31. +8.j], dtype=complex64)
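As a hand check of one entry, take k = 2 in the defining sum with a = x1 and v = y1, writing i for the imaginary unit (Python's 1j) to avoid a clash with the summation index j. No complex conjugation is involved; each term is an ordinary complex product:

\[
\begin{aligned}
c_2 &= a_2 v_0 + a_1 v_1 + a_0 v_2 \\
    &= (4 - 3i)(1) + 2(2 - 3i) + (3 + i)(4 + 5i) \\
    &= (4 - 3i) + (4 - 6i) + (7 + 19i) \\
    &= 15 + 10i,
\end{aligned}
\]

which matches the third entry, 15.+10.j, of the complex result above.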