jax.ref module (JAX documentation)

jax.ref module

jax.ref provides the API for creating, reading, and updating mutable array references (ArrayRef), represented by the Ref class.

API

AbstractRef(inner_aval[, memory_space, kind]): Abstract mutable array reference.
Ref(aval, refs): Mutable array reference.
freeze(ref): Invalidate a given reference and return its final value.
get(ref[, idx]): Read a value from a Ref, optionally at index idx.
new_ref(init_val, *[, memory_space]): Create a mutable array reference with initial value init_val.
set(ref, idx, value): Set the value at index idx of a Ref in place.
swap(ref, idx, value[, _function_name]): Write value at index idx of a Ref in place and return the previous value.
addupdate(ref, idx, x): Add x to the element(s) at index idx of a Ref in place.