Using XGBoost External Memory Version — xgboost 3.1.0-dev documentation

Overview

When working with large datasets, training XGBoost models can be challenging as the entire dataset needs to be loaded into the main memory. This can be costly and sometimes infeasible.

External memory training is sometimes called out-of-core training. It refers to XGBoost's ability to optionally cache data in a location external to the main processor, be it a CPU or a GPU. XGBoost doesn't support network file systems by itself. As a result, for CPU, the external memory usually refers to a hard drive, and for GPU, it refers to either the host memory or a hard drive.

Users can define a custom iterator to load data in chunks for running XGBoost algorithms. External memory can be used for both training and prediction, but training is the primary use case and the focus of this tutorial. For prediction and evaluation, users can iterate through the data themselves, whereas training requires access to the entire dataset. During model training, XGBoost fetches the cache in batches to construct the decision trees, thereby avoiding loading the entire dataset into the main memory and achieving better vertical scaling (scaling within the same node).

Significant progress was made on the GPU implementation in the 3.0 release. The differences between the CPU and GPU versions are described in the following sections.

Note

Training on data from external memory is not supported by the exact tree method. We recommend using the default hist tree method for performance reasons.

Note

The feature is considered experimental but ready for public testing in 3.0. Vector-leaf is not yet supported.

The external memory support has undergone multiple development iterations. See the sections below for a brief history.


Data Iterator

To start using external memory, users need to define a data iterator. The data iterator interface was added to the Python and C interfaces in 1.5, and to the R interface in 3.0.0. As with the QuantileDMatrix and DataIter, XGBoost loads data batch by batch using the custom iterator supplied by the user. However, unlike the QuantileDMatrix, external memory does not concatenate the batches (unless requested via extmem_single_page for GPU). Instead, it caches all batches in the external memory and fetches them on demand. Go to the end of the document to see a comparison between the QuantileDMatrix and its external memory counterpart, the ExtMemQuantileDMatrix.

Some examples are in the demo directory for a quick start. To enable external memory training, the custom data iterator needs to implement two class methods: next and reset.

```python
import os
from typing import Callable, List, Literal, Tuple

import numpy as np
import xgboost


class Iterator(xgboost.DataIter):
    """A custom iterator for loading files in batches."""

    def __init__(
        self, device: Literal["cpu", "cuda"], file_paths: List[Tuple[str, str]]
    ) -> None:
        self.device = device

        self._file_paths = file_paths
        self._it = 0
        # XGBoost will generate some cache files under the current directory with the
        # prefix "cache"
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def load_file(self) -> Tuple[np.ndarray, np.ndarray]:
        """Load a single batch of data."""
        X_path, y_path = self._file_paths[self._it]
        # When the `ExtMemQuantileDMatrix` is used, the device must match. GPU cannot
        # consume CPU input data and vice-versa.
        if self.device == "cpu":
            X = np.load(X_path)
            y = np.load(y_path)
        else:
            import cupy as cp

            X = cp.load(X_path)
            y = cp.load(y_path)

        assert X.shape[0] == y.shape[0]
        return X, y

    def next(self, input_data: Callable) -> bool:
        """Advance the iterator by 1 step and pass the data to XGBoost. This function
        is called by XGBoost during the construction of ``DMatrix``.
        """
        if self._it == len(self._file_paths):
            # return False to let XGBoost know this is the end of iteration
            return False

        # input_data is a keyword-only function passed in by XGBoost and has a
        # similar signature to the ``DMatrix`` constructor.
        X, y = self.load_file()
        input_data(data=X, label=y)
        self._it += 1
        return True

    def reset(self) -> None:
        """Reset the iterator to its beginning."""
        self._it = 0
```

After defining the iterator, we can pass it into the DMatrix or the ExtMemQuantileDMatrix constructor:

```python
# Each entry is an (X_path, y_path) pair, as expected by ``load_file`` above.
# The file names are placeholders.
it = Iterator(
    device="cpu",
    file_paths=[("X_0.npy", "y_0.npy"), ("X_1.npy", "y_1.npy"), ("X_2.npy", "y_2.npy")],
)

# Use the ExtMemQuantileDMatrix for the hist tree method, recommended.
Xy = xgboost.ExtMemQuantileDMatrix(it)
booster = xgboost.train({"tree_method": "hist"}, Xy)

# The approx tree method also works, but with lower performance and cannot be used
# with the quantile DMatrix.
Xy = xgboost.DMatrix(it)
booster = xgboost.train({"tree_method": "approx"}, Xy)
```

The above snippet is a simplified version of Experimental support for external memory. For an example in C, please see demo/c-api/external-memory/. The iterator is the common interface for using external memory with XGBoost; you can pass the resulting DMatrix object to training, prediction, and evaluation.
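The iterator reads each batch from a pair of `.npy` files, one for features and one for labels. For a quick local test, such files can be produced with a small helper like the one below. This is a minimal sketch: `make_batches`, the `X_{i}.npy` / `y_{i}.npy` naming scheme, and the synthetic linear target are hypothetical choices for illustration, not part of XGBoost.

```python
import os
from typing import List, Tuple

import numpy as np


def make_batches(
    out_dir: str, n_batches: int, n_samples: int, n_features: int, seed: int = 0
) -> List[Tuple[str, str]]:
    """Write ``n_batches`` pairs of synthetic feature/label ``.npy`` files.

    Hypothetical helper: returns a list of ``(X_path, y_path)`` tuples that can be
    used as the iterator's ``file_paths`` argument.
    """
    rng = np.random.default_rng(seed)
    os.makedirs(out_dir, exist_ok=True)
    file_paths = []
    for i in range(n_batches):
        X = rng.normal(size=(n_samples, n_features)).astype(np.float32)
        # Synthetic regression target: a noisy linear function of the features.
        y = X.sum(axis=1) + rng.normal(scale=0.1, size=n_samples).astype(np.float32)
        X_path = os.path.join(out_dir, f"X_{i}.npy")
        y_path = os.path.join(out_dir, f"y_{i}.npy")
        np.save(X_path, X)
        np.save(y_path, y)
        file_paths.append((X_path, y_path))
    return file_paths
```

The returned list can then be handed to the iterator, for example `Iterator(device="cpu", file_paths=make_batches("data", 3, 100_000, 32))`. Each batch here is small for demonstration; real batches should follow the sizing advice below.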

The ExtMemQuantileDMatrix is an external memory version of the QuantileDMatrix. These two classes are designed specifically for the hist tree method to reduce memory usage and data loading overhead. See the respective references for more info.

It is important to set the batch size based on the available memory. A good starting point for CPU is to set the batch size to 10GB per batch if you have 64GB of memory. It is not recommended to use small batch sizes like 32 samples per batch, as this can severely hurt performance in gradient boosting. See the sections below for information about the GPU version and other best practices.
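To turn a byte budget like the one above into a number of rows per batch, divide it by the size of one row. The sketch below assumes a dense matrix of 32-bit floats, treats "GB" as GiB, and ignores labels, weights, and any other overhead, so the result is only a rough upper bound. `rows_per_batch` is a hypothetical helper, not an XGBoost function.

```python
def rows_per_batch(batch_bytes: int, n_features: int, itemsize: int = 4) -> int:
    """Number of dense rows that fit in ``batch_bytes``.

    Assumes ``n_features`` columns with ``itemsize`` bytes per value
    (4 for float32); labels, weights and allocator overhead are ignored.
    """
    bytes_per_row = n_features * itemsize
    return batch_bytes // bytes_per_row


GiB = 1024**3
# A 10 GiB batch of float32 data with 256 features:
print(rows_per_batch(10 * GiB, n_features=256))  # 10485760 rows (about 10.5 million)
```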

GPU Version (GPU Hist tree method)

External memory is supported by GPU algorithms (i.e., when device is set to cuda). Starting with 3.0, the default GPU implementation is similar to the CPU version. It also supports the use of ExtMemQuantileDMatrix when the hist tree method is employed (default). For a GPU device, the main memory is the device memory, whereas the external memory can be either a disk or the CPU memory. XGBoost stages the cache in CPU memory by default. Users can change the backing storage to disk by specifying the on_host parameter in the DataIter. However, using the disk is not recommended, as it is likely to make the GPU slower than the CPU; the option exists for experimentation purposes only. In addition, the ExtMemQuantileDMatrix parameters min_cache_page_bytes and max_quantile_batches can help control data placement and memory usage.

Inputs to the ExtMemQuantileDMatrix (through the iterator) must be on the GPU. The following is a snippet from Experimental support for external memory:

```python
import cupy as cp
import rmm
import xgboost
from rmm.allocators.cupy import rmm_cupy_allocator

# It's important to use RMM for GPU-based external memory to improve performance.
# If XGBoost is not built with RMM support, a warning will be raised.
# We use the pool memory resource here for simplicity, you can also try the
# `ArenaMemoryResource` for improved memory fragmentation handling.
mr = rmm.mr.PoolMemoryResource(rmm.mr.CudaAsyncMemoryResource())
rmm.mr.set_current_device_resource(mr)
# Set the allocator for cupy as well.
cp.cuda.set_allocator(rmm_cupy_allocator)

# Make sure XGBoost is using RMM for all allocations.
with xgboost.config_context(use_rmm=True):
    # Construct the iterators for ExtMemQuantileDMatrix
    # ...
    # (it_train, it_valid, n_bins, n_rounds and device are defined by the user.)
    # Build the ExtMemQuantileDMatrix and start training
    Xy_train = xgboost.ExtMemQuantileDMatrix(it_train, max_bin=n_bins)
    # Use the training DMatrix as a reference
    Xy_valid = xgboost.ExtMemQuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train)
    booster = xgboost.train(
        {
            "tree_method": "hist",
            "max_depth": 6,
            "max_bin": n_bins,
            "device": device,
        },
        Xy_train,
        num_boost_round=n_rounds,
        evals=[(Xy_train, "Train"), (Xy_valid, "Valid")],
    )
```

It's crucial to use the RAPIDS Memory Manager (RMM) with an asynchronous memory resource for all memory allocation when training with external memory. XGBoost relies on the asynchronous memory pool to reduce the overhead of data fetching. In addition, the open source NVIDIA Linux driver is required for Heterogeneous Memory Management (HMM) support. Usually, users need not change ExtMemQuantileDMatrix parameters like min_cache_page_bytes; they are configured automatically based on the device and don't change model accuracy. However, max_quantile_batches can be useful if ExtMemQuantileDMatrix runs out of device memory during construction; see QuantileDMatrix and the following sections for more info. Currently, GPU-based external memory support focuses on devices with NVLink-C2C.

In addition to batch-based data fetching, the GPU version supports concatenating batches into a single blob for the training data to improve performance. For GPUs connected via PCIe instead of NVLink, the performance overhead of batch-based training is significant, particularly for non-dense data; overall, it can be at least five times slower than in-core training. Concatenating pages can bring the performance closer to in-core training. This option should be combined with subsampling to reduce memory usage: during concatenation, subsampling removes a portion of the samples, reducing the size of the training dataset. The GPU hist tree method supports gradient-based sampling, enabling users to set a low sampling rate without compromising accuracy. Before 3.0, concatenation with subsampling was the only option for GPU-based external memory. Since 3.0, XGBoost uses regular batch fetching by default, while page concatenation can be enabled with:

```python
param = {
    "device": "cuda",
    "extmem_single_page": True,
    "subsample": 0.2,
    "sampling_method": "gradient_based",
}
```

For more information about the sampling algorithm and its use in external memory training, see this paper. Lastly, see the following sections for best practices.

Newer NVIDIA platforms like Grace-Hopper use NVLink-C2C, which provides a fast interconnect between the CPU and the GPU. With the host memory serving as the data cache, XGBoost can retrieve data with significantly lower overhead. When the input data is dense, there's minimal to no performance loss for training, except for the initial construction of the ExtMemQuantileDMatrix. The initial construction iterates through the input data twice; as a result, the most significant overhead compared to in-core training is one additional data read when the data is dense. Please note that there are multiple variants of the platform, and they come with different C2C bandwidths. During initial development of the feature, we used the LPDDR5 480G version, which has about 350GB/s bandwidth for host-to-device transfer. When choosing a variant for training XGBoost models, pay extra attention to the C2C bandwidth.

To run experiments on these platforms, the open source NVIDIA Linux driver with version >=565.47 is required; it should come with CTK 12.7 and later versions. Lastly, there's a known issue with Linux 6.11 that can lead to CUDA host memory allocation failures with an invalid argument error.

Distributed Training

Distributed training is similar to in-core learning, but the work for framework integration is still ongoing. See Experimental support for distributed training with external memory for an example of using the communicator to build a simple pipeline. Since users can define their own custom data loaders, the existing distributed framework interfaces in XGBoost are unlikely to meet all use cases; the example can serve as a starting point for users who have custom infrastructure.

Best Practices

In previous sections, we demonstrated how to train a tree-based model with data residing in external memory and made some recommendations for batch size. Here are some other configurations we find useful. The external memory feature involves iterating through data batches stored in a cache during tree construction. For optimal performance, we recommend the grow_policy=depthwise setting, which allows XGBoost to build an entire layer of tree nodes with only a few batch iterations. Conversely, the lossguide policy requires XGBoost to iterate over the dataset for each tree node, resulting in significantly slower performance.
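A minimal parameter sketch following this advice. `depthwise` is already XGBoost's default `grow_policy`, so spelling it out mainly documents intent and guards against it being overridden elsewhere; `max_depth` is an arbitrary illustrative value.

```python
# Illustrative training parameters for external memory; only ``tree_method`` and
# ``grow_policy`` reflect the recommendations in this section.
params = {
    "tree_method": "hist",  # preferred over "approx" for external memory
    "grow_policy": "depthwise",  # grow one whole level per few batch iterations
    "max_depth": 6,  # example value, tune for your data
}
# booster = xgboost.train(params, Xy)  # Xy: e.g. an ExtMemQuantileDMatrix
```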

In addition, the hist tree method should be preferred over the approx tree method, as the former doesn't recreate the histogram bins in every iteration. Creating the histogram bins requires loading the raw input data, which is prohibitively expensive. The ExtMemQuantileDMatrix, designed for the hist tree method, can significantly speed up the initial data construction and the evaluation for external memory.

Since the external memory implementation focuses on training, where XGBoost needs to access the entire dataset, only X is divided into batches while everything else is concatenated. As a result, users are encouraged to write their own code to iterate through the data for inference, especially for SHAP value computation. The SHAP matrix can be larger than the feature matrix X, making external memory in XGBoost less effective.
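A minimal sketch of such user-side batching, assuming the batches are stored as the `(X_path, y_path)` `.npy` pairs used by the iterator earlier. `predict_in_batches` and its output file naming are hypothetical. The model call is passed in as a plain function, so only one batch of features and its output are in memory at a time.

```python
from typing import Callable, List, Tuple

import numpy as np


def predict_in_batches(
    file_paths: List[Tuple[str, str]],
    predict_fn: Callable[[np.ndarray], np.ndarray],
    out_prefix: str,
) -> List[str]:
    """Apply ``predict_fn`` to one feature batch at a time and save each result.

    Hypothetical helper: labels are not needed for inference, so only the
    feature file of each pair is loaded.
    """
    out_paths = []
    for i, (X_path, _) in enumerate(file_paths):
        X = np.load(X_path)
        out = predict_fn(X)  # e.g. predictions or SHAP contributions
        out_path = f"{out_prefix}_{i}.npy"
        np.save(out_path, out)  # write to disk instead of accumulating in memory
        out_paths.append(out_path)
    return out_paths
```

With a trained booster, SHAP values could be computed per batch with `predict_fn=lambda X: booster.predict(xgboost.DMatrix(X), pred_contribs=True)`.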

When external memory is used, CPU training performance is limited by disk IO (input/output) speed: the disk IO speed primarily determines the training speed. Similarly, PCIe bandwidth limits GPU performance, assuming the CPU memory is used as the cache and address translation services (ATS) are unavailable. During development, we observed that typical data transfer in XGBoost over PCIe 4.0 x16 reaches about 24GB/s, which is significantly lower than the GPU processing throughput. On a C2C-enabled machine, by contrast, the data transfer and processing speeds in training are close to each other.
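The bandwidth figures translate directly into a lower bound on the time spent moving data. The sketch below only divides data volume by link bandwidth and ignores compute, latency, and cache compression. The 24GB/s (PCIe 4.0 x16) and 350GB/s (LPDDR5 480G Grace-Hopper variant) figures are the ones quoted in this document, while the 120 GB cache size and the `transfer_seconds` helper are made-up examples.

```python
def transfer_seconds(data_gb: float, bandwidth_gb_per_s: float, passes: int = 1) -> float:
    """Lower bound on time spent moving ``data_gb`` over a link ``passes`` times."""
    return data_gb * passes / bandwidth_gb_per_s


# Approximate bandwidths quoted in this document; real values depend on hardware.
PCIE4_X16 = 24.0  # GB/s, observed XGBoost transfer over PCIe 4.0 x16
C2C_LPDDR5_480G = 350.0  # GB/s, host-to-device on the Grace-Hopper variant used

cache_gb = 120.0  # hypothetical size of the cached training data
for name, bw in [("PCIe", PCIE4_X16), ("C2C", C2C_LPDDR5_480G)]:
    print(f"{name}: {transfer_seconds(cache_gb, bw):.2f} s per full pass")
# PCIe: 5.00 s per full pass
# C2C: 0.34 s per full pass
```

Training makes several passes over the cache (and the initial ExtMemQuantileDMatrix construction reads the input twice), so the per-pass figure is multiplied accordingly, which is what the `passes` argument is for.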

Running inference is much less computation-intensive than training and, hence, much faster. As a result, the performance bottleneck of inference shifts back to data transfer. For GPU, the time it takes to copy the data from host to device completely determines the time it takes to run inference, even if a C2C link is available.

```python
# Use the training DMatrix as a reference when building the validation DMatrix.
Xy_train = xgboost.ExtMemQuantileDMatrix(it_train, max_bin=n_bins)
Xy_valid = xgboost.ExtMemQuantileDMatrix(it_valid, max_bin=n_bins, ref=Xy_train)
```

In addition, the GPU implementation relies on an asynchronous memory pool, which is subject to memory fragmentation even when the CudaAsyncMemoryResource is used. You might want to start training with a fresh pool instead of starting right after the ETL process. If you run into out-of-memory errors and you are convinced that the pool is not full yet (pool memory usage can be profiled with Nsight Systems), consider using the ArenaMemoryResource memory resource. Alternatively, using CudaAsyncMemoryResource in conjunction with BinningMemoryResource(mr, 21, 25) instead of the default PoolMemoryResource can be an option.

During CPU benchmarking, we used an NVMe drive connected to a PCIe 4.0 slot. Other types of storage can be too slow for practical usage. However, your system will likely perform some caching to reduce the overhead of file reads. See the following sections for remarks.

Compared to the QuantileDMatrix

Passing an iterator to the QuantileDMatrix enables direct construction of the QuantileDMatrix from data chunks. Passing it to the DMatrix or the ExtMemQuantileDMatrix instead enables the external memory feature. The QuantileDMatrix concatenates the data in memory after compression and doesn't fetch data during training, whereas the external memory DMatrix (ExtMemQuantileDMatrix) fetches data batches from external memory on demand. Use the QuantileDMatrix (with an iterator if necessary) when you can fit most of your data in memory. On many platforms, training with it can be an order of magnitude faster than with external memory.

Brief History

For a long time, external memory support has been an experimental feature and has undergone multiple development iterations. Here’s a brief summary of major changes:

Text File Inputs

Warning

This is the original form of external memory support before 1.5 and is now deprecated; users are encouraged to use a custom data iterator instead.

There is no significant difference between using the external memory version of text input and the in-memory version of text input. The only difference is the filename format.

The external memory version takes in the following URI format:

filename?format=libsvm#cacheprefix

The filename is the typical path to the LIBSVM format file you want to load, and cacheprefix is a path to a cache file that XGBoost will use for caching preprocessed data in binary form.

To load from csv files, use the following syntax:

filename.csv?format=csv&label_column=0#cacheprefix

where label_column should point to the csv column acting as the label.
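The URI pieces described above can be assembled mechanically. The helper below is a hypothetical convenience for illustration only (XGBoost does not provide it), and the format it builds is deprecated in favor of custom iterators.

```python
from typing import Optional


def external_memory_uri(
    filename: str, fmt: str, cache_prefix: str, label_column: Optional[int] = None
) -> str:
    """Build a ``filename?format=...#cacheprefix`` URI for deprecated text input."""
    uri = f"{filename}?format={fmt}"
    if label_column is not None:
        # Only meaningful for CSV input: index of the column holding the label.
        uri += f"&label_column={label_column}"
    return f"{uri}#{cache_prefix}"


print(external_memory_uri("train.csv", "csv", "cacheprefix", label_column=0))
# train.csv?format=csv&label_column=0#cacheprefix
```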

If you have a dataset stored in a file similar to demo/data/agaricus.txt.train with LIBSVM format, the external memory support can be enabled by:

dtrain = DMatrix('../data/agaricus.txt.train?format=libsvm#dtrain.cache')

XGBoost will first load agaricus.txt.train, preprocess it, and then write it to a new file named dtrain.cache as an on-disk cache for storing preprocessed data in an internal binary format. For more notes about text input formats, see Text Input Format of DMatrix.

For the CLI version, simply add the cache suffix, e.g. "../data/agaricus.txt.train?format=libsvm#dtrain.cache".