davechallis/rust-xgboost: Rust bindings for XGBoost.

rust-xgboost


Rust bindings for the XGBoost gradient boosting library.

Basic usage example:

extern crate xgboost;

use xgboost::{parameters, DMatrix, Booster};

fn main() {
    // training matrix with 5 training examples and 3 features
    let x_train = &[1.0, 1.0, 1.0,
                    1.0, 1.0, 0.0,
                    1.0, 1.0, 1.0,
                    0.0, 0.0, 0.0,
                    1.0, 1.0, 1.0];
    let num_rows = 5;
    let y_train = &[1.0, 1.0, 1.0, 0.0, 1.0];

    // convert training data into XGBoost's matrix format
    let mut dtrain = DMatrix::from_dense(x_train, num_rows).unwrap();

    // set ground truth labels for the training matrix
    dtrain.set_labels(y_train).unwrap();

    // test matrix with 1 row
    let x_test = &[0.7, 0.9, 0.6];
    let num_rows = 1;
    let y_test = &[1.0];
    let mut dtest = DMatrix::from_dense(x_test, num_rows).unwrap();
    dtest.set_labels(y_test).unwrap();

    // configure objectives, metrics, etc.
    let learning_params = parameters::learning::LearningTaskParametersBuilder::default()
        .objective(parameters::learning::Objective::BinaryLogistic)
        .build().unwrap();

    // configure the tree-based learning model's parameters
    let tree_params = parameters::tree::TreeBoosterParametersBuilder::default()
        .max_depth(2)
        .eta(1.0)
        .build().unwrap();

    // overall configuration for Booster
    let booster_params = parameters::BoosterParametersBuilder::default()
        .booster_type(parameters::BoosterType::Tree(tree_params))
        .learning_params(learning_params)
        .verbose(true)
        .build().unwrap();

    // specify datasets to evaluate against during training
    let evaluation_sets = &[(&dtrain, "train"), (&dtest, "test")];

    // overall configuration for training/evaluation
    let params = parameters::TrainingParametersBuilder::default()
        .dtrain(&dtrain)                         // dataset to train with
        .boost_rounds(2)                         // number of training iterations
        .booster_params(booster_params)          // model parameters
        .evaluation_sets(Some(evaluation_sets))  // optional datasets to evaluate against in each iteration
        .build().unwrap();

    // train model, and print evaluation data
    let bst = Booster::train(&params).unwrap();

    println!("{:?}", bst.predict(&dtest).unwrap());
}
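In the example, DMatrix::from_dense receives a flat slice of 15 values together with num_rows = 5, and the comment describes this as 5 training examples with 3 features each. The std-only sketch below illustrates that reading, assuming the slice is row-major (each run of num_cols consecutive values is one example) and that the column count is inferred as len / num_rows. The DenseMatrix type, its from_dense and row methods, and the use of f32 are hypothetical illustrations, not the xgboost crate's actual implementation.

```rust
/// Hypothetical stand-in for a dense, row-major feature matrix.
/// Not part of the xgboost crate; it only illustrates the assumed layout.
#[derive(Debug)]
struct DenseMatrix<'a> {
    data: &'a [f32],
    num_rows: usize,
    num_cols: usize,
}

impl<'a> DenseMatrix<'a> {
    /// Infers the column count from the slice length, rejecting slices
    /// that cannot be split evenly into `num_rows` rows.
    fn from_dense(data: &'a [f32], num_rows: usize) -> Result<Self, String> {
        if num_rows == 0 || data.len() % num_rows != 0 {
            return Err(format!(
                "{} values cannot be split into {} equal rows",
                data.len(),
                num_rows
            ));
        }
        Ok(DenseMatrix { data, num_rows, num_cols: data.len() / num_rows })
    }

    /// Returns the features of row `i`; under the row-major assumption,
    /// row i occupies data[i * num_cols .. (i + 1) * num_cols].
    fn row(&self, i: usize) -> &[f32] {
        &self.data[i * self.num_cols..(i + 1) * self.num_cols]
    }
}

fn main() {
    // Same values as x_train in the example above.
    let x_train: &[f32] = &[1.0, 1.0, 1.0,
                            1.0, 1.0, 0.0,
                            1.0, 1.0, 1.0,
                            0.0, 0.0, 0.0,
                            1.0, 1.0, 1.0];
    let m = DenseMatrix::from_dense(x_train, 5).unwrap();
    println!("{} rows x {} cols", m.num_rows, m.num_cols); // 5 rows x 3 cols
    println!("row 1 = {:?}", m.row(1)); // [1.0, 1.0, 0.0]

    // 15 values cannot form 4 equal rows, so this is rejected.
    assert!(DenseMatrix::from_dense(x_train, 4).is_err());
}
```

If the real constructor works this way, it would explain why the example passes only num_rows: the number of features follows from the slice length.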

See the examples directory for more detailed examples of different features.
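The example trains with Objective::BinaryLogistic, for which XGBoost's predictions are probabilities: the raw margin (the sum of the trees' outputs plus a base margin) is passed through the logistic sigmoid. The std-only sketch below shows that transformation, a 0.5 threshold that turns probabilities into 0/1 labels, and the resulting classification error rate, which we believe corresponds to XGBoost's error evaluation metric. The margins and labels are made-up illustrative values rather than output of the example above, and the helper names sigmoid, to_label and error_rate are hypothetical.

```rust
/// Logistic sigmoid: maps a raw margin to a probability in (0, 1).
fn sigmoid(margin: f32) -> f32 {
    1.0 / (1.0 + (-margin).exp())
}

/// Hypothetical helper: converts a probability to a 0/1 label,
/// treating values strictly above 0.5 as the positive class.
fn to_label(prob: f32) -> f32 {
    if prob > 0.5 { 1.0 } else { 0.0 }
}

/// Hypothetical helper: fraction of examples whose thresholded
/// prediction differs from the ground-truth label.
fn error_rate(probs: &[f32], labels: &[f32]) -> f32 {
    assert_eq!(probs.len(), labels.len(), "one label per prediction");
    let wrong = probs
        .iter()
        .zip(labels)
        .filter(|&(&p, &y)| to_label(p) != y)
        .count();
    wrong as f32 / labels.len() as f32
}

fn main() {
    // Illustrative raw margins (summed tree outputs), not real model output.
    let margins = [2.0_f32, -1.0, 0.0, 0.5];
    let labels = [1.0_f32, 0.0, 1.0, 0.0];

    let probs: Vec<f32> = margins.iter().map(|&m| sigmoid(m)).collect();
    let predicted: Vec<f32> = probs.iter().map(|&p| to_label(p)).collect();
    println!("probabilities: {:?}", probs);
    println!("predicted labels: {:?}", predicted);
    // The last two examples are misclassified, so this prints 0.5.
    println!("error rate: {}", error_rate(&probs, &labels));
}
```

A margin of exactly 0 maps to a probability of 0.5, which this sketch assigns to the negative class because the threshold comparison is strict.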

Status

The project is at a very early stage of development, so the API changes as usability issues come up or as new features are supported.

Builds against XGBoost 0.81.

Platforms

Tested:

Unsupported: