txtai - Trainer

HFTrainer


Trains a new Hugging Face Transformer model using the Trainer framework.

Example

The following shows a simple example using this pipeline.

```python
import pandas as pd

from datasets import load_dataset

from txtai.pipeline import HFTrainer

trainer = HFTrainer()

# Pandas DataFrame
df = pd.read_csv("training.csv")
model, tokenizer = trainer("bert-base-uncased", df)

# Hugging Face dataset
ds = load_dataset("glue", "sst2")
model, tokenizer = trainer("bert-base-uncased", ds["train"], columns=("sentence", "label"))

# List of dicts
dt = [{"text": "sentence 1", "label": 0}, {"text": "sentence 2", "label": 1}]
model, tokenizer = trainer("bert-base-uncased", dt)

# Support additional TrainingArguments
model, tokenizer = trainer("bert-base-uncased", dt, learning_rate=3e-5, num_train_epochs=5)
```

Any Hugging Face `TrainingArguments` field can be passed as a keyword argument to the trainer call.

See the links below for more detailed examples.

| Notebook | Description |
| --- | --- |
| Train a text labeler | Build text sequence classification models |
| Train without labels | Use zero-shot classifiers to train new models |
| Train a QA model | Build and fine-tune question-answering models |
| Train a language model from scratch | Build new language models |

Training tasks

The HFTrainer pipeline builds and/or fine-tunes models for the following training tasks.

| Task | Description |
| --- | --- |
| `language-generation` | Causal language model for text generation (e.g. GPT) |
| `language-modeling` | Masked language model for general tasks (e.g. BERT) |
| `question-answering` | Extractive question-answering model, typically trained with the SQuAD dataset |
| `sequence-sequence` | Sequence-to-sequence model (e.g. T5) |
| `text-classification` | Classify text with a set of labels |
| `token-detection` | ELECTRA-style pre-training with replaced token detection |
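To make the table more concrete, the sketch below pairs each task name with the standard Hugging Face `transformers` model class usually associated with that training objective. This mapping is an assumption for illustration only: txtai's own model-loading logic may use different classes, and ELECTRA-style token detection normally trains a generator alongside the `ElectraForPreTraining` discriminator. The `TASK_MODEL_CLASSES` dict and `model_class_name` helper are hypothetical names.

```python
# Hypothetical lookup: task name -> transformers class commonly used for that objective.
# Assumption: this mirrors standard Hugging Face conventions, not txtai's internal code.
TASK_MODEL_CLASSES = {
    "language-generation": "AutoModelForCausalLM",
    "language-modeling": "AutoModelForMaskedLM",
    "question-answering": "AutoModelForQuestionAnswering",
    "sequence-sequence": "AutoModelForSeq2SeqLM",
    "text-classification": "AutoModelForSequenceClassification",
    # ELECTRA discriminator with a replaced-token-detection head
    "token-detection": "ElectraForPreTraining",
}


def model_class_name(task: str = "text-classification") -> str:
    """Return the transformers class name typically used for a training task."""
    try:
        return TASK_MODEL_CLASSES[task]
    except KeyError:
        raise ValueError(f"Unknown task: {task!r}") from None
```

In HFTrainer itself, the task is selected through the `task` argument of the trainer call, which defaults to `"text-classification"`.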

PEFT

Parameter-Efficient Fine-Tuning (PEFT) is supported through Hugging Face's PEFT library. Quantization is provided through bitsandbytes. See the examples below.

```python
from txtai.pipeline import HFTrainer

trainer = HFTrainer()
trainer(..., quantize=True, lora=True)
```

When these parameters are set to `True`, a default configuration is used. Both can also be customized by passing a configuration dictionary.

```python
quantize = {
    "load_in_4bit": True,
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_compute_dtype": "bfloat16"
}

lora = {
    "r": 16,
    "lora_alpha": 8,
    "target_modules": "all-linear",
    "lora_dropout": 0.05,
    "bias": "none"
}

trainer(..., quantize=quantize, lora=lora)
```

The parameters also accept transformers.BitsAndBytesConfig and peft.LoraConfig instances.
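For reference, the standard LoRA formulation (as applied by the PEFT library with its default settings) shows how the `r` and `lora_alpha` values in the LoRA dictionary above interact. A frozen weight matrix W is augmented with a trainable low-rank update BA, scaled by lora_alpha / r. The 768 × 768 projection used in the worked numbers is an illustrative assumption (the hidden size of BERT-base), not something the trainer requires.

```latex
h = W x + \frac{\alpha}{r} B A x,
\qquad W \in \mathbb{R}^{d \times k},\; B \in \mathbb{R}^{d \times r},\; A \in \mathbb{R}^{r \times k},\; \alpha = \texttt{lora\_alpha}

\frac{\alpha}{r} = \frac{8}{16} = 0.5,
\qquad d k = 768 \cdot 768 = 589{,}824,
\qquad r (d + k) = 16 \cdot (768 + 768) = 24{,}576 \approx 4.2\%\ \text{of}\ d k
```

With these settings, the adapter update is scaled by 0.5 and each adapted 768 × 768 layer trains roughly 4.2% as many parameters as full fine-tuning of that layer would.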

See the following PEFT documentation links for more information.

Merge

An important parameter for the `language-generation` and `language-modeling` tasks is `merge`, which controls how data is packed into chunks.

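It supports the following options.

| Option | Description |
| --- | --- |
| `concat` | Concatenates all records and splits the result into chunks of `maxlength` tokens (default) |
| `pack` | Combines whole records into chunks of up to `maxlength` tokens, without splitting a record across chunks |
| `None` | No merging; each record is its own chunk |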

Merging reduces training time because chunks can be processed efficiently with little or no padding. `concat` maximizes this, since chunks are filled up to `maxlength` tokens. `pack` is a middle ground: records are combined into chunks, but each record is kept intact.

For general language modeling tasks such as masked language modeling, `concat` is the best choice. For instruction or prompt fine-tuning, `pack` or `None` is the better choice, since both keep each record intact so that a prompt and its response are never split across chunks.
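The difference between the three options can be seen in a minimal, pure-Python sketch that operates on pre-tokenized records (lists of token ids). The function name `merge_chunks` is hypothetical and this is not txtai's implementation: real training code also builds attention masks and labels, and concatenation pipelines commonly drop the final partial chunk, which this sketch keeps.

```python
from typing import List, Optional


def merge_chunks(records: List[List[int]], maxlength: int, merge: Optional[str]) -> List[List[int]]:
    """Illustrative only: group tokenized records into training chunks."""
    if merge == "concat":
        # Flatten all records into one token stream, then cut fixed-size chunks.
        # Records may be split across chunk boundaries.
        stream = [token for record in records for token in record]
        return [stream[i:i + maxlength] for i in range(0, len(stream), maxlength)]

    if merge == "pack":
        # Greedily add whole records to the current chunk; start a new chunk
        # when the next record would not fit. No record is split.
        chunks, current = [], []
        for record in records:
            record = record[:maxlength]  # truncate records longer than maxlength
            if current and len(current) + len(record) > maxlength:
                chunks.append(current)
                current = []
            current = current + record
        if current:
            chunks.append(current)
        return chunks

    # None: no merging, each (truncated) record is its own chunk and would need padding
    return [record[:maxlength] for record in records]
```

For example, with `records = [[1, 2], [3], [4, 5, 6], [7]]` and `maxlength=4`, `concat` yields `[[1, 2, 3, 4], [5, 6, 7]]` (the record `[4, 5, 6]` is split across two chunks), `pack` yields `[[1, 2, 3], [4, 5, 6, 7]]`, and `None` keeps four separate chunks.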

Methods

Python documentation for the pipeline.

__call__(base, train, validation=None, columns=None, maxlength=None, stride=128, task='text-classification', prefix=None, metrics=None, tokenizers=None, checkpoint=None, quantize=None, lora=None, merge='concat', **args)

Builds a new model using arguments.

Parameters:

| Name | Description | Default |
| --- | --- | --- |
| `base` | path to base model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple | required |
| `train` | training data | required |
| `validation` | validation data | `None` |
| `columns` | tuple of columns to use for text/label, defaults to (text, None, label) | `None` |
| `maxlength` | maximum sequence length, defaults to `tokenizer.model_max_length` | `None` |
| `stride` | chunk size for splitting data for QA tasks | `128` |
| `task` | optional model task or category, determines the model type | `'text-classification'` |
| `prefix` | optional source prefix | `None` |
| `metrics` | optional function that computes and returns a dict of evaluation metrics | `None` |
| `tokenizers` | optional number of concurrent tokenizers; `True` uses `os.cpu_count()` | `None` |
| `checkpoint` | optional resume from checkpoint flag or path to checkpoint directory | `None` |
| `quantize` | quantization configuration to pass to base model | `None` |
| `lora` | LoRA configuration to pass to PEFT model | `None` |
| `merge` | determines how chunks are combined for language modeling tasks: `"concat"`, `"pack"` or `None` | `'concat'` |
| `args` | training arguments (`TrainingArguments` fields) | `{}` |
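The `metrics` function is passed to the Hugging Face `Trainer` as `compute_metrics` (see the source code below), which calls it with an `EvalPrediction` that unpacks into `(predictions, label_ids)`. Because evaluation only runs when `validation` data is provided, metrics are only computed in that case. A minimal accuracy function for `text-classification` might look like the following; `compute_accuracy` is a hypothetical name, and the code assumes predictions are per-class logits. It uses plain Python so that it works with both NumPy arrays and lists.

```python
def compute_accuracy(eval_pred):
    """Hypothetical metrics function for a text-classification task."""
    # EvalPrediction unpacks into (predictions, label_ids)
    logits, labels = eval_pred

    # Predicted class = index of the largest logit in each row (first index wins ties)
    predictions = [max(range(len(row)), key=lambda i: row[i]) for row in logits]

    # Fraction of predictions that match the reference labels
    correct = sum(int(p == l) for p, l in zip(predictions, labels))
    return {"accuracy": correct / len(labels) if len(labels) else 0.0}
```

It would then be supplied together with validation data, e.g. `trainer("bert-base-uncased", train, validation=validation, metrics=compute_accuracy)`.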

Returns:

| Type | Description |
| --- | --- |
| `tuple` | `(model, tokenizer)`: the trained model, in eval mode, and its tokenizer |

Source code in txtai/pipeline/train/hftrainer.py

```python
def __call__(
    self,
    base,
    train,
    validation=None,
    columns=None,
    maxlength=None,
    stride=128,
    task="text-classification",
    prefix=None,
    metrics=None,
    tokenizers=None,
    checkpoint=None,
    quantize=None,
    lora=None,
    merge="concat",
    **args
):
    """
    Builds a new model using arguments.

    Args:
        base: path to base model, accepts Hugging Face model hub id, local path or (model, tokenizer) tuple
        train: training data
        validation: validation data
        columns: tuple of columns to use for text/label, defaults to (text, None, label)
        maxlength: maximum sequence length, defaults to tokenizer.model_max_length
        stride: chunk size for splitting data for QA tasks
        task: optional model task or category, determines the model type, defaults to "text-classification"
        prefix: optional source prefix
        metrics: optional function that computes and returns a dict of evaluation metrics
        tokenizers: optional number of concurrent tokenizers, defaults to None
        checkpoint: optional resume from checkpoint flag or path to checkpoint directory, defaults to None
        quantize: quantization configuration to pass to base model
        lora: lora configuration to pass to PEFT model
        merge: determines how chunks are combined for language modeling tasks - "concat" (default), "pack" or None
        args: training arguments

    Returns:
        (model, tokenizer)
    """

    # Quantization / LoRA support
    if (quantize or lora) and not PEFT:
        raise ImportError('PEFT is not available - install "pipeline" extra to enable')

    # Parse TrainingArguments
    args = self.parse(args)

    # Set seed for model reproducibility
    set_seed(args.seed)

    # Load model configuration, tokenizer and max sequence length
    config, tokenizer, maxlength = self.load(base, maxlength)

    # Default tokenizer pad token if it's not set
    tokenizer.pad_token = tokenizer.pad_token if tokenizer.pad_token is not None else tokenizer.eos_token

    # Prepare parameters
    process, collator, labels = self.prepare(task, train, tokenizer, columns, maxlength, stride, prefix, merge, args)

    # Tokenize training and validation data
    train, validation = process(train, validation, os.cpu_count() if tokenizers and isinstance(tokenizers, bool) else tokenizers)

    # Create model to train
    model = self.model(task, base, config, labels, tokenizer, quantize)

    # Default config pad token if it's not set
    model.config.pad_token_id = model.config.pad_token_id if model.config.pad_token_id is not None else model.config.eos_token_id

    # Load as PEFT model, if necessary
    model = self.peft(task, lora, model)

    # Add model to collator
    if collator:
        collator.model = model

    # Build trainer
    trainer = Trainer(
        model=model,
        processing_class=tokenizer,
        data_collator=collator,
        args=args,
        train_dataset=train,
        eval_dataset=validation if validation else None,
        compute_metrics=metrics,
    )

    # Run training
    trainer.train(resume_from_checkpoint=checkpoint)

    # Run evaluation
    if validation:
        trainer.evaluate()

    # Save model outputs
    if args.should_save:
        trainer.save_model()
        trainer.save_state()

    # Put model in eval mode to disable weight updates and return (model, tokenizer)
    return (model.eval(), tokenizer)
```
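The `tokenizers` argument is resolved by the one-line expression passed to `process` in the source code. The standalone sketch below mirrors that expression so its behavior is easy to check; `resolve_tokenizers` is a hypothetical helper, and the assumption that the resulting number is used as the worker count for tokenization is inferred from the parameter description rather than shown here.

```python
import os
from typing import Optional, Union


def resolve_tokenizers(tokenizers: Optional[Union[bool, int]]) -> Optional[Union[bool, int]]:
    """Mirror the tokenizers expression used in HFTrainer.__call__."""
    # A boolean True expands to one tokenizer per CPU core
    if tokenizers and isinstance(tokenizers, bool):
        return os.cpu_count()

    # Integers, None and False are passed through unchanged
    return tokenizers
```

The `isinstance(..., bool)` check matters because `True` is also an `int` in Python; without it, `tokenizers=True` would be passed through as the value 1 instead of the CPU count.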