
This document is relevant for: Inf2, Trn1, Trn2

Developer guide for LoRA finetuning#

This document describes how to enable model finetuning with LoRA (Low-Rank Adaptation).

For the complete API guide, refer to API.

Enable LoRA finetuning#

We first set up LoRA-related configurations:

lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,                                    # rank of the low-rank adapter matrices
    lora_alpha=32,                                   # scaling factor applied to the adapter output
    lora_dropout=0.05,                               # dropout applied to the adapter input
    bias="none",                                     # do not adapt bias terms
    lora_verbose=True,
    target_modules=["q_proj", "v_proj", "k_proj"],   # modules that receive LoRA adapters
    save_lora_base=False,                            # do not save the base model weights
    merge_lora=False,                                # do not merge the adapter into the base model when saving
)

The default target modules for different model architectures can be found in model.py.
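For example, a minimal configuration can omit target_modules and rely on those per-architecture defaults. The sketch below assumes the defaults in model.py cover your model architecture; it only reuses options shown above.

# Minimal sketch: target_modules is omitted, so the defaults from model.py
# for the model architecture are used (assumption).
lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
)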

We then initialize the NxD model with LoRA enabled:

nxd_config = nxd.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)
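As a quick sanity check after initialization, you can confirm that only the LoRA adapter weights are trainable. This is a minimal sketch that assumes the wrapped NxD model exposes standard torch parameters; it is not part of the NxD API.

# Sketch: with LoRA enabled, only adapter parameters should have requires_grad=True.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} / {total}")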

Save LoRA checkpoint#

Users can save the LoRA adapter with:

nxd.save_checkpoint(
    checkpoint_dir_str=checkpoint_dir,  # checkpoint path
    tag=tag,                            # sub-directory under checkpoint path
    model=model,
)

Because save_lora_base=False and merge_lora=False, only the LoRA adapter is saved under checkpoint_dir/tag/. We can also set merge_lora=True to save the merged model, i.e., to merge the LoRA adapter into the base model before saving.
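For example, to save a merged checkpoint instead of a separate adapter, flip the flag in the LoRA configuration before training. This is a sketch based only on the note above; the remaining options are unchanged from the earlier example.

# Sketch: with merge_lora=True, nxd.save_checkpoint writes the base model
# with the LoRA adapter merged into it under checkpoint_dir/tag/.
lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    merge_lora=True,   # merge the adapter into the base model when saving
)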

Load LoRA checkpoint#

A sample usage:

lora_config = LoraConfig(
    enable_lora=True,
    load_lora_from_ckpt=True,
    lora_save_dir=checkpoint_dir,  # checkpoint path
    lora_load_tag=tag,             # sub-directory under checkpoint path
)
nxd_config = nxd.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)

The NxD model will be initialized with LoRA enabled and the LoRA weights loaded. The remaining LoRA-related configurations are taken from the LoRA adapter checkpoint, so they match the settings used when the adapter was saved.
