PyTorch Neuron for Trainium Hugging Face BERT MRPC task finetuning using Hugging Face Trainer API — AWS Neuron Documentation

This document is relevant for: Trn1, Trn2

PyTorch Neuron for Trainium Hugging Face BERT MRPC task finetuning using Hugging Face Trainer API#

Note

Please use Hugging Face Optimum Neuron (https://huggingface.co/docs/optimum-neuron/index) for best coverage and support of Hugging Face models running on Trainium and Inferentia devices.

In this tutorial, we show how to run a Hugging Face script that uses the Hugging Face Trainer API to do fine-tuning on Trainium. The example follows the text-classification example, which fine-tunes a BERT-base model for sequence classification on the GLUE benchmark.


Note

Logs shown in tutorials do not reflect the latest performance numbers.

For the latest performance numbers, visit Neuron performance.

Setup and compilation#

Before running the tutorial please follow the installation instructions at:

Install PyTorch Neuron on Trn1

Please set the instance storage to 512GB or more if you also want to run through the BERT pretraining and GPT pretraining tutorials.

For all the commands below, make sure you are in the virtual environment that you have created above before you run the commands:

source ~/aws_neuron_venv_pytorch/bin/activate

First, we install a recent version of the Hugging Face transformers package, along with the datasets, evaluate, and scikit-learn packages, and download the transformers source matching the installed version. In this example, we use the text classification example from the Hugging Face transformers source:

export HF_VER=4.44.0
pip install -U transformers==$HF_VER datasets evaluate scikit-learn
cd ~/
git clone https://github.com/huggingface/transformers --branch v$HF_VER
cd ~/transformers/examples/pytorch/text-classification

Single-worker training#

We will run MRPC task fine-tuning following the example in README.md located in the path ~/transformers/examples/pytorch/text-classification. In this part of the tutorial we will use the Hugging Face model hub’s pretrained bert-large-uncased model.
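
If you are not familiar with the MRPC task (paraphrase detection on sentence pairs), the following minimal sketch peeks at the data using the datasets package installed earlier; the split and field names follow the standard GLUE MRPC layout:

# Peek at the GLUE MRPC data that run_glue.py fine-tunes on
# (uses the `datasets` package installed earlier).
from datasets import load_dataset

mrpc = load_dataset("glue", "mrpc")
print(mrpc)              # DatasetDict with train/validation/test splits
print(mrpc["train"][0])  # fields: sentence1, sentence2, label, idx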

Note

If you are using older versions of transformers <4.27.0 or PyTorch Neuron <1.13.0, please see section Older versions of transformers <4.27.0 or PyTorch Neuron <1.13.0 for necessary workarounds.

We use BF16 mixed-precision casting via the Trainer API --bf16 option and the compiler flag --model-type=transformer to enable best performance. We also launch the run_glue.py script with torchrun, using the --nproc_per_node=N option to specify the number of workers. Here we start off with 1 worker.

Note

With transformers version 4.44 and up, please use torchrun even for one worker (--nproc_per_node=1) to avoid execution hang.

First, paste the following script into your terminal to create a “run.sh” file and change it to executable:

tee run.sh > /dev/null <<'EOF'
#!/usr/bin/env bash
set -eExuo
export TASK_NAME=mrpc
export NEURON_CC_FLAGS="--model-type=transformer"
NEURON_RT_STOCHASTIC_ROUNDING_EN=1 torchrun --nproc_per_node=1 ./run_glue.py \
    --model_name_or_path bert-large-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --bf16 \
    --max_seq_length 128 \
    --per_device_train_batch_size 8 \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --save_total_limit 1 \
    --overwrite_output_dir \
    --output_dir /tmp/$TASK_NAME/ |& tee log_run
EOF

chmod +x run.sh

We optionally precompile the model and training script using neuron_parallel_compile to warm up the persistent graph cache (Neuron Cache) such that the actual run has fewer compilations (faster run time):

neuron_parallel_compile ./run.sh

Please ignore the results from this precompile run as it is only for extracting and compiling the XLA graphs.

Note

With both train and evaluation options (--do_train and --do_eval), you will encounter a harmless error ValueError: Target is multiclass but average='binary' when using neuron_parallel_compile.

Precompilation is optional and only needs to be done once unless hyperparameters such as batch size are modified. After the optional precompilation, the actual run will be faster with minimal additional compilations.

If precompilation was not done, the first execution of ./run.sh will be slower due to serial compilations. Rerunning the same script a second time will be quicker because the compiled graphs will already be cached in the persistent cache.
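
To confirm that the persistent cache is being populated, you can list its contents between runs. This is a minimal sketch; the cache location used below is an assumption of the usual default and may differ depending on your Neuron SDK configuration:

# Count the entries in the Neuron persistent cache (location is assumed here;
# it can be redirected via the Neuron cache settings).
import os

cache_dir = "/var/tmp/neuron-compile-cache"   # assumed default location
if os.path.isdir(cache_dir):
    print(f"{len(os.listdir(cache_dir))} entries under {cache_dir}")
else:
    print(f"No cache directory found at {cache_dir}")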

Multi-worker data-parallel training#

The above script would run one worker on one Logical NeuronCore. To run on multiple Logical NeuronCores in data-parallel configuration, launch the run_glue.py script with torchrun using --nproc_per_node=N option to specify the number of workers (N=2 for trn1.2xlarge, and N=2, 8, or 32 for trn1.32xlarge).

Note

If you are using older versions of transformers <4.27.0 or PyTorch Neuron <1.13.0, please see section Older versions of transformers <4.27.0 or PyTorch Neuron <1.13.0 for necessary workarounds.

The following example runs 2 workers. Paste the following script into your terminal to create a “run_2w.sh” file and change it to executable:

tee run_2w.sh > /dev/null <<'EOF'
#!/usr/bin/env bash
set -eExuo
export TASK_NAME=mrpc
export NEURON_CC_FLAGS="--model-type=transformer"
NEURON_RT_STOCHASTIC_ROUNDING_EN=1 torchrun --nproc_per_node=2 ./run_glue.py \
    --model_name_or_path bert-large-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --bf16 \
    --max_seq_length 128 \
    --per_device_train_batch_size 8 \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --save_total_limit 1 \
    --overwrite_output_dir \
    --output_dir /tmp/$TASK_NAME/ |& tee log_run_2w
EOF

chmod +x run_2w.sh

Again, we optionally precompile the model and training script using neuron_parallel_compile to warm up the persistent graph cache (Neuron Cache), ignoring the results from this precompile run as it is only for extracting and compiling the XLA graphs:

neuron_parallel_compile ./run_2w.sh

Precompilation is optional and only needs to be done once unless hyperparameters such as batch size are modified. After the optional precompilation, the actual run will be faster with minimal additional compilations.

During the run, you will notice that the “Total train batch size” is now 16 and the “Total optimization steps” is half the number seen in single-worker training.
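
As a rough sanity check of those numbers, here is a minimal sketch of how they are derived; the MRPC train split size used below is an assumption for illustration:

# Rough sanity check of the "Total train batch size" and "Total optimization
# steps" values reported by the Trainer (sketch; dataset size is assumed).
import math

num_train_examples = 3668   # approximate GLUE MRPC train split size (assumption)
per_device_batch = 8        # --per_device_train_batch_size
epochs = 5                  # --num_train_epochs

for workers in (1, 2):
    total_batch = per_device_batch * workers                        # "Total train batch size"
    steps = math.ceil(num_train_examples / total_batch) * epochs    # ~"Total optimization steps"
    print(f"{workers} worker(s): total batch size {total_batch}, ~{steps} optimization steps")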

Converting BERT pretrained checkpoint to Hugging Face pretrained model format#

If you have a pretrained checkpoint (e.g., from the BERT phase 2 pretraining tutorial), you can run the script below (saved as “convert.py”) to convert the saved BERT pretraining checkpoint to the Hugging Face pretrained model format. An example phase 2 pretrained checkpoint can be downloaded from s3://neuron-s3/training_checkpoints/pytorch/dp_bert_large_hf_pretrain/ckpt_29688.pt. Note that here we also use the bert-large-uncased model configuration to match the BERT-Large model trained following the BERT phase 2 pretraining tutorial.

tee convert.py > /dev/null <<EOF
import os
import sys
import argparse
import torch
import transformers
from transformers import (
    BertForPreTraining,
)
import torch_xla.core.xla_model as xm
from transformers.utils import check_min_version
from transformers.utils.versions import require_version

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='bert-large-uncased', help="Path to model identifier from huggingface.co/models")
    parser.add_argument('--output_saved_model_path', type=str, default='./hf_saved_model', help="Directory to save the HF pretrained model format.")
    parser.add_argument('--checkpoint_path', type=str, required=True, help="Path to pretrained checkpoint which needs to be converted to a HF pretrained model format")
    args = parser.parse_args(sys.argv[1:])

    model = BertForPreTraining.from_pretrained(args.model_name)
    check_point = torch.load(args.checkpoint_path, map_location='cpu')
    model.load_state_dict(check_point['model'], strict=False)
    model.save_pretrained(args.output_saved_model_path, save_config=True, save_function=xm.save)
    print("Done converting checkpoint {} to HuggingFace saved model in directory {}.".format(args.checkpoint_path, args.output_saved_model_path))
EOF

Run the conversion script as:

python convert.py --checkpoint_path ckpt_29688.pt

After conversion, the new Hugging Face pretrained model is stored in the output directory specified by the --output_saved_model_path option which is hf_saved_model by default. You will use this directory in the next step.
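
Optionally, you can sanity-check the converted model by loading it back with the standard Hugging Face API before fine-tuning on it. This is a minimal sketch; since the checkpoint was saved from a pretraining model, expect a warning that the classification head is newly initialized:

# Load the converted checkpoint from hf_saved_model to confirm it is a valid
# Hugging Face pretrained model directory. The classifier head will be newly
# initialized, which is expected before fine-tuning.
from transformers import AutoTokenizer, BertForSequenceClassification

model = BertForSequenceClassification.from_pretrained("hf_saved_model", num_labels=2)
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")
print("Loaded model with {:,} parameters".format(model.num_parameters()))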

Paste the following script into your terminal to create a “run_converted.sh” file and change it to executable (note that it uses the converted Hugging Face pretrained model in the hf_saved_model directory):

tee run_converted.sh > /dev/null <<'EOF'
#!/usr/bin/env bash
set -eExuo
export TASK_NAME=mrpc
export NEURON_CC_FLAGS="--model-type=transformer"
NEURON_RT_STOCHASTIC_ROUNDING_EN=1 torchrun --nproc_per_node=2 ./run_glue.py \
    --model_name_or_path hf_saved_model \
    --tokenizer_name bert-large-uncased \
    --task_name $TASK_NAME \
    --do_train \
    --do_eval \
    --bf16 \
    --max_seq_length 128 \
    --per_device_train_batch_size 8 \
    --learning_rate 2e-5 \
    --num_train_epochs 5 \
    --save_total_limit 1 \
    --overwrite_output_dir \
    --output_dir /tmp/$TASK_NAME/ |& tee log_run_converted
EOF

chmod +x run_converted.sh

If this is the first time running with the bert-large-uncased model configuration, or if hyperparameters have changed, then the optional one-time precompilation step can save compilation time:

neuron_parallel_compile ./run_converted.sh

If you have run the single-worker training in a previous section, then you can skip the precompilation step and just run:

./run_converted.sh

Older versions of transformers <4.27.0 or PyTorch Neuron <1.13.0#

If using older versions of transformers package before 4.27.0 or PyTorch Neuron before 1.13.0, please edit the python script run_glue.py and add the following lines after the Python imports. They set the compiler flag for transformer model type and enable data parallel training using torchrun:

Enable torchrun

import os
import torch
import torch_xla.distributed.xla_backend
from packaging import version
from transformers import __version__, Trainer

if version.parse(__version__) < version.parse("4.26.0") and os.environ.get("WORLD_SIZE"):
    torch.distributed.init_process_group('xla')

Disable DDP for torchrun

import contextlib
if version.parse(__version__) < version.parse("4.20.0"):
    def _wrap_model(self, model, training=True):
        model.no_sync = lambda: contextlib.nullcontext()
        return model
else:
    def _wrap_model(self, model, training=True, dataloader=None):
        model.no_sync = lambda: contextlib.nullcontext()
        return model
Trainer._wrap_model = _wrap_model

Workaround for NaNs seen with transformers version >= 4.21.0

https://github.com/aws-neuron/aws-neuron-sdk/issues/593

import transformers
if os.environ.get("XLA_USE_BF16") or os.environ.get("XLA_DOWNCAST_BF16"):
    transformers.modeling_utils.get_parameter_dtype = lambda x: torch.bfloat16

Known issues and limitations#

The following are currently known issues:

The snippet below, added after the Python imports in run_glue.py, is intended to bypass DDP parameter-shape verification while neuron_parallel_compile is extracting graphs (i.e., when NEURON_EXTRACT_GRAPHS_ONLY is set):

import os
if os.environ.get("NEURON_EXTRACT_GRAPHS_ONLY", None):
    import torch.distributed as dist
    _verify_param_shape_across_processes = lambda process_group, tensors, logger=None: True


This document is relevant for: Trn1, Trn2