Tensorflow 2 efficace (original) (raw)
Aperçu
Ce guide fournit une liste des meilleures pratiques pour écrire du code à l'aide de TensorFlow 2 (TF2). Il est destiné aux utilisateurs qui ont récemment basculé depuis TensorFlow 1 (TF1). Reportez-vous à la section migration du guide pour plus d'informations sur la migration de votre code TF1 vers TF2.
Installer
Importez TensorFlow et d'autres dépendances pour les exemples de ce guide.
import tensorflow as tf
import tensorflow_datasets as tfds
Recommandations pour TensorFlow 2 idiomatique
Refactorisez votre code en modules plus petits
Une bonne pratique consiste à refactoriser votre code en fonctions plus petites qui sont appelées selon les besoins. Pour de meilleures performances, vous devez essayer de décorer les plus grands blocs de calcul que vous pouvez dans un tf.function (notez que les fonctions python imbriquées appelées par un tf.function ne nécessitent pas leurs propres décorations séparées, sauf si vous souhaitez utiliser différents jit_compile
paramètres de la tf.function .). Selon votre cas d'utilisation, il peut s'agir de plusieurs étapes d'entraînement ou même de toute votre boucle d'entraînement. Pour les cas d'utilisation d'inférence, il peut s'agir d'une seule passe avant de modèle.
Ajuster le taux d'apprentissage par défaut pour certains tf.keras.optimizer
s
Certains optimiseurs Keras ont des taux d'apprentissage différents dans TF2. Si vous constatez un changement dans le comportement de convergence de vos modèles, vérifiez les taux d'apprentissage par défaut.
Il n'y a aucun changement pour les optimizers.SGD , les optimizers.Adam ou les optimizers.RMSprop .
Les taux d'apprentissage par défaut suivants ont changé :
- optimizers.Adagrad de
0.01
à0.001
- optimizers.Adadelta de
1.0
à0.001
- optimizers.Adamax de
0.002
à0.001
- optimizers.Nadam de
0.002
à0.001
Utiliser les tf.Module s et Keras pour gérer les variables
tf.Module s et tf.keras.layers.Layer s offrent les variables
pratiques et les propriétés trainable_variables
, qui rassemblent de manière récursive toutes les variables dépendantes. Cela facilite la gestion des variables localement là où elles sont utilisées.
Les couches/modèles Keras héritent de tf.train.Checkpointable
et sont intégrés à @tf.function , ce qui permet de contrôler directement ou d'exporter des SavedModels à partir d'objets Keras. Vous n'avez pas nécessairement besoin d'utiliser l'API Model.fit de Keras pour tirer parti de ces intégrations.
Lisez la section sur l'apprentissage par transfert et le réglage fin du guide Keras pour savoir comment collecter un sous-ensemble de variables pertinentes à l'aide de Keras.
Combinez tf.data.Dataset s et tf.function
Le package TensorFlow Datasets ( tfds ) contient des utilitaires permettant de charger des ensembles de données prédéfinis en tant tf.data.Dataset . Pour cet exemple, vous pouvez charger le jeu de données MNIST à l'aide tfds :
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Préparez ensuite les données pour la formation :
- Redimensionnez chaque image.
- Mélangez l'ordre des exemples.
- Collectez des lots d'images et d'étiquettes.
BUFFER_SIZE = 10 # Use a much larger value for real code
BATCH_SIZE = 64
NUM_EPOCHS = 5
def scale(image, label):
image = tf.cast(image, tf.float32)
image /= 255
return image, label
Pour que l'exemple reste court, découpez l'ensemble de données pour ne renvoyer que 5 lots :
train_data = mnist_train.map(scale).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
test_data = mnist_test.map(scale).batch(BATCH_SIZE)
STEPS_PER_EPOCH = 5
train_data = train_data.take(STEPS_PER_EPOCH)
test_data = test_data.take(STEPS_PER_EPOCH)
image_batch, label_batch = next(iter(train_data))
2021-12-08 17:15:01.637157: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Utilisez l'itération Python régulière pour itérer sur les données d'entraînement qui tiennent dans la mémoire. Sinon, tf.data.Dataset est le meilleur moyen de diffuser des données d'entraînement à partir du disque. Les ensembles de données sont des itérables (et non des itérateurs) et fonctionnent comme les autres itérables Python dans une exécution hâtive. Vous pouvez utiliser pleinement les fonctionnalités de prélecture/diffusion asynchrones des ensembles de données en enveloppant votre code dans tf.function , qui remplace l'itération Python par les opérations de graphe équivalentes à l'aide d'AutoGraph.
@tf.function
def train(model, dataset, optimizer):
for x, y in dataset:
with tf.GradientTape() as tape:
# training=True is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
prediction = model(x, training=True)
loss = loss_fn(prediction, y)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Si vous utilisez l'API Keras Model.fit , vous n'aurez pas à vous soucier de l'itération de l'ensemble de données.
model.compile(optimizer=optimizer, loss=loss_fn)
model.fit(dataset)
Utiliser les boucles d'entraînement Keras
Si vous n'avez pas besoin d'un contrôle de bas niveau de votre processus d'entraînement, il est recommandé d'utiliser les méthodes d' fit
, d' evaluate
et de predict
intégrées de Keras. Ces méthodes fournissent une interface uniforme pour former le modèle quelle que soit l'implémentation (séquentielle, fonctionnelle ou sous-classée).
Les avantages de ces méthodes incluent :
- Ils acceptent les tableaux Numpy, les générateurs Python et
tf.data.Datasets
. - Ils appliquent automatiquement la régularisation et les pertes d'activation.
- Ils prennent en charge tf.distribute où le code de formation reste le même quelle que soit la configuration matérielle .
- Ils prennent en charge les callables arbitraires comme les pertes et les métriques.
- Ils prennent en charge les rappels tels que tf.keras.callbacks.TensorBoard et les rappels personnalisés.
- Ils sont performants, utilisant automatiquement les graphes TensorFlow.
Voici un exemple d'entraînement d'un modèle à l'aide d'un Dataset
. Pour plus de détails sur la façon dont cela fonctionne, consultez les didacticiels .
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
# Model is the full model w/o custom layers
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
model.fit(train_data, epochs=NUM_EPOCHS)
loss, acc = model.evaluate(test_data)
print("Loss {}, Accuracy {}".format(loss, acc))
Epoch 1/5
5/5 [==============================] - 9s 7ms/step - loss: 1.5762 - accuracy: 0.4938
Epoch 2/5
2021-12-08 17:15:11.145429: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.5087 - accuracy: 0.8969
Epoch 3/5
2021-12-08 17:15:11.559374: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
5/5 [==============================] - 2s 5ms/step - loss: 0.3348 - accuracy: 0.9469
Epoch 4/5
2021-12-08 17:15:13.860407: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
5/5 [==============================] - 0s 5ms/step - loss: 0.2445 - accuracy: 0.9688
Epoch 5/5
2021-12-08 17:15:14.269850: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
5/5 [==============================] - 0s 6ms/step - loss: 0.2006 - accuracy: 0.9719
2021-12-08 17:15:14.717552: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
5/5 [==============================] - 1s 4ms/step - loss: 1.4553 - accuracy: 0.5781
Loss 1.4552843570709229, Accuracy 0.578125
2021-12-08 17:15:15.862684: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Personnalisez la formation et écrivez votre propre boucle
Si les modèles Keras fonctionnent pour vous, mais que vous avez besoin de plus de flexibilité et de contrôle de l'étape d'entraînement ou des boucles d'entraînement externes, vous pouvez mettre en œuvre vos propres étapes d'entraînement ou même des boucles d'entraînement entières. Consultez le guide Keras sur la personnalisation de l' fit pour en savoir plus.
Vous pouvez également implémenter de nombreuses choses en tant que tf.keras.callbacks.Callback .
Cette méthode présente de nombreux avantages mentionnés précédemment , mais vous donne le contrôle de l'étape du train et même de la boucle extérieure.
Une boucle d'entraînement standard comporte trois étapes :
- Itérez sur un générateur Python ou tf.data.Dataset pour obtenir des lots d'exemples.
- Utilisez tf.GradientTape pour collecter les dégradés.
- Utilisez l'un des tf.keras.optimizers pour appliquer des mises à jour de poids aux variables du modèle.
Rappelles toi:
- Incluez toujours un argument de
training
sur la méthode d'call
des couches et modèles sous-classés. - Assurez-vous d'appeler le modèle avec l'argument d'
training
défini correctement. - Selon l'utilisation, les variables de modèle peuvent ne pas exister tant que le modèle n'est pas exécuté sur un lot de données.
- Vous devez gérer manuellement des choses comme les pertes de régularisation pour le modèle.
Il n'est pas nécessaire d'exécuter des initialiseurs de variables ou d'ajouter des dépendances de contrôle manuel. tf.function gère pour vous les dépendances de contrôle automatique et l'initialisation des variables lors de la création.
model = tf.keras.Sequential([
tf.keras.layers.Conv2D(32, 3, activation='relu',
kernel_regularizer=tf.keras.regularizers.l2(0.02),
input_shape=(28, 28, 1)),
tf.keras.layers.MaxPooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dropout(0.1),
tf.keras.layers.Dense(64, activation='relu'),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.001)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
for epoch in range(NUM_EPOCHS):
for inputs, labels in train_data:
train_step(inputs, labels)
print("Finished epoch", epoch)
2021-12-08 17:15:16.714849: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Finished epoch 0
2021-12-08 17:15:17.097043: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Finished epoch 1
2021-12-08 17:15:17.502480: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Finished epoch 2
2021-12-08 17:15:17.873701: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Finished epoch 3
Finished epoch 4
2021-12-08 17:15:18.344196: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Tirez parti de tf.function avec le flux de contrôle Python
tf.function fournit un moyen de convertir le flux de contrôle dépendant des données en équivalents en mode graphique comme tf.cond et tf.while_loop .
Un endroit commun où le flux de contrôle dépendant des données apparaît est dans les modèles de séquence. tf.keras.layers.RNN enveloppe une cellule RNN, vous permettant de dérouler la récurrence de manière statique ou dynamique. Par exemple, vous pouvez réimplémenter le déroulement dynamique comme suit.
class DynamicRNN(tf.keras.Model):
def __init__(self, rnn_cell):
super(DynamicRNN, self).__init__(self)
self.cell = rnn_cell
@tf.function(input_signature=[tf.TensorSpec(dtype=tf.float32, shape=[None, None, 3])])
def call(self, input_data):
# [batch, time, features] -> [time, batch, features]
input_data = tf.transpose(input_data, [1, 0, 2])
timesteps = tf.shape(input_data)[0]
batch_size = tf.shape(input_data)[1]
outputs = tf.TensorArray(tf.float32, timesteps)
state = self.cell.get_initial_state(batch_size = batch_size, dtype=tf.float32)
for i in tf.range(timesteps):
output, state = self.cell(input_data[i], state)
outputs = outputs.write(i, output)
return tf.transpose(outputs.stack(), [1, 0, 2]), state
lstm_cell = tf.keras.layers.LSTMCell(units = 13)
my_rnn = DynamicRNN(lstm_cell)
outputs, state = my_rnn(tf.random.normal(shape=[10,20,3]))
print(outputs.shape)
(10, 20, 13)
Lisez le guide des tf.function pour plus d'informations.
Métriques et pertes de style nouveau
Les métriques et les pertes sont à la fois des objets qui fonctionnent avec impatience et dans tf.function s.
Un objet loss est appelable et attend ( y_true
, y_pred
) comme arguments :
cce = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cce([[1, 0]], [[-1.0,3.0]]).numpy()
4.01815
Utiliser des métriques pour collecter et afficher des données
Vous pouvez utiliser tf.metrics pour agréger les données et tf.summary pour consigner les résumés et les rediriger vers un rédacteur à l'aide d'un gestionnaire de contexte. Les résumés sont émis directement au rédacteur, ce qui signifie que vous devez fournir la valeur du step
au site d'appel.
summary_writer = tf.summary.create_file_writer('/tmp/summaries')
with summary_writer.as_default():
tf.summary.scalar('loss', 0.1, step=42)
Utilisez tf.metrics pour agréger les données avant de les enregistrer sous forme de résumés. Les métriques sont avec état ; ils accumulent des valeurs et renvoient un résultat cumulé lorsque vous appelez la méthode result
(comme Mean.result ). Effacez les valeurs accumulées avec Model.reset_states .
def train(model, optimizer, dataset, log_freq=10):
avg_loss = tf.keras.metrics.Mean(name='loss', dtype=tf.float32)
for images, labels in dataset:
loss = train_step(model, optimizer, images, labels)
avg_loss.update_state(loss)
if tf.equal(optimizer.iterations % log_freq, 0):
tf.summary.scalar('loss', avg_loss.result(), step=optimizer.iterations)
avg_loss.reset_states()
def test(model, test_x, test_y, step_num):
# training=False is only needed if there are layers with different
# behavior during training versus inference (e.g. Dropout).
loss = loss_fn(model(test_x, training=False), test_y)
tf.summary.scalar('loss', loss, step=step_num)
train_summary_writer = tf.summary.create_file_writer('/tmp/summaries/train')
test_summary_writer = tf.summary.create_file_writer('/tmp/summaries/test')
with train_summary_writer.as_default():
train(model, optimizer, dataset)
with test_summary_writer.as_default():
test(model, test_x, test_y, optimizer.iterations)
Visualisez les résumés générés en faisant pointer TensorBoard vers le répertoire du journal des résumés :
tensorboard --logdir /tmp/summaries
Utilisez l'API tf.summary pour écrire des données récapitulatives à visualiser dans TensorBoard. Pour plus d'informations, lisez le guide tf.summary .
# Create the metrics
loss_metric = tf.keras.metrics.Mean(name='train_loss')
accuracy_metric = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
@tf.function
def train_step(inputs, labels):
with tf.GradientTape() as tape:
predictions = model(inputs, training=True)
regularization_loss=tf.math.add_n(model.losses)
pred_loss=loss_fn(labels, predictions)
total_loss=pred_loss + regularization_loss
gradients = tape.gradient(total_loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
# Update the metrics
loss_metric.update_state(total_loss)
accuracy_metric.update_state(labels, predictions)
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
accuracy_metric.reset_states()
for inputs, labels in train_data:
train_step(inputs, labels)
# Get the metric results
mean_loss=loss_metric.result()
mean_accuracy = accuracy_metric.result()
print('Epoch: ', epoch)
print(' loss: {:.3f}'.format(mean_loss))
print(' accuracy: {:.3f}'.format(mean_accuracy))
2021-12-08 17:15:19.339736: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Epoch: 0
loss: 0.142
accuracy: 0.991
2021-12-08 17:15:19.781743: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Epoch: 1
loss: 0.125
accuracy: 0.997
2021-12-08 17:15:20.219033: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Epoch: 2
loss: 0.110
accuracy: 0.997
2021-12-08 17:15:20.598085: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Epoch: 3
loss: 0.099
accuracy: 0.997
Epoch: 4
loss: 0.085
accuracy: 1.000
2021-12-08 17:15:20.981787: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
Noms des métriques Keras
Les modèles Keras sont cohérents quant à la gestion des noms de métriques. Lorsque vous transmettez une chaîne dans la liste des métriques, cette chaîne exacte est utilisée comme name
de la métrique . Ces noms sont visibles dans l'objet historique renvoyé par model.fit
et dans les journaux transmis à keras.callbacks . est défini sur la chaîne que vous avez transmise dans la liste des métriques.
model.compile(
optimizer = tf.keras.optimizers.Adam(0.001),
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics = ['acc', 'accuracy', tf.keras.metrics.SparseCategoricalAccuracy(name="my_accuracy")])
history = model.fit(train_data)
5/5 [==============================] - 1s 5ms/step - loss: 0.0963 - acc: 0.9969 - accuracy: 0.9969 - my_accuracy: 0.9969
2021-12-08 17:15:21.942940: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to dataset.cache().take(k).repeat()
. You should use dataset.take(k).cache().repeat()
instead.
history.history.keys()
dict_keys(['loss', 'acc', 'accuracy', 'my_accuracy'])
Débogage
Utilisez une exécution rapide pour exécuter votre code étape par étape afin d'inspecter les formes, les types de données et les valeurs. Certaines API, comme tf.function , tf.keras , etc. sont conçues pour utiliser l'exécution de Graph, pour les performances et la portabilité. Lors du débogage, utilisez tf.config.run_functions_eagerly(True) pour utiliser une exécution rapide dans ce code.
Par example:
@tf.function
def f(x):
if x > 0:
import pdb
pdb.set_trace()
x = x + 1
return x
tf.config.run_functions_eagerly(True)
f(tf.constant(1))
>>> f()
-> x = x + 1
(Pdb) l
6 @tf.function
7 def f(x):
8 if x > 0:
9 import pdb
10 pdb.set_trace()
11 -> x = x + 1
12 return x
13
14 tf.config.run_functions_eagerly(True)
15 f(tf.constant(1))
[EOF]
Cela fonctionne également à l'intérieur des modèles Keras et d'autres API qui prennent en charge l'exécution rapide :
class CustomModel(tf.keras.models.Model):
@tf.function
def call(self, input_data):
if tf.reduce_mean(input_data) > 0:
return input_data
else:
import pdb
pdb.set_trace()
return input_data // 2
tf.config.run_functions_eagerly(True)
model = CustomModel()
model(tf.constant([-2, -4]))
>>> call()
-> return input_data // 2
(Pdb) l
10 if tf.reduce_mean(input_data) > 0:
11 return input_data
12 else:
13 import pdb
14 pdb.set_trace()
15 -> return input_data // 2
16
17
18 tf.config.run_functions_eagerly(True)
19 model = CustomModel()
20 model(tf.constant([-2, -4]))
Remarques:
- Les méthodes tf.keras.Model telles que
fit
,predict
etevaluate
s'exécutent sous forme de graphiques avec tf.function sous le capot. - Lors de l'utilisation tf.keras.Model.compile , définissez
run_eagerly = True
pour empêcher la logique duModel
d'être enveloppée dans une tf.function . - Utilisez tf.data.experimental.enable_debug_mode pour activer le mode débogage pour tf.data . Lisez la documentation de l' API pour plus de détails.
Ne gardez pas tf.Tensors
dans vos objets
Ces objets tenseurs peuvent être créés soit dans une tf.function soit dans le contexte impatient, et ces tenseurs se comportent différemment. Utilisez toujours tf.Tensor s uniquement pour les valeurs intermédiaires.
Pour suivre l'état, utilisez tf.Variable s car ils sont toujours utilisables dans les deux contextes. Lisez le guide tf.Variable pour en savoir plus.
Ressources et lectures complémentaires
- Lisez les guides et tutoriels TF2 pour en savoir plus sur l'utilisation de TF2.
- Si vous utilisiez auparavant TF1.x, il est fortement recommandé de migrer votre code vers TF2. Lisez les guides de migration pour en savoir plus.