jax.scipy module — JAX documentation (original) (raw)

jax.scipy module#

jax.scipy.cluster#

vq(obs, code_book[, check_finite]) Assign codes from a code book to a set of observations.

jax.scipy.fft#

dct(x[, type, n, axis, norm]) Computes the discrete cosine transform of the input
dctn(x[, type, s, axes, norm]) Computes the multidimensional discrete cosine transform of the input
idct(x[, type, n, axis, norm]) Computes the inverse discrete cosine transform of the input
idctn(x[, type, s, axes, norm]) Computes the multidimensional inverse discrete cosine transform of the input

jax.scipy.integrate#

trapezoid(y[, x, dx, axis]) Integrate along the given axis using the composite trapezoidal rule.

jax.scipy.interpolate#

jax.scipy.linalg#

block_diag(*arrs) Create a block diagonal matrix from input arrays.
cho_factor(a[, lower, overwrite_a, check_finite]) Factorization for Cholesky-based linear solves
cho_solve(c_and_lower, b[, overwrite_b, ...]) Solve a linear system using a Cholesky factorization
cholesky(a[, lower, overwrite_a, check_finite]) Compute the Cholesky decomposition of a matrix.
det(a[, overwrite_a, check_finite]) Compute the determinant of a matrix
eigh() Compute eigenvalues and eigenvectors for a Hermitian matrix
eigh_tridiagonal(d, e, *[, eigvals_only, ...]) Solve the eigenvalue problem for a symmetric real tridiagonal matrix
expm(A, *[, upper_triangular, max_squarings]) Compute the matrix exponential
expm_frechet() Compute the Frechet derivative of the matrix exponential.
funm(A, func[, disp]) Evaluate a matrix-valued function
hessenberg() Compute the Hessenberg form of the matrix
hilbert(n) Create a Hilbert matrix of order n.
inv(a[, overwrite_a, check_finite]) Return the inverse of a square matrix
lu() Compute the LU decomposition
lu_factor(a[, overwrite_a, check_finite]) Factorization for LU-based linear solves
lu_solve(lu_and_piv, b[, trans, ...]) Solve a linear system using an LU factorization
pascal(n[, kind]) Create a Pascal matrix approximation of order n.
polar(a[, side, method, eps, max_iterations]) Computes the polar decomposition.
qr() Compute the QR decomposition of an array
rsf2csf(T, Z[, check_finite]) Convert real Schur form to complex Schur form.
schur(a[, output]) Compute the Schur decomposition
solve(a, b[, lower, overwrite_a, ...]) Solve a linear system of equations.
solve_triangular(a, b[, trans, lower, ...]) Solve a triangular linear system of equations
sqrtm(A[, blocksize]) Compute the matrix square root
svd() Compute the singular value decomposition.
toeplitz(c[, r]) Construct a Toeplitz matrix.

jax.scipy.ndimage#

map_coordinates(input, coordinates, order[, ...]) Map the input array to new coordinates using interpolation.

jax.scipy.optimize#

minimize(fun, x0[, args, tol, options]) Minimization of scalar function of one or more variables.
OptimizeResults(x, success, status, fun, ...) Object holding optimization results.

jax.scipy.signal#

fftconvolve(in1, in2[, mode, axes]) Convolve two N-dimensional arrays using Fast Fourier Transform (FFT).
convolve(in1, in2[, mode, method, precision]) Convolution of two N-dimensional arrays.
convolve2d(in1, in2[, mode, boundary, ...]) Convolution of two 2-dimensional arrays.
correlate(in1, in2[, mode, method, precision]) Cross-correlation of two N-dimensional arrays.
correlate2d(in1, in2[, mode, boundary, ...]) Cross-correlation of two 2-dimensional arrays.
csd(x, y[, fs, window, nperseg, noverlap, ...]) Estimate cross power spectral density (CSD) using Welch's method.
detrend(data[, axis, type, bp, overwrite_data]) Remove linear or piecewise linear trends from data.
istft(Zxx[, fs, window, nperseg, noverlap, ...]) Perform the inverse short-time Fourier transform (ISTFT).
stft(x[, fs, window, nperseg, noverlap, ...]) Compute the short-time Fourier transform (STFT).
welch(x[, fs, window, nperseg, noverlap, ...]) Estimate power spectral density (PSD) using Welch's method.

