TensorFlow API — sagemaker 2.199.0 documentation


To use the TensorFlow-specific APIs for SageMaker distributed model parallelism, add the following import statement at the top of your training script.

import smdistributed.modelparallel.tensorflow as smp

class smp.DistributedModel

A sub-class of the Keras Model class, which defines the model to be partitioned. The model is defined by sub-classing the smp.DistributedModel class and implementing the call() method, in the same way as the Keras model sub-classing API. Any operation that is part of the smp.DistributedModel.call() method is subject to partitioning, meaning that every operation placed inside it executes on exactly one of the devices (operations outside run on all devices).

Similar to the regular Keras API, the forward pass is done by directly calling the model object on the input tensors. For example:

predictions = model(inputs)   # model is a smp.DistributedModel object

However, model() calls can only be made inside a smp.step-decorated function.

The outputs from a smp.DistributedModel are available in all ranks, regardless of which rank computed the last operation.
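The following is a minimal sketch that ties these pieces together: a model defined by sub-classing smp.DistributedModel, and a forward pass invoked inside a smp.step-decorated function. The MyModel class name, the layer sizes, and the forward function name are illustrative placeholders, not part of the library API.

import tensorflow as tf
import smdistributed.modelparallel.tensorflow as smp

class MyModel(smp.DistributedModel):
    def __init__(self):
        super().__init__()
        # illustrative layers; every operation in call() is subject to partitioning
        self.dense1 = tf.keras.layers.Dense(64, activation="relu")
        self.dense2 = tf.keras.layers.Dense(10)

    def call(self, x):
        x = self.dense1(x)
        return self.dense2(x)

model = MyModel()

@smp.step
def forward(inputs):
    # model() may only be called inside a smp.step-decorated function
    return model(inputs)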

Methods:

save_model(save_path='/opt/ml/model')

Inputs

- save_path (string): A path to save an unpartitioned model with the latest training weights.

Saves the entire, unpartitioned model with the latest trained weights to save_path in TensorFlow SavedModel format. Defaults to "/opt/ml/model", which SageMaker monitors to upload the model artifacts to Amazon S3.
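For example, a minimal sketch (the model object is assumed to be an smp.DistributedModel instance created elsewhere in the training script):

model.save_model(save_path="/opt/ml/model")   # saved in SavedModel format; SageMaker uploads /opt/ml/model to Amazon S3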

smp.partition(index)

Inputs

- index (int): The index of the partition.

A context manager which places all operations defined inside it into the partition whose ID is equal to index. When smp.partition contexts are nested, the innermost context overrides the rest. The index argument must be smaller than the number of partitions.

smp.partition is used in the manual partitioning API; if the auto_partition parameter is set to True while launching training, then smp.partition contexts are ignored. Any operation that is not placed in any smp.partition context is placed in the default_partition, as shown in the following example:

auto_partition: False

default_partition: 0

smp.init()

[...]

x = tf.constant(1.2)                     # placed in partition 0
with smp.partition(1):
    y = tf.add(x, tf.constant(2.3))      # placed in partition 1
    with smp.partition(3):
        z = tf.reduce_sum(y)             # placed in partition 3

register_post_partition_hook(hook)

Registers a callable hook to be executed after the model is partitioned. This is useful in situations where an operation needs to be executed after the model is partitioned during the first call to smp.step, but before the actual execution of the first forward pass.

@smp.register_post_partition_hook
def test_eager():
    # All statements here will be executed right after partition but before the first forward pass
    tf.print("Entered hook through eager context")

class smp.CheckpointManager

A subclass of the TensorFlow CheckpointManager, which is used to manage checkpoints. The usage is similar to that of tf.train.CheckpointManager.

The following returns a CheckpointManager object.

smp.CheckpointManager(checkpoint,
                      directory="/opt/ml/checkpoints",
                      max_to_keep=None,
                      checkpoint_name="ckpt")

Parameters

- checkpoint (tf.train.Checkpoint): The checkpoint object to be managed.
- directory (str): The directory in which checkpoints are saved. Defaults to "/opt/ml/checkpoints", which SageMaker monitors to upload the checkpoints to Amazon S3.
- max_to_keep (int): The maximum number of checkpoints to keep. Defaults to None.
- checkpoint_name (str): The prefix used for the checkpoint file names. Defaults to "ckpt".

Methods:

save()

Saves a new checkpoint in the specified directory. Internally uses tf.train.CheckpointManager.save().

restore()

Restores the latest checkpoint in the specified directory. Internally restores from the latest checkpoint tracked by tf.train.CheckpointManager.

Examples:

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
ckpt_manager = smp.CheckpointManager(checkpoint, max_to_keep=5)  # use /opt/ml/checkpoints

for inputs in train_ds:
    loss = train_step(inputs)
    # [...]
    ckpt_manager.save()  # save a new checkpoint in /opt/ml/checkpoints

for step, inputs in enumerate(train_ds):
    if step == 0:
        ckpt_manager.restore()
    loss = train_step(inputs)