Interpolations - Diffrax

When solving controlled differential equations, it is relatively common for the control to be an interpolation of discrete data.

The following interpolation routines may be used to perform this interpolation.

Note

Missing data, represented as NaN, can also be handled. (If you are familiar with the problem of informative missingness, note that this can be handled too: see Sections 3.5 and 3.6 of this paper.)

References

The main two references for using interpolation with controlled differential equations are as follows.

Original neural CDE paper:

    @article{kidger2020neuralcde,
        author={Kidger, Patrick and Morrill, James and Foster, James and Lyons, Terry},
        title={{N}eural {C}ontrolled {D}ifferential {E}quations for {I}rregular {T}ime {S}eries},
        journal={Neural Information Processing Systems},
        year={2020},
    }

Investigating specifically the choice of interpolation scheme for CDEs:

    @article{morrill2021cdeonline,
        title={{N}eural {C}ontrolled {D}ifferential {E}quations for {O}nline {P}rediction {T}asks},
        author={Morrill, James and Kidger, Patrick and Yang, Lingyi and Lyons, Terry},
        journal={arXiv:2106.11028},
        year={2021}
    }

How to pick an interpolation scheme

There are a few main types of interpolation provided here. For 99% of applications you will want either rectilinear or cubic interpolation, as follows.

Rectilinear interpolation can be obtained by combining diffrax.rectilinear_interpolation and diffrax.LinearInterpolation.

Hermite cubic splines with backward differences can be obtained by combining diffrax.backward_hermite_coefficients and diffrax.CubicInterpolation.


Interpolation classes

The following are the main interpolation classes. Instances of these classes are suitable controls to pass to diffrax.ControlTerm.

diffrax.LinearInterpolation(diffrax.AbstractPath)

Linearly interpolates some data ys over the interval \([t_0, t_1]\) with knots at ts.

Warning

If using LinearInterpolation as part of a diffrax.ControlTerm, then the vector field will make a jump every time one of the knots ts is passed. If using an adaptive step size controller such as diffrax.PIDController, then this means the controller should be informed about the jumps, so that it can handle them appropriately:

    ts = ...
    interp = LinearInterpolation(ts=ts, ...)
    term = ControlTerm(..., control=interp)
    stepsize_controller = PIDController(..., jump_ts=ts)

t0 property

The start of the interval over which the interpolation is defined.

t1 property

The end of the interval over which the interpolation is defined.

__init__(ts: Real[Array, 'times'], ys: PyTree[Shaped[Array, 'times ...']])

Arguments:

Note that if ys has any missing data then you may wish to use diffrax.linear_interpolation or diffrax.rectilinear_interpolation first to fill it in.

evaluate(t0: Real[ArrayLike, ''], t1: Real[ArrayLike, ''] | None = None, left: bool = True) -> PyTree[Array]

Evaluate the linear interpolation.

Arguments:

FAQ

Note that we use \(t_0\) and \(t_1\) to refer to the overall interval, as obtained via instance.t0 and instance.t1, and t0 and t1 to refer to some subinterval of \([t_0, t_1]\). This naming is used for consistency with the rest of the package, and happens to be a little confusing here.

Returns:

If t1 is not passed:

The interpolation of the data. Suppose \(t_j < t < t_{j+1}\), where \(t\) is t0 and \(t_j\) and \(t_{j+1}\) are consecutive elements of ts as passed to __init__. Then the value returned is \(y_j + (y_{j+1} - y_j)\frac{t - t_j}{t_{j+1} - t_j}\).

If t1 is passed:

As above, with \(t\) taken to be each of t0 and t1 in turn; the increment between the two values (the value at t1 minus the value at t0) is returned.
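To make the formula concrete, here is a minimal NumPy sketch of point evaluation and of the increment returned when t1 is passed. It is not Diffrax's implementation (Diffrax is built on JAX and works with PyTrees); the helper names `linear_eval` and `linear_increment` are hypothetical, and a single 1-D channel of data is assumed.

```python
import numpy as np


def linear_eval(ts, ys, t):
    """Evaluate the piecewise-linear interpolant of (ts, ys) at a scalar t.

    Sketch only: assumes ts is strictly increasing and t lies in [ts[0], ts[-1]].
    """
    ts = np.asarray(ts, dtype=float)
    ys = np.asarray(ys, dtype=float)
    # Index j of the interval [ts[j], ts[j + 1]] that contains t.
    j = np.clip(np.searchsorted(ts, t, side="right") - 1, 0, len(ts) - 2)
    frac = (t - ts[j]) / (ts[j + 1] - ts[j])
    # y_j + (y_{j+1} - y_j) * (t - t_j) / (t_{j+1} - t_j)
    return ys[j] + (ys[j + 1] - ys[j]) * frac


def linear_increment(ts, ys, t0, t1):
    """Increment between t0 and t1: the value at t1 minus the value at t0."""
    return linear_eval(ts, ys, t1) - linear_eval(ts, ys, t0)


ts = [0.0, 1.0, 3.0]
ys = [0.0, 2.0, 1.0]
value = linear_eval(ts, ys, 2.0)              # 1.5, halfway between 2.0 and 1.0
increment = linear_increment(ts, ys, 0.5, 2.0)  # 1.5 - 1.0 = 0.5
```

Because the interpolant is continuous, it does not matter which of the two adjacent intervals is used when t lands exactly on a knot.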

derivative(t: Real[ArrayLike, ''], left: bool = True) -> PyTree[Array]

Evaluate the derivative of the linear interpolation. Essentially equivalent to the tangent output of jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),)).

Arguments:

Returns:

The derivative of the interpolation of the data. Suppose \(t_j < t < t_{j+1}\), where \(t_j\) and \(t_{j+1}\) are some elements of ts passed in __init__. Then the value returned is \(\frac{y_{j+1} - y_j}{t_{j+1} - t_j}\).
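A minimal NumPy sketch of this derivative formula also shows why the Warning about jump_ts matters: the slope is piecewise constant and generally changes at every interior knot. The helper `linear_derivative` is hypothetical, a single 1-D channel is assumed, and its `left` flag is this sketch's own convention (at a knot it selects the interval ending there, otherwise the interval starting there); it mirrors the idea of the documented `left` argument without specifying it.

```python
import numpy as np


def linear_derivative(ts, ys, t, left=True):
    """Slope (y_{j+1} - y_j) / (t_{j+1} - t_j) of the interval containing t.

    Sketch only: at a knot, left=True uses the interval ending at t and
    left=False uses the interval starting at t.
    """
    ts = np.asarray(ts, dtype=float)
    ys = np.asarray(ys, dtype=float)
    side = "left" if left else "right"
    j = np.clip(np.searchsorted(ts, t, side=side) - 1, 0, len(ts) - 2)
    return (ys[j + 1] - ys[j]) / (ts[j + 1] - ts[j])


ts = [0.0, 1.0, 3.0]
ys = [0.0, 2.0, 1.0]
slope_before = linear_derivative(ts, ys, 1.0, left=True)   # 2.0
slope_after = linear_derivative(ts, ys, 1.0, left=False)   # -0.5
```

The jump from 2.0 to -0.5 at the knot t = 1.0 is exactly the kind of discontinuity in the vector field that an adaptive step size controller benefits from being told about.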

diffrax.CubicInterpolation(diffrax.AbstractPath)

Piecewise cubic spline interpolation over the interval \([t_0, t_1]\).

t0 property

The start of the interval over which the interpolation is defined.

t1 property

The end of the interval over which the interpolation is defined.

__init__(ts: Real[Array, 'times'], coeffs: tuple[PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y'], PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y'], PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y'], PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y']])

Arguments:

Any kind of spline (natural, ...) may be used; simply pass the appropriate coefficients.

In practice a good choice is typically "cubic Hermite splines with backward differences", introduced in this paper. Such coefficients can be obtained using diffrax.backward_hermite_coefficients.

Letting d, c, b, a = coeffs, then for all t in the interval from ts[i] to ts[i + 1] the interpolation is defined as

d[i] * (t - ts[i]) ** 3 + c[i] * (t - ts[i]) ** 2 + b[i] * (t - ts[i]) + a[i]
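The piecewise definition can be evaluated directly. Below is a minimal NumPy sketch, not Diffrax's implementation: the helper name `cubic_eval` is hypothetical, a single 1-D channel is assumed, and the coefficients are hand-picked so that the two pieces (s**3 on the first interval, 1 + 3*s on the second, with s = t - ts[i]) join with matching value and slope at t = 1.

```python
import numpy as np


def cubic_eval(ts, coeffs, t):
    """Evaluate d[i]*(t-ts[i])**3 + c[i]*(t-ts[i])**2 + b[i]*(t-ts[i]) + a[i].

    Sketch only: ts has length T, each of d, c, b, a has length T - 1,
    and t lies in [ts[0], ts[-1]].
    """
    ts = np.asarray(ts, dtype=float)
    d, c, b, a = (np.asarray(x, dtype=float) for x in coeffs)
    # Index i of the interval [ts[i], ts[i + 1]] that contains t.
    i = np.clip(np.searchsorted(ts, t, side="right") - 1, 0, len(ts) - 2)
    s = t - ts[i]
    return d[i] * s**3 + c[i] * s**2 + b[i] * s + a[i]


ts = [0.0, 1.0, 2.0]
# coeffs = (d, c, b, a): piece 0 is s**3, piece 1 is 1 + 3*s.
coeffs = ([1.0, 0.0], [0.0, 0.0], [0.0, 3.0], [0.0, 1.0])
y_mid_first = cubic_eval(ts, coeffs, 0.5)   # 0.125
y_mid_second = cubic_eval(ts, coeffs, 1.5)  # 2.5
```

If t1 is also passed to evaluate, the corresponding quantity in this sketch would be cubic_eval(ts, coeffs, t1) - cubic_eval(ts, coeffs, t0).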

evaluate(t0: Real[ArrayLike, ''], t1: Real[ArrayLike, ''] | None = None, left: bool = True) -> PyTree[Shaped[Array, '?*shape'], 'Y']

Evaluate the cubic interpolation.

Arguments:

FAQ

Note that we use \(t_0\) and \(t_1\) to refer to the overall interval, as obtained via instance.t0 and instance.t1, and t0 and t1 to refer to some subinterval of \([t_0, t_1]\). This naming is used for consistency with the rest of the package, and happens to be a little confusing here.

Returns:

If t1 is not passed:

The interpolation of the data at t0.

If t1 is passed:

The increment between t0 and t1.

derivative(t: Real[ArrayLike, ''], left: bool = True) -> PyTree[Shaped[Array, '?*shape'], 'Y']

Evaluate the derivative of the cubic interpolation. Essentially equivalent to the tangent output of jax.jvp(self.evaluate, (t,), (jnp.ones_like(t),)).

Arguments:

Returns:

The derivative of the interpolation of the data.
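Differentiating the piecewise cubic d[i] * s**3 + c[i] * s**2 + b[i] * s + a[i] (with s = t - ts[i]) term by term gives 3 * d[i] * s**2 + 2 * c[i] * s + b[i]. The following NumPy sketch computes this for a single 1-D channel; the helper name `cubic_derivative` is hypothetical and this is not Diffrax's implementation.

```python
import numpy as np


def cubic_derivative(ts, coeffs, t):
    """Derivative 3*d[i]*s**2 + 2*c[i]*s + b[i], where s = t - ts[i].

    Sketch only: single channel, t inside [ts[0], ts[-1]], coeffs = (d, c, b, a).
    """
    ts = np.asarray(ts, dtype=float)
    d, c, b, _ = (np.asarray(x, dtype=float) for x in coeffs)
    i = np.clip(np.searchsorted(ts, t, side="right") - 1, 0, len(ts) - 2)
    s = t - ts[i]
    return 3 * d[i] * s**2 + 2 * c[i] * s + b[i]


ts = [0.0, 1.0, 2.0]
# coeffs = (d, c, b, a): piece 0 is s**3, piece 1 is 1 + 3*s.
coeffs = ([1.0, 0.0], [0.0, 0.0], [0.0, 3.0], [0.0, 1.0])
slope_first = cubic_derivative(ts, coeffs, 0.5)   # 3 * 0.5**2 = 0.75
slope_second = cubic_derivative(ts, coeffs, 1.5)  # 3.0
```

With these coefficients the slope approaching t = 1 from the first piece (3 * 1**2) equals the slope of the second piece (3), so the derivative is continuous across the knot.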


Handling missing data

To use diffrax.LinearInterpolation with data that has missing values (represented as NaN), those values can first be filled in using one of the following.

diffrax.linear_interpolation(ts: Real[Array, 'times'], ys: PyTree[Shaped[Array, 'times ?*shape'], 'Y'], *, fill_forward_nans_at_end: bool = False, replace_nans_at_start: PyTree[Shaped[ArrayLike, '?#*shape'], 'Y'] | None = None) -> PyTree[Shaped[Array, 'times ?*shape'], 'Y']

Fill in any missing values via linear interpolation.

Any missing values in ys (represented as NaN) are filled in by looking at the nearest non-NaN values on either side, and linearly interpolating between them.

This is often useful prior to using diffrax.LinearInterpolation to create a continuous path from discrete observations.

Arguments:

Returns:

As ys, but with NaN values filled in.
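The interior filling rule can be sketched in NumPy for a single 1-D channel. The helper `fill_nans_linear` is hypothetical and is not Diffrax's implementation: it leaves NaNs before the first observation and after the last one in place, whereas Diffrax controls those cases through replace_nans_at_start and fill_forward_nans_at_end, which this sketch does not model.

```python
import numpy as np


def fill_nans_linear(ts, ys):
    """Replace each interior NaN in ys by linearly interpolating between the
    nearest non-NaN observations on either side (with respect to ts).

    Sketch only: NaNs before the first or after the last observation are kept.
    """
    ts = np.asarray(ts, dtype=float)
    ys = np.asarray(ys, dtype=float).copy()  # do not modify the caller's data
    observed = ~np.isnan(ys)
    if observed.sum() < 2:
        return ys  # nothing lies between two observations
    first, last = np.flatnonzero(observed)[[0, -1]]
    inner = slice(first, last + 1)
    ys[inner] = np.interp(ts[inner], ts[observed], ys[observed])
    return ys


ts = [0.0, 1.0, 2.0, 4.0, 5.0]
ys = [np.nan, 1.0, np.nan, 7.0, np.nan]
filled = fill_nans_linear(ts, ys)
# The interior NaN at t = 2.0 becomes 3.0 (a third of the way from 1.0 to 7.0);
# the NaNs at t = 0.0 and t = 5.0 are left as they are.
```

Note that the interpolation is in terms of the times ts, not the array index, so irregularly spaced observations are weighted correctly.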

diffrax.rectilinear_interpolation(ts: Real[Array, 'times'], ys: PyTree[Shaped[Array, 'times ?*shape'], 'Y'], replace_nans_at_start: PyTree[Shaped[ArrayLike, '?#*shape'], 'Y'] | None = None) -> tuple[Real[Array, '2*times-1'], PyTree[Shaped[Array, '2*times-1 ?*shape'], 'Y']]

Rectilinearly interpolates the input. This is a variant of linear interpolation that is particularly useful when using neural CDEs in a real-time scenario.

It is typically applied prior to diffrax.LinearInterpolation, to create a continuous path from discrete observations.

If you are unfamiliar with rectilinear interpolation, it is strongly recommended that you read the reference below.

Reference

Morrill, Kidger, Yang and Lyons, "Neural Controlled Differential Equations for Online Prediction Tasks", arXiv:2106.11028, 2021 (entry morrill2021cdeonline in the References above).

Example

Suppose ts = [t0, t1, t2, t3] and ys = [y0, y1, y2, y3]. Then rectilinearly interpolating these produces new_ts = [t0, t1, t1, t2, t2, t3, t3] and new_ys = [y0, y0, y1, y1, y2, y2, y3].

This can be thought of as advancing time whilst keeping the data fixed, then keeping time fixed whilst advancing the data.
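The pattern in this example can be reproduced with a short pure-Python sketch. The helper `rectilinear` is hypothetical and ignores NaN handling and replace_nans_at_start; it only illustrates how a length-T input becomes a length 2T - 1 output.

```python
def rectilinear(ts, ys):
    """Return (new_ts, new_ys), each of length 2 * len(ts) - 1.

    Sketch of the pattern new_ts = [t0, t1, t1, t2, t2, ...] and
    new_ys = [y0, y0, y1, y1, y2, ...]: first advance time with the data
    held fixed, then hold time fixed while the data advances.
    """
    new_ts, new_ys = [ts[0]], [ys[0]]
    for i in range(1, len(ts)):
        new_ts += [ts[i], ts[i]]      # time advances, then pauses
        new_ys += [ys[i - 1], ys[i]]  # data holds, then advances
    return new_ts, new_ys


# Symbolic placeholders reproduce the example exactly.
new_ts, new_ys = rectilinear(["t0", "t1", "t2", "t3"], ["y0", "y1", "y2", "y3"])
# new_ts == ['t0', 't1', 't1', 't2', 't2', 't3', 't3']
# new_ys == ['y0', 'y0', 'y1', 'y1', 'y2', 'y2', 'y3']
```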

Arguments:

Returns:

A new version of both ts and ys, subject to rectilinear interpolation.

Example

Suppose we wish to use a rectilinearly interpolated control to drive a neural CDE. Then this should be done something like the following:

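    import jax.numpy as jnp
    from diffrax import LinearInterpolation, rectilinear_interpolation

    ts = jnp.array([0., 1., 1.5, 2.])
    ys = jnp.array([5., 6., 5., 6.])
    # 4 observations become 7 after rectilinear interpolation.
    ts, ys = rectilinear_interpolation(ts, ys)
    # Time and observations become the two channels of the data, shape (7, 2).
    data = jnp.stack([ts, ys], axis=-1)
    # The knot times of the interpolation are ours to choose.
    interp_ts = jnp.arange(7)
    interp = LinearInterpolation(interp_ts, data)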

Note how time and observations are stacked together as the data of the interpolation (as usual for a neural CDE), and how the interpolation times are something we are free to pick.
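To see what the resulting path looks like, here is a NumPy sketch that evaluates the stacked (time, observation) data from this example channel by channel, using np.interp as a stand-in for LinearInterpolation.evaluate. The rectilinear ts and ys are written out by hand following the pattern [t0, t1, t1, t2, t2, ...] and [y0, y0, y1, y1, y2, ...]; the helper `path` is hypothetical and this is an illustration, not Diffrax code.

```python
import numpy as np

# Rectilinear version of ts = [0., 1., 1.5, 2.] and ys = [5., 6., 5., 6.].
rect_ts = np.array([0.0, 1.0, 1.0, 1.5, 1.5, 2.0, 2.0])
rect_ys = np.array([5.0, 5.0, 6.0, 6.0, 5.0, 5.0, 6.0])
data = np.stack([rect_ts, rect_ys], axis=-1)  # shape (7, 2): (time, observation)
interp_ts = np.arange(7.0)                    # freely chosen knot times 0, 1, ..., 6


def path(s):
    """Linearly interpolate each channel of `data` at the scalar s."""
    return np.array([np.interp(s, interp_ts, data[:, k]) for k in range(data.shape[1])])


p1 = path(0.5)  # time channel 0.5, observation 5.0: time moves, data held
p2 = path(1.5)  # time channel 1.0, observation 5.5: time held, data moves
```

Between consecutive knots of interp_ts exactly one channel changes, which is the "advance time, then advance data" behaviour described above.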


Calculating coefficients

diffrax.backward_hermite_coefficients(ts: Real[Array, 'times'], ys: PyTree[Shaped[Array, 'times ?*shape'], 'Y'], *, deriv0: PyTree[Shaped[Array, '?#*shape'], 'Y'] | None = None, replace_nans_at_start: PyTree[Shaped[ArrayLike, '?#*shape'], 'Y'] | None = None, fill_forward_nans_at_end: bool = False) -> tuple[PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y'], PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y'], PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y'], PyTree[Shaped[Array, 'times-1 ?*shape'], 'Y']]

Interpolates the data with a cubic spline. Specifically, this calculates the coefficients for Hermite cubic splines with backward differences.

This is most useful prior to using diffrax.CubicInterpolation to create a smooth path from discrete observations.

Reference

Hermite cubic splines with backward differences were introduced in this paper:

Morrill, Kidger, Yang and Lyons, "Neural Controlled Differential Equations for Online Prediction Tasks", arXiv:2106.11028, 2021 (entry morrill2021cdeonline in the References above).

Arguments:

Returns:

The coefficients of the Hermite cubic spline. If ts has length \(T\) then the coefficients will be of length \(T - 1\), covering each of the intervals from ts[0] to ts[1], ts[1] to ts[2], and so on.
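The coefficient construction can be sketched in NumPy for a single 1-D channel with no missing data. The helper `backward_hermite` is hypothetical, not Diffrax's implementation. Each piece is a cubic Hermite polynomial whose derivative at every knot ts[i] is the backward difference (ys[i] - ys[i-1]) / (ts[i] - ts[i-1]); at ts[0] the derivative is deriv0, which this sketch defaults to the slope of the first interval (an assumption of the sketch that may differ from Diffrax's default). The result uses the (d, c, b, a) order of diffrax.CubicInterpolation.

```python
import numpy as np


def backward_hermite(ts, ys, deriv0=None):
    """Return (d, c, b, a), each of length len(ts) - 1.

    Sketch only: single channel, no NaNs. On [ts[i], ts[i+1]] the cubic runs
    from ys[i] to ys[i+1], starting with the backward difference at ts[i] and
    ending with the backward difference at ts[i+1] (the interval's own slope).
    """
    ts = np.asarray(ts, dtype=float)
    ys = np.asarray(ys, dtype=float)
    h = np.diff(ts)
    slopes = np.diff(ys) / h  # backward difference at ts[1:]
    if deriv0 is None:
        deriv0 = slopes[0]    # sketch's assumed default
    start_derivs = np.concatenate([[deriv0], slopes[:-1]])  # derivative at ts[:-1]
    gap = slopes - start_derivs  # end slope minus start slope, per interval
    a = ys[:-1]
    b = start_derivs
    c = 2 * gap / h
    d = -gap / h**2
    return d, c, b, a


ts = [0.0, 1.0, 2.0, 3.0]
ys = [0.0, 1.0, 0.0, 2.0]
d, c, b, a = backward_hermite(ts, ys)
h = np.diff(ts)
# Each piece ends on the next observation: this equals ys[1:], i.e. [1.0, 0.0, 2.0].
ends = d * h**3 + c * h**2 + b * h + a
```

Since each interval's end derivative equals the next interval's start derivative, the resulting path is continuously differentiable, and each piece depends only on observations up to its right endpoint.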