jax.scipy.spatial.transform#

Rotation(quat) Rotation in 3 dimensions.
Slerp(times, timedelta, rotations, rotvecs) Spherical Linear Interpolation of Rotations.

jax.scipy.sparse.linalg#

bicgstab(A, b[, x0, tol, atol, maxiter, M]) Use Bi-Conjugate Gradient Stable iteration to solve Ax = b.
cg(A, b[, x0, tol, atol, maxiter, M]) Use Conjugate Gradient iteration to solve Ax = b.
gmres(A, b[, x0, tol, atol, restart, ...]) GMRES solves the linear system A x = b for x, given A and b.

jax.scipy.special#

bernoulli(n) Generate the first N Bernoulli numbers.
beta(a, b) The beta function
betainc(a, b, x) The regularized incomplete beta function.
betaln(a, b) Natural log of the absolute value of the beta function
digamma(x) The digamma function
entr(x) The entropy function
erf(x) The error function
erfc(x) The complement of the error function
erfinv(x) The inverse of the error function
exp1(x) Exponential integral function.
expi Exponential integral function.
expit(x) The logistic sigmoid (expit) function
expn Generalized exponential integral function.
factorial(n[, exact]) Factorial function
fresnel The Fresnel integrals
gamma(x) The gamma function.
gammainc(a, x) The regularized lower incomplete gamma function.
gammaincc(a, x) The regularized upper incomplete gamma function.
gammaln(x) Natural log of the absolute value of the gamma function.
gammasgn(x) Sign of the gamma function.
hyp1f1 The 1F1 hypergeometric function.
i0(x) Modified Bessel function of zeroth order.
i0e(x) Exponentially scaled modified Bessel function of zeroth order.
i1(x) Modified Bessel function of first order.
i1e(x) Exponentially scaled modified Bessel function of first order.
kl_div(p, q) The Kullback-Leibler divergence.
log_ndtr Log of the normal cumulative distribution function.
log_softmax(x, /, *[, axis]) Log-Softmax function.
logit The logit function
logsumexp() Log-sum-exp reduction.
lpmn(m, n, z) The associated Legendre functions (ALFs) of the first kind.
lpmn_values(m, n, z, is_normalized) The associated Legendre functions (ALFs) of the first kind.
multigammaln(a, d) The natural log of the multivariate gamma function.
ndtr(x) Normal distribution function.
ndtri(p) The inverse of the normal cumulative distribution function.
poch The Pochhammer symbol.
polygamma(n, x) The polygamma function.
rel_entr(p, q) The relative entropy function.
softmax(x, /, *[, axis]) Softmax function.
spence(x) Spence's function, also known as the dilogarithm for real values.
sph_harm(m, n, theta, phi[, n_max]) Computes the spherical harmonics.
xlog1py Compute x*log(1 + y), returning 0 for x=0.
xlogy Compute x*log(y), returning 0 for x=0.
zeta The Hurwitz zeta function.

jax.scipy.stats#

mode(a[, axis, nan_policy, keepdims]) Compute the mode (most common value) along an axis of an array.
rankdata(a[, method, axis, nan_policy]) Compute the rank of data along an array axis.
sem(a[, axis, ddof, nan_policy, keepdims]) Compute the standard error of the mean.

jax.scipy.stats.bernoulli#

logpmf(k, p[, loc]) Bernoulli log probability mass function.
pmf(k, p[, loc]) Bernoulli probability mass function.
cdf(k, p) Bernoulli cumulative distribution function.
ppf(q, p) Bernoulli percent point function.

