gradient checkpointing failed in xla_device · Issue #5766 · pytorch/xla
❓ Questions and Help
I am trying to fine-tune a large language model from Hugging Face on an xla_device, and the following error is reported:
Traceback (most recent call last):
File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
r = call_item.fn(*call_item.args, **call_item.kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
return [fn(*args) for args in chunk]
File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
return [fn(*args) for args in chunk]
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/runtime.py", line 85, in wrapper
return fn(*args, **kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 75, in _run_thread_per_device
replica_results = list(
File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
yield fs.pop().result()
File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 444, in result
return self.__get_result()
File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
raise self._exception
File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/thread.py", line 57, in run
result = self.fn(*self.args, **self.kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 68, in _thread_fn
return fn()
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 184, in __call__
self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 196, in _mp_fn
train()
File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 193, in train
train_loop_fn(train_device_loader, epoch)
File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 184, in train_loop_fn
output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/root/.cache/huggingface/modules/transformers_modules/modeling_llm.py", line 692, in forward
outputs = self.model(
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/root/.cache/huggingface/modules/transformers_modules/modeling_llm.py", line 459, in forward
layer_outputs = torch.utils.checkpoint.checkpoint(
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
return fn(*args, **kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 457, in checkpoint
next(gen)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 1157, in _checkpoint_without_reentrant_generator
device_module = _get_device_module(device)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 67, in _get_device_module
device_module = getattr(torch, device)
File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/__init__.py", line 1833, in __getattr__
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'xla'
The function torch.utils.checkpoint.checkpoint is decorated with @_disable_dynamo:
https://github.com/pytorch/pytorch/blob/0d95378341b4eb19849295c7ccab08cc9be328a7/torch/utils/checkpoint.py#L341
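From my reading of the traceback, though, the failure itself comes from _get_device_module() resolving the device type string with getattr(torch, device): that lookup works for "cuda", but there is no torch.xla submodule because the XLA backend lives in the separate torch_xla package. A minimal illustration (not part of the repro):

import torch

getattr(torch, "cuda")  # works: returns the torch.cuda module
getattr(torch, "xla")   # AttributeError: module 'torch' has no attribute 'xla'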
Does this mean that if the model's device is set to xla, then torch.utils.checkpoint.checkpoint cannot be used?
If so, are there any alternative approaches, or ways to avoid gradient checkpointing when fine-tuning an LLM?
Any help on this would be greatly appreciated!
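For context, the only workaround I have thought of so far (untested) would be to redirect the model's checkpoint calls to torch_xla's own wrapper; I am assuming here that torch_xla.utils.checkpoint.checkpoint accepts the same (function, *args) call signature as the stock torch.utils.checkpoint.checkpoint:

# Untested sketch, based on the assumption above: patch the stock entry point
# before the remote modeling code runs, so that the
# torch.utils.checkpoint.checkpoint(...) call in modeling_llm.py resolves to
# the XLA-side wrapper instead.
import torch.utils.checkpoint
import torch_xla.utils.checkpoint as xla_checkpoint

torch.utils.checkpoint.checkpoint = xla_checkpoint.checkpoint

I have not verified that this plays well with dynamo or with the non-reentrant checkpoint path, so treat it purely as a sketch.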
fine-tune.py
import os
import math
import pathlib
from typing import Optional, Dict
from dataclasses import dataclass, field
import json
import torch
from torch.utils.data import Dataset
import transformers
from transformers.training_args import TrainingArguments
import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp
@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")


@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=[195],
        assistant_tokens=[196],
    ):
        super(SupervisedDataset, self).__init__()
        self.data = json.load(open(data_path))
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        # Sanity check: preprocess the first example and collect its non-ignored labels.
        item = self.preprocessing(self.data[0])
        labels = []
        for id_ in item["labels"]:
            if id_ == -100:
                continue
            labels.append(id_)

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
        input_ids = []
        labels = []
        for message in example["conversations"]:
            from_ = message["from"]
            value = message["value"]
            value_ids = self.tokenizer.encode(value)
            if from_ == "human":
                input_ids += self.user_tokens + value_ids
                labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(
                    value_ids
                )
            else:
                input_ids += self.assistant_tokens + value_ids
                labels += [self.ignore_index] + value_ids
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        # Truncate and right-pad to model_max_length.
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [self.ignore_index] * (self.model_max_length - len(labels))
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])
def train():
    device = xm.xla_device()
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    ).to(device)
    print('model device:', model.device)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        use_fast=False,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
    )
    # if training_args.use_lora:
    #     from peft import LoraConfig, TaskType, get_peft_model
    #
    #     peft_config = LoraConfig(
    #         task_type=TaskType.CAUSAL_LM,
    #         target_modules=["W_pack"],
    #         inference_mode=False,
    #         r=1,
    #         lora_alpha=32,
    #         lora_dropout=0.1,
    #     )
    #     model.enable_input_require_grads()
    #     model = get_peft_model(model, peft_config)
    #     model.print_trainable_parameters()
    dataset = SupervisedDataset(
        data_args.data_path, tokenizer, training_args.model_max_length
    )
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=training_args.per_device_train_batch_size,
        shuffle=True,
        num_workers=8,
        persistent_workers=False,
        prefetch_factor=16)
    torch.manual_seed(training_args.seed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate,
                                  betas=(training_args.adam_beta1, training_args.adam_beta2),
                                  eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay)
    # Wrap the host DataLoader so batches are prefetched onto the XLA device.
    train_device_loader = pl.MpDeviceLoader(
        data_loader,
        device,
        loader_prefetch_size=8,
        device_prefetch_size=4,
        host_to_device_transfer_threads=8)

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, batch in enumerate(loader):
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            labels = batch["labels"]
            with xp.StepTrace('train_baichuan2-13b-chat'):
                with xp.Trace('build_graph'):
                    optimizer.zero_grad()
                    output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                    loss = output.loss
                    loss.backward()
                # Reduce gradients across replicas and apply the optimizer update.
                xm.optimizer_step(optimizer)
            tracker.add(training_args.per_device_train_batch_size)

    for epoch in range(1, int(training_args.num_train_epochs)):
        train_loop_fn(train_device_loader, epoch)


def _mp_fn(index):
    train()


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=None)
fine-tune.py can be run by executing the command
GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA python3 fine-tune.py \
--report_to "none" \
--data_path "Baichuan2-13B-Chat/data/belle_chat_ramdon_10k.json" \
--model_name_or_path "Baichuan2-13B-Chat/models--baichuan-inc--Baichuan2-13B-Chat/snapshots/8f6e343d545c503b91429582231d1d354dac2740/" \
--output_dir "output" \
--model_max_length 512 \
--num_train_epochs 4 \
--per_device_train_batch_size 16 \
--gradient_accumulation_steps 1 \
--save_strategy epoch \
--learning_rate 2e-5 \
--lr_scheduler_type constant \
--adam_beta1 0.9 \
--adam_beta2 0.98 \
--adam_epsilon 1e-8 \
--max_grad_norm 1.0 \
--weight_decay 1e-4 \
--warmup_ratio 0.0 \
--logging_steps 1 \
--gradient_checkpointing False \
--bf16 False \
--tf32 False
dataset: https://github.com/baichuan-inc/Baichuan2/tree/main/fine-tune/data
model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/tree/main
torch version: 2.1.0
torch-xla version: 2.1.0