Developer guide for LoRA finetuning — AWS Neuron Documentation
This document is relevant for: Inf2, Trn1, Trn2
Developer guide for LoRA finetuning#
This document describes how to enable model finetuning with LoRA (Low-Rank Adaptation) using NeuronX Distributed (NxD).
For the complete API guide, refer to API.
Enable LoRA finetuning#
We first set up LoRA-related configurations:
lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    lora_verbose=True,
    target_modules=["q_proj", "v_proj", "k_proj"],
    save_lora_base=False,
    merge_lora=False,
)
The default target modules for different model architectures can be found in model.py.
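If the defaults do not match your model, you can override target_modules with the names of the submodules you want to adapt. The sketch below is illustrative only; the extra names ("o_proj", "gate_proj") are typical Llama-style projection names and are not defaults taken from model.py, so check that file for the names that apply to your architecture.

# Hypothetical override: adapt the attention output projection and an MLP
# projection in addition to q/k/v. Module names are illustrative only.
lora_config = nxd.modules.lora.LoraConfig(
    enable_lora=True,
    lora_rank=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"],
)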
We then initialize the NxD model with LoRA enabled:
nxd_config = nxd.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)
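As a quick sanity check that only the adapter weights are trainable, you can count trainable versus total parameters on the returned model. This is a plain-PyTorch diagnostic sketch that assumes the wrapped model exposes the usual parameters() iterator; it is not part of the NxD API.

# Diagnostic sketch: with LoRA enabled, the trainable fraction should be small.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable parameters: {trainable} / {total} ({100.0 * trainable / total:.2f}%)")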
Save LoRA checkpoint#
Users can save the LoRA adapter with:
nxd.save_checkpoint(
    checkpoint_dir_str=checkpoint_dir,  # checkpoint path
    tag=tag,  # sub-directory under checkpoint path
    model=model,
)
Because save_lora_base=False and merge_lora=False, only the LoRA adapter is saved under checkpoint_dir/tag/. We can also set merge_lora=True to save the merged model, i.e., the base model with the LoRA adapter merged into it.
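For reference, merging folds the low-rank update into the base weight, so the saved model needs no separate adapter matrices at inference time. The sketch below illustrates the arithmetic with standalone tensors (illustrative shapes and names, not the NxD merge code), assuming the standard LoRA update W + (lora_alpha / lora_rank) * B @ A.

import torch

# Illustrative shapes: W is (out_features, in_features), A is (rank, in_features),
# B is (out_features, rank). Double precision is used only for a tight numerical check.
W = torch.randn(1024, 1024, dtype=torch.float64)
lora_A = torch.randn(16, 1024, dtype=torch.float64)
lora_B = torch.randn(1024, 16, dtype=torch.float64)
scaling = 32 / 16  # lora_alpha / lora_rank

W_merged = W + scaling * (lora_B @ lora_A)  # weight stored when merge_lora=True

# The merged weight produces the same output as base + adapter applied separately.
x = torch.randn(2, 1024, dtype=torch.float64)
assert torch.allclose(x @ W_merged.T,
                      x @ W.T + scaling * (x @ lora_A.T) @ lora_B.T)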
Load LoRA checkpoint#
A sample usage:
lora_config = LoraConfig(
    enable_lora=True,
    load_lora_from_ckpt=True,
    lora_save_dir=checkpoint_dir,  # checkpoint path
    lora_load_tag=tag,  # sub-directory under checkpoint path
)
nxd_config = nxd.neuronx_distributed_config(
    ...
    lora_config=lora_config,
)
model = nxd.initialize_parallel_model(nxd_config, ...)
The NxD model will be initialized with LoRA enabled and the LoRA weights loaded. The LoRA-related configurations are kept the same as those stored with the LoRA adapter checkpoint.
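To confirm the adapter weights were actually loaded, one option is to list the LoRA parameters on the wrapped model. This is a plain-PyTorch diagnostic sketch, not part of the NxD API, and it assumes the adapter parameter names contain "lora" (typical for LoRA implementations, but verify against your model).

# Diagnostic sketch: list a few adapter parameter names after loading.
lora_params = [name for name, _ in model.named_parameters() if "lora" in name.lower()]
print(f"{len(lora_params)} LoRA parameter tensors loaded, e.g. {lora_params[:3]}")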