predict - Predict responses using generalized additive model (GAM) - MATLAB
Predict responses using generalized additive model (GAM)
Since R2021a
Syntax
Description
yFit = predict(Mdl,X)
returns a vector of predicted responses for the predictor data in the table or matrix X, based on the generalized additive model Mdl for regression. The trained model can be either full or compact.
yFit = predict(Mdl,X,Name,Value)
specifies options using one or more name-value arguments. For example, 'IncludeInteractions',true specifies to include interaction terms in computations.
[yFit,ySD,yInt] = predict(___)
also returns the standard deviations and prediction intervals of the response variable, evaluated at each observation in the predictor data X, using any of the input argument combinations in the previous syntaxes. This syntax is valid only when Mdl was trained with the 'FitStandardDeviation' name-value argument of fitrgam set to true, so that the IsStandardDeviationFit property of Mdl is true.
Examples
Train a generalized additive model using training samples, and then predict the test sample responses.
Load the patients data set.
load patients
Create a table that contains the predictor variables (Age, Diastolic, Smoker, Weight, Gender, SelfAssessedHealthStatus) and the response variable (Systolic).
tbl = table(Age,Diastolic,Smoker,Weight,Gender,SelfAssessedHealthStatus,Systolic);
Randomly partition observations into a training set and a test set. Specify a 10% holdout sample for testing.
rng('default') % For reproducibility
cv = cvpartition(size(tbl,1),'HoldOut',0.10);
Extract the training and test indices.
trainInds = training(cv);
testInds = test(cv);
Train a univariate GAM that contains the linear terms for the predictors in tbl.
Mdl = fitrgam(tbl(trainInds,:),'Systolic')
Mdl = 
  RegressionGAM
           PredictorNames: {'Age'  'Diastolic'  'Smoker'  'Weight'  'Gender'  'SelfAssessedHealthStatus'}
             ResponseName: 'Systolic'
    CategoricalPredictors: [3 5 6]
        ResponseTransform: 'none'
                Intercept: 122.7444
   IsStandardDeviationFit: 0
          NumObservations: 90
Mdl is a RegressionGAM model object.
Predict responses for the test set.
yFit = predict(Mdl,tbl(testInds,:));
Create a table containing the observed response values and the predicted response values.
table(tbl.Systolic(testInds),yFit, ...
    'VariableNames',{'Observed Value','Predicted Value'})
ans=10×2 table
    Observed Value    Predicted Value
    ______________    _______________

         124              126.58
         121              123.95
         130              116.72
         115              117.35
         121              117.45
         116               118.5
         123              126.16
         132              124.14
         125              127.36
         124              115.99
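To summarize how close these test-set predictions are, you can compute the mean squared error (MSE) and root mean squared error (RMSE). This sketch is an illustration rather than part of the original example: it copies the rounded values shown in the table, so the results only approximate what the full-precision yFit would give. With the example's own variables, the equivalent is mean((tbl.Systolic(testInds) - yFit).^2). RegressionGAM also provides a loss object function for computing regression loss directly.

```matlab
% Observed and predicted test responses, copied from the rounded table display above
yObs  = [124 121 130 115 121 116 123 132 125 124]';
yPred = [126.58 123.95 116.72 117.35 117.45 118.5 126.16 124.14 127.36 115.99]';

% Residuals, mean squared error, and root mean squared error
res  = yObs - yPred;
mse  = mean(res.^2);
rmse = sqrt(mse);
fprintf('MSE = %.4f, RMSE = %.4f\n', mse, rmse)
```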
Predict responses for new observations using a generalized additive model that contains both linear and interaction terms for predictors. Use a memory-efficient model object, and specify whether to include interaction terms when predicting responses.
Load the carbig data set, which contains measurements of cars made in the 1970s and early 1980s.
load carbig
Specify Acceleration, Displacement, Horsepower, and Weight as the predictor variables (X) and MPG as the response variable (Y).
X = [Acceleration,Displacement,Horsepower,Weight];
Y = MPG;
Partition the data set into two sets: one containing training data, and the other containing new, unobserved test data. Reserve 10 observations for the new test data set.
rng('default') % For reproducibility
n = size(X,1);
newInds = randsample(n,10);
inds = ~ismember(1:n,newInds);
XNew = X(newInds,:);
YNew = Y(newInds);
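The indexing in this step can be checked on a small, self-contained case. This sketch is an illustration, not part of the original example: it uses a hypothetical n of 20 and the base-MATLAB function randperm(n,k), which, like randsample(n,k), draws k distinct indices without replacement. It confirms that the logical mask inds keeps every row except the held-out ones.

```matlab
rng('default')                   % reproducible random draws
n = 20;                          % hypothetical number of observations
k = 5;                           % hypothetical number of held-out rows
newInds = randperm(n,k);         % k distinct row indices, drawn without replacement
inds = ~ismember(1:n,newInds);   % logical mask: true for rows NOT held out

nTrain  = nnz(inds);             % number of rows left for training
overlap = any(inds(newInds));    % true if a held-out row leaked into the training mask
fprintf('%d training rows, %d held-out rows, overlap = %d\n', nTrain, numel(newInds), overlap)
```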
Train a GAM that contains all the available linear and interaction terms in X.
Mdl = fitrgam(X(inds,:),Y(inds),'Interactions','all');
Mdl is a RegressionGAM model object.
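With four predictors, the candidate interaction terms are the unordered pairs of predictors. The sketch below enumerates them with nchoosek. It assumes that GAM interaction terms are pairwise, as the t-by-2 Interactions property of a RegressionGAM object suggests, and it lists candidates rather than reading the terms actually stored in Mdl.

```matlab
% Predictor names in the column order used to build X
predictorNames = {'Acceleration','Displacement','Horsepower','Weight'};

% Every unordered pair of predictor indices (one row per candidate interaction term)
pairs = nchoosek(1:numel(predictorNames),2);
for i = 1:size(pairs,1)
    fprintf('%s x %s\n', predictorNames{pairs(i,1)}, predictorNames{pairs(i,2)})
end
```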
Conserve memory by reducing the size of the trained model.
CMdl = compact(Mdl);
whos('Mdl','CMdl')
  Name      Size        Bytes  Class                                          Attributes

  CMdl      1x1       1255766  classreg.learning.regr.CompactRegressionGAM
  Mdl       1x1       1289882  RegressionGAM
CMdl is a CompactRegressionGAM model object.
Predict the responses using both linear and interaction terms, and then using only linear terms. To exclude interaction terms, specify 'IncludeInteractions',false.
yFit = predict(CMdl,XNew);
yFit_nointeraction = predict(CMdl,XNew,'IncludeInteractions',false);
Create a table containing the observed response values and the predicted response values.
t = table(YNew,yFit,yFit_nointeraction, ...
    'VariableNames',{'Observed Response', ...
    'Predicted Response','Predicted Response Without Interactions'})
t=10×3 table
    Observed Response    Predicted Response    Predicted Response Without Interactions
    _________________    __________________    _______________________________________

          27.9                 23.04                           23.649
           NaN                37.163                           35.779
           NaN                25.876                           21.978
            13                12.786                           14.141
            36                28.889                           27.281
          19.9                22.199                           18.451
          24.2                23.995                           24.885
            12                14.247                           13.982
            38                33.797                           33.528
            13                12.225                           11.127
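Because a GAM is additive, the gap between the two prediction columns is the combined contribution of the interaction terms at each observation, assuming 'IncludeInteractions',false simply drops those terms from the sum. This sketch computes that gap from the rounded values in the table (an illustration, not part of the original example). With the example's variables, you would use yFit - yFit_nointeraction directly. The NaN entries in the Observed Response column are missing MPG values and do not affect this comparison.

```matlab
% Predictions copied from the rounded table display above
yFitDisp   = [23.04 37.163 25.876 12.786 28.889 22.199 23.995 14.247 33.797 12.225]';
yNoIntDisp = [23.649 35.779 21.978 14.141 27.281 18.451 24.885 13.982 33.528 11.127]';

% Combined contribution of the interaction terms for each new observation
interactionPart = yFitDisp - yNoIntDisp;
[~,iMax] = max(abs(interactionPart));   % observation where interactions matter most
fprintf('Largest interaction contribution: %.3f (row %d)\n', interactionPart(iMax), iMax)
```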
Train a generalized additive model (GAM), and then compute and plot the prediction intervals of response values.
Load the patients data set.
load patients
Create a table that contains the predictor variables (Age, Diastolic, Smoker, Weight, Gender, SelfAssessedHealthStatus) and the response variable (Systolic).
tbl = table(Age,Diastolic,Smoker,Weight,Gender,SelfAssessedHealthStatus,Systolic);
Train a univariate GAM that contains the linear terms for the predictors in tbl. Specify the FitStandardDeviation name-value argument as true so that you can use the trained model to compute prediction intervals. When you fit the standard deviation model, a recommended practice is to use optimal hyperparameters so that the standard deviation estimates are accurate. Specify 'OptimizeHyperparameters' as 'all-univariate'. For reproducibility, use the 'expected-improvement-plus' acquisition function. Specify 'ShowPlots' as false and 'Verbose' as 0 to disable plot and message displays, respectively.
rng('default') % For reproducibility
Mdl = fitrgam(tbl,'Systolic','FitStandardDeviation',true, ...
    'OptimizeHyperparameters','all-univariate', ...
    'HyperparameterOptimizationOptions',struct('AcquisitionFunctionName','expected-improvement-plus', ...
    'ShowPlots',false,'Verbose',0))
Mdl = 
  RegressionGAM
                       PredictorNames: {'Age'  'Diastolic'  'Smoker'  'Weight'  'Gender'  'SelfAssessedHealthStatus'}
                         ResponseName: 'Systolic'
                CategoricalPredictors: [3 5 6]
                    ResponseTransform: 'none'
                            Intercept: 122.7800
               IsStandardDeviationFit: 1
                      NumObservations: 100
    HyperparameterOptimizationResults: [1×1 BayesianOptimization]
Mdl is a RegressionGAM model object that uses the best estimated feasible point. The best estimated feasible point is the set of hyperparameters that minimizes the upper confidence bound of the objective function value, based on the underlying objective function model of the Bayesian optimization process. For more details on the optimization process, see Optimize GAM Using OptimizeHyperparameters.
Predict responses for the training data in tbl, and compute the 99% prediction intervals of the response variable. Specify the significance level ('Alpha') as 0.01 to set the confidence level of the prediction intervals to 99%.
[yFit,~,yInt] = predict(Mdl,tbl,'Alpha',0.01);
Plot the sorted true responses together with the predicted responses and prediction intervals.
figure
yTrue = tbl.Systolic;
[sortedYTrue,I] = sort(yTrue);
plot(sortedYTrue,'o')
hold on
plot(yFit(I))
plot(yInt(I,1),'k:')
plot(yInt(I,2),'k:')
legend('True responses','Predicted responses', ...
    'Prediction interval limits','Location','best')
hold off
Input Arguments
Data Types: table | double | single
Name-Value Arguments
Specify optional pairs of arguments as Name1=Value1,...,NameN=ValueN, where Name is the argument name and Value is the corresponding value. Name-value arguments must appear after other arguments, but the order of the pairs does not matter.
Before R2021a, use commas to separate each name and value, and enclose Name in quotes.
Example: 'Alpha',0.01,'IncludeInteractions',false specifies the confidence level as 99% and excludes interaction terms from computations.
Significance level (Alpha) for the prediction intervals yInt, specified as a numeric scalar in the range [0,1]. The confidence level of yInt is 100(1 – Alpha)%.
This argument is valid only when the IsStandardDeviationFit property of Mdl is true. Specify the 'FitStandardDeviation' name-value argument of fitrgam as true to fit the model for the standard deviation.
Example: 'Alpha',0.01 specifies to return 99% prediction intervals.
Data Types: single | double
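The link between Alpha and the interval width is simple arithmetic under the normal model described in Algorithms: a central 100(1 – Alpha)% interval spans ±z standard deviations around the mean, where z is the (1 – Alpha/2) quantile of the standard normal distribution. This sketch computes z with base-MATLAB erfcinv (norminv(1 - alpha/2) from Statistics and Machine Learning Toolbox gives the same value). Treating yInt as this central normal interval is an assumption made for illustration, not a documented formula.

```matlab
alpha = 0.01;                               % significance level, as in 'Alpha',0.01
confLevel = 100*(1 - alpha);                % confidence level in percent
z = -sqrt(2)*erfcinv(2*(1 - alpha/2));      % standard normal quantile at 1 - alpha/2
fprintf('Confidence level: %g%%, z = %.4f\n', confLevel, z)
```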
Output Arguments
Predicted responses, returned as a column vector of length n, where n is the number of observations in the predictor data X.
Standard deviations of the response variable, evaluated at each observation in the predictor data X, returned as a column vector of length n, where n is the number of observations in X. The ith element ySD(i) contains the standard deviation of the ith response for the ith observation X(i,:), estimated using the trained standard deviation model in Mdl.
This argument is valid only when the IsStandardDeviationFit property of Mdl is true. Specify the 'FitStandardDeviation' name-value argument of fitrgam as true to fit the model for the standard deviation.
Prediction intervals of the response variable, evaluated at each observation in the predictor data X, returned as an n-by-2 matrix, where n is the number of observations in X. The ith row yInt(i,:) contains the 100(1 – Alpha)% prediction interval of the ith response for the ith observation X(i,:). The Alpha value is the probability that the prediction interval does not contain the true response value for X(i,:). The first column of yInt contains the lower limits of the prediction intervals, and the second column contains the upper limits.
This argument is valid only when the IsStandardDeviationFit property of Mdl is true. Specify the 'FitStandardDeviation' name-value argument of fitrgam as true to fit the model for the standard deviation.
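Because Alpha is the probability that an interval misses the true response, the fraction of observations whose true response falls inside its interval is a quick sanity check: for a well-calibrated model it should be roughly 1 – Alpha. This sketch uses hypothetical toy values; in the prediction interval example you would use yTrue and yInt instead. The check is an illustration, not a feature of predict.

```matlab
% Hypothetical true responses and n-by-2 prediction intervals [lower upper]
yTrue = [120; 131; 115; 140];
yInt  = [110 130; 118 128; 105 125; 125 150];

% Logical vector: true where the interval contains the true response
inside = yTrue >= yInt(:,1) & yTrue <= yInt(:,2);
coverage = mean(inside);        % empirical coverage, to compare with 1 - Alpha
fprintf('Empirical coverage: %.2f\n', coverage)
```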
Algorithms
predict returns the predicted responses (yFit) and, optionally, the standard deviations (ySD) and prediction intervals (yInt) of the response variable, estimated at each observation in X.
A generalized additive model (GAM) for regression assumes that the response variable y follows a normal distribution with mean μ and standard deviation σ. If you specify 'FitStandardDeviation' of fitrgam as false (default), then fitrgam trains a model for μ. If you specify 'FitStandardDeviation' as true, then fitrgam trains an additional model for σ and sets the IsStandardDeviationFit property of the GAM object to true. The outputs yFit and ySD correspond to the estimated mean μ and standard deviation σ, respectively.
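Given this normal model, a natural way to turn yFit and ySD into a 100(1 – Alpha)% prediction interval is the central interval of N(μ, σ²), that is, yFit ± z·ySD, where z is the (1 – Alpha/2) standard normal quantile. This sketch assumes that construction and uses hypothetical values. It is not taken from the predict source code, so compare it against yInt from predict before relying on it.

```matlab
alpha = 0.05;                          % hypothetical significance level
yFit  = [120; 125; 130];               % hypothetical predicted means (mu)
ySD   = [5; 8; 10];                    % hypothetical predicted standard deviations (sigma)

z = -sqrt(2)*erfcinv(2*(1 - alpha/2)); % standard normal quantile at 1 - alpha/2
yIntSketch = yFit + [-z z].*ySD;       % n-by-2 [lower upper], via implicit expansion
disp(yIntSketch)
```

Each row is symmetric about yFit, and its width scales with ySD. Observations with a larger estimated σ therefore get wider intervals.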
Version History
Introduced in R2021a