RAFT (LMFlow documentation)
1 Introduction
Note that this example is built on LLaMA, whose license permits non-commercial use only.
Reinforcement Learning from Human Feedback (RLHF) requires a reward function to guide the adjustment of the generative model. In this example, we show how to use the LMFlow framework to train a reward model following the procedure in the InstructGPT paper (https://arxiv.org/abs/2203.02155) and then align the model via the RAFT algorithm (Reward rAnked FineTuning).
This example covers both reward modeling and RAFT alignment for completeness. For convenience, we also provide a reward model based on GPT-Neo-2.7B on the Hugging Face Hub, so you can skip the reward modeling step at first.
1.1 Dataset description
We use the Dahoas/full-hh-rlhf dataset as an example. Each sample in this dataset consists of a prompt and two responses from the assistant. The response labeled “chosen” is preferred over the one labeled “rejected”. The dataset contains 112K training samples and 12.5K test samples. Here is an example from the dataset:
Prompt: " Human: What kind of noises did dinosaurs make? Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be Human: yes they did Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. Human: you cant read Assistant:"
Chosen response: "You can read?"
Rejected response: "there’s a lot of stuff humans don’t know"
To facilitate training, we reformulate the prompt by adding "###" before each speaker tag ("Human:" and "Assistant:") so that the model knows when to reply. The new sample has the following form:
Prompt: "###Human: What kind of noises did dinosaurs make? ###Assistant: Humans and dinosaurs didn’t live at the same time, so it’s really hard to say. The best place to find out what noises dinosaurs made would be ###Human: yes they did ###Assistant: to guess, and that would probably require lots of reading and a certain amount of imagination, so we’re not really prepared to do that. ###Human: you cant read ###Assistant:"
Chosen response: "You can read?"
Rejected response: "there’s a lot of stuff humans don’t know"
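The reformulation above can be sketched as a single regular-expression substitution. This is a minimal illustration and not LMFlow's preprocessing code. The helper name add_speaker_markers is hypothetical. It assumes that speaker turns are marked literally as "Human:" and "Assistant:". It will also tag those words if they happen to appear inside a turn, and it is not idempotent on already-formatted text.

```python
import re

# Hypothetical helper (not part of LMFlow): prefix each speaker tag with "###".
# Assumes speaker turns are marked literally as "Human:" and "Assistant:".
SPEAKER_TAG = re.compile(r"\b(Human|Assistant):")


def add_speaker_markers(prompt: str) -> str:
    """Insert "###" before every speaker tag and drop leading whitespace."""
    return SPEAKER_TAG.sub(r"###\1:", prompt).lstrip()


if __name__ == "__main__":
    raw = " Human: you cant read Assistant:"
    print(add_speaker_markers(raw))  # ###Human: you cant read ###Assistant:
```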
All datasets used in this example are placed in ./data/hh_rlhf. You can obtain them by running the following command from the LMFlow root directory:
cd data && ./download.sh hh_rlhf && cd -
2 Reward Modeling
We follow the procedure in the InstructGPT paper (https://arxiv.org/abs/2203.02155) and first train a reward model on the HH-RLHF dataset. The procedure includes:
- Supervised Finetuning (SFT);
- Reward modeling on the comparison dataset.
2.1 Supervised Finetuning (SFT)
Here is an example from the dataset LMFlow/data/hh_rlhf/sft/hh_rlhf_sft.json. We use only the preferred (“chosen”) responses, which gives 112K training samples.
{"type": "text_only",
"instances":
[
{"text": "###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances. If you carry your phone in a pocket or a purse then you probably want a case. But if you only need a phone for quick interactions, a case may actually cause more harm than good. What do you need the phone for? Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing. It can also get you in trouble if you have your phone turned off for some reason. Then you will turn it back on and it won’t do anything. If you can afford to replace it, then you need a case to protect it. The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant: You’re welcome."},
{"text": "###Human: I'm trying to learn about the salam witch trials###Assistant: If you’re looking for books about witchcraft trials, I can recommend some books for you. But you probably mean you’re looking for more historical information about the Salem witch trials in 1692, and specifically about the salam witch trials in 1692?###Human: What are some good books to learn about the salam witch trials###Assistant: What would you like to learn? If you’re interested in history, one of the best books is The Witch Persecutions in Colonial America: A History. If you’re interested in witchcraft as a cultural phenomenon, you might enjoy two excellent books: Religion and the Decline of Magic: Studies in Popular Beliefs in Sixteenth- and Seventeenth-Century England by Keith Thomas and Magic, Witchcraft, and the Otherworld: An Anthropology of Superstition by Jack Goody. If you’re interested in history specifically as it relates to religion, you might enjoy The Popish Plot, or Prelates' Plot: A History of the Popish Plot in England, by K. J. Everett."}
]
}
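A file in the "text_only" format shown above could be produced with a few lines of Python. This is a hedged sketch and not LMFlow's data script. The helper build_sft_dataset is hypothetical. It assumes that each training text is the formatted prompt followed by the chosen response, joined with a single space.

```python
import json

# Hypothetical helper (not part of LMFlow) that builds the "text_only" format.
# Assumption: each text is the prompt followed by the chosen response,
# joined with a single space.


def build_sft_dataset(pairs):
    """pairs: iterable of (prompt, chosen_response) strings."""
    instances = [{"text": f"{prompt} {chosen.strip()}"} for prompt, chosen in pairs]
    return {"type": "text_only", "instances": instances}


if __name__ == "__main__":
    data = build_sft_dataset(
        [("###Human: Thanks for letting me know.###Assistant:", "You’re welcome.")]
    )
    # ensure_ascii=False keeps characters such as ’ readable in the output file.
    print(json.dumps(data, ensure_ascii=False, indent=2))
```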
You can edit ./scripts/run_finetune.sh to modify the parameters. Here we present an example with GPT-Neo-2.7B:
- --model_name_or_path: EleutherAI/gpt-neo-2.7B
- --dataset_path: ${project_dir}/data/hh_rlhf/sft
- --output_dir: the path where you want to store the SFT model
- --num_train_epochs: 1
- --learning_rate: 2e-5
- --per_device_train_batch_size: adjust according to your GPU memory.
Then, we can run the following command to perform SFT.
./scripts/run_finetune.sh
You can also use LoRA training with the following command. Remember to set model_name_or_path and the dataset path appropriately by editing run_finetune_with_lora.sh.
./scripts/run_finetune_with_lora.sh
2.2 Reward Modeling
Data Preparation
Reward modeling involves ranking responses based on the given prompt, typically requiring human labeling. Fortunately, we already have two responses for each prompt in this example, with one preferred over the other. To create a new dataset, we modify the original dataset to include two keys, “positive” and “negative”, with the former being the preferred response.
Here is an example of the dataset LMFlow/data/hh_rlhf/rm/hh_rlhf_rm_training.json, consisting of two samples.
{"instances": [
{"positive": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that. Can you tell me more about the situation? Are you using the heater on a fireplace? Do you have any other heater in the room? What materials are in the room? Do you feel any heat from the heater? How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”. It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like. You’d also have to make sure you’re using it correctly. Is it switched off? Is it safe to touch? It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: Glad to be helpful. I’m always happy to assist.", "negative": "###Human: Are space heaters safe?###Assistant: What do you mean? Can you give me more details?###Human: I mean, can they be a fire hazard?###Assistant: I don’t know about that. Can you tell me more about the situation? Are you using the heater on a fireplace? Do you have any other heater in the room? What materials are in the room? Do you feel any heat from the heater? How are you touching the heater?###Human: Can you just tell me if there are any situations where they could be a fire hazard?###Assistant: The answer is “it depends”. It would depend on the specific heater you are using, how you are using it, what kind of fuel it uses, what materials you have in the room, and what the surroundings are like. You’d also have to make sure you’re using it correctly. Is it switched off? Is it safe to touch? It’s not the sort of thing I could just answer by looking at the text you provided.###Human: Thank you. That information is helpful.###Assistant: You’re welcome. I’m glad to help. Is there anything else you want to know?"},
{"positive": "###Human: how do i teach my kids to make forts at home?###Assistant: Do you mean you want to teach them to build a simple structure they can hide in? Or something more elaborate? I don’t know exactly what you mean by a “fort”, but maybe this gives you some ideas?", "negative": "###Human: how do i teach my kids to make forts at home?###Assistant: There are many ways to make forts! The simplest might be to just pile up all the furniture in one room. This works if you have a single room in your house, although it might still be a bit challenging to build a taller and sturdier fort this way. The best way to build forts is to start with something easy, and then have fun exploring how you can improve it."}
]}
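The conversion from (prompt, chosen, rejected) triples to the "positive"/"negative" format can be sketched as follows. The helper build_rm_dataset is hypothetical and not part of LMFlow. It assumes that both fields hold the full conversation, meaning the shared prompt followed by the chosen or rejected response and joined with a single space, as the example above suggests.

```python
import json


def build_rm_dataset(samples):
    """samples: iterable of (prompt, chosen, rejected) strings.

    Hypothetical helper (not part of LMFlow). Assumption: "positive" and
    "negative" are full conversations, i.e. the shared prompt followed by the
    chosen or rejected response, joined with a single space.
    """
    instances = []
    for prompt, chosen, rejected in samples:
        instances.append({
            "positive": f"{prompt} {chosen.strip()}",
            "negative": f"{prompt} {rejected.strip()}",
        })
    return {"instances": instances}


if __name__ == "__main__":
    data = build_rm_dataset([("###Human: Hi###Assistant:", "Hello!", "Go away.")])
    print(json.dumps(data, ensure_ascii=False))
```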
Training
To start from the model obtained in the previous SFT step, edit the ./scripts/run_reward_modeling.sh script. Point “model_name_or_path” to the SFT checkpoint and update “dataset_path” to the desired dataset. By default, we use LoRA training for reward modeling, because we found that it is superior in terms of optimization and gives comparable performance.
- --model_name_or_path: /root/data/usr_name/output_models/hh_rlhf_rm_sft_gptneo_2_7B/checkpoint-1659 (the SFT checkpoint; adjust according to your setup)
- --dataset_path: ${project_dir}/data/hh_rlhf/rm/hh_rlhf_rm_training.json
- --output_dir: the path where you want to store the reward model
- --num_train_epochs: 1
- --learning_rate: 3e-5
- --per_device_train_batch_size: adjust according to your GPU memory.
- --eval_steps: 400
- --validation_split_percentage: 10
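In the InstructGPT procedure, the reward model r is trained on each comparison pair with the pairwise ranking loss -log σ(r(x, y_pos) - r(x, y_neg)). This loss pushes the score of the "positive" response above that of the "negative" one. The sketch below evaluates this objective on scalar rewards in plain Python. It illustrates the loss only and is not taken from run_reward_modeling.py.

```python
import math


def pairwise_ranking_loss(r_pos: float, r_neg: float) -> float:
    """-log(sigmoid(r_pos - r_neg)), computed stably as softplus(r_neg - r_pos)."""
    z = r_neg - r_pos
    # softplus(z) = log(1 + exp(z)); branch on the sign so exp never overflows.
    if z > 0:
        return z + math.log1p(math.exp(-z))
    return math.log1p(math.exp(z))


def mean_loss(pairs):
    """Average loss over (r_pos, r_neg) reward pairs, e.g. one training batch."""
    return sum(pairwise_ranking_loss(p, n) for p, n in pairs) / len(pairs)
```

When both rewards are equal, the loss is log 2. The loss shrinks toward 0 as the positive response outscores the negative one.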
The load_dataset function in ./examples/run_reward_modeling.py splits the dataset into training and evaluation sets. You can customize this function if you want to prepare your own dataset. In the default implementation, validation_split_percentage percent of the samples are used as the evaluation set.
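To illustrate what a percentage-based split means, here is a minimal sketch. It is not LMFlow's load_dataset. The helper split_train_eval is hypothetical, and the shuffling and the fixed seed are assumptions made for illustration.

```python
import random


def split_train_eval(instances, validation_split_percentage=10, seed=42):
    """Shuffle and hold out validation_split_percentage percent for evaluation.

    Hypothetical sketch (not LMFlow's load_dataset): shuffling with a fixed
    seed is an assumption made for illustration.
    """
    items = list(instances)
    random.Random(seed).shuffle(items)
    n_eval = int(len(items) * validation_split_percentage / 100)
    return items[n_eval:], items[:n_eval]  # (train, eval)
```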
Run the reward modeling script with:
./scripts/run_reward_modeling.sh
Examples
We train reward models on the HH-RLHF dataset with three base models: LLaMA-7B, GPT-Neo-2.7B, and GPT-Neo-1.3B. Each model is first supervised fine-tuned on the training set as in the previous step. The reward models are then trained on the 112K training samples and evaluated on the 12.5K test samples.
2.3 LoRA Merge and Get Reward Model
We use ./examples/merge_lora.py to merge the LoRA adapter into the SFT model, which gives the final reward model. We are now ready to align our model.
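Merging folds the low-rank update into the base weights, so the merged model no longer needs a separate adapter. In standard LoRA, a weight matrix W is updated as W + (alpha / r) B A. Here B and A are the adapter matrices, r is the rank, and alpha is the scaling hyper-parameter. The sketch below applies this formula to small nested lists for illustration only. It is not the code in merge_lora.py, which operates on model checkpoints.

```python
def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def merge_lora_weight(w, lora_b, lora_a, alpha, r):
    """Return W + (alpha / r) * B @ A, the merged weight of one LoRA-adapted layer."""
    scale = alpha / r
    delta = matmul(lora_b, lora_a)  # (d_out x r) @ (r x d_in) -> (d_out x d_in)
    return [[w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]
```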
3 RAFT Alignment
Original paper: RAFT: Reward rAnked FineTuning for Generative Foundation Model Alignment
3.1 Algorithms Overview
Main ideas of RAFT
Clearly, the global ranking strategy is more efficient in terms of reward learning. However, in some cases (e.g., the example presented here), the rewards are heavily influenced by the prompts. In such cases, local ranking among responses to the same prompt is more appropriate. We can choose the data collection strategy with the corresponding hyper-parameter (--collection_strategy in the example of Section 3.3.2), as introduced in the next subsection.
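To make the distinction concrete, here is a minimal sketch of one RAFT data-collection step under both strategies. It assumes that K responses have already been generated and scored for each prompt. The function names, the data layout, and the rounding rule are illustrative assumptions and do not reflect LMFlow's raft_aligner implementation.

```python
import math


def collect_local(scored, top_percentage):
    """Local ranking: keep the top responses within each prompt's own group.

    scored maps prompt -> list of (response, reward) pairs.
    """
    selected = []
    for prompt, candidates in scored.items():
        k = max(1, math.floor(len(candidates) * top_percentage))
        best = sorted(candidates, key=lambda c: c[1], reverse=True)[:k]
        selected.extend((prompt, resp) for resp, _ in best)
    return selected


def collect_global(scored, top_percentage):
    """Global ranking: pool every (prompt, response) pair and keep the top ones."""
    pool = [(prompt, resp, rew) for prompt, cands in scored.items() for resp, rew in cands]
    k = max(1, math.floor(len(pool) * top_percentage))
    best = sorted(pool, key=lambda c: c[2], reverse=True)[:k]
    return [(prompt, resp) for prompt, resp, _ in best]


if __name__ == "__main__":
    scored = {"p1": [("a", 1.0), ("b", 3.0)], "p2": [("c", 5.0), ("d", 4.0)]}
    print(collect_local(scored, 0.5))   # [('p1', 'b'), ('p2', 'c')]
    print(collect_global(scored, 0.5))  # [('p2', 'c'), ('p2', 'd')]
```

In this toy case, global ranking spends the whole budget on p2 because its rewards are higher across the board. This is the kind of prompt dependence that makes local ranking preferable here.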
3.2 Hyper-parameters
Table 1: Hyper-parameters of RAFT.
3.3 Examples
As an example, we align the LLaMA-7B model with RAFT in this subsection.
3.3.1 SFT
As before, we first fine-tune the base model on the HH-RLHF dataset. The only difference is that --model_name_or_path points to the LLaMA model. Note that LLaMA is licensed for non-commercial use only. See https://optimalscale.github.io/LMFlow/examples/checkpoints.html for details on obtaining the LLaMA-7B model.
3.3.2 RAFT Alignment
We align the LLaMA-7B-SFT model in this subsection. Alignment is challenging because the reward function (the RL environment) is far from perfect. Both the traditional DRL method (PPO) and RAFT can exploit these imperfections. We present a step-by-step record to demonstrate how we align the model while avoiding these issues.
Data Preparation
A long context window places a heavy burden on GPU memory. We therefore use a context window of 256 tokens and discard prompts with more tokens. This results in a prompt set of 82,147 samples (out of the original 112K). The following is an example prompt, obtained by simply discarding the response:
"###Human: Should you buy a case to protect your cell phone?###Assistant: It depends on your circumstances. If you carry your phone in a pocket or a purse then you probably want a case. But if you only need a phone for quick interactions, a case may actually cause more harm than good. What do you need the phone for? Are you a parent, or do you work from home?###Human: What harm could it do?###Assistant: A phone case can damage the screen, for one thing. It can also get you in trouble if you have your phone turned off for some reason. Then you will turn it back on and it won’t do anything. If you can afford to replace it, then you need a case to protect it. The problem is that most people aren’t able to afford to replace their phones all the time.###Human: Thanks for letting me know.###Assistant:"
We additionally use 2K samples from the test set to evaluate the models. In what follows, we show how we apply RAFT to LLaMA-7B-SFT and improve the model step by step.
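The prompt-set construction described in the data preparation step above can be sketched as follows. The 256-token limit comes from the text. The helper names, and the choice to cut at the last "###Assistant:" tag, are assumptions for illustration. Here tokenize can be any callable that returns a list of tokens, for example the tokenize method of a Hugging Face tokenizer, or str.split for a rough check.

```python
ASSISTANT_TAG = "###Assistant:"


def to_prompt(text):
    """Drop the final assistant response, keeping the trailing "###Assistant:" tag.

    Hypothetical helper: assumes the text ends with the response to discard.
    Returns None when the text has no assistant turn.
    """
    cut = text.rfind(ASSISTANT_TAG)
    if cut == -1:
        return None
    return text[:cut + len(ASSISTANT_TAG)]


def build_prompt_set(texts, tokenize, max_tokens=256):
    """Keep only prompts whose token count fits within the context window."""
    prompts = []
    for text in texts:
        prompt = to_prompt(text)
        if prompt is not None and len(tokenize(prompt)) <= max_tokens:
            prompts.append(prompt)
    return prompts
```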
Step 1: test the SFT model
We first evaluate the LLaMA-7B-SFT model on the held-out test set and observe that it tends to reply to the prompt with multiple rounds of conversation. We therefore adopt the following post-processing strategy, which keeps only the first round as the response.
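def _clean_text(self, text):
    # Keep only the first round of the reply: cut at the next "###Human" turn.
    stext = [x for x in text.split("###Human") if x]
    # Guard against empty generations, then trim whitespace and stray "#" markers.
    return stext[0].strip().strip("#") if stext else ""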
Step 2: train model
Reward function setting
The reward model is specified in LMFlow/examples/raft_align.py. In our case, we use the GPT-Neo-2.7B reward model trained in the previous section, which is set as follows:
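# A field of the dataclass-based argument class in examples/raft_align.py
# (needs: from dataclasses import field; from typing import Optional)
reward_model_or_path: Optional[str] = field(
    default="weqweasdas/hh_rlhf_rm",
    metadata={"help": "reward model name (huggingface) or its path"},
)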
In general, if your reward function was not trained by following the steps in the previous section, you may also need to modify the get_reward_function function in the same file to use your own reward function.
We run the alignment with the following command and hyper-parameters:
./scripts/run_raft_align.sh
- --model_name_or_path: /root/data/usr_name/output_models/hh_rlhf_llama-sft (the model obtained in the SFT step; adjust according to your setup)
- --dataset_path: ${project_dir}/data/hh_rlhf/rlhf_prompt
- --output_dir: /root/data/usr_name/output_models/hh_rlhf_raft_align
- --num_train_epochs: 4
- --learning_rate: 2e-5
- --per_device_train_batch_size: adjust according to your GPU memory.
- --inference_batch_size_per_device: adjust according to your GPU memory.
- --num_raft_iteration: 20
- --top_reward_percentage: 0.125 (which means that we sample 8 responses for each prompt)
- --raft_batch_size: 1024
- --collection_strategy: "local"
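As a quick sanity check on these settings, a top_reward_percentage of 0.125 corresponds to 1 / 0.125 = 8 sampled responses per prompt. The hypothetical helper below also estimates the per-iteration generation budget. It relies on an assumption not stated above: that raft_batch_size counts prompts per iteration, and that the top top_reward_percentage of all generations is kept.

```python
def raft_iteration_budget(raft_batch_size, top_reward_percentage):
    """Return (responses_per_prompt, generations, kept) for one RAFT iteration.

    Assumption: raft_batch_size is the number of prompts per iteration, and
    the top `top_reward_percentage` of all generations is kept for fine-tuning.
    """
    responses_per_prompt = round(1 / top_reward_percentage)
    generations = raft_batch_size * responses_per_prompt
    kept = round(generations * top_reward_percentage)
    return responses_per_prompt, generations, kept


if __name__ == "__main__":
    print(raft_iteration_budget(1024, 0.125))  # (8, 8192, 1024)
```

Under the same assumption, 20 iterations (--num_raft_iteration 20) would amount to 20 × 8192 = 163,840 generations in total.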
The experiment runs smoothly, and the training reward increases from ~2.7 to 3.4. However, we observe a significant drop in the diversity metric (e.g., distinct-2 drops from 0.39 to 0.22). We examined the samples generated at each RAFT iteration. In the first iteration, the initial checkpoint occasionally includes # in its responses, and it turns out that our reward function does not detect a random #. As a result, responses containing # can still receive high rewards and be selected into the training set. The situation then gets progressively worse, and eventually half of the responses contain noisy # symbols.
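Distinct-2 is commonly defined as the number of unique bigrams divided by the total number of bigrams over a set of generated responses. Lower values indicate more repetitive outputs. Here is a minimal sketch that assumes whitespace tokenization and corpus-level counting. The exact tokenization behind the numbers above is not specified here, so results from this sketch need not match them.

```python
def distinct_n(responses, n=2):
    """Ratio of unique n-grams to total n-grams across all responses.

    Whitespace tokenization is an assumption made for illustration.
    """
    total = 0
    unique = set()
    for response in responses:
        tokens = response.split()
        # Responses shorter than n tokens contribute no n-grams.
        ngrams = [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]
        total += len(ngrams)
        unique.update(ngrams)
    return len(unique) / total if total else 0.0
```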
Step 3: retrain the model
To alleviate the problem in Step 2, we discard collected samples that contain # by assigning them a large negative reward. This turns out to work well for our purpose. To disable this behavior, modify the following function to always return False.
def _discard_sample(self, text):
    # Discard any response containing "#"; return False here to disable the filter.
    return "#" in text
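The effect of this penalty on selection can be sketched as follows: once a flagged response receives a large negative reward, it can no longer win the ranking. The penalty value and the helper names are assumptions for illustration. The text only says "a large negative reward" and does not give the value used in raft_aligner.py.

```python
DISCARD_PENALTY = -1e4  # assumed value; the text only says "a large negative reward"


def discard_sample(text):
    """Standalone version of the _discard_sample rule: flag any response with '#'."""
    return "#" in text


def penalized_rewards(responses, rewards, penalty=DISCARD_PENALTY):
    """Replace the reward of every discarded response with a large negative value."""
    return [penalty if discard_sample(r) else rew for r, rew in zip(responses, rewards)]


def best_response(responses, rewards):
    """Pick the highest-reward response after applying the discard penalty."""
    scored = penalized_rewards(responses, rewards)
    best = max(range(len(responses)), key=lambda i: scored[i])
    return responses[best]
```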
The following figure shows the reward curve of RAFT (note that we use a smaller temperature to test the model, leading to a higher evaluation reward):
It turns out that the resulting model achieves a good reward as well as an acceptable diversity metric; we refer interested readers to the original paper for details. However, this is only the starting point of our journey. We present some randomly sampled responses here. The RAFT-aligned model generally tends to reply with more details, although its responses sometimes contain redundant words. We suspect that this is because the reward model favors such responses, and this imperfection is exploited.
3.3.3 End Note
If you want to try RAFT on your own task, you should carefully modify the following two functions in LMFlow/src/lmflow/pipeline/raft_aligner.py. They extract the response from the generated text and decide whether a collected sample is discarded:
- _clean_text(self, text);
- _discard_sample(self, text).
You should also first check that the collected samples (automatically saved in output_dir) look good.
There is still much room for improvement. To further improve model performance, we can improve the reward model (e.g., by using LLaMA-7B-RM) and try more advanced generation strategies (by modifying the generation configuration in LMFlow/src/lmflow/pipeline/raft_aligner.py), which we leave for future work. We are still actively developing RAFT and welcome feedback and contributions! Also check out our LMFlow framework for more fun with LLMs.