tfp.math.custom_gradient | TensorFlow Probability
Embeds a custom gradient into a `Tensor`.
tfp.math.custom_gradient(
fx, gx, x, fx_gx_manually_stopped=False, name=None
)
This function works by a clever application of `stop_gradient`. That is, observe that:
h(x) = stop_gradient(f(x)) + stop_gradient(g(x)) * (x - stop_gradient(x))
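is such that `h(x) == stop_gradient(f(x))` and `grad[h(x), x] == stop_gradient(g(x))`. This holds because `x - stop_gradient(x)` evaluates to zero, so the second term adds nothing to the value, yet its derivative with respect to `x` is one, so that term contributes exactly `stop_gradient(g(x))` to the gradient.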
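A minimal sketch of the identity above, assuming TensorFlow 2 in eager mode. It uses only core `tf.stop_gradient` and `tf.GradientTape` rather than calling `tfp.math.custom_gradient` itself; the helper name `stop_gradient_trick` and the choice `f(x) = round(x)`, `g(x) = 1` are illustrative assumptions.

```python
import tensorflow as tf


def stop_gradient_trick(fx, gx, x):
    """Hypothetical helper: value of `fx`, gradient `gx` w.r.t. `x`.

    Illustrates the formula above for a single Tensor `x`; this is a sketch,
    not the TFP implementation.
    """
    zeros_like_x = x - tf.stop_gradient(x)  # value 0, derivative 1 w.r.t. x
    return tf.stop_gradient(fx) + tf.stop_gradient(gx) * zeros_like_x


x = tf.constant(2.7)
with tf.GradientTape() as tape:
    tape.watch(x)
    fx = tf.round(x)      # f(x) = round(x); true derivative is 0 almost everywhere
    gx = tf.ones_like(x)  # custom g(x) = 1, a "straight-through" choice
    h = stop_gradient_trick(fx, gx, x)

dh_dx = tape.gradient(h, x)
print(h.numpy())      # 3.0, i.e. round(2.7)
print(dh_dx.numpy())  # 1.0, i.e. the supplied g(x)
```

The forward value comes entirely from `f`, while the gradient comes entirely from the supplied `g`, which is what makes this pattern useful for functions whose natural gradient is zero, undefined, or numerically poor.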
In addition to scalar-domain/scalar-range functions, this function also supports tensor-domain/scalar-range functions.
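The tensor-domain/scalar-range case can be sketched the same way, assuming `fx` is a scalar and `gx` has the same shape as `x`: a `tf.reduce_sum` collapses the all-zero product to a scalar, so the value of `fx` is unchanged while each element of `x` receives the matching element of `gx` as its gradient. The helper name `custom_grad_scalar_range` and the logsumexp/softmax pair are illustrative assumptions, and only core TensorFlow is used.

```python
import tensorflow as tf


def custom_grad_scalar_range(fx, gx, x):
    """Hypothetical sketch: scalar `fx`; `x` and `gx` share a shape.

    The result has the value of `fx` and gradient `gx` w.r.t. `x`.
    """
    zeros_like_x = x - tf.stop_gradient(x)  # all zeros, Jacobian = identity
    # reduce_sum turns the all-zero product into a scalar so it can be added
    # to the scalar fx without changing its value or shape.
    override = tf.reduce_sum(tf.stop_gradient(gx) * zeros_like_x)
    return tf.stop_gradient(fx) + override


x = tf.constant([0.0, 0.0])
with tf.GradientTape() as tape:
    tape.watch(x)
    fx = tf.reduce_logsumexp(x)  # f: R^2 -> R
    gx = tf.nn.softmax(x)        # closed-form gradient of logsumexp, by hand
    h = custom_grad_scalar_range(fx, gx, x)

dh_dx = tape.gradient(h, x)
print(h.numpy())      # approximately 0.693, i.e. log(2)
print(dh_dx.numpy())  # [0.5 0.5]
```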
Partial Custom Gradient:
Suppose `h(x) = htilde(x, y)`. Note that `dh/dx = stop(g(x))` but `dh/dy = None`. This is because a `Tensor` cannot have only a portion of its gradient stopped: applying `stop_gradient` to `f` cuts the gradient path to `y` as well as to `x`. To circumvent this issue, one must manually apply `stop_gradient` to the relevant portions of `f` and `g` (and set `fx_gx_manually_stopped=True`). For an example, see the unit test `test_works_correctly_fx_gx_manually_stopped`.
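The manual stopping described above can be sketched with core TensorFlow, assuming a hypothetical `h(x, y) = round(x) * y` with the straight-through gradient `dh/dx = y`. Wrapping all of `f` in `stop_gradient` loses `dh/dy`, whereas stopping only the `x`-dependence keeps it. This is an illustration of the same idea, not the referenced unit test.

```python
import tensorflow as tf

x = tf.constant(2.7)
y = tf.constant(4.0)

with tf.GradientTape(persistent=True) as tape:
    tape.watch([x, y])
    # Naive: stopping all of f(x, y) also cuts the gradient path to y.
    naive = (tf.stop_gradient(tf.round(x) * y)
             + tf.stop_gradient(y) * (x - tf.stop_gradient(x)))
    # Manual: stop only the x-dependence inside f and g; y stays live in f.
    fx = tf.round(tf.stop_gradient(x)) * y
    gx = tf.stop_gradient(y)  # custom dh/dx = y (straight-through for round)
    manual = fx + gx * (x - tf.stop_gradient(x))

naive_dx, naive_dy = tape.gradient(naive, [x, y])
manual_dx, manual_dy = tape.gradient(manual, [x, y])
del tape  # release the persistent tape's resources

print(naive_dx.numpy(), naive_dy)            # 4.0 None
print(manual_dx.numpy(), manual_dy.numpy())  # 4.0 3.0
```

Both versions have the same value (12.0) and the same custom `dh/dx`; only the manually stopped one also reports `dh/dy = round(x)`.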
Args | |
---|---|
fx | Tensor. Output of function evaluated at x. |
gx | Tensor or list of Tensors. Gradient of function at (each) x. |
x | Tensor or list of Tensors. Args of evaluation for f. |
fx_gx_manually_stopped | Python bool indicating that stop_gradient has already been applied manually to fx and gx. |
name | Python str name prefixed to Ops created by this function. |
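To tie the arguments together, here is a hypothetical re-implementation of the documented behavior in core TensorFlow. It is a sketch under stated assumptions (scalar `fx`; `x` and `gx` given either as single Tensors or as lists of matching shapes; no shape checks, no name scoping), not TFP's source, and `custom_gradient_sketch` is an invented name.

```python
import tensorflow as tf


def custom_gradient_sketch(fx, gx, x, fx_gx_manually_stopped=False):
    """Hypothetical sketch of the documented idea (not TFP's code).

    Assumes a scalar `fx`; `x` and `gx` may each be a Tensor or a list of
    Tensors, paired element by element.
    """
    def maybe_stop(t):
        # Skip stop_gradient when the caller has already applied it.
        return t if fx_gx_manually_stopped else tf.stop_gradient(t)

    xs = x if isinstance(x, (list, tuple)) else [x]
    gxs = gx if isinstance(gx, (list, tuple)) else [gx]
    # Each term has value 0 and gradient gx_i w.r.t. x_i.
    override = tf.add_n([
        tf.reduce_sum(maybe_stop(g) * (x_ - tf.stop_gradient(x_)))
        for x_, g in zip(xs, gxs)])
    return maybe_stop(fx) + override


a = tf.constant(3.0)
b = tf.constant(5.0)
with tf.GradientTape() as tape:
    tape.watch([a, b])
    fx = a * b     # f(a, b) = a * b
    gx = [b, a]    # [df/da, df/db], supplied by hand
    h = custom_gradient_sketch(fx, gx, [a, b])

da, db = tape.gradient(h, [a, b])
print(h.numpy(), da.numpy(), db.numpy())  # 15.0 5.0 3.0
```

Passing `fx_gx_manually_stopped=True` skips the internal `stop_gradient` calls, which corresponds to the partial-gradient case in which the caller has already stopped the `x`-dependent parts of `fx` and `gx`.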
Returns | |
---|---|
fx | Floating-type Tensor equal to f(x) but whose gradient is stop_gradient(g(x)). |