jax.scipy module (JAX documentation)
jax.scipy.cluster
Name | Description |
---|---|
vq(obs, code_book[, check_finite]) | Assign codes from a code book to a set of observations. |
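The following is a minimal sketch of `vq`. It assumes the function can be imported from the `jax.scipy.cluster.vq` submodule, as in SciPy. The observations and code book are made-up values, chosen so that each point has an obvious nearest centroid.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

# Three 2-D observations and a code book of three centroids (illustrative values).
obs = jnp.array([[0.0, 0.0], [0.9, 1.1], [5.0, 5.2]])
code_book = jnp.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])

# codes[i] is the index of the centroid nearest to obs[i];
# dists[i] is the Euclidean distance from obs[i] to that centroid.
codes, dists = vq(obs, code_book)
print(codes, dists)
```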
jax.scipy.fft
Name | Description |
---|---|
dct(x[, type, n, axis, norm]) | Computes the discrete cosine transform of the input |
dctn(x[, type, s, axes, norm]) | Computes the multidimensional discrete cosine transform of the input |
idct(x[, type, n, axis, norm]) | Computes the inverse discrete cosine transform of the input |
idctn(x[, type, s, axes, norm]) | Computes the multidimensional inverse discrete cosine transform of the input |
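Below is a short sketch of the DCT pair on arbitrary input values. It assumes the SciPy-style `norm="ortho"` option and uses the default `type=2`. Under that normalization, `idct` inverts `dct`, and a constant signal maps entirely onto the zeroth coefficient.

```python
import jax.numpy as jnp
from jax.scipy.fft import dct, idct

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Orthonormal type-II DCT followed by its inverse.
X = dct(x, norm="ortho")
x_back = idct(X, norm="ortho")

# A constant signal puts all of its energy in the zeroth coefficient.
C = dct(jnp.ones(4), norm="ortho")
print(x_back, C)
```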
jax.scipy.integrate
Name | Description |
---|---|
trapezoid(y[, x, dx, axis]) | Integrate along the given axis using the composite trapezoidal rule. |
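This minimal sketch applies `trapezoid` to sampled data. The first call passes the sample positions as `x`; the second passes a uniform spacing `dx` instead. Both integrands are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

# Sample f(x) = x**2 on [0, 1]; the exact integral is 1/3.
x = jnp.linspace(0.0, 1.0, 101)
approx = trapezoid(x ** 2, x=x)

# With uniform spacing, dx can replace the sample positions.
# The rule is exact for linear data: y = [0, 1, 2] over [0, 2] integrates to 2.
linear = trapezoid(jnp.array([0.0, 1.0, 2.0]), dx=1.0)
print(approx, linear)
```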
jax.scipy.interpolate
jax.scipy.linalg
Name | Description |
---|---|
block_diag(*arrs) | Create a block diagonal matrix from input arrays. |
cho_factor(a[, lower, overwrite_a, check_finite]) | Factorization for Cholesky-based linear solves |
cho_solve(c_and_lower, b[, overwrite_b, ...]) | Solve a linear system using a Cholesky factorization |
cholesky(a[, lower, overwrite_a, check_finite]) | Compute the Cholesky decomposition of a matrix. |
det(a[, overwrite_a, check_finite]) | Compute the determinant of a matrix |
eigh() | Compute eigenvalues and eigenvectors for a Hermitian matrix |
eigh_tridiagonal(d, e, *[, eigvals_only, ...]) | Solve the eigenvalue problem for a symmetric real tridiagonal matrix |
expm(A, *[, upper_triangular, max_squarings]) | Compute the matrix exponential |
expm_frechet() | Compute the Frechet derivative of the matrix exponential. |
funm(A, func[, disp]) | Evaluate a matrix-valued function |
hessenberg() | Compute the Hessenberg form of the matrix |
hilbert(n) | Create a Hilbert matrix of order n. |
inv(a[, overwrite_a, check_finite]) | Return the inverse of a square matrix |
lu() | Compute the LU decomposition |
lu_factor(a[, overwrite_a, check_finite]) | Factorization for LU-based linear solves |
lu_solve(lu_and_piv, b[, trans, ...]) | Solve a linear system using an LU factorization |
pascal(n[, kind]) | Create a Pascal matrix approximation of order n. |
polar(a[, side, method, eps, max_iterations]) | Computes the polar decomposition. |
qr() | Compute the QR decomposition of an array |
rsf2csf(T, Z[, check_finite]) | Convert real Schur form to complex Schur form. |
schur(a[, output]) | Compute the Schur decomposition |
solve(a, b[, lower, overwrite_a, ...]) | Solve a linear system of equations. |
solve_triangular(a, b[, trans, lower, ...]) | Solve a triangular linear system of equations |
sqrtm(A[, blocksize]) | Compute the matrix square root |
svd() | Compute the singular value decomposition. |
toeplitz(c[, r]) | Construct a Toeplitz matrix. |
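A common pattern with the Cholesky helpers is to factor a symmetric positive-definite matrix once with `cho_factor`, then pass the resulting `(c, lower)` pair to `cho_solve` for each right-hand side. The sketch below does this for a made-up 2x2 system and compares the result against the general `solve`.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, solve

# A small symmetric positive-definite system (illustrative values).
a = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([2.0, 1.0])

# Factor once; the (c, lower) tuple can be reused for other right-hand sides.
c_and_lower = cho_factor(a)
x_chol = cho_solve(c_and_lower, b)

# General-purpose solve, for comparison.
x_gen = solve(a, b)
print(x_chol, x_gen)
```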
jax.scipy.ndimage
Name | Description |
---|---|
map_coordinates(input, coordinates, order[, ...]) | Map the input array to new coordinates using interpolation. |
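This minimal sketch uses `map_coordinates` with linear interpolation (`order=1`). It avoids SciPy's higher spline orders, which JAX's version is not assumed to support. The image is linear in its indices, so the interpolated values are easy to check by hand.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

image = jnp.arange(12.0).reshape(3, 4)  # image[r, c] == 4 * r + c

# One coordinate array per input dimension: row positions, then column positions.
rows = jnp.array([0.5, 1.5])
cols = jnp.array([0.5, 2.0])

# order=1 performs linear interpolation between neighbouring pixels.
values = map_coordinates(image, [rows, cols], order=1)
print(values)
```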
jax.scipy.optimize
Name | Description |
---|---|
minimize(fun, x0[, args, tol, options]) | Minimization of scalar function of one or more variables. |
OptimizeResults(x, success, status, fun, ...) | Object holding optimization results. |
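The next sketch runs `minimize` on a convex quadratic whose minimizer is known in advance. It assumes `method` is passed as a keyword and that `"BFGS"` is an available method. The returned `OptimizeResults` object exposes fields such as `x`, `success`, and `fun`.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

target = jnp.array([1.0, -2.0])

def loss(x):
    # Convex quadratic with its minimum at `target`.
    return jnp.sum((x - target) ** 2)

# The method name is given explicitly; "BFGS" is assumed to be supported.
result = minimize(loss, jnp.zeros(2), method="BFGS")
print(result.x, result.success, result.fun)
```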
jax.scipy.signal
Name | Description |
---|---|
fftconvolve(in1, in2[, mode, axes]) | Convolve two N-dimensional arrays using Fast Fourier Transform (FFT). |
convolve(in1, in2[, mode, method, precision]) | Convolution of two N-dimensional arrays. |
convolve2d(in1, in2[, mode, boundary, ...]) | Convolution of two 2-dimensional arrays. |
correlate(in1, in2[, mode, method, precision]) | Cross-correlation of two N-dimensional arrays. |
correlate2d(in1, in2[, mode, boundary, ...]) | Cross-correlation of two 2-dimensional arrays. |
csd(x, y[, fs, window, nperseg, noverlap, ...]) | Estimate cross power spectral density (CSD) using Welch's method. |
detrend(data[, axis, type, bp, overwrite_data]) | Remove linear or piecewise linear trends from data. |
istft(Zxx[, fs, window, nperseg, noverlap, ...]) | Perform the inverse short-time Fourier transform (ISTFT). |
stft(x[, fs, window, nperseg, noverlap, ...]) | Compute the short-time Fourier transform (STFT). |
welch(x[, fs, window, nperseg, noverlap, ...]) | Estimate power spectral density (PSD) using Welch's method. |
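This small sketch contrasts the `mode` options of `convolve` on made-up 1-D data. It also checks that `fftconvolve` agrees with the direct result up to floating-point rounding.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve, fftconvolve

signal = jnp.array([1.0, 2.0, 3.0])
kernel = jnp.array([0.0, 1.0, 0.5])

full = convolve(signal, kernel)               # default mode="full": length 3 + 3 - 1
same = convolve(signal, kernel, mode="same")  # centred, same length as `signal`
via_fft = fftconvolve(signal, kernel)         # FFT-based; equal up to rounding
print(full, same, via_fft)
```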
jax.scipy.spatial.transform
Name | Description |
---|---|
Rotation(quat) | Rotation in 3 dimensions. |
Slerp(times, timedelta, rotations, rotvecs) | Spherical Linear Interpolation of Rotations. |
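This minimal sketch rotates the x axis by 90 degrees about z. It assumes that SciPy's `Rotation.from_rotvec` constructor and the `apply` and `as_matrix` methods carry over to the JAX class.

```python
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation

# 90-degree rotation about z, given as a rotation vector
# (unit axis scaled by the angle in radians).
r = Rotation.from_rotvec(jnp.array([0.0, 0.0, jnp.pi / 2]))

rotated = r.apply(jnp.array([1.0, 0.0, 0.0]))  # the x axis maps onto the y axis
matrix = r.as_matrix()                          # equivalent 3x3 rotation matrix
print(rotated)
```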
jax.scipy.sparse.linalg
Name | Description |
---|---|
bicgstab(A, b[, x0, tol, atol, maxiter, M]) | Use Bi-Conjugate Gradient Stable iteration to solve Ax = b. |
cg(A, b[, x0, tol, atol, maxiter, M]) | Use Conjugate Gradient iteration to solve Ax = b. |
gmres(A, b[, x0, tol, atol, restart, ...]) | GMRES solves the linear system A x = b for x, given A and b. |
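The following minimal sketch calls `cg` with `A` given as a function that computes the matrix-vector product, instead of a dense array. The 2x2 system is illustrative. Conjugate gradient assumes `A` is symmetric positive-definite.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

# Symmetric positive-definite matrix (illustrative values).
a = jnp.array([[4.0, 2.0], [2.0, 3.0]])
b = jnp.array([2.0, 1.0])

# A can be a function computing A @ v, so the matrix need not be materialized;
# here the function simply wraps a dense product.
def matvec(v):
    return a @ v

x, info = cg(matvec, b)  # the second return value is not used in this sketch
print(x)
```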
jax.scipy.special
Name | Description |
---|---|
bernoulli(n) | Generate the first N Bernoulli numbers. |
beta(a, b) | The beta function |
betainc(a, b, x) | The regularized incomplete beta function. |
betaln(a, b) | Natural log of the absolute value of the beta function |
digamma(x) | The digamma function |
entr(x) | The entropy function |
erf(x) | The error function |
erfc(x) | The complement of the error function |
erfinv(x) | The inverse of the error function |
exp1(x) | Exponential integral function E1(x). |
expi | Exponential integral function Ei(x). |
expit(x) | The logistic sigmoid (expit) function |
expn | Generalized exponential integral function. |
factorial(n[, exact]) | Factorial function |
fresnel | The Fresnel integrals |
gamma(x) | The gamma function. |
gammainc(a, x) | The regularized lower incomplete gamma function. |
gammaincc(a, x) | The regularized upper incomplete gamma function. |
gammaln(x) | Natural log of the absolute value of the gamma function. |
gammasgn(x) | Sign of the gamma function. |
hyp1f1 | The 1F1 hypergeometric function. |
i0(x) | Modified Bessel function of the first kind, of order zero. |
i0e(x) | Exponentially scaled modified Bessel function of the first kind, of order zero. |
i1(x) | Modified Bessel function of the first kind, of order one. |
i1e(x) | Exponentially scaled modified Bessel function of the first kind, of order one. |
kl_div(p, q) | The Kullback-Leibler divergence. |
log_ndtr | Log of the normal distribution function (the standard normal CDF). |
log_softmax(x, /, *[, axis]) | Log-Softmax function. |
logit | The logit function |
logsumexp() | Log-sum-exp reduction. |
lpmn(m, n, z) | The associated Legendre functions (ALFs) of the first kind, with their derivatives. |
lpmn_values(m, n, z, is_normalized) | Values of the associated Legendre functions (ALFs) of the first kind, optionally normalized. |
multigammaln(a, d) | The natural log of the multivariate gamma function. |
ndtr(x) | Normal distribution function. |
ndtri(p) | The inverse of the CDF of the Normal distribution function. |
poch | The Pochhammer symbol. |
polygamma(n, x) | The polygamma function. |
rel_entr(p, q) | The relative entropy function. |
softmax(x, /, *[, axis]) | Softmax function. |
spence(x) | Spence's function, also known as the dilogarithm for real values. |
sph_harm(m, n, theta, phi[, n_max]) | Computes the spherical harmonics. |
xlog1py | Compute x*log(1 + y), returning 0 for x=0. |
xlogy | Compute x*log(y), returning 0 for x=0. |
zeta | The Hurwitz zeta function. |
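Several of these functions exist mainly for numerical robustness. The sketch below uses illustrative inputs to show three cases. `logsumexp` avoids the float32 overflow of a naive computation, `xlogy` returns 0 at x = 0, and `gammaln` gives log(n!) through log Γ(n + 1).

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, logsumexp, xlogy

x = jnp.array([1000.0, 1000.0])

# A naive log(sum(exp(x))) overflows to inf ...
naive = jnp.log(jnp.sum(jnp.exp(x)))
# ... while logsumexp factors out the maximum and returns 1000 + log(2).
stable = logsumexp(x)

# xlogy treats 0 * log(0) as 0 instead of nan.
entropy_term = xlogy(0.0, 0.0)

# gammaln(n + 1) = log(n!) without forming the factorial itself.
log_fact_4 = gammaln(5.0)
print(naive, stable, entropy_term, log_fact_4)
```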
jax.scipy.stats
Name | Description |
---|---|
mode(a[, axis, nan_policy, keepdims]) | Compute the mode (most common value) along an axis of an array. |
rankdata(a[, method, axis, nan_policy]) | Compute the rank of data along an array axis. |
sem(a[, axis, ddof, nan_policy, keepdims]) | Compute the standard error of the mean. |
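This minimal sketch applies two of the summary functions to made-up data. `rankdata` uses its default tie handling (average ranks), and `sem` uses its default `ddof=1`.

```python
import jax.numpy as jnp
from jax.scipy.stats import rankdata, sem

data = jnp.array([10.0, 20.0, 20.0, 40.0])

# Tied values share the average of the ranks they span.
ranks = rankdata(data)

# Standard error of the mean: sample standard deviation (ddof=1) / sqrt(n).
err = sem(jnp.array([1.0, 2.0, 3.0, 4.0]))
print(ranks, err)
```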
jax.scipy.stats.bernoulli
Name | Description |
---|---|
logpmf(k, p[, loc]) | Bernoulli log probability mass function. |
pmf(k, p[, loc]) | Bernoulli probability mass function. |
cdf(k, p) | Bernoulli cumulative distribution function. |
ppf(q, p) | Bernoulli percent point function. |
jax.scipy.stats.beta
Name | Description |
---|---|
logpdf(x, a, b[, loc, scale]) | Beta log probability distribution function. |
pdf(x, a, b[, loc, scale]) | Beta probability distribution function. |
cdf(x, a, b[, loc, scale]) | Beta cumulative distribution function |
logcdf(x, a, b[, loc, scale]) | Beta log cumulative distribution function. |
sf(x, a, b[, loc, scale]) | Beta distribution survival function. |
logsf(x, a, b[, loc, scale]) | Beta distribution log survival function. |
jax.scipy.stats.betabinom
Name | Description |
---|---|
logpmf(k, n, a, b[, loc]) | Beta-binomial log probability mass function. |
pmf(k, n, a, b[, loc]) | Beta-binomial probability mass function. |
jax.scipy.stats.binom
Name | Description |
---|---|
logpmf(k, n, p[, loc]) | Binomial log probability mass function. |
pmf(k, n, p[, loc]) | Binomial probability mass function. |
jax.scipy.stats.cauchy
Name | Description |
---|---|
logpdf(x[, loc, scale]) | Cauchy log probability distribution function. |
pdf(x[, loc, scale]) | Cauchy probability distribution function. |
cdf(x[, loc, scale]) | Cauchy cumulative distribution function. |
logcdf(x[, loc, scale]) | Cauchy log cumulative distribution function. |
sf(x[, loc, scale]) | Cauchy distribution survival function. |
logsf(x[, loc, scale]) | Cauchy distribution log survival function. |
isf(q[, loc, scale]) | Cauchy distribution inverse survival function. |
ppf(q[, loc, scale]) | Cauchy distribution percent point function. |
jax.scipy.stats.chi2
Name | Description |
---|---|
logpdf(x, df[, loc, scale]) | Chi-square log probability distribution function. |
pdf(x, df[, loc, scale]) | Chi-square probability distribution function. |
cdf(x, df[, loc, scale]) | Chi-square cumulative distribution function. |
logcdf(x, df[, loc, scale]) | Chi-square log cumulative distribution function. |
sf(x, df[, loc, scale]) | Chi-square survival function. |
logsf(x, df[, loc, scale]) | Chi-square log survival function. |
jax.scipy.stats.dirichlet
Name | Description |
---|---|
logpdf(x, alpha) | Dirichlet log probability distribution function. |
pdf(x, alpha) | Dirichlet probability distribution function. |
jax.scipy.stats.expon
Name | Description |
---|---|
logpdf(x[, loc, scale]) | Exponential log probability distribution function. |
pdf(x[, loc, scale]) | Exponential probability distribution function. |
logcdf(x[, loc, scale]) | Exponential log cumulative distribution function. |
cdf(x[, loc, scale]) | Exponential cumulative distribution function. |
logsf(x[, loc, scale]) | Exponential log survival function. |
sf(x[, loc, scale]) | Exponential survival function. |
ppf(q[, loc, scale]) | Exponential distribution percent point function. |
jax.scipy.stats.gamma
Name | Description |
---|---|
logpdf(x, a[, loc, scale]) | Gamma log probability distribution function. |
pdf(x, a[, loc, scale]) | Gamma probability distribution function. |
cdf(x, a[, loc, scale]) | Gamma cumulative distribution function. |
logcdf(x, a[, loc, scale]) | Gamma log cumulative distribution function. |
sf(x, a[, loc, scale]) | Gamma survival function. |
logsf(x, a[, loc, scale]) | Gamma log survival function. |
jax.scipy.stats.gennorm
Name | Description |
---|---|
cdf(x, beta) | Generalized normal cumulative distribution function. |
logpdf(x, beta) | Generalized normal log probability distribution function. |
pdf(x, beta) | Generalized normal probability distribution function. |
jax.scipy.stats.geom
Name | Description |
---|---|
logpmf(k, p[, loc]) | Geometric log probability mass function. |
pmf(k, p[, loc]) | Geometric probability mass function. |
jax.scipy.stats.laplace
Name | Description |
---|---|
cdf(x[, loc, scale]) | Laplace cumulative distribution function. |
logpdf(x[, loc, scale]) | Laplace log probability distribution function. |
pdf(x[, loc, scale]) | Laplace probability distribution function. |
jax.scipy.stats.logistic
Name | Description |
---|---|
cdf(x[, loc, scale]) | Logistic cumulative distribution function. |
isf(x[, loc, scale]) | Logistic distribution inverse survival function. |
logpdf(x[, loc, scale]) | Logistic log probability distribution function. |
pdf(x[, loc, scale]) | Logistic probability distribution function. |
ppf(x[, loc, scale]) | Logistic distribution percent point function. |
sf(x[, loc, scale]) | Logistic distribution survival function. |
jax.scipy.stats.multinomial
Name | Description |
---|---|
logpmf(x, n, p) | Multinomial log probability mass function. |
pmf(x, n, p) | Multinomial probability mass function. |
jax.scipy.stats.multivariate_normal
Name | Description |
---|---|
logpdf(x, mean, cov[, allow_singular]) | Multivariate normal log probability distribution function. |
pdf(x, mean, cov) | Multivariate normal probability distribution function. |
jax.scipy.stats.nbinom
Name | Description |
---|---|
logpmf(k, n, p[, loc]) | Negative-binomial log probability mass function. |
pmf(k, n, p[, loc]) | Negative-binomial probability mass function. |
jax.scipy.stats.norm
Name | Description |
---|---|
logpdf(x[, loc, scale]) | Normal log probability distribution function. |
pdf(x[, loc, scale]) | Normal probability distribution function. |
cdf(x[, loc, scale]) | Normal cumulative distribution function. |
logcdf(x[, loc, scale]) | Normal log cumulative distribution function. |
ppf(q[, loc, scale]) | Normal distribution percent point function. |
sf(x[, loc, scale]) | Normal distribution survival function. |
logsf(x[, loc, scale]) | Normal distribution log survival function. |
isf(q[, loc, scale]) | Normal distribution inverse survival function. |
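The distribution modules follow a SciPy-like layout with `loc` and `scale` keywords. This sketch uses illustrative inputs for three checks. It confirms that `ppf` inverts `cdf`, evaluates `logpdf` at the mean, and differentiates `logpdf` with `jax.grad`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

x = jnp.array([-1.0, 0.0, 1.0])

p = norm.cdf(x)                      # standard normal CDF (loc=0, scale=1 by default)
x_back = norm.ppf(p)                 # ppf is the inverse of cdf
log_density_at_0 = norm.logpdf(0.0)  # equals -0.5 * log(2 * pi)

# Differentiable like other JAX numerics: d/d(loc) of logpdf(x; loc, 1) is x - loc.
score = jax.grad(lambda mu: norm.logpdf(1.5, loc=mu, scale=1.0))(0.5)
print(p, x_back, log_density_at_0, score)
```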
jax.scipy.stats.pareto
Name | Description |
---|---|
logpdf(x, b[, loc, scale]) | Pareto log probability distribution function. |
pdf(x, b[, loc, scale]) | Pareto probability distribution function. |
jax.scipy.stats.poisson
Name | Description |
---|---|
logpmf(k, mu[, loc]) | Poisson log probability mass function. |
pmf(k, mu[, loc]) | Poisson probability mass function. |
cdf(k, mu[, loc]) | Poisson cumulative distribution function. |
jax.scipy.stats.t
Name | Description |
---|---|
logpdf(x, df[, loc, scale]) | Student's T log probability distribution function. |
pdf(x, df[, loc, scale]) | Student's T probability distribution function. |
jax.scipy.stats.truncnorm
Name | Description |
---|---|
cdf(x, a, b[, loc, scale]) | Truncated normal cumulative distribution function. |
logcdf(x, a, b[, loc, scale]) | Truncated normal log cumulative distribution function. |
logpdf(x, a, b[, loc, scale]) | Truncated normal log probability distribution function. |
logsf(x, a, b[, loc, scale]) | Truncated normal distribution log survival function. |
pdf(x, a, b[, loc, scale]) | Truncated normal probability distribution function. |
sf(x, a, b[, loc, scale]) | Truncated normal distribution survival function. |
jax.scipy.stats.uniform
Name | Description |
---|---|
logpdf(x[, loc, scale]) | Uniform log probability distribution function. |
pdf(x[, loc, scale]) | Uniform probability distribution function. |
cdf(x[, loc, scale]) | Uniform cumulative distribution function. |
ppf(q[, loc, scale]) | Uniform distribution percent point function. |
jax.scipy.stats.gaussian_kde
Name | Description |
---|---|
gaussian_kde(dataset[, bw_method, weights]) | Gaussian Kernel Density Estimator |
gaussian_kde.evaluate(points) | Evaluate the Gaussian KDE on the given points. |
gaussian_kde.integrate_gaussian(mean, cov) | Integrate the distribution weighted by a Gaussian. |
gaussian_kde.integrate_box_1d(low, high) | Integrate the distribution over the given limits. |
gaussian_kde.integrate_kde(other) | Integrate the product of two Gaussian KDE distributions. |
gaussian_kde.resample(key[, shape]) | Randomly sample a dataset from the estimated pdf |
gaussian_kde.pdf(x) | Probability density function |
gaussian_kde.logpdf(x) | Log probability density function |
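This minimal sketch fits `gaussian_kde` to 500 synthetic standard-normal samples (illustrative data, not taken from this page). It evaluates the estimated density at two points and integrates the estimate over a wide interval with `integrate_box_1d`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde

# 500 synthetic samples from a standard normal distribution.
samples = jax.random.normal(jax.random.PRNGKey(0), (500,))

# A 1-D dataset is accepted; bw_method=None selects the default bandwidth rule.
kde = gaussian_kde(samples)

density = kde.pdf(jnp.array([0.0, 3.0]))  # estimated density at x = 0 and x = 3
mass = kde.integrate_box_1d(-10.0, 10.0)  # close to 1 for this wide interval
print(density, mass)
```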
jax.scipy.stats.vonmises
Name | Description |
---|---|
logpdf(x, kappa) | von Mises log probability distribution function. |
pdf(x, kappa) | von Mises probability distribution function. |
jax.scipy.stats.wrapcauchy
Name | Description |
---|---|
logpdf(x, c) | Wrapped Cauchy log probability distribution function. |
pdf(x, c) | Wrapped Cauchy probability distribution function. |