Serving OPT-175B, BLOOM-176B and CodeGen-16B using Alpa

This tutorial shows how to set up a serving system to serve one of the largest available pretrained language models, OPT-175B. Instructions for other models (BLOOM and CodeGen) are listed at the end.

👉 Try a live demo at Alpa-OPT Demo 👈

Overview

As a serving system, Alpa lets you serve very large models by partitioning them across a distributed GPU cluster, so no single GPU needs to hold the full model.

In this example, we use Alpa to serve the open-source OPT model, supporting all sizes ranging from 125M to 175B parameters. Alpa provides a distributed backend for model-parallel inference together with a huggingface/transformers-style generation interface.

Note

The pre-trained OPT model weights can be obtained from Metaseq, subject to their license.

Note

You will need at least 350GB of GPU memory across your entire cluster to serve the OPT-175B model. For example, you can use 4 x AWS p3.16xlarge instances, which provide 4 (instances) x 8 (GPUs/instance) x 16 (GB/GPU) = 512 GB of GPU memory.

You can also follow this guide to set up a serving system to serve smaller versions of OPT, such as OPT-66B, OPT-30B, etc. Pick an appropriate size from the OPT weight downloading page based on your available resources.
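As a rough rule of thumb, serving an OPT model in half precision takes about 2 bytes of GPU memory per parameter (hence the 350GB figure for the 175B model), plus extra room for activations and generation buffers. The snippet below is a minimal back-of-the-envelope sketch of this estimate; the 20% headroom factor is an illustrative assumption, not a measured requirement.

# Back-of-the-envelope GPU memory estimate for fp16 weights.
# The 1.2x headroom factor is an illustrative assumption.
def estimate_total_gpu_memory_gb(num_params, headroom=1.2):
    return num_params * 2 * headroom / 1e9

for name, params in [("opt-30b", 30e9), ("opt-66b", 66e9), ("opt-175b", 175e9)]:
    print(f"{name}: ~{estimate_total_gpu_memory_gb(params):.0f} GB across the cluster")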

Demo

The code below shows how to use the huggingface/transformers interface and the Alpa distributed backend for large model inference.

from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model

# Load the tokenizer. All OPT models with different sizes share the same tokenizer.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False

# Load the model. Alpa automatically downloads the weights to the specified path.
model = get_model(model_name="alpa/opt-2.7b", path="~/opt_weights/")

# Generate
prompt = "Paris is the capital city of"

input_ids = tokenizer(prompt, return_tensors="pt").input_ids
output = model.generate(input_ids=input_ids, max_length=256, do_sample=True)
generated_string = tokenizer.batch_decode(output, skip_special_tokens=True)

print(generated_string)

Requirements

  1. Install Alpa following the installation guide. You can either install by python wheel or build from source.
  2. Install additional requirements for llm_serving:

pip3 install "transformers<=4.23.1" fastapi uvicorn omegaconf jinja2

Install torch corresponding to your CUDA version, e.g., for CUDA 11.3:

pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu113

  3. Clone the alpa repo. If you installed Alpa via the python wheel, you need to clone the alpa repo manually. If you installed from source, you have already done this step.

git clone git@github.com:alpa-projects/alpa.git

  4. Install the llm_serving package. Go to the examples folder and install the package.

cd alpa/examples
pip3 install -e .
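To verify the installation, you can try importing the packages before moving on; the snippet below is a minimal sanity check and does not load any model weights.

# Sanity check: confirm that alpa and llm_serving are importable.
import alpa
from llm_serving.model.wrapper import get_model

print("alpa and llm_serving imported successfully")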

Convert Weights Format

The weights of OPT 125M–66B models are publicly available. Huggingface hosts copies of these weights. For OPT 125M–66B, you do not need to download or convert the weights manually. Alpa will automatically download the weights from huggingface to the given path if Alpa cannot find cached weights locally.

The weights of OPT-175B can be obtained from Meta by filling out a request form. You then need to manually convert the obtained weights into the Alpa format.

Convert OPT-175B weights into Alpa formats

We provide detailed instructions below on how to convert the original OPT-175B weights into Alpa-compatible formats. You can skip this section if you only want to run smaller models.

Note

The procedure below for converting the OPT-175B weights takes about 1 hour.

  1. Download and verify the original weights
    First, download Metaseq’s original OPT-175B weights in 992 shards, verify the MD5 of each shard, and put the shards under a folder, say, PATH_TO_992_SHARDS/.
  2. Consolidate the weights from 992 shards into one single checkpoint
    Use the script step_2_consolidate_992_shards_to_singleton.py as:

python3 step_2_consolidate_992_shards_to_singleton.py --read-prefix [PATH_TO_992_SHARDS]/checkpoint_last --save-prefix [PATH_TO_SAVE_CHECKPOINT]

The consolidated checkpoint will be saved at PATH_TO_SAVE_CHECKPOINT as specified in the command.

Note

The above script requires a peak memory (RAM) usage as large as twice the model size. For example, when consolidating the 175B model, the peak memory usage is approximately 175B x 2 bytes x 2 = 700GB. Please make sure your RAM is sufficient to run the script without throwing an OOM exception.

Note

The above script will save the model weights as a single consolidated checkpoint at PATH_TO_SAVE_CHECKPOINT, and hence requires at least 350GB of free disk space.

  3. Convert the single checkpoint into Alpa-compatible formats
    Alpa reads weights directly from numpy formats. Use the script step_3_convert_to_numpy_weights.py to convert the single checkpoint into numpy formats:
    python3 step_3_convert_to_numpy_weights.py --ckpt-path PATH_TO_SAVE_CHECKPOINT --output-folder OUTPUT_PATH
    The weights will be saved in the folder OUTPUT_PATH as specified in the command.

Note

The above script also requires 350GB free disk space to write the numpy-formatted weights.

Converted weights for other models

You do not need to download the weights manually for OPT 125M–66B. However, if you have trouble with the automatic downloading or with huggingface, we also provide pre-converted copies of these weights.

Copy Weights to Multiple Nodes

If you want to run the model on multiple nodes, you can use one of the following methods to copy the weights to all nodes.

  1. Put the weights under a shared network file system, so all nodes can access it.
  2. Run the script first on a driver node. The driver node will download the weights to its local disk, but the script will fail later because the worker nodes cannot access the weights. You can then manually copy all weights downloaded under the specified path from the driver node to the same path on all worker nodes.
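As an illustration of method 2, the sketch below copies a locally downloaded weights folder to the same path on each worker node. The hostnames, the weights path, and the use of rsync over passwordless SSH are assumptions that you should adapt to your own cluster.

# Hypothetical sketch: push locally downloaded weights to worker nodes.
# WORKER_HOSTS and WEIGHTS_PATH are placeholders for your own cluster.
import os
import subprocess

WORKER_HOSTS = ["worker-1", "worker-2"]   # assumed hostnames
WEIGHTS_PATH = "~/opt_weights/"           # same path passed to get_model

local_path = os.path.expanduser(WEIGHTS_PATH)
for host in WORKER_HOSTS:
    # rsync over SSH; requires passwordless SSH access to each worker.
    subprocess.run(["rsync", "-a", local_path, f"{host}:{WEIGHTS_PATH}"], check=True)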

Run Generation in the Command Line

The code of this tutorial is under examples/llm_serving.

# Start ray on the node
ray start --head

python3 textgen.py --model alpa/opt-2.7b

Launch a Web Server to Serve the OPT Models

We need to run two scripts: one for the web server and another for the model serving worker. They use two ports: the port of the website is defined on the command line, and the port of the worker is defined in service/constants.py.

# Launch the model worker
python3 launch_model_worker.py --model alpa/opt-175b

# Launch the website (in a new terminal)
uvicorn launch_website:app --host 0.0.0.0 --port 8001

Then open http://[IP-ADDRESS]:8001 in your browser to try out the model!

There is also a client library which can be used to query the model worker from a Python script. Please check test_completions.py for its usage.
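If you prefer not to use the client library, the model worker can also be queried over HTTP from a short script. The sketch below is purely illustrative: the port, the /completions path, and the request fields are placeholder assumptions; check test_completions.py and service/constants.py for the actual endpoint and payload format.

# Hypothetical sketch of querying the model worker over HTTP.
# The port, path, and payload fields are placeholders; see
# test_completions.py and service/constants.py for the real interface.
import requests

resp = requests.post(
    "http://localhost:20001/completions",   # assumed host, port, and path
    json={"prompt": "Paris is the capital city of", "max_tokens": 64},
    timeout=60,
)
print(resp.json())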

Improving Generation Speed

Here are some tips for improving the generation speed.

  1. Batching. Single sequence generation cannot fully utilize the GPU power. Applying batching can greatly boost the performance. See textgen.py for the usage.
  2. Tune the encoder_chunk_sizes argument of get_model. Alpa compiles multiple executables and uses these executables to encode a prompt chunk by chunk. This argument controls the possible chunk sizes. Depending on the length of your prompts, you can try different combinations. For example, if your prompt lengths are around 1000-1500, a good combination is [1, 256, 1024]. A sketch combining this with batching is shown after this list.
  3. Tune the parallelization strategy. If you are familiar with Alpa, you can tune the method argument of alpa.parallelize and try different parallelization methods.
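As a minimal sketch of tips 1 and 2, the snippet below loads a model with explicit encoder_chunk_sizes and generates for a small batch of prompts at once. It assumes the wrapper follows the huggingface-style generate interface from the Demo section; the chunk sizes and padding behavior are assumptions to tune and verify for your workload.

# Sketch: batched generation with tuned encoder chunk sizes.
# The chunk sizes below are illustrative; tune them for your prompt lengths.
from transformers import AutoTokenizer
from llm_serving.model.wrapper import get_model

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-2.7b")
tokenizer.add_bos_token = False

model = get_model(model_name="alpa/opt-2.7b",
                  path="~/opt_weights/",
                  encoder_chunk_sizes=[1, 64, 256])

# Encode several prompts as one padded batch instead of one at a time.
prompts = ["Paris is the capital city of",
           "Computer science studies",
           "Deep learning is"]
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
output = model.generate(input_ids=inputs.input_ids, max_length=128, do_sample=True)
print(tokenizer.batch_decode(output, skip_special_tokens=True))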

If you find the generation speed too slow and want to accelerate it, please join Alpa slack and tell us your use cases. We are actively working on improving the performance.

OPT License

The use of the OPT pretrained weights is subject to the Model License by Metaseq.

Other Models (BLOOM)

Alpa also supports BLOOM. You can use commands similar to those for OPT, just with a different model name.

Huggingface/pytorch backend

python3 textgen.py --model bigscience/bloom-560m

Jax backend

python3 textgen.py --model jax/bloom-560m

Alpa backend

python3 textgen.py --model alpa/bloom-560m

Other Models (CodeGen)

Alpa also supports CodeGen. You can use commands similar to those for OPT, just with a different model name.

Huggingface/pytorch backend

python3 codegen.py --model Salesforce/codegen-2B-mono

Alpa backend

python3 codegen.py --model alpa/codegen-2B-mono