jax.scipy.stats.beta#

logpdf(x, a, b[, loc, scale]) Beta log probability distribution function.
pdf(x, a, b[, loc, scale]) Beta probability distribution function.
cdf(x, a, b[, loc, scale]) Beta cumulative distribution function
logcdf(x, a, b[, loc, scale]) Beta log cumulative distribution function.
sf(x, a, b[, loc, scale]) Beta distribution survival function.
logsf(x, a, b[, loc, scale]) Beta distribution log survival function.

jax.scipy.stats.betabinom#

logpmf(k, n, a, b[, loc]) Beta-binomial log probability mass function.
pmf(k, n, a, b[, loc]) Beta-binomial probability mass function.

jax.scipy.stats.binom#

logpmf(k, n, p[, loc]) Binomial log probability mass function.
pmf(k, n, p[, loc]) Binomial probability mass function.

jax.scipy.stats.cauchy#

logpdf(x[, loc, scale]) Cauchy log probability distribution function.
pdf(x[, loc, scale]) Cauchy probability distribution function.
cdf(x[, loc, scale]) Cauchy cumulative distribution function.
logcdf(x[, loc, scale]) Cauchy log cumulative distribution function.
sf(x[, loc, scale]) Cauchy distribution survival function.
logsf(x[, loc, scale]) Cauchy distribution log survival function.
isf(q[, loc, scale]) Cauchy distribution inverse survival function.
ppf(q[, loc, scale]) Cauchy distribution percent point function.

jax.scipy.stats.chi2#

logpdf(x, df[, loc, scale]) Chi-square log probability distribution function.
pdf(x, df[, loc, scale]) Chi-square probability distribution function.
cdf(x, df[, loc, scale]) Chi-square cumulative distribution function.
logcdf(x, df[, loc, scale]) Chi-square log cumulative distribution function.
sf(x, df[, loc, scale]) Chi-square survival function.
logsf(x, df[, loc, scale]) Chi-square log survival function.

jax.scipy.stats.dirichlet#

logpdf(x, alpha) Dirichlet log probability distribution function.
pdf(x, alpha) Dirichlet probability distribution function.

jax.scipy.stats.expon#

logpdf(x[, loc, scale]) Exponential log probability distribution function.
pdf(x[, loc, scale]) Exponential probability distribution function.
logcdf(x[, loc, scale]) Exponential log cumulative distribution function.
cdf(x[, loc, scale]) Exponential cumulative distribution function.
logsf(x[, loc, scale]) Exponential log survival function.
sf(x[, loc, scale]) Exponential survival function.
ppf(q[, loc, scale]) Exponential distribution percent point function.

jax.scipy.stats.gamma#

logpdf(x, a[, loc, scale]) Gamma log probability distribution function.
pdf(x, a[, loc, scale]) Gamma probability distribution function.
cdf(x, a[, loc, scale]) Gamma cumulative distribution function.
logcdf(x, a[, loc, scale]) Gamma log cumulative distribution function.
sf(x, a[, loc, scale]) Gamma survival function.
logsf(x, a[, loc, scale]) Gamma log survival function.

jax.scipy.stats.gennorm#

cdf(x, beta) Generalized normal cumulative distribution function.
logpdf(x, beta) Generalized normal log probability distribution function.
pdf(x, beta) Generalized normal probability distribution function.

jax.scipy.stats.geom#

logpmf(k, p[, loc]) Geometric log probability mass function.
pmf(k, p[, loc]) Geometric probability mass function.

jax.scipy.stats.laplace#

cdf(x[, loc, scale]) Laplace cumulative distribution function.
logpdf(x[, loc, scale]) Laplace log probability distribution function.
pdf(x[, loc, scale]) Laplace probability distribution function.

jax.scipy.stats.logistic#

