dataset_utils (Model Optimizer 0.31.0)
Utility functions for getting samples and forward loop function for different datasets.
Functions
| Function | Description |
|---|---|
| create_forward_loop | Creates and returns a forward loop function configured for a specific model, dataset, and tokenizer. |
| get_dataset_dataloader | Gets a dataloader for the given dataset name and the tokenizer of the target model. |
| get_max_batch_size | Gets the maximum batch size that can be used for the model. |
| get_supported_datasets | Retrieves the list of supported datasets. |
create_forward_loop(model=None, dataset_name='cnn_dailymail', tokenizer=None, batch_size=1, num_samples=512, max_sample_length=512, device=None, include_labels=False, dataloader=None)
Creates and returns a forward loop function configured for a specific model, dataset, and tokenizer.
This function builds a forward loop function that processes batches from the specified dataset using the given model and tokenizer. When called, the forward loop iterates over the dataset, applies the tokenizer to prepare the input data, feeds each batch into the model, and returns the model's predictions.
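The closure pattern described above can be illustrated in plain Python. The following sketch binds a model and already-tokenized batches into a loop that takes no arguments and returns the model outputs. `make_forward_loop` and `toy_model` are hypothetical stand-ins and are not modelopt APIs. The real function tokenizes samples from `dataset_name` itself and drives a PyTorch model, so treat this only as a sketch of the control flow.

```python
from typing import Any, Callable, Iterable, List


def make_forward_loop(model: Callable[..., Any], batches: Iterable[dict]) -> Callable[[], List[Any]]:
    """Hypothetical helper: bind a model and pre-tokenized batches into a zero-argument loop."""
    cached = list(batches)  # materialize once so the loop can be called repeatedly

    def forward_loop() -> List[Any]:
        outputs = []
        for batch in cached:
            # Each batch is a dict of model inputs, e.g. {"input_ids": ...}.
            outputs.append(model(**batch))
        return outputs

    return forward_loop


# Usage with a toy "model" that sums token ids instead of running a network.
def toy_model(input_ids):
    return sum(sum(row) for row in input_ids)


loop = make_forward_loop(toy_model, [{"input_ids": [[1, 2], [3, 4]]}, {"input_ids": [[5]]}])
print(loop())  # [10, 5]
```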
Parameters:
- model (Module | None) – The PyTorch model for inference.
- dataset_name (str) – The name of the dataset to be used. Must be one of the datasets in get_supported_datasets().
- tokenizer (PreTrainedTokenizerBase | None) – The tokenizer used to preprocess text data into a format suitable for the model.
- batch_size (int) – Batch size of the dataloader created internally. If 0 is provided, the batch size is determined automatically.
- num_samples (int) – Number of samples from the dataset.
- max_sample_length (int) – Maximum length of a sample.
- device (str | None) – Target device for the dataloader created internally.
- include_labels (bool) – Whether to include labels in the dataloader.
- dataloader (DataLoader | None) – If provided, this dataloader is used instead of creating one from dataset_name and tokenizer.
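The interaction between `dataloader` and `batch_size` in the parameter list can be sketched as a small precedence function. `choose_batch_source` and the stub callables are hypothetical and are not modelopt functions. The sketch assumes that a supplied dataloader takes priority even when `batch_size` is 0, which is one reading of "instead" above.

```python
from typing import Callable, List, Optional, Sequence


def choose_batch_source(
    dataloader: Optional[Sequence],
    batch_size: int,
    estimate_batch_size: Callable[[], int],
    build_dataloader: Callable[[int], Sequence],
) -> Sequence:
    """Hypothetical precedence rules for the dataloader and batch_size parameters."""
    if dataloader is not None:
        # A caller-supplied dataloader is used as-is; batch_size is not consulted.
        return dataloader
    if batch_size == 0:
        # 0 acts as a sentinel meaning "determine the batch size automatically".
        batch_size = estimate_batch_size()
    return build_dataloader(batch_size)


# Stubs standing in for real dataset loading and batch-size estimation.
def build_stub(batch_size: int) -> List[str]:
    return [f"batches of {batch_size}"]


def estimate_stub() -> int:
    return 8


print(choose_batch_source(None, 0, estimate_stub, build_stub))      # ['batches of 8']
print(choose_batch_source(None, 4, estimate_stub, build_stub))      # ['batches of 4']
print(choose_batch_source(["mine"], 0, estimate_stub, build_stub))  # ['mine']
```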
Return type:
Callable
Example usage for quantization:
```python
import modelopt.torch.quantization as mtq
from modelopt.torch.utils import create_forward_loop

# Initialize model and tokenizer
...

# Create forward loop for calibration
forward_loop = create_forward_loop(model=model, dataset_name="cnn_dailymail", tokenizer=tokenizer)

# Quantize the model with the calibration dataset (INT8 shown; pick the config you need)
quant_cfg = mtq.INT8_DEFAULT_CFG
model = mtq.quantize(model, quant_cfg, forward_loop=forward_loop)
```
Returns:
A forward loop function that can be called with no arguments. When called, this function iterates over the dataset specified by dataset_name.
get_dataset_dataloader(dataset_name='cnn_dailymail', tokenizer=None, batch_size=1, num_samples=512, max_sample_length=512, device=None, include_labels=False)
Gets a dataloader for the given dataset name and the tokenizer of the target model.
Parameters:
- dataset_name (str) – Name of the dataset to load.
- tokenizer (PreTrainedTokenizerBase | None) – Instance of a Hugging Face tokenizer.
- batch_size (int) – Batch size of the returned dataloader.
- num_samples (int) – Number of samples from the dataset.
- max_sample_length (int) – Maximum length of a sample.
- device (str | None) – Target device for the returned dataloader.
- include_labels (bool) – Whether to include labels in the dataloader.
Returns:
A DataLoader instance.
Return type:
DataLoader
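The way `num_samples`, `max_sample_length`, and `batch_size` shape the batches can be sketched without a real tokenizer. This illustration rests on several assumptions and is not the modelopt implementation. `make_batches` and `toy_tokenize` are hypothetical. Sequences are right-padded with `pad_id` to the longest sequence in each batch, and an attention mask marks the real tokens. A Hugging Face tokenizer would handle truncation and padding itself, and the returned DataLoader would typically yield tensors on `device` rather than Python lists.

```python
from typing import Callable, Dict, List


def make_batches(
    texts: List[str],
    tokenize: Callable[[str], List[int]],
    batch_size: int = 1,
    num_samples: int = 512,
    max_sample_length: int = 512,
    pad_id: int = 0,
) -> List[Dict[str, List[List[int]]]]:
    """Hypothetical sketch: keep num_samples texts, truncate, pad per batch, and group."""
    # Keep the first num_samples texts and truncate each to max_sample_length ids.
    ids = [tokenize(t)[:max_sample_length] for t in texts[:num_samples]]
    batches = []
    for start in range(0, len(ids), batch_size):
        chunk = ids[start : start + batch_size]
        width = max(len(seq) for seq in chunk)  # pad to the longest sequence in this batch
        input_ids = [seq + [pad_id] * (width - len(seq)) for seq in chunk]
        attention_mask = [[1] * len(seq) + [0] * (width - len(seq)) for seq in chunk]
        batches.append({"input_ids": input_ids, "attention_mask": attention_mask})
    return batches


def toy_tokenize(text: str) -> List[int]:
    # Stand-in tokenizer: one id per word, equal to the word's length.
    return [len(word) for word in text.split()]


batches = make_batches(
    ["a bb ccc dddd", "ee f", "ggg"], toy_tokenize, batch_size=2, num_samples=3, max_sample_length=3
)
print(batches[0]["input_ids"])       # [[1, 2, 3], [2, 1, 0]]
print(batches[0]["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
print(len(batches))                  # 2
```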
get_max_batch_size(model, max_sample_length=512, sample_memory_usage_ratio=1.0, sample_input_single_batch=None)
Get the maximum batch size that can be used for the model.
Parameters:
- model (Module)
- max_sample_length (int)
- sample_memory_usage_ratio (float)
- sample_input_single_batch (Tensor | None)
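The entry above lists the parameters of get_max_batch_size but does not describe how the estimate is made. One common heuristic, offered here only as an assumption, works in three steps. First, measure the memory that a single-batch forward pass consumes. Second, scale that figure by a factor; this sketch assumes that is the role of `sample_memory_usage_ratio`. Third, divide the free device memory by the result. `estimate_max_batch_size` and its cap of 64 are hypothetical. In practice, the free-memory figure could come from `torch.cuda.mem_get_info()`, which returns free and total bytes.

```python
def estimate_max_batch_size(
    free_bytes: int,
    bytes_per_single_batch: int,
    sample_memory_usage_ratio: float = 1.0,
    cap: int = 64,
) -> int:
    """Hypothetical heuristic: how many single-sample batches fit into free memory."""
    # Scale the measured cost of one batch; a ratio above 1.0 reserves extra headroom.
    per_batch = max(bytes_per_single_batch * sample_memory_usage_ratio, 1.0)
    fit = int(free_bytes // per_batch)
    # Never return less than 1, and clamp to the (hypothetical) upper bound.
    return max(1, min(fit, cap))


GIB = 2**30
print(estimate_max_batch_size(16 * GIB, 1 * GIB))       # 16
print(estimate_max_batch_size(16 * GIB, 1 * GIB, 2.0))  # 8
print(estimate_max_batch_size(1024 * GIB, 1 * GIB))     # 64 (capped)
print(estimate_max_batch_size(GIB // 2, 1 * GIB))       # 1 (floor)
```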
get_supported_datasets()
Retrieves the list of supported datasets.
Returns:
A list of strings, where each string is the name of a supported dataset.
Return type:
list[str]
Example usage:
```python
from modelopt.torch.utils import get_supported_datasets

print("Supported datasets:", get_supported_datasets())
```
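Because create_forward_loop requires `dataset_name` to be one of the names returned by get_supported_datasets(), a caller may want to validate the name before any data is loaded. `check_dataset_name` below is a hypothetical helper and is not part of modelopt. The stand-in list contains only the documented default, `cnn_dailymail`, so that the sketch runs without modelopt installed.

```python
from typing import Iterable


def check_dataset_name(dataset_name: str, supported: Iterable[str]) -> str:
    """Hypothetical guard: reject unknown dataset names with a readable message."""
    names = sorted(supported)
    if dataset_name not in names:
        raise ValueError(
            f"Unsupported dataset {dataset_name!r}; expected one of: {', '.join(names)}"
        )
    return dataset_name


# In real code, pass get_supported_datasets() as `supported`; this stand-in list
# holds only the documented default so the sketch runs without modelopt installed.
supported = ["cnn_dailymail"]
print(check_dataset_name("cnn_dailymail", supported))  # cnn_dailymail
try:
    check_dataset_name("my_corpus", supported)
except ValueError as err:
    print(err)  # Unsupported dataset 'my_corpus'; expected one of: cnn_dailymail
```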