Core Concepts — TextBrewer 0.2.1.post1 documentation

Conventions

Batch Format (important)

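Forward conventions: each batch passed to the model should be a dict or a tuple. A dict batch is unpacked into keyword arguments of model.forward, and a tuple batch into positional arguments.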

Hence, if the batch is not a dict, users should make sure that the order of the elements in the batch matches the order of the arguments of model.forward. args is used to pass additional parameters.
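The convention above can be sketched in a few lines. This is only an illustration of the dispatch rule, not TextBrewer's actual implementation; forward_batch and toy_model are hypothetical names introduced for the example.

```python
def forward_batch(model, batch, args):
    """Call `model` on `batch` following the forward conventions (sketch)."""
    if isinstance(batch, dict):
        # Dict batches are unpacked as keyword arguments.
        return model(**batch, **args)
    # Tuple (or list) batches are unpacked positionally, so their order
    # must match the argument order of model.forward.
    return model(*batch, **args)


def toy_model(input_ids, attention_mask, scale=1):
    # Hypothetical stand-in for model.forward; returns a simple sum.
    return scale * (sum(input_ids) + sum(attention_mask))


# The same data as a dict batch and as a tuple batch.
dict_out = forward_batch(
    toy_model, {"input_ids": [1, 2], "attention_mask": [1, 1]}, {"scale": 2}
)
tuple_out = forward_batch(toy_model, ([1, 2], [1, 1]), {"scale": 2})
```

Both calls reach toy_model with the same arguments; only the tuple form depends on element order.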

Users can additionally define a batch_postprocessor function to post-process batches if needed. batch_postprocessor should take a batch and return a batch. See the explanation of the train method of Distillers for more details.
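As a minimal sketch of such a function, assume (hypothetically) that the dataloader yields tuples ordered as (labels, input_ids, attention_mask), while model.forward expects (input_ids, attention_mask, labels). A batch_postprocessor can restore the expected order. The name reorder_postprocessor and the field layout are assumptions made for illustration.

```python
def reorder_postprocessor(batch):
    """Hypothetical batch_postprocessor: takes a batch, returns a batch."""
    labels, input_ids, attention_mask = batch
    # Reorder so that positions match forward(input_ids, attention_mask, labels).
    return (input_ids, attention_mask, labels)


# Placeholder strings stand in for tensors to keep the sketch dependency-free.
raw_batch = ("labels", "input_ids", "attention_mask")
processed = reorder_postprocessor(raw_batch)
```

The function would then be handed to the distiller as described in the documentation of the train method.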

Since version 0.2.1, TextBrewer supports a more flexible input scheme: users can feed different batches to the student and the teacher, or feed cached values to save the time of the teacher's forward pass. See Feed Different batches to Student and Teacher, Feed Cached Values.

Configurations

Distillers

Distillers are in charge of conducting the actual experiments. The following distillers are available:

User-Defined Functions

In TextBrewer, there are two functions that should be implemented by users: callback() and adaptor().

callback(model, step) → None

At each checkpoint, after saving the student model, the callback function will be called by the distiller. callback can be used to evaluate the performance of the student model at each checkpoint.

Note

If users want to do an evaluation in the callback, remember to add model.eval() in the callback.
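The following is a minimal sketch of an evaluation callback with the callback(model, step) signature. The factory make_eval_callback, the history list, and the toy linear model are assumptions made for illustration; a real callback would evaluate the student on a held-out dataset.

```python
import torch


def make_eval_callback(eval_inputs, eval_labels, history):
    """Build a callback(model, step) that records accuracy (hypothetical helper)."""

    def callback(model, step):
        was_training = model.training
        model.eval()  # switch off dropout etc. before evaluating
        with torch.no_grad():
            logits = model(eval_inputs)
        accuracy = (logits.argmax(dim=-1) == eval_labels).float().mean().item()
        history.append((step, accuracy))
        # Restore the previous mode; a defensive choice, not a requirement.
        model.train(was_training)

    return callback


history = []
model = torch.nn.Linear(4, 2)  # toy stand-in for the student model
callback = make_eval_callback(torch.randn(8, 4), torch.randint(0, 2, (8,)), history)
callback(model, step=100)
```

Binding the evaluation data through a closure keeps the callback itself down to the two arguments the distiller passes.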

Parameters

adaptor(batch, model_outputs) → dict

It converts the model inputs and outputs to the specified format so that they can be recognized by the distiller. At each training step, the batch and the model outputs are passed to the adaptor; the adaptor reorganizes the data and returns a dict.

The functionality of the adaptor is shown in the figure below:

_images/adaptor.png

Parameters

Return type

dict

Returns

a dictionary that may contain the following keys and values:

Note

These keys are all optional:

In most cases logits should be provided, unless you are doing multi-stage training or non-classification tasks, etc.

Example:

'''
Suppose the model outputs are: logits, sequence_output, total_loss
class MyModel():
  def forward(self, input_ids, attention_mask, labels, ...):
    ...
    return logits, sequence_output, total_loss

logits: Tensor of shape (batch_size, num_classes)
sequence_output: List of tensors of (batch_size, length, hidden_dim)
total_loss: scalar tensor

model inputs are:
input_ids      = batch[0] : input_ids (batch_size, length)
attention_mask = batch[1] : attention_mask (batch_size, length)
labels         = batch[2] : labels (batch_size, num_classes)
'''
def SimpleAdaptor(batch, model_outputs):
  return {'logits': (model_outputs[0],),
          'hidden': model_outputs[1],
          'inputs_mask': batch[1]}

Feed Different batches to Student and Teacher, Feed Cached Values

Feed Different batches

In some cases, student and teacher read different inputs. For example, if you distill a RoBERTa model to a BERT model, they cannot share the inputs since they have different vocabularies.

To solve this, one can build a dataset that returns a dict as the batch with keys 'student' and 'teacher'. TextBrewer will unpack the dict, feed batch['student'] to the student and its adaptor, and feed batch['teacher'] to the teacher and its adaptor, following the forward conventions.

Here is an example.

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

class TSDataset(Dataset):
    def __init__(self, teacher_dataset, student_dataset):
        # teacher_dataset and student_dataset are normal datasets
        # whose each element is a tuple or a dict.
        assert len(teacher_dataset) == len(student_dataset), \
            f"lengths of teacher_dataset {len(teacher_dataset)} and student_dataset {len(student_dataset)} are not the same!"

        self.teacher_dataset = teacher_dataset
        self.student_dataset = student_dataset

    def __len__(self):
        return len(self.teacher_dataset)

    def __getitem__(self, i):
        return {'teacher': self.teacher_dataset[i], 'student': self.student_dataset[i]}

teacher_dataset = TensorDataset(torch.randn(32,3), torch.randn(32,3))
student_dataset = TensorDataset(torch.randn(32,2), torch.randn(32,2))
tsdataset = TSDataset(teacher_dataset=teacher_dataset, student_dataset=student_dataset)
dataloader = DataLoader(dataset=tsdataset, ... )
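To make the unpacking step concrete, here is a small sketch of how a batch could be routed to the student and the teacher. It illustrates the behavior described in this section and is not TextBrewer's own code; route_batch is a hypothetical name, and placeholder strings stand in for tensors.

```python
def route_batch(batch):
    """Return (student_batch, teacher_batch) for one training step (sketch)."""
    if isinstance(batch, dict) and "student" in batch and "teacher" in batch:
        # Separate inputs: each model gets its own sub-batch.
        return batch["student"], batch["teacher"]
    # Otherwise both models read the same batch.
    return batch, batch


shared = ("input_ids", "attention_mask")
split = {"student": ("s_ids", "s_mask"), "teacher": ("t_ids", "t_mask")}

student_a, teacher_a = route_batch(shared)
student_b, teacher_b = route_batch(split)
```

Each returned sub-batch is then passed to its model under the usual forward conventions, so it may itself be a dict or a tuple.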

Feed Cached Values

If you are ready to provide a dataset that returns a dict with keys 'student' and 'teacher' like the one above, you can also add another key, 'teacher_cache', which stores the pre-computed outputs from the teacher. TextBrewer will then treat batch['teacher_cache'] as the output from the teacher and feed it to the teacher's adaptor. The teacher's forward will not be called.

Here is an example.

import torch
from torch.utils.data import Dataset, TensorDataset, DataLoader

class TSDataset(Dataset):
    def __init__(self, teacher_dataset, student_dataset, teacher_cache):
        # teacher_dataset and student_dataset are normal datasets
        # whose each element is a tuple or a dict.
        # teacher_cache is a list of items; each item is the output from the teacher.
        assert len(teacher_dataset) == len(student_dataset), \
            f"lengths of teacher_dataset {len(teacher_dataset)} and student_dataset {len(student_dataset)} are not the same!"
        assert len(teacher_dataset) == len(teacher_cache), \
            f"lengths of teacher_dataset {len(teacher_dataset)} and teacher_cache {len(teacher_cache)} are not the same!"
        self.teacher_dataset = teacher_dataset
        self.student_dataset = student_dataset
        self.teacher_cache = teacher_cache

    def __len__(self):
        return len(self.teacher_dataset)

    def __getitem__(self, i):
        return {'teacher': self.teacher_dataset[i], 'student': self.student_dataset[i], 'teacher_cache': self.teacher_cache[i]}

teacher_dataset = TensorDataset(torch.randn(32,3), torch.randn(32,3))
student_dataset = TensorDataset(torch.randn(32,2), torch.randn(32,2))

# We make some fake data and assume teacher model outputs are (logits, loss)
fake_logits = [torch.randn(3) for _ in range(32)]
fake_loss = [torch.randn(1)[0] for _ in range(32)]
teacher_cache = [(fake_logits[i], fake_loss[i]) for i in range(32)]

tsdataset = TSDataset(teacher_dataset=teacher_dataset, student_dataset=student_dataset, teacher_cache=teacher_cache)
dataloader = DataLoader(dataset=tsdataset, ... )
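The effect of 'teacher_cache' can be sketched as a simple shortcut: when the key is present, the cached value is used and the teacher is never run. The helper get_teacher_outputs and the fake forward function below are hypothetical and only illustrate the behavior described in this section; they are not part of TextBrewer's API.

```python
def get_teacher_outputs(batch, teacher_forward):
    """Return teacher outputs, preferring cached values when present (sketch)."""
    if "teacher_cache" in batch:
        # Cached outputs are used as-is; no forward pass is run.
        return batch["teacher_cache"]
    return teacher_forward(batch["teacher"])


calls = []  # records every invocation of the fake teacher


def fake_teacher_forward(teacher_batch):
    # Hypothetical stand-in for the teacher's forward pass.
    calls.append(teacher_batch)
    return ("computed_logits", "computed_loss")


cached_batch = {
    "student": ("s_ids",),
    "teacher": ("t_ids",),
    "teacher_cache": ("cached_logits", "cached_loss"),
}
uncached_batch = {"student": ("s_ids",), "teacher": ("t_ids",)}

cached_out = get_teacher_outputs(cached_batch, fake_teacher_forward)
calls_after_cached = len(calls)
uncached_out = get_teacher_outputs(uncached_batch, fake_teacher_forward)
```

In both cases the returned value plays the role of the teacher's output and would be handed to the teacher's adaptor together with batch['teacher'].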