Tensorflow ResNet 50 Optimization Tutorial — AWS Neuron Documentation (original) (raw)

TensorFlow ResNet 50 Optimization Tutorial

Note: this tutorial runs on tensorflow-neuron 1.x only

Introduction

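In this tutorial we provide three main sections:

1. Optimize a Keras ResNet 50 model for inference and cast it to FP16.
2. Compile the model over a sweep of batch sizes and NeuronCore Group sizes.
3. Run inference with the compiled models to find the configuration with the best throughput.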

Verify that this Jupyter notebook is running the Python kernel environment that was set up according to the TensorFlow Installation Guide. You can select the kernel from the "Kernel -> Change Kernel" option at the top of this Jupyter notebook page.

Install Dependencies

# Each install must be on its own line: a trailing "# ..." comment would
# otherwise swallow every command that follows it on the same line.
!pip install pillow requests  # Necessary for loading images
!pip install tensorflow_neuron==1.15.5.2.8.9.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com/
!pip install neuron_cc==1.13.5.0 --extra-index-url=https://pip.repos.neuron.amazonaws.com

Compile

The following example shows how to compile an FP16 ResNet50 network using various batching parameters to find the optimal solution. On an inf1.6xlarge instance, run through the following steps to get an optimized ResNet 50 model. First, extract the Keras ResNet50 FP32 graph (resnet50_fp32_keras.pb will be generated):

import argparse
import os
import re

import numpy as np
import tensorflow as tf
import tensorflow.python.saved_model
from google.protobuf import text_format
from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import decode_predictions, preprocess_input
from tensorflow.keras.preprocessing import image

# Set Keras global configurations: inference mode, NHWC layout, FP32 floats.
tf.keras.backend.set_learning_phase(0)
tf.keras.backend.set_image_data_format('channels_last')

float_type = 'float32'
float_type2 = 'fp32'
tf.keras.backend.set_floatx(float_type)

# Load the pre-trained (ImageNet) ResNet50 model using Keras.
model_name = 'resnet50_%s_keras' % float_type2
model = ResNet50(weights='imagenet')

# Files produced by this cell (frozen) and the next one (optimized).
frozen_file = model_name + '.pb'
opt_file = model_name + '_opt.pb'

# Obtain parameters: graph node names are the tensor names without ':0'.
model_input = model.input.name.replace(':0', '')
model_output = model.output.name.replace(':0', '')
batch, height, width, channels = model.input.shape

print("model, frozen file, optimized file, input size, input node, output node,")
print("%s, %s, %s, %dx%dx%d, %s, %s" % (model_name, frozen_file, opt_file, width, height,
                                       channels, model_input, model_output))

# Obtain the TF session that backs the Keras model.
sess = tf.compat.v1.keras.backend.get_session()

# Save checkpoint files for freezing. TF1's Saver.save refuses to write into
# a parent directory that does not exist, so create it up front.
ckpt_dir = os.path.join('/tmp', model_name)
os.makedirs(ckpt_dir, exist_ok=True)
ckpt_file = os.path.join(ckpt_dir, model_name + '.ckpt')
graph_file = os.path.join(ckpt_dir, model_name + '.pb')
tf.compat.v1.train.Saver().save(sess, ckpt_file)
tf.io.write_graph(sess.graph.as_graph_def(), logdir=ckpt_dir,
                  name=os.path.basename(graph_file), as_text=False)

# Restore the checkpoint into a fresh graph, fold variables into constants,
# drop training-only nodes, and write the frozen GraphDef to frozen_file.
print(model_output)
with tf.compat.v1.Session(graph=tf.Graph()) as freeze_sess:
    saver = tf.compat.v1.train.import_meta_graph(ckpt_file + '.meta')
    saver.restore(freeze_sess, ckpt_file)
    output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
        freeze_sess, tf.compat.v1.get_default_graph().as_graph_def(), [model_output])
    output_graph_def = tf.compat.v1.graph_util.remove_training_nodes(
        output_graph_def, protected_nodes=[model_output])
    with open(frozen_file, 'wb') as f:
        f.write(output_graph_def.SerializeToString())

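Optimize the extracted Keras ResNet50 FP32 graph for inference before casting (resnet50_fp32_keras_opt.pb will be generated) by applying the following transformations to the graph:

- move each BiasAdd after the FusedBatchNorm it feeds, rescaling the bias, so that the batch norms can be folded;
- remove Identity and CheckNumerics nodes;
- fold constants;
- fold FusedBatchNorm constants into the preceding Conv2D weights;
- strip unused nodes and sort the remaining nodes by execution order.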

import copy
import string

from google.protobuf import text_format
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.tools.graph_transforms import TransformGraph

def clear_input(node):
    """Empty ``node.input`` in place, keeping the same container object."""
    del node.input[:]

def replace_name(node, name):
    """Rename ``node`` in place by overwriting its ``name`` field."""
    setattr(node, "name", name)

def replace_input(node, input_name, new_name):
    """Rewire ``node`` so that every input equal to ``input_name`` becomes ``new_name``.

    The inputs are rewritten in place and keep their original order. Inputs
    that do not match are left untouched, and a missing ``input_name`` is a
    no-op.

    Args:
        node: an object with a mutable, indexable ``input`` sequence (for
            example a ``NodeDef``).
        input_name: the input name to replace.
        new_name: the replacement input name.
    """
    # Neither protobuf repeated fields nor lists have .replace(); assign the
    # matching slots by index instead. Iterate over a snapshot so the writes
    # cannot disturb iteration.
    for index, current in enumerate(list(node.input)):
        if current == input_name:
            node.input[index] = new_name

def swap_names(node1, node2):
    """Exchange the ``name`` fields of ``node1`` and ``node2`` in place."""
    node1.name, node2.name = node2.name, node1.name

def get_const_node(const_node_name, const_by_name):
    """Return the entry of ``const_by_name`` that ``const_node_name`` refers to.

    A trailing ``/read`` suffix is stripped before the lookup. It presumably
    names an Identity that reads a frozen variable; confirm against the graph.
    A missing entry raises KeyError.
    """
    return const_by_name[re.sub("/read$", "", const_node_name)]

def get_const_ndarray(const_node_name, const_by_name):
    """Return the tensor value of the Const behind ``const_node_name`` as an ndarray.

    The node is resolved with ``get_const_node`` (a trailing "/read" is
    stripped first). A missing entry raises KeyError.
    """
    const_node = get_const_node(const_node_name, const_by_name)
    return tf.make_ndarray(const_node.attr.get("value").tensor)

def adjust_bias_values(bias_node, fbn_node, const_by_name):
    """Rescale a BiasAdd's bias so that the bias can be applied after a FusedBatchNorm.

    FusedBatchNorm computes gamma * (x - mean) / sqrt(variance + epsilon) + beta.
    Feeding it x + b therefore adds gamma * b / sqrt(variance + epsilon) to its
    output. Moving the BiasAdd after the batch norm needs exactly that rescaled
    bias.

    Args:
        bias_node: the BiasAdd node; input[1] names its bias Const.
        fbn_node: the FusedBatchNorm node; input[1] is scale (gamma) and
            input[4] is variance.
        const_by_name: map from Const node name to node. The bias Const found
            here is overwritten in place with the rescaled value.

    Raises:
        KeyError: if any referenced constant is missing from ``const_by_name``.
    """
    bias_val = get_const_ndarray(bias_node.input[1], const_by_name)
    gamma_val = get_const_ndarray(fbn_node.input[1], const_by_name)
    variance_val = get_const_ndarray(fbn_node.input[4], const_by_name)
    # Normalization divides by sqrt(variance + epsilon), not sqrt(variance).
    # Fall back to FusedBatchNorm's op-def default when the attr is absent. The
    # membership test avoids inserting an empty attr into the node.
    if "epsilon" in fbn_node.attr:
        epsilon = fbn_node.attr["epsilon"].f
    else:
        epsilon = 0.0001
    new_bias = bias_val * gamma_val / np.sqrt(variance_val + epsilon)
    new_tensor = tensor_util.make_tensor_proto(new_bias, new_bias.dtype, new_bias.shape)
    bias_const_node = get_const_node(bias_node.input[1], const_by_name)
    bias_const_node.attr["value"].CopyFrom(attr_value_pb2.AttrValue(tensor=new_tensor))

def MoveBiasAddAfterFusedBatchNorm(graphdef):
    """Move each BiasAdd that feeds a FusedBatchNorm to just after that batch norm.

    The fold_batch_norm function of TransformGraph is unable to fold Keras
    ResNet50 because a BiasAdd sits between Conv2D and FusedBatchNorm. The
    BiasAdd is redundant when FusedBatchNorm is used, but Keras emits it. This
    pass:

    - swaps the two ops and rescales the bias with ``adjust_bias_values``;
    - renames FusedBatchNormV3 to FusedBatchNorm and drops its "U" attribute,
      so that fold_batch_norms recognizes the node.

    Args:
        graphdef: frozen ``GraphDef`` to rewrite. It is not modified.

    Returns:
        A new ``tf.compat.v1.GraphDef`` containing the rewritten node list.

    Raises:
        KeyError: if a bias, scale or variance constant used by a moved pair
            is not a Const appearing earlier in the node list.
    """
    # Collect the rewritten nodes in a plain Python list and copy them into a
    # GraphDef only at the end. Protobuf's repeated-message extend() stores
    # *copies*, so edits made to a node after it was extended would be lost.
    # That includes adjust_bias_values rewriting a bias Const through
    # const_by_name.
    output_nodes = []
    node_by_name = {}
    const_by_name = {}
    for node in graphdef.node:
        copied_node = node_def_pb2.NodeDef()
        copied_node.CopyFrom(node)
        # Hack: downgrade FusedBatchNormV3 to FusedBatchNorm so that
        # fold_batch_norms recognizes it. FusedBatchNorm has no "U" attr.
        if copied_node.op == "FusedBatchNormV3":
            copied_node.op = "FusedBatchNorm"
            if "U" in copied_node.attr:
                del copied_node.attr["U"]
        node_by_name[copied_node.name] = copied_node

        moved_bias = False
        if copied_node.op == "Const":
            const_by_name[copied_node.name] = copied_node
        elif copied_node.op.startswith("FusedBatchNorm"):
            for input_name in list(copied_node.input):
                input_node = node_by_name.get(input_name)
                if input_node is None or input_node.op != "BiasAdd":
                    continue
                # Pull the BiasAdd out; it is re-appended after the batch norm.
                output_nodes.remove(input_node)
                bias_producer = input_node.input[0]
                # Adjust the bias value for its new position after the batch norm.
                adjust_bias_values(input_node, copied_node, const_by_name)
                # Hack: swap names so that consumers of the batch norm's output
                # now read the BiasAdd, without rewriting those consumers.
                swap_names(copied_node, input_node)
                # Rewire: producer -> batch norm -> BiasAdd.
                replace_input(copied_node, input_name, bias_producer)
                replace_input(input_node, bias_producer, copied_node.name)
                # Keep the lookup table consistent with the swapped names.
                node_by_name[copied_node.name] = copied_node
                node_by_name[input_node.name] = input_node
                # Fix the order in the node list: batch norm first, then BiasAdd.
                output_nodes.append(copied_node)
                output_nodes.append(input_node)
                moved_bias = True
                break  # only the data input (x) can be a BiasAdd

        if not moved_bias:
            output_nodes.append(copied_node)

    output_graph_def = tf.compat.v1.GraphDef()
    output_graph_def.node.extend(output_nodes)
    return output_graph_def

def FoldFusedBatchNorm(graph_def, inputs=('input_1',), outputs=('probs/Softmax',)):
    """Optimize a graph for inference with TensorFlow's TransformGraph.

    The transforms are applied in this order:

    - add default attributes;
    - remove Identity and CheckNumerics nodes;
    - fold constants, ignoring errors;
    - fold FusedBatchNorm constants into the previous Conv2D weights
      (fold_batch_norms and fold_old_batch_norms);
    - strip unused nodes;
    - sort by execution order.

    Args:
        graph_def: the GraphDef to transform.
        inputs: input node names. The default 'input_1' matches the Keras
            ResNet50 graph used in this tutorial.
        outputs: output node names. The default is 'probs/Softmax'.

    Returns:
        The transformed GraphDef.
    """
    transforms = [
        'add_default_attributes',
        'remove_nodes(op=Identity, op=CheckNumerics)',
        'fold_constants(ignore_errors=true)',
        'fold_batch_norms',
        'fold_old_batch_norms',
        'strip_unused_nodes',
        'sort_by_execution_order',
    ]
    transformed_graph_def = TransformGraph(graph_def, list(inputs), list(outputs), transforms)
    return transformed_graph_def

def load_graph(model_file):
    """Parse the binary-serialized GraphDef stored at ``model_file`` and return it."""
    with open(model_file, "rb") as handle:
        serialized = handle.read()
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(serialized)
    return graph_def

# Run both graph passes over the frozen FP32 graph, then save the optimized result.
graph_orig = load_graph('resnet50_fp32_keras.pb')
graph_mod = MoveBiasAddAfterFusedBatchNorm(graph_orig)
graph_mod2 = FoldFusedBatchNorm(graph_mod)
serialized_opt_graph = graph_mod2.SerializeToString()
with tf.io.gfile.GFile('resnet50_fp32_keras_opt.pb', "wb") as f:
    f.write(serialized_opt_graph)

Convert the full graph to FP16 (resnet50_fp16_keras_opt.pb will be generated). This will take about a minute.

from tensorflow.core.framework import graph_pb2
from tensorflow.python.platform import gfile

def ConvertFP32ToOther(graphdef, cast_type="float16"):
    """Convert an FP32 network to a lower-precision floating point type.

    Every float32 Const has its value cast to ``cast_type``. Every float32
    type attribute (T, Tparams, dtype, SrcT, DstT) on the other nodes is
    switched to ``cast_type``. Nothing else is changed.

    Args:
        graphdef: the FP32 GraphDef. It is not modified.
        cast_type: target dtype name accepted by ``tf.as_dtype``. The default
            is "float16", the FP16 conversion used in this tutorial.

    Returns:
        A new ``GraphDef`` holding the converted nodes. Only the node list is
        copied.
    """
    target_dtype = tf.as_dtype(cast_type)
    # Build the dtype attrs from enums directly; no Session or default-graph
    # ops are needed for that.
    dt_float_type_attr = attr_value_pb2.AttrValue(type=tf.float32.as_datatype_enum)
    dt_half_type_attr = attr_value_pb2.AttrValue(type=target_dtype.as_datatype_enum)
    type_attr_names = ("T", "Tparams", "dtype", "SrcT", "DstT")

    output_graph_def = graph_pb2.GraphDef()
    for node in graphdef.node:
        output_node = node_def_pb2.NodeDef()
        output_node.CopyFrom(node)
        if node.op == "Const":
            if "dtype" in node.attr and node.attr["dtype"] == dt_float_type_attr:
                # Cast the weights with NumPy rather than a tf.cast op per
                # constant, which would grow the default graph.
                values = tensor_util.MakeNdarray(node.attr["value"].tensor)
                values = values.astype(target_dtype.as_numpy_dtype)
                output_node.attr["dtype"].CopyFrom(dt_half_type_attr)
                output_node.attr["value"].CopyFrom(
                    attr_value_pb2.AttrValue(
                        tensor=tensor_util.make_tensor_proto(
                            values, dtype=target_dtype, shape=values.shape)))
        else:
            for attr_name in type_attr_names:
                if attr_name in node.attr and node.attr[attr_name] == dt_float_type_attr:
                    output_node.attr[attr_name].CopyFrom(dt_half_type_attr)
        output_graph_def.node.extend([output_node])
    return output_graph_def

def load_graph(model_file):
    """Parse the binary-serialized GraphDef stored at ``model_file`` and return it.

    Uses ``tf.compat.v1.GraphDef``, consistent with the rest of this tutorial;
    the bare ``tf.GraphDef`` alias only exists in TF 1.x.
    """
    graph_def = tf.compat.v1.GraphDef()
    with open(model_file, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def

# Cast the optimized FP32 graph to FP16 and write it next to the FP32 file.
graph_f32 = load_graph('resnet50_fp32_keras_opt.pb')
graph_f16 = ConvertFP32ToOther(graph_f32)
output_xformed_graph_name = 'resnet50_fp16_keras_opt.pb'
serialized_f16_graph = graph_f16.SerializeToString()
with gfile.GFile(output_xformed_graph_name, "wb") as f:
    f.write(serialized_f16_graph)

Run the compilation script to sweep through batch sizes up to 5 and several NeuronCore Group sizes up to 16. For each configuration the sweep calls pb2sm_compile.py, which performs the compilation. Some error messages are expected due to known issues (see the Known Issues section below). If you run all the configurations it will take about 45 minutes.

%%bash
#!/usr/bin/env bash
# Compile every (batch size, NeuronCore pipeline cores) pair and tabulate the
# last line each compilation prints (its 0/1/2 status) in full_sweep_results.txt.

core_counts="1 2 4 8 12 16"

echo "" > full_sweep.log
echo "" > full_sweep_results.txt

results=()
for b in $(seq 1 5); do
    for i in $core_counts; do
        python pb2sm_compile.py --batch_size=$b --neuroncore-pipeline-cores=$i | tee -a full_sweep.log
        # Append the status line just written by this compilation.
        results[$b]+=", $(tail -1 full_sweep.log)"
    done
done

head="batch"
for i in $core_counts; do
    head+=", nc${i}"
done
echo "$head" | tee -a full_sweep_results.txt
for b in $(seq 1 5); do
    echo "${b}${results[$b]}" | tee -a full_sweep_results.txt
done

You should see some output like this:

INFO: Compilation finished in 95 seconds with 99.5% operations placed on Inferentia

1

*** Batch size 1, num NeuronCores 2 (input shape: (1, 224, 224, 3), saved model dir: rn50_fp16_compiled_b1_nc2) ***

INFO: Compilation finished in 95 seconds with 99.5% operations placed on Inferentia

1

*** Batch size 1, num NeuronCores 4 (input shape: (1, 224, 224, 3), saved model dir: rn50_fp16_compiled_b1_nc4) ***

INFO: Compilation finished in 95 seconds with 99.5% operations placed on Inferentia

1

... (outputs removed)

*** Batch size 5, num NeuronCores 16 (input shape: (5, 224, 224, 3), saved model dir: rn50_fp16_compiled_b5_nc16) ***

ERROR: Compilation finished in 120 seconds with less than 50% operations placed on Inferentia (0.0%)

INFO: Retry compilation without static weights

ERROR: Retry compilation finished in 137 seconds with less than 50% operations placed on Inferentia (0.0%)

0

The file full_sweep_results.txt shows a summary of the sweep results with latest Neuron 1/27/20 release (0 means compilation unsuccessful and 0 ops mapped to Inferentia, 1 means most ops mapped to Inferentia and non-static weights, 2 means most ops mapped to Inferentia and using static weights):

batch, nc1, nc2, nc4, nc8, nc12, nc16
1, 1, 1, 1, 2, 2, 2
2, 1, 1, 0, 1, 2, 2
3, 1, 1, 1, 1, 1, 1
4, 1, 1, 0, 1, 1, 1
5, 1, 1, 0, 0, 0, 0

Inference

Run inference over different batch sizes and NeuronCore Groups to obtain throughput and latency results for ResNet50. To apply dynamic batching, the user batch size is set to 10x the compiled batch size, in order to keep the input queue full and to amortize the framework-to-Neuron overhead.

Note: The results are based on the Neuron v1.12.2 (Mar 4th 2021) release. These will continue to improve as we increase Neuron performance.

%%bash
# Run the load test for compiled batch sizes 1-5 at each NeuronCore pipeline
# size, collecting all output in batch.log.
# Run as one bash cell so that the cd applies to the commands below.
# A "!cd" line runs in its own subshell and does not persist.
# This also keeps IPython from expanding $i/$nc as Python variables.
cd ~/aws-neuron-sdk/src/examples/tensorflow/keras_resnet50/
echo "" > batch.log
for nc in 1 2 4 8 12 16; do
    for i in $(seq 1 5); do
        python infer_resnet50_keras_loadtest.py --batch_size=$i --neuroncore-pipeline-cores=$nc | tee -a batch.log
    done
done

The file batch.log now contains the results for each batch size. We can look at the throughput values to get an idea of which models are performing well. The output should look something like this:

The best model configuration for throughput (if you run on an inf1.6xlarge as suggested in the tutorial) is batch size 5 with NeuronCore Group size 2. Increasing batch size usually helps to increase throughput (up to a certain extent).

*** Compiled batch size 5, user batch size 10, num NeuronCores 2 (input shape: (10, 224, 224, 3), saved model dir: ./rn50_fp16_compiled_b5_nc2/1) ***

Instance type inf1.6xlarge with 16 NeuronCores NEURON_MAX_NUM_INFERS (env): 5 NEURONCORE_GROUP_SIZES (env): 2,2,2,2,2,2,2,2 NUM THREADS: 16 NUM_LOOPS_PER_THREAD: 400 USER_BATCH_SIZE: 10 Throughput values collected: [10680, 10700, 10660]

(rest of outputs removed)

Known Issues

Unable to compile with some batch size and NeuronCore combinations

For some combinations of batch size and number of NeuronCores, you may see an internal compiler error like the one below. Please see the sweep results above for the Neuron 1/27/20 release. Furthermore, auto-casting an FP32 network to bfloat16 with a batch size larger than 1 results in the same error.

INFO:tensorflow:fusing subgraph neuron_op_a73aed4b95ca5d5b with neuron-cc; log file is at /home/ubuntu/keras_fp16_benchmarking_db/compiler_workdir/neuron_op_a73aed4b95ca5d5b/graph_def.neuron-cc.log WARNING:tensorflow:Failed to fuse subgraph neuron_op_a73aed4b95ca5d5b with '/home/ubuntu/test_venv/bin/neuron-cc compile /home/ubuntu/keras_fp16_benchmarking_db/compiler_workdir/neuron_op_a73aed4b95ca5d5b/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /home/ubuntu/keras_fp16_benchmarking_db/compiler_workdir/neuron_op_a73aed4b95ca5d5b/graph_def.neff --io-config "{"inputs": {"input_10/_0:0": [[6, 224, 224, 3], "float16"]}, "outputs": ["probs/Softmax:0"]}" --batching_en --rematerialization_en --sb_size 120 --spill_dis --enable-replication True' WARNING:tensorflow:neuron-cc error message: WARNING:tensorflow:01/23/2020 01:15:40 AM ERROR [neuron-cc]: 01/23/2020 01:15:40 AM ERROR [neuron-cc]: *************************************************************** 01/23/2020 01:15:40 AM ERROR [neuron-cc]: An Internal Compiler Error has occurred 01/23/2020 01:15:40 AM ERROR [neuron-cc]: *************************************************************** 01/23/2020 01:15:40 AM ERROR [neuron-cc]: 01/23/2020 01:15:40 AM ERROR [neuron-cc]: Please contact Customer Support and provide the following details. 
01/23/2020 01:15:40 AM ERROR [neuron-cc]: 01/23/2020 01:15:40 AM ERROR [neuron-cc]: Error message: Non-zero exit status (134) for command: /home/ubuntu/test_venv/lib/python3.6/site-packages/neuroncc/starfish/bin/list_sch --hhir hh-tr-external-move.json --verbose 0 --sb_size 120 --arith_intensity_target 2300 --sb_watermark_low 0.250000 --sb_watermark_high 0.750000 --sb_size_tol 1 --alloc simple1 --alloc_opt --depth_diff 0.100000 --verbose_start_cycle 0 --tt_dist --mm_meet_cnt 1 --load_speed_factor 0.300000 --schir sch_tmp.json --spill_depth_limit 5 --spill_dis --true_dep --mm_order --batching_en --rematerialization_en 01/23/2020 01:15:40 AM ERROR [neuron-cc]: 01/23/2020 01:15:40 AM ERROR [neuron-cc]: Error class: CompilerInternalError 01/23/2020 01:15:40 AM ERROR [neuron-cc]: Error location: job.Scheduler.3 01/23/2020 01:15:40 AM ERROR [neuron-cc]: Command line: /home/ubuntu/test_venv/bin/neuron-cc compile /home/ubuntu/keras_fp16_benchmarking_db/compiler_workdir/neuron_op_a73aed4b95ca5d5b/graph_def.pb --framework TENSORFLOW --pipeline compile SaveTemps --output /home/ubuntu/keras_fp16_benchmarking_db/compiler_workdir/neuron_op_a73aed4b95ca5d5b/graph_def.neff --io-config '{"inputs": {"input_10/_0:0": [[6, 224, 224, 3], "float16"]}, "outputs": ["probs/Softmax:0"]}' --batching_en --rematerialization_en --sb_size 120 --spill_dis --enable-replication True 01/23/2020 01:15:40 AM ERROR [neuron-cc]: 01/23/2020 01:15:40 AM ERROR [neuron-cc]: Internal details: 01/23/2020 01:15:40 AM ERROR [neuron-cc]: File "neuroncc/driver/Job.py", line 207, in neuroncc.driver.Job.runSingleInputFn 01/23/2020 01:15:40 AM ERROR [neuron-cc]: File "neuroncc/driver/jobs/Scheduler.py", line 58, in neuroncc.driver.jobs.Scheduler.Scheduler.runSingleInput 01/23/2020 01:15:40 AM ERROR [neuron-cc]: File "neuroncc/driver/Job.py", line 145, in neuroncc.driver.Job.Job.shellCommand 01/23/2020 01:15:40 AM ERROR [neuron-cc]: 01/23/2020 01:15:40 AM ERROR [neuron-cc]: Version information: 01/23/2020 01:15:41 AM 
ERROR [neuron-cc]: Neuron Compiler version 1.0.6632.0+6001610955 01/23/2020 01:15:41 AM ERROR [neuron-cc]: 01/23/2020 01:15:41 AM ERROR [neuron-cc]: HWM version 1.0.839.0-6001300654 01/23/2020 01:15:41 AM ERROR [neuron-cc]: NEFF version 0.6 01/23/2020 01:15:41 AM ERROR [neuron-cc]: TVM version 1.0.1589.0+6001610955 01/23/2020 01:15:41 AM ERROR [neuron-cc]: NumPy version 1.16.5 01/23/2020 01:15:41 AM ERROR [neuron-cc]: MXNet not available 01/23/2020 01:15:41 AM ERROR [neuron-cc]: TF version 1.15.0 01/23/2020 01:15:41 AM ERROR [neuron-cc]: