GitHub - salesforce/jaxformer: Minimal library to train LLMs on TPU in JAX with pjit().

Jaxformer

JAX library for training large language models on TPU v3/v4, with data and model parallelism based on the pjit() operator.

Citation

Please cite:

@misc{Jaxformer,
  title={Jaxformer: A minimal library for training LLMs on TPU},
  author={Nijkamp, Erik},
  howpublished={\url{https://github.com/salesforce/jaxformer}},
  year={2022}
}

Acknowledgments: Ben Wang, James Bradbury, Zak Stone, Bo Pang.

Models

CodeGen

350M

gs://sfr-codegen-research/checkpoints/codegen-350M-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-350M-multi/150000
gs://sfr-codegen-research/checkpoints/codegen-350M-mono/150000

2B

gs://sfr-codegen-research/checkpoints/codegen-2B-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-2B-multi/150000
gs://sfr-codegen-research/checkpoints/codegen-2B-mono/100000

6B

gs://sfr-codegen-research/checkpoints/codegen-6B-nl/350000
gs://sfr-codegen-research/checkpoints/codegen-6B-multi/100000
gs://sfr-codegen-research/checkpoints/codegen-6B-mono/140000

Sanity TPU

import jax
import jax.numpy as jnp

device_count = jax.device_count()              # devices across all hosts
local_device_count = jax.local_device_count()  # devices attached to this host
print(jax.devices())

# Every local device holds a 1; psum adds them over the pmap axis 'i'.
xs = jnp.ones(local_device_count)
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

print('global device count:', device_count)
print('local device count:', local_device_count)
print('pmap result:', r)
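Each device contributes a 1.0 to the sum. On a multi-host slice, pmap collectives such as psum span the devices of every host. Each local entry of the pmap result should therefore equal the global device count. The small pure-Python sketch below computes the value to expect. The helper name is hypothetical and not part of jaxformer.

```python
def expected_pmap_result(device_count: int, local_device_count: int) -> list:
    """Value the sanity check should print as 'pmap result'.

    Every device holds a 1.0; psum over axis 'i' adds the contributions of
    all devices taking part in the pmap, so each of this host's
    local_device_count entries equals the global device_count.
    """
    if device_count % local_device_count != 0:
        raise ValueError("global count must be a multiple of the local count")
    return [float(device_count)] * local_device_count


# A single-host v3-8 exposes 8 devices, all of them local:
print(expected_pmap_result(8, 8))  # [8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0, 8.0]
```

If any entry differs from the global count, some host most likely did not join the computation.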

On a multi-host slice, install JAX on all workers, copy the sanity-check script (saved as test.py) to each of them, and run it:

gcloud compute tpus tpu-vm ssh erik.nijkamp@sfr-erik.nijkamp-tpu-v3-128-us-east1-d-1 --zone=us-east1-d --internal-ip --worker=all --command="pip install 'jax[tpu]==0.3.16' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
gcloud compute tpus tpu-vm scp test.py erik.nijkamp@sfr-erik.nijkamp-tpu-v3-128-us-east1-d-1:/home/erik.nijkamp/ --zone=us-east1-d --internal-ip --worker=all
gcloud compute tpus tpu-vm ssh erik.nijkamp@sfr-erik.nijkamp-tpu-v3-128-us-east1-d-1 --zone=us-east1-d --internal-ip --worker=all --command="python3 /home/erik.nijkamp/test.py"
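The three gcloud invocations for the sanity check differ only in their subcommand and a few arguments. The minimal Python sketch below builds them as argument lists. The helper names and the example user, TPU name and zone ("alice", "my-tpu") are hypothetical placeholders. Only flags that already appear in the commands are used.

```python
import shlex


def tpu_ssh(user: str, tpu: str, zone: str, command: str) -> list:
    """argv for running `command` on every worker of a TPU VM."""
    return ["gcloud", "compute", "tpus", "tpu-vm", "ssh", f"{user}@{tpu}",
            f"--zone={zone}", "--internal-ip", "--worker=all",
            f"--command={command}"]


def tpu_scp(user: str, tpu: str, zone: str, src: str, dest_dir: str) -> list:
    """argv for copying a local file into dest_dir on every worker."""
    return ["gcloud", "compute", "tpus", "tpu-vm", "scp", src,
            f"{user}@{tpu}:{dest_dir}", f"--zone={zone}", "--internal-ip",
            "--worker=all"]


def sanity_check_commands(user: str, tpu: str, zone: str) -> list:
    home = f"/home/{user}/"
    install = ("pip install 'jax[tpu]==0.3.16' "
               "-f https://storage.googleapis.com/jax-releases/libtpu_releases.html")
    return [
        tpu_ssh(user, tpu, zone, install),                    # install JAX
        tpu_scp(user, tpu, zone, "test.py", home),            # copy the script
        tpu_ssh(user, tpu, zone, f"python3 {home}test.py"),   # run it
    ]


for argv in sanity_check_commands("alice", "my-tpu", "us-east1-d"):
    print(shlex.join(argv))  # shell-quoted form, for inspection
```

Each list can be passed directly to `subprocess.run(argv, check=True)`. Because no shell is involved, the `--command` value needs no extra quoting.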

Training

Mode 1: CPU local

brew install python@3.9                      # macOS
apt install --yes python3.9 python3.9-venv   # Debian/Ubuntu

git clone https://:@github.com/salesforce/jaxformer.git/
cd jaxformer

python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt
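Before starting training, it can help to confirm that the interpreter is the Python 3.9 from the virtual environment created above. This minimal sketch uses only the standard library. The check is a hypothetical convenience, not part of jaxformer.

```python
import sys


def in_virtualenv() -> bool:
    # Inside a venv, sys.prefix points at the environment while
    # sys.base_prefix still points at the base interpreter.
    return sys.prefix != sys.base_prefix


def check_environment(required=(3, 9)) -> list:
    """Return a list of problems; an empty list means the setup looks right."""
    problems = []
    if tuple(sys.version_info[:2]) != tuple(required):
        problems.append(
            f"expected Python {required[0]}.{required[1]}, "
            f"got {sys.version_info.major}.{sys.version_info.minor}")
    if not in_virtualenv():
        problems.append("no virtual environment active (run: source .venv/bin/activate)")
    return problems


print(check_environment() or "environment OK")
```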

python3 -m jaxformer.train --config config/debug_cpu.json

Mode 2: TPU local

gcloud compute tpus list --zone=europe-west4-a

gcloud compute tpus tpu-vm delete sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --quiet

gcloud compute tpus tpu-vm create sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --accelerator-type=v3-8 --version=v2-alpha
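On TPU v3, the number in the accelerator type counts TensorCores. JAX exposes each core as one device, and each host drives 8 of them. The v3-8 created here is therefore a single host, while the v3-64 and v3-128 slices used elsewhere in this README span 8 and 16 hosts; hence `--worker=all` when setting those up. The small sketch below encodes that arithmetic and is limited to v3. The helper name is hypothetical.

```python
def v3_topology(accelerator_type: str) -> dict:
    """Device and host counts for a TPU v3 slice name such as 'v3-8'.

    Assumes the v3 convention: the suffix is the number of TensorCores,
    JAX exposes one device per core, and each host owns 8 cores.
    """
    generation, _, cores = accelerator_type.partition("-")
    if generation != "v3" or not cores.isdigit():
        raise ValueError(f"expected a name like 'v3-8', got {accelerator_type!r}")
    n = int(cores)
    if n % 8 != 0:
        raise ValueError("v3 slices come in multiples of 8 cores")
    return {"device_count": n, "local_device_count": 8, "hosts": n // 8}


print(v3_topology("v3-8"))    # {'device_count': 8, 'local_device_count': 8, 'hosts': 1}
print(v3_topology("v3-128"))  # {'device_count': 128, 'local_device_count': 8, 'hosts': 16}
```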

gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-8-europe-west4-d-1 --zone=europe-west4-a --project --worker 0

export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials//adc.json
export GCLOUD_PROJECT=
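GOOGLE_APPLICATION_CREDENTIALS is the standard variable Google Cloud client libraries use to locate a credentials JSON file. In this README, the account directory in that path and the project ID are left blank, so a quick check that both were filled in can save a failed run. The minimal sketch below is a hypothetical helper using only the standard library.

```python
import os
from pathlib import Path


def check_gcp_env(env=None) -> list:
    """Return problems with the credential/project variables exported above."""
    env = os.environ if env is None else env
    problems = []
    creds = env.get("GOOGLE_APPLICATION_CREDENTIALS", "")
    if not creds:
        problems.append("GOOGLE_APPLICATION_CREDENTIALS is not set")
    elif not Path(creds).expanduser().is_file():
        problems.append(f"credentials file not found: {creds}")
    if not env.get("GCLOUD_PROJECT"):
        problems.append("GCLOUD_PROJECT is empty")
    return problems


print(check_gcp_env() or "credentials and project look set")
```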

git clone https://:@github.com/salesforce/jaxformer.git/
cd jaxformer

./jaxformer/env/env_tpu_v3.sh
pip install -r requirements.txt

source .venv/bin/activate

python3
>>> import jax
>>> jax.devices()
>>> quit()

python3 -m jaxformer.train --config config/debug_tpu_v3_8.json

Mode 3: TPU remote

gcloud beta compute --project= instances create sfr--cpu-small-us-east1-d-1 \
  --zone=us-east1-d \
  --machine-type=e2-standard-4 \
  --network-tier=PREMIUM \
  --maintenance-policy=MIGRATE \
  --service-account= \
  --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append \
  --image=ubuntu-minimal-2004-focal-v20210720 \
  --image-project=ubuntu-os-cloud \
  --boot-disk-size=50GB \
  --boot-disk-type=pd-balanced \
  --boot-disk-device-name=sfr-cpu-small \
  --no-shielded-secure-boot \
  --shielded-vtpm \
  --shielded-integrity-monitoring \
  --reservation-affinity=any

gcloud beta compute ssh sfr--cpu-small-us-east1-d-1 --project= --zone=us-east1-d

sudo apt update
sudo apt install --yes git screen python3.9 python3.9-venv

screen -S codegen_350M_nl

curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init
ssh-keygen -t rsa -f ~/.ssh/google_compute_engine -N ''

export WANDB_API_KEY=
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials//adc.json
export GCLOUD_PROJECT=

git clone https://:@github.com/salesforce/jaxformer.git/
cd jaxformer

python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt

python3 -m jaxformer.train --config config/codegen_350M_nl.json

gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-64-us-east1-d-1 --zone us-east1-d --internal-ip --worker=0

Fine-tuning

TPU fine-tune

gcloud beta compute --project= instances create sfr--cpu-small-us-east1-d-1 \
  --zone=us-east1-d \
  --machine-type=e2-standard-4 \
  --network-tier=PREMIUM \
  --maintenance-policy=MIGRATE \
  --service-account= \
  --scopes=https://www.googleapis.com/auth/devstorage.read_only,https://www.googleapis.com/auth/logging.write,https://www.googleapis.com/auth/monitoring.write,https://www.googleapis.com/auth/servicecontrol,https://www.googleapis.com/auth/service.management.readonly,https://www.googleapis.com/auth/trace.append \
  --image=ubuntu-minimal-2004-focal-v20210720 \
  --image-project=ubuntu-os-cloud \
  --boot-disk-size=50GB \
  --boot-disk-type=pd-balanced \
  --boot-disk-device-name=sfr-cpu-small \
  --no-shielded-secure-boot \
  --shielded-vtpm \
  --shielded-integrity-monitoring \
  --reservation-affinity=any

gcloud beta compute ssh sfr--cpu-small-us-east1-d-1 --project= --zone=us-east1-d

sudo apt update
sudo apt install --yes git screen python3.9 python3.9-venv

screen -S codegen_350M_mono

curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init
ssh-keygen -t rsa -f ~/.ssh/google_compute_engine -N ''

export WANDB_API_KEY=
export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials//adc.json
export GCLOUD_PROJECT=

git clone https://:@github.com/salesforce/jaxformer.git/
cd jaxformer

python3.9 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install -r requirements.txt

python3 -m jaxformer.train --config config/codegen_350M_multi.json

gcloud compute tpus tpu-vm ssh sfr-erik.nijkamp-tpu-v3-64-us-east1-d-1 --zone us-east1-d --internal-ip --worker=0

A100 fine-tune

apt install python3.8 python3.8-venv python3.8-dev

curl https://sdk.cloud.google.com | bash
source ~/.bashrc
gcloud init

export GOOGLE_APPLICATION_CREDENTIALS=~/.config/gcloud/legacy_credentials//adc.json
export GCLOUD_PROJECT=

python3.8 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip setuptools
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
pip install transformers==4.21.1 datasets==1.16.1 deepspeed==0.7.0 tensorflow-cpu==2.5.0
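The A100 path pins specific library versions. This minimal sketch compares installed distributions against those pins using `importlib.metadata` (Python 3.8+). The pin table is copied from the pip install command for this setup, and the helper names are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins copied from the pip install command of the A100 setup.
PINS = {
    "transformers": "4.21.1",
    "datasets": "1.16.1",
    "deepspeed": "0.7.0",
    "tensorflow-cpu": "2.5.0",
}


def installed_versions(names) -> dict:
    """Installed version of each distribution, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def check_pins(pins=PINS, installed=None) -> dict:
    """Map each mismatched distribution to (wanted, installed-or-None)."""
    installed = installed_versions(pins) if installed is None else installed
    return {name: (wanted, installed.get(name))
            for name, wanted in pins.items()
            if installed.get(name) != wanted}


print(check_pins() or "all pinned versions installed")
```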

pip install -e .

deepspeed --num_gpus=1 jaxformer/hf/train.py

Conversion

python3 -m jaxformer.hf.convert --config=config/codegen_1B_mono.json --step=150000
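The converter takes a training config and a checkpoint step. Configs in this README are named like config/codegen_350M_nl.json. Assuming that naming also covers the checkpoints listed under Models, a hypothetical helper can derive the converter arguments from a checkpoint path. Confirm that the derived config file actually exists in config/ before relying on it.

```python
import re
import shlex

# Assumed convention (hypothetical), inferred from config names such as
# config/codegen_350M_nl.json:
#   .../codegen-<size>-<data>/<step>  ->  config/codegen_<size>_<data>.json
_CKPT_RE = re.compile(
    r"codegen-(?P<size>\d+[MB])-(?P<data>nl|multi|mono)/(?P<step>\d+)/?$")


def convert_argv(checkpoint_path: str) -> list:
    """argv for `python3 -m jaxformer.hf.convert` derived from a checkpoint path."""
    match = _CKPT_RE.search(checkpoint_path)
    if match is None:
        raise ValueError(f"unrecognized checkpoint path: {checkpoint_path}")
    config = f"config/codegen_{match['size']}_{match['data']}.json"
    return ["python3", "-m", "jaxformer.hf.convert",
            f"--config={config}", f"--step={match['step']}"]


print(shlex.join(convert_argv(
    "gs://sfr-codegen-research/checkpoints/codegen-350M-mono/150000")))
# python3 -m jaxformer.hf.convert --config=config/codegen_350M_mono.json --step=150000
```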

Features

v1