
tf.recompute_grad


Defines a function as a recompute-checkpoint for the tape auto-diff.

View aliases

Compat aliases for migration

See the Migration guide for more details.

tf.compat.v1.recompute_grad

tf.recompute_grad(
    f
)
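
Because tf.recompute_grad returns an ordinary wrapped function, it can also be applied as a decorator. A minimal sketch (the function and variable below are illustrative, not part of the API):

import tensorflow as tf

w = tf.Variable(3.0)

@tf.recompute_grad
def scaled(x):
  # Intermediates of this function are recomputed on the backward pass
  # rather than being stored on the tape.
  return x * w

with tf.GradientTape() as tape:
  out = scaled(tf.constant(2.0))
grad = tape.gradient(out, [w])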

Tape checkpointing is a technique to reduce the memory consumption of the auto-diff tape:

y = tf.Variable(1.0)

def my_function(x):
  tf.print('running')
  z = x*y
  return z

my_function_recompute = tf.recompute_grad(my_function)

with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function_recompute(r)
running
running
running
running

grad = tape.gradient(r, [y])
running
running
running
running

Without recompute_grad, the tape contains all intermediate steps, and no recomputation is performed.

with tf.GradientTape() as tape:
  r = tf.constant(1.0)
  for i in range(4):
    r = my_function(r)
running
running
running
running

grad = tape.gradient(r, [y])

If f is a tf.keras Model or Layer object, methods and attributes such as f.variables are not available on the returned function g. Either keep a reference to f, or use g.__wrapped__ to access these variables and methods.

def print_running_and_return(x):
  tf.print("running")
  return x

model = tf.keras.Sequential([
  tf.keras.layers.Lambda(print_running_and_return),
  tf.keras.layers.Dense(2)
])

model_recompute = tf.recompute_grad(model)

with tf.GradientTape(persistent=True) as tape:
  r = tf.constant([[1,2]])
  for i in range(4):
    r = model_recompute(r)
running
running
running
running

grad = tape.gradient(r, model.variables)
running
running
running
running

Alternatively, use the __wrapped__ attribute to access the original model object.

grad = tape.gradient(r, model_recompute.__wrapped__.variables)
running
running
running
running

Args
f: function f(*x) that returns a Tensor or sequence of Tensor outputs.
Returns
A function g wrapping f that defines a custom gradient, which recomputes f on the backwards pass of a gradient call.
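
As a sanity check (not part of the original reference), the wrapped function should produce the same gradient values as the unwrapped one; only the memory/compute trade-off changes. A minimal sketch, assuming a simple function that closes over a variable:

import tensorflow as tf

w = tf.Variable(2.0)

def f(x):
  # Intermediate activations here would normally be kept on the tape.
  return tf.sin(x * w)

g = tf.recompute_grad(f)

x = tf.constant(0.5)
with tf.GradientTape(persistent=True) as tape:
  y_plain = f(x)
  y_recomputed = g(x)

# Both gradients should be numerically identical; g simply re-runs f
# on the backward pass instead of storing its intermediates.
print(tape.gradient(y_plain, [w]))
print(tape.gradient(y_recomputed, [w]))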