RegressionPartitionedGP
Cross-validated Gaussian process regression (GPR) model
Since R2022b
Description
RegressionPartitionedGP is a set of Gaussian process regression models trained on cross-validated folds. Estimate the quality of the cross-validated regression by using one or more kfold functions: kfoldPredict, kfoldLoss, and kfoldfun.
Every kfold object function uses models trained on training-fold (in-fold) observations to predict the response for validation-fold (out-of-fold) observations. For example, suppose you cross-validate using five folds. The software randomly assigns each observation to one of five groups of roughly equal size. The training fold contains four of the groups (roughly 4/5 of the data), and the validation fold contains the other group (roughly 1/5 of the data). In this case, cross-validation proceeds as follows:
- The software trains the first model (stored in CVMdl.Trained{1}) by using the observations in the last four groups, and reserves the observations in the first group for validation.
- The software trains the second model (stored in CVMdl.Trained{2}) by using the observations in the first group and the last three groups. The software reserves the observations in the second group for validation.
- The software proceeds in a similar manner for the third, fourth, and fifth models.
If you validate by using kfoldPredict, the software computes predictions for the observations in group i by using the ith model. In short, the software estimates a response for every observation by using the model trained without that observation.
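The five-fold procedure described above can be sketched by hand with cvpartition, fitrgp, and predict. This is only an illustration under stated assumptions: the simulated data, the default fitrgp settings, and the variable names (inFold, outFold, mseManual) are hypothetical, and the software's internal implementation may differ.

```matlab
% Hypothetical simulated data; any numeric predictor and response would do
rng("default")                        % For reproducibility
n = 200;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);

% Randomly assign the observations to five groups of roughly equal size
c = cvpartition(n,"KFold",5);

ypred = zeros(n,1);
for i = 1:c.NumTestSets
    inFold  = training(c,i);          % Logical index of training-fold observations
    outFold = test(c,i);              % Logical index of validation-fold observations
    mdl = fitrgp(x(inFold),y(inFold));           % Train without group i
    ypred(outFold) = predict(mdl,x(outFold));    % Predict only group i
end

% Mean squared error of the out-of-fold predictions
mseManual = mean((y - ypred).^2)
```

Each entry of ypred comes from a model that never saw the corresponding observation, which is the property that kfoldPredict provides for a RegressionPartitionedGP object.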
Creation
You can create a RegressionPartitionedGP object in two ways:
- Create a cross-validated model from a GPR model object RegressionGP by using the crossval object function.
- Create a cross-validated model by using the fitrgp function and specifying one of the name-value arguments CrossVal, CVPartition, Holdout, KFold, or Leaveout.
Regardless of whether you train a full or cross-validated GPR model first, you cannot specify an ActiveSet value in the call to fitrgp.
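As a minimal sketch of the two creation routes, the following code builds a cross-validated model both ways. The simulated data and the variable names (cvMdl1, cvMdl2) are assumptions made for illustration.

```matlab
% Hypothetical simulated data
rng("default")                        % For reproducibility
n = 200;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);

% Route 1: train a full GPR model, then cross-validate it
mdl = fitrgp(x,y);
cvMdl1 = crossval(mdl,KFold=5);

% Route 2: cross-validate directly in the call to fitrgp
cvMdl2 = fitrgp(x,y,KFold=5);

% Both objects report five folds
[cvMdl1.KFold, cvMdl2.KFold]
```

The two calls draw separate random partitions, so the fold assignments (and therefore the trained compact models) generally differ between cvMdl1 and cvMdl2.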
Properties
Cross-Validation Properties
CrossValidatedModel
This property is read-only.
Cross-validated model name, specified as 'GP'.
Data Types: char
KFold
This property is read-only.
Number of cross-validated folds, specified as a positive integer.
Data Types: double
ModelParameters
This property is read-only.
Cross-validation parameter values, specified as an EnsembleParams object. The parameter values correspond to the values of the name-value arguments used to cross-validate the GPR model. ModelParameters does not contain estimated parameters.
You can access the properties of ModelParameters using dot notation.
Partition
This property is read-only.
Data partition indicating how the software splits the data into cross-validation folds, specified as a cvpartition object.
Trained
This property is read-only.
Compact models trained on cross-validation folds, specified as a cell array of CompactRegressionGP model objects. Trained has k cells, where k is the number of folds.
Data Types: cell
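For example, you can inspect the cross-validation properties with dot notation. The following is a sketch on hypothetical simulated data; TrainSize and TestSize are properties of the cvpartition object, and Sigma is the estimated noise standard deviation stored in each compact model.

```matlab
% Hypothetical simulated data
rng("default")                        % For reproducibility
n = 200;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);

cvMdl = fitrgp(x,y,KFold=5);

numFolds   = cvMdl.KFold                  % Number of folds
trainSizes = cvMdl.Partition.TrainSize    % In-fold observations per fold
testSizes  = cvMdl.Partition.TestSize     % Out-of-fold observations per fold

firstMdl = cvMdl.Trained{1};              % Compact model trained without the first group
sigma1 = firstMdl.Sigma                   % Noise standard deviation estimated on that fold
```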
Other Regression Properties
CategoricalPredictors
This property is read-only.
Categorical predictor indices, specified as a vector of positive integers. CategoricalPredictors contains index values indicating that the corresponding predictors are categorical. The index values are between 1 and p, where p is the number of predictors used to train the model. If none of the predictors are categorical, then this property is empty ([]).
Data Types: double
NumObservations
This property is read-only.
Number of observations in the training data stored in X and Y, specified as a numeric scalar.
Data Types: double
PredictorNames
This property is read-only.
Predictor variable names, specified as a cell array of character vectors. The order of the elements in PredictorNames corresponds to the order in which the predictor names appear in the training data.
Data Types: cell
ResponseName
This property is read-only.
Response variable name, specified as a character vector.
Data Types: char
ResponseTransform
Response transformation function, specified as 'none' or a function handle. In general, ResponseTransform describes how the software transforms raw response values.
For GPR model objects, ResponseTransform is 'none' by default. Regardless of the ResponseTransform value, the software does not use a response transformation when making predictions.
Data Types: char | function_handle
W
This property is read-only.
Observation weights, specified as an n-by-1 numeric vector, where n is the number of observations (NumObservations). The software normalizes the observation weights so that the elements of W sum to 1.
For GPR model objects, the software ignores observation weights when training a model or making predictions.
Data Types: double
X
This property is read-only.
Predictors used to cross-validate the model, specified as a numeric matrix or table.
Each row of X corresponds to one observation, and each column corresponds to one variable.
Data Types: single | double | table
Y
This property is read-only.
Response used to cross-validate the model, specified as a numeric vector.
Each row of Y represents the observed response of the corresponding row of X.
Data Types: single | double
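As an illustrative sketch, the following code trains a cross-validated model from a table and reads back several of these properties. The table tbl and its variable names x and y are hypothetical; the weight check relies on the normalization described above.

```matlab
% Hypothetical simulated data stored in a table
rng("default")                        % For reproducibility
n = 200;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);
tbl = table(x,y);

% Use the table variable y as the response and x as the predictor
cvMdl = fitrgp(tbl,"y",KFold=5);

numObs    = cvMdl.NumObservations     % Number of observations
predNames = cvMdl.PredictorNames      % Names taken from the table variables
respName  = cvMdl.ResponseName        % Name of the response variable
transform = cvMdl.ResponseTransform   % 'none' by default
weightSum = sum(cvMdl.W)              % Weights are normalized to sum to 1
```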
Object Functions
| Function | Description |
|---|---|
| kfoldLoss | Loss for cross-validated partitioned regression model |
| kfoldPredict | Predict responses for observations in cross-validated regression model |
| kfoldfun | Cross-validate function for regression |
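As a hedged sketch of how these functions fit together, the following code computes the per-fold mean squared error with kfoldLoss and a custom per-fold statistic with kfoldfun. The simulated data and the handle name maeFun are assumptions; the seven-input function signature follows the documented kfoldfun calling convention.

```matlab
% Hypothetical simulated data
rng("default")                        % For reproducibility
n = 200;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);

cvMdl = fitrgp(x,y,KFold=5);

% Mean squared error of each fold instead of the average over folds
foldMSE = kfoldLoss(cvMdl,Mode="individual")

% Custom statistic: mean absolute error on each validation fold.
% kfoldfun calls the function once per fold with the compact model (CMP)
% and the training-fold and validation-fold data and weights.
maeFun = @(CMP,Xtrain,Ytrain,Wtrain,Xtest,Ytest,Wtest) ...
    mean(abs(Ytest - predict(CMP,Xtest)));
foldMAE = kfoldfun(cvMdl,maeFun)
```

Both results contain one value per fold, which makes it easy to see whether a single fold dominates the averaged loss.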
Examples
Compute the cross-validation mean squared error of a Gaussian process regression (GPR) model. Compare the predicted response values to the true response values.
Simulate 1000 observations from the model y = 1 + 0.05x + sin(x)/x + 0.2ϵ, where:
- x is a 1000-by-1 vector of evenly spaced values between –10 and 10.
- ϵ is a 1000-by-1 vector of standard normal random errors (mean 0 and standard deviation 1), so the noise term 0.2ϵ has mean 0 and standard deviation 0.2.
rng("default"); % For reproducibility
n = 1000;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);
Create a 5-fold cross-validated GPR model. Use a linear basis function, the exact fitting method to estimate model parameters, and the exact method to make predictions.
cvMdl = fitrgp(x,y,Basis="linear", ...
    FitMethod="exact",PredictMethod="exact", ...
    KFold=5);
cvMdl is a RegressionPartitionedGP object that contains five trained CompactRegressionGP model objects (cvMdl.Trained). Each of the five GPR models is trained using approximately 4/5 of the observations in x.
Compute the average mean squared error.
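A minimal sketch of this step, assuming the data and model settings used in this example: by default, kfoldLoss returns the mean squared error averaged over the folds. The variable name mse is illustrative, and no output value is shown here because it depends on the random partition.

```matlab
% Recreate the example data and cross-validated model
rng("default"); % For reproducibility
n = 1000;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);
cvMdl = fitrgp(x,y,Basis="linear", ...
    FitMethod="exact",PredictMethod="exact", ...
    KFold=5);

% Average mean squared error over the five validation folds
mse = kfoldLoss(cvMdl)
```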
Predict the response values using the cross-validated model. The predicted response values are the predictions on the holdout (validation) observations. In other words, the software obtains each prediction by using a model that was trained without the corresponding observation.
ypred = kfoldPredict(cvMdl);
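To illustrate that each prediction comes from the model trained without the corresponding observation, the following sketch compares the kfoldPredict output for one fold with a direct call to predict on the matching compact model. The simulated data and the names outFold, ypredManual, and maxDiff are assumptions made for illustration.

```matlab
% Hypothetical simulated data
rng("default")                        % For reproducibility
n = 200;
x = linspace(-10,10,n)';
y = 1 + 0.05*x + sin(x)./x + 0.2*randn(n,1);

cvMdl = fitrgp(x,y,KFold=5);
ypred = kfoldPredict(cvMdl);

% Validation-fold observations of the first fold
i = 1;
outFold = test(cvMdl.Partition,i);

% Predict those observations directly with the first compact model
ypredManual = predict(cvMdl.Trained{i},x(outFold));

% The difference is expected to be zero up to floating-point error
maxDiff = max(abs(ypred(outFold) - ypredManual))
```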
Plot the true response values and the predicted response values for the cross-validated model.
plot(x,y,".");
hold on
plot(x,ypred,".");
xlabel("x")
ylabel("y")
title("Cross-Validation Predictions")
legend(["True","Predicted"])
hold off
The five CompactRegressionGP models generally agree, but some of the predictions differ near the endpoints of the predictor data range (around –10 and 10).
You cannot use the cross-validated model directly to make predictions on new data. If you want to predict response values for a new data set, you can train a new GPR model using all the data in x and then use the predict object function. For example, predict response values for each even integer between –10 and 10.
mdl = fitrgp(x,y,Basis="linear", ...
    FitMethod="exact",PredictMethod="exact");
xnew = (-10:2:10)';
prednew = predict(mdl,xnew)
prednew = 11×1
0.5473
0.7012
0.6395
0.5945
1.3450
2.0073
1.5643
0.9842
1.2214
1.4962
1.5089
⋮
Alternatively, you can use the individual compact models in the Trained property of the cross-validated model and then combine the predictions (for example, through averaging). For example, predict average response values for each even integer between –10 and 10.
preds = zeros(length(xnew),cvMdl.KFold);
for i = 1:cvMdl.KFold
    preds(:,i) = predict(cvMdl.Trained{i},xnew);
end
meanpreds = mean(preds,2)
meanpreds = 11×1
0.5462
0.7012
0.6395
0.5949
1.3451
2.0067
1.5640
0.9851
1.2213
1.4963
1.5111
⋮
Version History
Introduced in R2022b