# jax.scipy module

## jax.scipy.cluster

| Function | Description |
|---|---|
| vq(obs, code_book[, check_finite]) | Assign codes from a code book to a set of observations. |
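
A minimal sketch of `vq`, assuming a toy two-centroid code book (the data are illustrative only). The function is imported from `jax.scipy.cluster.vq` and returns a pair: the index of the nearest code for each observation, and the Euclidean distance to that code.

```python
import jax.numpy as jnp
from jax.scipy.cluster.vq import vq

# Three 2-D observations and a code book holding two centroids.
obs = jnp.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])
code_book = jnp.array([[0.0, 0.0], [5.0, 5.0]])

# codes: index of the nearest centroid for each observation.
# dists: Euclidean distance from each observation to that centroid.
codes, dists = vq(obs, code_book)
print(codes, dists)
```
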

## jax.scipy.fft

| Function | Description |
|---|---|
| dct(x[, type, n, axis, norm]) | Computes the discrete cosine transform of the input |
| dctn(x[, type, s, axes, norm]) | Computes the multidimensional discrete cosine transform of the input |
| idct(x[, type, n, axis, norm]) | Computes the inverse discrete cosine transform of the input |
| idctn(x[, type, s, axes, norm]) | Computes the multidimensional inverse discrete cosine transform of the input |
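
A minimal sketch of a DCT round trip with `dct` and `idct`, assuming a short 1-D input chosen for illustration. With `norm="ortho"` the type-II transform is orthonormal, so `idct` with the same `type` and `norm` recovers the input, and the first coefficient equals `sum(x) / sqrt(N)`.

```python
import jax.numpy as jnp
from jax.scipy.fft import dct, idct

x = jnp.array([1.0, 2.0, 3.0, 4.0])

# Orthonormal type-II DCT of the input.
X = dct(x, type=2, norm="ortho")

# Using the same type and norm, idct undoes dct.
x_back = idct(X, type=2, norm="ortho")
print(X, x_back)
```
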

## jax.scipy.integrate

| Function | Description |
|---|---|
| trapezoid(y[, x, dx, axis]) | Integrate along the given axis using the composite trapezoidal rule. |
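
A minimal sketch of `trapezoid` on y = x**2 over [0, 1], whose exact integral is 1/3; the grid of 101 points is an arbitrary choice. Sample positions can be given explicitly through `x` or as a uniform spacing through `dx`.

```python
import jax.numpy as jnp
from jax.scipy.integrate import trapezoid

x = jnp.linspace(0.0, 1.0, 101)  # spacing of 0.01
y = x ** 2

# Integrate using the sample coordinates directly.
area = trapezoid(y, x=x)
# Same integral, describing the grid by its uniform spacing instead.
area_dx = trapezoid(y, dx=0.01)
print(area, area_dx)
```
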

## jax.scipy.interpolate

## jax.scipy.linalg

| Function | Description |
|---|---|
| block_diag(*arrs) | Create a block diagonal matrix from input arrays. |
| cho_factor(a[, lower, overwrite_a, check_finite]) | Factorization for Cholesky-based linear solves |
| cho_solve(c_and_lower, b[, overwrite_b, ...]) | Solve a linear system using a Cholesky factorization |
| cholesky(a[, lower, overwrite_a, check_finite]) | Compute the Cholesky decomposition of a matrix. |
| det(a[, overwrite_a, check_finite]) | Compute the determinant of a matrix |
| eigh() | Compute eigenvalues and eigenvectors for a Hermitian matrix |
| eigh_tridiagonal(d, e, *[, eigvals_only, ...]) | Solve the eigenvalue problem for a symmetric real tridiagonal matrix |
| expm(A, *[, upper_triangular, max_squarings]) | Compute the matrix exponential |
| expm_frechet() | Compute the Fréchet derivative of the matrix exponential. |
| funm(A, func[, disp]) | Evaluate a matrix-valued function |
| hessenberg() | Compute the Hessenberg form of the matrix |
| hilbert(n) | Create a Hilbert matrix of order n. |
| inv(a[, overwrite_a, check_finite]) | Return the inverse of a square matrix |
| lu() | Compute the LU decomposition |
| lu_factor(a[, overwrite_a, check_finite]) | Factorization for LU-based linear solves |
| lu_solve(lu_and_piv, b[, trans, ...]) | Solve a linear system using an LU factorization |
| pascal(n[, kind]) | Create a Pascal matrix approximation of order n. |
| polar(a[, side, method, eps, max_iterations]) | Computes the polar decomposition. |
| qr() | Compute the QR decomposition of an array |
| rsf2csf(T, Z[, check_finite]) | Convert real Schur form to complex Schur form. |
| schur(a[, output]) | Compute the Schur decomposition |
| solve(a, b[, lower, overwrite_a, ...]) | Solve a linear system of equations. |
| solve_sylvester(A, B, C, *[, method, tol]) | Solve the Sylvester equation AX + XB = C. |
| solve_triangular(a, b[, trans, lower, ...]) | Solve a triangular linear system of equations |
| sqrtm(A[, blocksize]) | Compute the matrix square root |
| svd() | Compute the singular value decomposition. |
| toeplitz(c[, r]) | Construct a Toeplitz matrix. |
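
A minimal sketch contrasting a direct `solve` with the factor-then-solve pairs `cho_factor`/`cho_solve` and `lu_factor`/`lu_solve`, assuming a small symmetric positive-definite matrix made up for illustration. Each factorization returns the tuple that the matching solve function takes as its first argument, so one factorization can be reused for many right-hand sides.

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve, lu_factor, lu_solve, solve

# Symmetric positive-definite matrix (illustrative values) and right-hand side.
A = jnp.array([[4.0, 1.0, 0.0],
               [1.0, 3.0, 1.0],
               [0.0, 1.0, 2.0]])
b = jnp.array([1.0, 2.0, 3.0])

# One-shot solve.
x_direct = solve(A, b)

# Cholesky: valid because A is symmetric positive-definite.
c_and_lower = cho_factor(A)
x_chol = cho_solve(c_and_lower, b)

# LU with partial pivoting: works for general square matrices.
lu_and_piv = lu_factor(A)
x_lu = lu_solve(lu_and_piv, b)
print(x_direct, x_chol, x_lu)
```
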

## jax.scipy.ndimage

| Function | Description |
|---|---|
| map_coordinates(input, coordinates, order[, ...]) | Map the input array to new coordinates using interpolation. |
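
A minimal sketch of `map_coordinates` with linear interpolation, assuming a small 3 x 4 ramp image. `coordinates` holds one array per input dimension (here, row positions and column positions), and `order=1` selects linear interpolation; JAX's implementation accepts only `order` 0 or 1.

```python
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

# image = [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]]
image = jnp.arange(12.0).reshape(3, 4)

# Sample at (row, col) = (0.5, 0.5) and (1.0, 2.0).
rows = jnp.array([0.5, 1.0])
cols = jnp.array([0.5, 2.0])

# (0.5, 0.5) averages the four neighbours 0, 1, 4, 5;
# (1.0, 2.0) lands exactly on a grid point.
values = map_coordinates(image, [rows, cols], order=1)
print(values)
```
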

## jax.scipy.optimize

| Name | Description |
|---|---|
| minimize(fun, x0[, args, tol, options]) | Minimization of scalar function of one or more variables. |
| OptimizeResults(x, success, status, fun, ...) | Object holding optimization results. |
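
A minimal sketch of `minimize` on a convex quadratic, assuming a made-up target point. In JAX, `method` is a required keyword-only argument (it does not appear in the signature listed in the table); `"BFGS"` is used here. The returned `OptimizeResults` exposes fields such as `x` and `success`.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

target = jnp.array([3.0, -1.0])

def loss(x):
    # Convex quadratic whose unique minimum is at `target`.
    return jnp.sum((x - target) ** 2)

x0 = jnp.zeros(2)  # x0 must be a 1-D array
result = minimize(loss, x0, method="BFGS")
print(result.x, result.success)
```
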

## jax.scipy.signal

| Function | Description |
|---|---|
| fftconvolve(in1, in2[, mode, axes]) | Convolve two N-dimensional arrays using Fast Fourier Transform (FFT). |
| convolve(in1, in2[, mode, method, precision]) | Convolution of two N-dimensional arrays. |
| convolve2d(in1, in2[, mode, boundary, ...]) | Convolution of two 2-dimensional arrays. |
| correlate(in1, in2[, mode, method, precision]) | Cross-correlation of two N-dimensional arrays. |
| correlate2d(in1, in2[, mode, boundary, ...]) | Cross-correlation of two 2-dimensional arrays. |
| csd(x, y[, fs, window, nperseg, noverlap, ...]) | Estimate cross power spectral density (CSD) using Welch's method. |
| detrend(data[, axis, type, bp, overwrite_data]) | Remove linear or piecewise linear trends from data. |
| istft(Zxx[, fs, window, nperseg, noverlap, ...]) | Perform the inverse short-time Fourier transform (ISTFT). |
| stft(x[, fs, window, nperseg, noverlap, ...]) | Compute the short-time Fourier transform (STFT). |
| welch(x[, fs, window, nperseg, noverlap, ...]) | Estimate power spectral density (PSD) using Welch's method. |
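
A minimal sketch comparing `convolve` and `fftconvolve` on a short signal and kernel chosen for illustration. In `"full"` mode the output has `len(x) + len(h) - 1` samples, and the two functions should agree up to floating-point rounding.

```python
import jax.numpy as jnp
from jax.scipy.signal import convolve, fftconvolve

x = jnp.array([1.0, 2.0, 3.0])
h = jnp.array([0.0, 1.0, 0.5])

# Full linear convolution: 3 + 3 - 1 = 5 output samples.
y_direct = convolve(x, h, mode="full")
# Same result computed through FFTs.
y_fft = fftconvolve(x, h, mode="full")
print(y_direct, y_fft)
```
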

## jax.scipy.spatial.transform

| Name | Description |
|---|---|
| Rotation(quat) | Rotation in 3 dimensions. |
| Slerp(times, timedelta, rotations, rotvecs) | Spherical Linear Interpolation of Rotations. |
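
A minimal sketch of `Rotation`, assuming the `from_rotvec`, `apply`, and `as_matrix` methods that mirror `scipy.spatial.transform.Rotation`. A rotation vector is the rotation axis scaled by the angle in radians, so `[0, 0, pi/2]` is a quarter turn about z, which should map the x unit vector onto the y unit vector.

```python
import jax.numpy as jnp
from jax.scipy.spatial.transform import Rotation

# Quarter turn about the z axis, given as a rotation vector.
rot = Rotation.from_rotvec(jnp.array([0.0, 0.0, jnp.pi / 2]))

# Rotate the x unit vector.
v = rot.apply(jnp.array([1.0, 0.0, 0.0]))

# The equivalent 3 x 3 rotation matrix.
R = rot.as_matrix()
print(v, R)
```
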

## jax.scipy.sparse.linalg

| Function | Description |
|---|---|
| bicgstab(A, b[, x0, tol, atol, maxiter, M]) | Use Bi-Conjugate Gradient Stable iteration to solve Ax = b. |
| cg(A, b[, x0, tol, atol, maxiter, M]) | Use Conjugate Gradient iteration to solve Ax = b. |
| gmres(A, b[, x0, tol, atol, restart, ...]) | GMRES solves the linear system A x = b for x, given A and b. |
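
A minimal sketch of `cg`, assuming a small symmetric positive-definite system made up for illustration. `A` may be passed as a function that computes the matrix-vector product, which lets these solvers work with operators that are never stored as dense matrices; `cg` returns a pair whose first element is the solution.

```python
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg

M = jnp.array([[4.0, 1.0],
               [1.0, 3.0]])  # symmetric positive-definite
b = jnp.array([1.0, 2.0])

def matvec(v):
    # cg only needs the action of A on a vector.
    return M @ v

x, info = cg(matvec, b, tol=1e-6)
print(x)
```
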

## jax.scipy.special

| Function | Description |
|---|---|
| bernoulli(n) | Generate the first N Bernoulli numbers. |
| beta(a, b) | The beta function |
| betainc(a, b, x) | The regularized incomplete beta function. |
| betaln(a, b) | Natural log of the absolute value of the beta function |
| digamma(x) | The digamma function |
| entr(x) | The entropy function |
| erf(x) | The error function |
| erfc(x) | The complement of the error function |
| erfinv(x) | The inverse of the error function |
| exp1(x) | Exponential integral function. |
| expi | Exponential integral function. |
| expit(x) | The logistic sigmoid (expit) function |
| expn | Generalized exponential integral function. |
| factorial(n[, exact]) | Factorial function |
| fresnel | The Fresnel integrals |
| gamma(x) | The gamma function. |
| gammainc(a, x) | The regularized lower incomplete gamma function. |
| gammaincc(a, x) | The regularized upper incomplete gamma function. |
| gammaln(x) | Natural log of the absolute value of the gamma function. |
| gammasgn(x) | Sign of the gamma function. |
| hyp1f1 | The 1F1 hypergeometric function. |
| hyp2f1 | The 2F1 hypergeometric function. |
| i0(x) | Modified Bessel function of zeroth order. |
| i0e(x) | Exponentially scaled modified Bessel function of zeroth order. |
| i1(x) | Modified Bessel function of first order. |
| i1e(x) | Exponentially scaled modified Bessel function of first order. |
| kl_div(p, q) | The Kullback-Leibler divergence. |
| log_ndtr | Log of the normal cumulative distribution function. |
| log_softmax(x, /, *[, axis]) | Log-Softmax function. |
| logit | The logit function |
| logsumexp() | Log-sum-exp reduction. |
| lpmn(m, n, z) | The associated Legendre functions (ALFs) of the first kind. |
| lpmn_values(m, n, z, is_normalized) | The associated Legendre functions (ALFs) of the first kind. |
| multigammaln(a, d) | The natural log of the multivariate gamma function. |
| ndtr(x) | Normal distribution function. |
| ndtri(p) | The inverse of the CDF of the Normal distribution function. |
| poch | The Pochhammer symbol. |
| polygamma(n, x) | The polygamma function. |
| rel_entr(p, q) | The relative entropy function. |
| sici | Sine and cosine integrals. |
| softmax(x, /, *[, axis]) | Softmax function. |
| spence(x) | Spence's function, also known as the dilogarithm for real values. |
| sph_harm(m, n, theta, phi[, n_max]) | Computes the spherical harmonics. |
| xlog1py | Compute x*log(1 + y), returning 0 for x=0. |
| xlogy | Compute x*log(y), returning 0 for x=0. |
| zeta | The Hurwitz zeta function. |

## jax.scipy.stats

| Function | Description |
|---|---|
| mode(a[, axis, nan_policy, keepdims]) | Compute the mode (most common value) along an axis of an array. |
| rankdata(a[, method, axis, nan_policy]) | Compute the rank of data along an array axis. |
| sem(a[, axis, ddof, nan_policy, keepdims]) | Compute the standard error of the mean. |
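
A minimal sketch of `rankdata`, `mode`, and `sem` on a tiny sample containing a tie (values are illustrative). `rankdata` gives tied values the average of their ranks by default, `mode` returns the most common value together with its count, and `sem` is the sample standard deviation (ddof=1) divided by sqrt(n).

```python
import jax.numpy as jnp
from jax.scipy import stats

data = jnp.array([10.0, 20.0, 20.0, 30.0])

# The two 20s share ranks 2 and 3, so each gets 2.5.
ranks = stats.rankdata(data)

# Named result with `mode` and `count` fields.
result = stats.mode(data)

# Standard error of the mean.
se = stats.sem(data)
print(ranks, result.mode, result.count, se)
```
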

## jax.scipy.stats.bernoulli

| Function | Description |
|---|---|
| logpmf(k, p[, loc]) | Bernoulli log probability mass function. |
| pmf(k, p[, loc]) | Bernoulli probability mass function. |
| cdf(k, p) | Bernoulli cumulative distribution function. |
| ppf(q, p) | Bernoulli percent point function. |

## jax.scipy.stats.beta

| Function | Description |
|---|---|
| logpdf(x, a, b[, loc, scale]) | Beta log probability distribution function. |
| pdf(x, a, b[, loc, scale]) | Beta probability distribution function. |
| cdf(x, a, b[, loc, scale]) | Beta cumulative distribution function |
| logcdf(x, a, b[, loc, scale]) | Beta log cumulative distribution function. |
| sf(x, a, b[, loc, scale]) | Beta distribution survival function. |
| logsf(x, a, b[, loc, scale]) | Beta distribution log survival function. |

## jax.scipy.stats.betabinom

| Function | Description |
|---|---|
| logpmf(k, n, a, b[, loc]) | Beta-binomial log probability mass function. |
| pmf(k, n, a, b[, loc]) | Beta-binomial probability mass function. |

## jax.scipy.stats.binom

| Function | Description |
|---|---|
| logpmf(k, n, p[, loc]) | Binomial log probability mass function. |
| pmf(k, n, p[, loc]) | Binomial probability mass function. |

## jax.scipy.stats.cauchy

| Function | Description |
|---|---|
| logpdf(x[, loc, scale]) | Cauchy log probability distribution function. |
| pdf(x[, loc, scale]) | Cauchy probability distribution function. |
| cdf(x[, loc, scale]) | Cauchy cumulative distribution function. |
| logcdf(x[, loc, scale]) | Cauchy log cumulative distribution function. |
| sf(x[, loc, scale]) | Cauchy distribution survival function. |
| logsf(x[, loc, scale]) | Cauchy distribution log survival function. |
| isf(q[, loc, scale]) | Cauchy distribution inverse survival function. |
| ppf(q[, loc, scale]) | Cauchy distribution percent point function. |

## jax.scipy.stats.chi2

| Function | Description |
|---|---|
| logpdf(x, df[, loc, scale]) | Chi-square log probability distribution function. |
| pdf(x, df[, loc, scale]) | Chi-square probability distribution function. |
| cdf(x, df[, loc, scale]) | Chi-square cumulative distribution function. |
| logcdf(x, df[, loc, scale]) | Chi-square log cumulative distribution function. |
| sf(x, df[, loc, scale]) | Chi-square survival function. |
| logsf(x, df[, loc, scale]) | Chi-square log survival function. |

## jax.scipy.stats.dirichlet

| Function | Description |
|---|---|
| logpdf(x, alpha) | Dirichlet log probability distribution function. |
| pdf(x, alpha) | Dirichlet probability distribution function. |

## jax.scipy.stats.expon

| Function | Description |
|---|---|
| logpdf(x[, loc, scale]) | Exponential log probability distribution function. |
| pdf(x[, loc, scale]) | Exponential probability distribution function. |
| logcdf(x[, loc, scale]) | Exponential log cumulative distribution function. |
| cdf(x[, loc, scale]) | Exponential cumulative distribution function. |
| logsf(x[, loc, scale]) | Exponential log survival function. |
| sf(x[, loc, scale]) | Exponential survival function. |
| ppf(q[, loc, scale]) | Exponential percent point function. |

## jax.scipy.stats.gamma

| Function | Description |
|---|---|
| logpdf(x, a[, loc, scale]) | Gamma log probability distribution function. |
| pdf(x, a[, loc, scale]) | Gamma probability distribution function. |
| cdf(x, a[, loc, scale]) | Gamma cumulative distribution function. |
| logcdf(x, a[, loc, scale]) | Gamma log cumulative distribution function. |
| sf(x, a[, loc, scale]) | Gamma survival function. |
| logsf(x, a[, loc, scale]) | Gamma log survival function. |

## jax.scipy.stats.gumbel_l

| Function | Description |
|---|---|
| logpdf(x[, loc, scale]) | Gumbel Distribution (Left Skewed) log probability distribution function. |
| pdf(x[, loc, scale]) | Gumbel Distribution (Left Skewed) probability distribution function. |
| cdf(x[, loc, scale]) | Gumbel Distribution (Left Skewed) cumulative distribution function. |
| logcdf(x[, loc, scale]) | Gumbel Distribution (Left Skewed) log cumulative distribution function. |
| sf(x[, loc, scale]) | Gumbel Distribution (Left Skewed) survival function. |
| logsf(x[, loc, scale]) | Gumbel Distribution (Left Skewed) log survival function. |
| ppf(p[, loc, scale]) | Gumbel Distribution (Left Skewed) percent point function (inverse of CDF) |

## jax.scipy.stats.gumbel_r

| Function | Description |
|---|---|
| logpdf(x[, loc, scale]) | Gumbel Distribution (Right Skewed) log probability distribution function. |
| pdf(x[, loc, scale]) | Gumbel Distribution (Right Skewed) probability distribution function. |
| cdf(x[, loc, scale]) | Gumbel Distribution (Right Skewed) cumulative distribution function. |
| logcdf(x[, loc, scale]) | Gumbel Distribution (Right Skewed) log cumulative distribution function. |
| sf(x[, loc, scale]) | Gumbel Distribution (Right Skewed) survival function. |
| logsf(x[, loc, scale]) | Gumbel Distribution (Right Skewed) log survival function. |
| ppf(p[, loc, scale]) | Gumbel Distribution (Right Skewed) percent point function. |

## jax.scipy.stats.gennorm

| Function | Description |
|---|---|
| cdf(x, beta) | Generalized normal cumulative distribution function. |
| logpdf(x, beta) | Generalized normal log probability distribution function. |
| pdf(x, beta) | Generalized normal probability distribution function. |

## jax.scipy.stats.geom

| Function | Description |
|---|---|
| logpmf(k, p[, loc]) | Geometric log probability mass function. |
| pmf(k, p[, loc]) | Geometric probability mass function. |

## jax.scipy.stats.laplace

| Function | Description |
|---|---|
| cdf(x[, loc, scale]) | Laplace cumulative distribution function. |
| logpdf(x[, loc, scale]) | Laplace log probability distribution function. |
| pdf(x[, loc, scale]) | Laplace probability distribution function. |

## jax.scipy.stats.logistic

| Function | Description |
|---|---|
| cdf(x[, loc, scale]) | Logistic cumulative distribution function. |
| isf(x[, loc, scale]) | Logistic distribution inverse survival function. |
| logpdf(x[, loc, scale]) | Logistic log probability distribution function. |
| pdf(x[, loc, scale]) | Logistic probability distribution function. |
| ppf(x[, loc, scale]) | Logistic distribution percent point function. |
| sf(x[, loc, scale]) | Logistic distribution survival function. |

## jax.scipy.stats.multinomial

| Function | Description |
|---|---|
| logpmf(x, n, p) | Multinomial log probability mass function. |
| pmf(x, n, p) | Multinomial probability mass function. |

## jax.scipy.stats.multivariate_normal

| Function | Description |
|---|---|
| logpdf(x, mean, cov[, allow_singular]) | Multivariate normal log probability distribution function. |
| pdf(x, mean, cov) | Multivariate normal probability distribution function. |
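
A minimal sanity check for `multivariate_normal.logpdf`, assuming an arbitrary 3-D point: with zero mean and identity covariance the components are independent standard normals, so the joint log-density should equal the sum of the 1-D `norm.logpdf` values.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal, norm

x = jnp.array([0.5, -1.0, 2.0])
mean = jnp.zeros(3)
cov = jnp.eye(3)

# Joint log-density under N(0, I).
joint = multivariate_normal.logpdf(x, mean, cov)
# Sum of independent standard-normal log-densities.
summed = jnp.sum(norm.logpdf(x))
print(joint, summed)
```
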

## jax.scipy.stats.nbinom

| Function | Description |
|---|---|
| logpmf(k, n, p[, loc]) | Negative-binomial log probability mass function. |
| pmf(k, n, p[, loc]) | Negative-binomial probability mass function. |

## jax.scipy.stats.norm

| Function | Description |
|---|---|
| logpdf(x[, loc, scale]) | Normal log probability distribution function. |
| pdf(x[, loc, scale]) | Normal probability distribution function. |
| cdf(x[, loc, scale]) | Normal cumulative distribution function. |
| logcdf(x[, loc, scale]) | Normal log cumulative distribution function. |
| ppf(q[, loc, scale]) | Normal distribution percent point function. |
| sf(x[, loc, scale]) | Normal distribution survival function. |
| logsf(x[, loc, scale]) | Normal distribution log survival function. |
| isf(q[, loc, scale]) | Normal distribution inverse survival function. |
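
A minimal sketch of how the `norm` functions relate: `ppf` inverts `cdf`, `sf` is `1 - cdf`, and `loc`/`scale` standardize the input as `(x - loc) / scale`. It also shows why `logcdf` exists: far in the left tail `cdf` underflows to 0, so `log(cdf(x))` becomes `-inf` while `logcdf(x)` stays finite. The probe values are arbitrary.

```python
import jax.numpy as jnp
from jax.scipy.stats import norm

q = jnp.array([0.025, 0.5, 0.975])

# Quantiles, roughly [-1.96, 0.0, 1.96]; cdf maps them back to q.
z = norm.ppf(q)
q_back = norm.cdf(z)

# Survival function with loc/scale: equals 1 - cdf((1 - 0) / 2).
tail = norm.sf(1.0, loc=0.0, scale=2.0)

# Far left tail: cdf underflows, logcdf does not.
far = -40.0
log_of_cdf = jnp.log(norm.cdf(far))
direct_logcdf = norm.logcdf(far)
print(z, q_back, tail, log_of_cdf, direct_logcdf)
```
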

## jax.scipy.stats.pareto

| Function | Description |
|---|---|
| logcdf(x, b[, loc, scale]) | Pareto log cumulative distribution function. |
| logpdf(x, b[, loc, scale]) | Pareto log probability distribution function. |
| logsf(x, b[, loc, scale]) | Pareto log survival function. |
| cdf(x, b[, loc, scale]) | Pareto cumulative distribution function. |
| pdf(x, b[, loc, scale]) | Pareto probability distribution function. |
| ppf(q, b[, loc, scale]) | Pareto percent point function (inverse CDF). |
| sf(x, b[, loc, scale]) | Pareto survival function. |

## jax.scipy.stats.poisson

| Function | Description |
|---|---|
| logpmf(k, mu[, loc]) | Poisson log probability mass function. |
| pmf(k, mu[, loc]) | Poisson probability mass function. |
| cdf(k, mu[, loc]) | Poisson cumulative distribution function. |
| entropy(mu[, loc]) | Shannon entropy of the Poisson distribution. |

## jax.scipy.stats.t

| Function | Description |
|---|---|
| logpdf(x, df[, loc, scale]) | Student's T log probability distribution function. |
| pdf(x, df[, loc, scale]) | Student's T probability distribution function. |

## jax.scipy.stats.truncnorm

| Function | Description |
|---|---|
| cdf(x, a, b[, loc, scale]) | Truncated normal cumulative distribution function. |
| logcdf(x, a, b[, loc, scale]) | Truncated normal log cumulative distribution function. |
| logpdf(x, a, b[, loc, scale]) | Truncated normal log probability distribution function. |
| logsf(x, a, b[, loc, scale]) | Truncated normal distribution log survival function. |
| pdf(x, a, b[, loc, scale]) | Truncated normal probability distribution function. |
| sf(x, a, b[, loc, scale]) | Truncated normal distribution survival function. |

## jax.scipy.stats.uniform

| Function | Description |
|---|---|
| logpdf(x[, loc, scale]) | Uniform log probability distribution function. |
| pdf(x[, loc, scale]) | Uniform probability distribution function. |
| cdf(x[, loc, scale]) | Uniform cumulative distribution function. |
| ppf(q[, loc, scale]) | Uniform distribution percent point function. |

## jax.scipy.stats.gaussian_kde

| Name | Description |
|---|---|
| gaussian_kde(dataset[, bw_method, weights]) | Gaussian Kernel Density Estimator |
| gaussian_kde.evaluate(points) | Evaluate the Gaussian KDE on the given points. |
| gaussian_kde.integrate_gaussian(mean, cov) | Integrate the distribution weighted by a Gaussian. |
| gaussian_kde.integrate_box_1d(low, high) | Integrate the distribution over the given limits. |
| gaussian_kde.integrate_kde(other) | Integrate the product of two Gaussian KDE distributions. |
| gaussian_kde.resample(key[, shape]) | Randomly sample a dataset from the estimated pdf |
| gaussian_kde.pdf(x) | Probability density function |
| gaussian_kde.logpdf(x) | Log probability density function |
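
A minimal sketch of `gaussian_kde` on four made-up 1-D samples, leaving the bandwidth method at its default. `pdf` evaluates the estimate on a grid, and `integrate_box_1d` (which only handles one-dimensional data) should return approximately 1 over a range wide enough to cover all of the kernel mass.

```python
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde

samples = jnp.array([-1.0, 0.0, 0.5, 2.0])

# 1-D kernel density estimate with the default bandwidth rule.
kde = gaussian_kde(samples)

# Density evaluated at 8 evenly spaced points.
grid = jnp.linspace(-3.0, 4.0, 8)
density = kde.pdf(grid)

# Integral of the estimate over a very wide interval.
total = kde.integrate_box_1d(-50.0, 50.0)
print(density, total)
```
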

## jax.scipy.stats.vonmises

| Function | Description |
|---|---|
| logpdf(x, kappa) | von Mises log probability distribution function. |
| pdf(x, kappa) | von Mises probability distribution function. |

## jax.scipy.stats.wrapcauchy

| Function | Description |
|---|---|
| logpdf(x, c) | Wrapped Cauchy log probability distribution function. |
| pdf(x, c) | Wrapped Cauchy probability distribution function. |