gradient checkpointing failed in xla_device · Issue #5766 · pytorch/xla

❓ Questions and Help

I am trying to fine-tune a large language model from Hugging Face on an XLA device. The following error is reported:

Traceback (most recent call last):
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 239, in _process_worker
    r = call_item.fn(*call_item.args, **call_item.kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 198, in _process_chunk
    return [fn(*args) for args in chunk]
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/process.py", line 198, in <listcomp>
    return [fn(*args) for args in chunk]
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/runtime.py", line 85, in wrapper
    return fn(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 75, in _run_thread_per_device
    replica_results = list(
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 68, in _thread_fn
    return fn()
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch_xla/_internal/pjrt.py", line 184, in __call__
    self.fn(runtime.global_ordinal(), *self.args, **self.kwargs)
  File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 196, in _mp_fn
    train()
  File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 193, in train
    train_loop_fn(train_device_loader, epoch)
  File "/projs/framework/root/llm2/fine-tune/fine-tune.py", line 184, in train_loop_fn
    output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/root/.cache/huggingface/modules/transformers_modules/modeling_llm.py", line 692, in forward
    outputs = self.model(
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/root/.cache/huggingface/modules/transformers_modules/modeling_llm.py", line 459, in forward
    layer_outputs = torch.utils.checkpoint.checkpoint(
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_compile.py", line 24, in inner
    return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 328, in _fn
    return fn(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/_dynamo/external_utils.py", line 17, in inner
    return fn(*args, **kwargs)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 457, in checkpoint
    next(gen)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 1157, in _checkpoint_without_reentrant_generator
    device_module = _get_device_module(device)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/utils/checkpoint.py", line 67, in _get_device_module
    device_module = getattr(torch, device)
  File "/home/root/.conda/envs/torch-xla/lib/python3.8/site-packages/torch/__init__.py", line 1833, in __getattr__
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
AttributeError: module 'torch' has no attribute 'xla'

The function torch.utils.checkpoint.checkpoint is decorated with _disable_dynamo:
https://github.com/pytorch/pytorch/blob/0d95378341b4eb19849295c7ccab08cc9be328a7/torch/utils/checkpoint.py#L341
Does this mean that torch.utils.checkpoint.checkpoint cannot be used when the model's device is set to xla?
If so, are there any alternative approaches that avoid using torch.utils.checkpoint.checkpoint for gradient checkpointing in an LLM?
Any help on this would be greatly appreciated!
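
For context, here is a minimal sketch that should hit the same code path (it is not from the original report; the Linear layer, tensor shapes, and the explicit use_reentrant=False flag are illustrative assumptions). The non-reentrant checkpoint generator calls _get_device_module(device), which does getattr(torch, "xla") and raises because torch has no xla submodule.

import torch
import torch.utils.checkpoint
import torch_xla.core.xla_model as xm

device = xm.xla_device()
layer = torch.nn.Linear(8, 8).to(device)
x = torch.randn(4, 8, device=device, requires_grad=True)

# The non-reentrant path (use_reentrant=False) looks up the device module via
# getattr(torch, "xla"), which raises AttributeError on torch 2.1:
# "module 'torch' has no attribute 'xla'"
out = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)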

fine-tune.py

import os
import math
import pathlib
from typing import Optional, Dict
from dataclasses import dataclass, field
import json

import torch
from torch.utils.data import Dataset
import transformers
from transformers.training_args import TrainingArguments

import torch_xla
from torch_xla import runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met
import torch_xla.debug.profiler as xp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_backend
import torch_xla.distributed.xla_multiprocessing as xmp

@dataclass
class ModelArguments:
    model_name_or_path: Optional[str] = field(default="baichuan-inc/Baichuan2-7B-Base")

@dataclass
class DataArguments:
    data_path: str = field(
        default=None, metadata={"help": "Path to the training data."}
    )

@dataclass
class TrainingArguments(transformers.TrainingArguments):
    cache_dir: Optional[str] = field(default=None)
    optim: str = field(default="adamw_torch")
    model_max_length: int = field(
        default=512,
        metadata={
            "help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
        },
    )
    use_lora: bool = field(default=False)


class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(
        self,
        data_path,
        tokenizer,
        model_max_length,
        user_tokens=[195],
        assistant_tokens=[196],
    ):
        super(SupervisedDataset, self).__init__()
        self.data = json.load(open(data_path))
        self.tokenizer = tokenizer
        self.model_max_length = model_max_length
        self.user_tokens = user_tokens
        self.assistant_tokens = assistant_tokens
        self.ignore_index = -100
        item = self.preprocessing(self.data[0])
        labels = []
        for id_ in item["labels"]:
            if id_ == -100:
                continue

            labels.append(id_)

    def __len__(self):
        return len(self.data)

    def preprocessing(self, example):
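        # Concatenate the conversation turns into one token sequence; human
        # turns are masked with ignore_index (-100) so the loss is computed
        # mainly on assistant tokens.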
        input_ids = []
        labels = []

        for message in example["conversations"]:
            from_ = message["from"]
            value = message["value"]
            value_ids = self.tokenizer.encode(value)

            if from_ == "human":
                input_ids += self.user_tokens + value_ids
                labels += [self.tokenizer.eos_token_id] + [self.ignore_index] * len(
                    value_ids
                )
            else:
                input_ids += self.assistant_tokens + value_ids
                labels += [self.ignore_index] + value_ids
        input_ids.append(self.tokenizer.eos_token_id)
        labels.append(self.tokenizer.eos_token_id)
        input_ids = input_ids[: self.model_max_length]
        labels = labels[: self.model_max_length]
        input_ids += [self.tokenizer.pad_token_id] * (
            self.model_max_length - len(input_ids)
        )
        labels += [self.ignore_index] * (self.model_max_length - len(labels))
        input_ids = torch.LongTensor(input_ids)
        labels = torch.LongTensor(labels)
        attention_mask = input_ids.ne(self.tokenizer.pad_token_id)
        return {
            "input_ids": input_ids,
            "labels": labels,
            "attention_mask": attention_mask,
        }

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        return self.preprocessing(self.data[idx])

def train():
    device = xm.xla_device()
    parser = transformers.HfArgumentParser(
        (ModelArguments, DataArguments, TrainingArguments)
    )
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        cache_dir=training_args.cache_dir,
    ).to(device)

    print('model device:', model.device)
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        trust_remote_code=True,
        use_fast=False,
        model_max_length=training_args.model_max_length,
        cache_dir=training_args.cache_dir,
    )
#    if training_args.use_lora:
#        from peft import LoraConfig, TaskType, get_peft_model
#
#        peft_config = LoraConfig(
#            task_type=TaskType.CAUSAL_LM,
#            target_modules=["W_pack"],
#            inference_mode=False,
#            r=1,
#            lora_alpha=32,
#            lora_dropout=0.1,
#        )
#        model.enable_input_require_grads()
#        model = get_peft_model(model, peft_config)
#        model.print_trainable_parameters()

    dataset = SupervisedDataset(
        data_args.data_path, tokenizer, training_args.model_max_length
    )

    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=training_args.per_device_train_batch_size,
        shuffle=True,
        num_workers=8,
        persistent_workers=False,
        prefetch_factor=16)

    torch.manual_seed(training_args.seed)
    optimizer = torch.optim.AdamW(model.parameters(), lr=training_args.learning_rate,
                                  betas=(training_args.adam_beta1, training_args.adam_beta2),
                                  eps=training_args.adam_epsilon, weight_decay=training_args.weight_decay)

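    # MpDeviceLoader wraps the CPU DataLoader and asynchronously prefetches
    # batches onto the XLA device using background transfer threads.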
    train_device_loader = pl.MpDeviceLoader(
        data_loader,
        device,
        loader_prefetch_size=8,
        device_prefetch_size=4,
        host_to_device_transfer_threads=8)

    def train_loop_fn(loader, epoch):
        tracker = xm.RateTracker()
        model.train()
        for step, batch in enumerate(loader):
          input_ids = batch["input_ids"]
          attention_mask = batch["attention_mask"]
          labels = batch["labels"]
          with xp.StepTrace('train_baichuan2-13b-chat'):
            with xp.Trace('build_graph'):
              optimizer.zero_grad()
              output = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
              loss = output.loss
              loss.backward()
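              # xm.optimizer_step all-reduces gradients across replicas
              # before calling optimizer.step().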
              xm.optimizer_step(optimizer)
              tracker.add(training_args.per_device_train_batch_size)

    for epoch in range(1, int(training_args.num_train_epochs) + 1):
      train_loop_fn(train_device_loader, epoch)

def _mp_fn(index):
    train()


if __name__ == "__main__":
    xmp.spawn(_mp_fn, args=(), nprocs=None)

fine-tune.py can be run with the following command:

GPU_NUM_DEVICES=1 PJRT_DEVICE=CUDA python3 fine-tune.py  \
    --report_to "none" \
    --data_path "Baichuan2-13B-Chat/data/belle_chat_ramdon_10k.json" \
    --model_name_or_path "Baichuan2-13B-Chat/models--baichuan-inc--Baichuan2-13B-Chat/snapshots/8f6e343d545c503b91429582231d1d354dac2740/" \
    --output_dir "output" \
    --model_max_length 512 \
    --num_train_epochs 4 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 1 \
    --save_strategy epoch \
    --learning_rate 2e-5 \
    --lr_scheduler_type constant \
    --adam_beta1 0.9 \
    --adam_beta2 0.98 \
    --adam_epsilon 1e-8 \
    --max_grad_norm 1.0 \
    --weight_decay 1e-4 \
    --warmup_ratio 0.0 \
    --logging_steps 1 \
    --gradient_checkpointing False \
    --bf16 False \
    --tf32 False

dataset: https://github.com/baichuan-inc/Baichuan2/tree/main/fine-tune/data
model: https://huggingface.co/baichuan-inc/Baichuan2-13B-Chat/tree/main
torch version: 2.1.0
torch-xla version: 2.1.0