cdf(x[, loc, scale]) Logistic cumulative distribution function.
isf(x[, loc, scale]) Logistic distribution inverse survival function.
logpdf(x[, loc, scale]) Logistic log probability distribution function.
pdf(x[, loc, scale]) Logistic probability distribution function.
ppf(x[, loc, scale]) Logistic distribution percent point function.
sf(x[, loc, scale]) Logistic distribution survival function.

jax.scipy.stats.multinomial#

logpmf(x, n, p) Multinomial log probability mass function.
pmf(x, n, p) Multinomial probability mass function.

jax.scipy.stats.multivariate_normal#

logpdf(x, mean, cov[, allow_singular]) Multivariate normal log probability distribution function.
pdf(x, mean, cov) Multivariate normal probability distribution function.

jax.scipy.stats.nbinom#

logpmf(k, n, p[, loc]) Negative-binomial log probability mass function.
pmf(k, n, p[, loc]) Negative-binomial probability mass function.

jax.scipy.stats.norm#

logpdf(x[, loc, scale]) Normal log probability distribution function.
pdf(x[, loc, scale]) Normal probability distribution function.
cdf(x[, loc, scale]) Normal cumulative distribution function.
logcdf(x[, loc, scale]) Normal log cumulative distribution function.
ppf(q[, loc, scale]) Normal distribution percent point function.
sf(x[, loc, scale]) Normal distribution survival function.
logsf(x[, loc, scale]) Normal distribution log survival function.
isf(q[, loc, scale]) Normal distribution inverse survival function.

jax.scipy.stats.pareto#

logpdf(x, b[, loc, scale]) Pareto log probability distribution function.
pdf(x, b[, loc, scale]) Pareto probability distribution function.

jax.scipy.stats.poisson#

logpmf(k, mu[, loc]) Poisson log probability mass function.
pmf(k, mu[, loc]) Poisson probability mass function.
cdf(k, mu[, loc]) Poisson cumulative distribution function.

jax.scipy.stats.t#

logpdf(x, df[, loc, scale]) Student's T log probability distribution function.
pdf(x, df[, loc, scale]) Student's T probability distribution function.

jax.scipy.stats.truncnorm#

cdf(x, a, b[, loc, scale]) Truncated normal cumulative distribution function.
logcdf(x, a, b[, loc, scale]) Truncated normal log cumulative distribution function.
logpdf(x, a, b[, loc, scale]) Truncated normal log probability distribution function.
logsf(x, a, b[, loc, scale]) Truncated normal distribution log survival function.
pdf(x, a, b[, loc, scale]) Truncated normal probability distribution function.
sf(x, a, b[, loc, scale]) Truncated normal distribution survival function.

jax.scipy.stats.uniform#

logpdf(x[, loc, scale]) Uniform log probability distribution function.
pdf(x[, loc, scale]) Uniform probability distribution function.
cdf(x[, loc, scale]) Uniform cumulative distribution function.
ppf(q[, loc, scale]) Uniform distribution percent point function.

jax.scipy.stats.gaussian_kde#

gaussian_kde(dataset[, bw_method, weights]) Gaussian Kernel Density Estimator
gaussian_kde.evaluate(points) Evaluate the Gaussian KDE on the given points.
gaussian_kde.integrate_gaussian(mean, cov) Integrate the distribution weighted by a Gaussian.
gaussian_kde.integrate_box_1d(low, high) Integrate the distribution over the given limits.
gaussian_kde.integrate_kde(other) Integrate the product of two Gaussian KDE distributions.
gaussian_kde.resample(key[, shape]) Randomly sample a dataset from the estimated pdf
gaussian_kde.pdf(x) Probability density function
gaussian_kde.logpdf(x) Log probability density function

jax.scipy.stats.vonmises#

logpdf(x, kappa) von Mises log probability distribution function.
pdf(x, kappa) von Mises probability distribution function.

jax.scipy.stats.wrapcauchy#

logpdf(x, c) Wrapped Cauchy log probability distribution function.
pdf(x, c) Wrapped Cauchy probability distribution function.