Stratified K Fold Cross Validation
Last Updated : 15 Jul, 2025
**Stratified K-Fold Cross Validation** is a technique for evaluating a model. It is particularly useful for classification problems in which the class labels are not evenly distributed, i.e. the data is imbalanced. It is an enhanced version of K-Fold Cross Validation. The key difference is that it uses **stratification**, which preserves the original proportion of each class in every fold.
For example, if your original dataset has 80% Class 0 and 20% Class 1, each fold will contain roughly the same 80:20 proportion. Because every fold is evaluated on a representative mix of classes, the resulting accuracy estimates are more reliable.
Problem with Random Splitting
Random splitting techniques like train_test_split() with its default settings or regular K-Fold can cause problems when they produce skewed class proportions in the training and test sets. For example, imagine a binary classification dataset with 100 samples where:
- 80 samples are Class 0
- 20 samples are Class 1
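With random sampling in an 80:20 split, nothing guarantees that each class is represented in both sets. In the most extreme case, all 80 Class 0 samples could end up in the training set and all 20 Class 1 samples in the test set. The model would then never see Class 1 during training, so it could not learn to classify it, and the reported accuracy would be misleading.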
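To get a feel for how much a purely random 80:20 split can drift, the sketch below repeatedly shuffles the 100 labels from the example and counts how many Class 1 samples land in the 20-sample test set. The NumPy-based setup, the seed and the number of trials are illustrative assumptions, not part of the original example.

```python
import numpy as np

# Labels from the example: 80 samples of Class 0 and 20 of Class 1.
y = np.array([0] * 80 + [1] * 20)

rng = np.random.default_rng(0)  # seed chosen arbitrarily for repeatability
n_trials = 1000                 # illustrative number of random splits

minority_in_test = []
for _ in range(n_trials):
    shuffled = rng.permutation(y)  # plain random shuffle, no stratification
    test_labels = shuffled[:20]    # the first 20 samples act as the test set
    minority_in_test.append(int(test_labels.sum()))

print("Fewest Class 1 samples in a test set:", min(minority_in_test))
print("Most Class 1 samples in a test set:", max(minority_in_test))
print("Average Class 1 samples per test set:", sum(minority_in_test) / n_trials)
```

A stratified split would place exactly 4 Class 1 samples in the test set every time, whereas the random splits only scatter around that value.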
Now, let’s use **stratified sampling** on the same dataset:
**1. Training Set (80 samples):**
- 64 from Class 0 (80% of 80)
- 16 from Class 1 (80% of 20)
**2. Test Set (20 samples):**
- 16 from Class 0 (20% of 80)
- 4 from Class 1 (20% of 20)
This ensures that both the training and test sets mirror the class proportions of the full dataset, so the score measured on the test set is a more representative estimate of how the model generalizes.
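The numbers above can be reproduced with a short sketch. It assumes scikit-learn's train_test_split() with its stratify argument; the placeholder feature matrix and the random_state value are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Placeholder features (one column) and the 80/20 labels from the example.
X = np.arange(100).reshape(-1, 1)
y = np.array([0] * 80 + [1] * 20)

# stratify=y asks for the class proportions of y to be kept in both sets.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

train_counts = np.bincount(y_train)  # [count of Class 0, count of Class 1]
test_counts = np.bincount(y_test)
print("Training set counts (Class 0, Class 1):", train_counts)
print("Test set counts (Class 0, Class 1):", test_counts)
```

With these class sizes the split is exact: 64 and 16 samples in the training set, 16 and 4 in the test set.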
In real-world classification tasks the distribution of observations per class is often imbalanced. In a medical dataset, for example, 90% of patients might be healthy (Class 0) and only 10% might have a disease (Class 1). If we split such data randomly, some training or test sets may contain very few samples of the minority class, or none at all. This is where Stratified K Fold Cross Validation becomes important.
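As a rough illustration of that difference, the sketch below builds a hypothetical 90:10 label vector and counts the minority-class samples in each test fold, once with a shuffled KFold and once with StratifiedKFold. The dataset, the helper name minority_per_fold, the choice of 5 folds and the seed are all assumptions made for this example.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical data: 90 healthy patients (Class 0), 10 with a disease (Class 1).
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((100, 1))  # features are irrelevant for counting labels


def minority_per_fold(splitter):
    """Return the number of Class 1 samples in each test fold."""
    return [int(y[test_idx].sum()) for _, test_idx in splitter.split(X, y)]


plain = minority_per_fold(KFold(n_splits=5, shuffle=True, random_state=0))
stratified = minority_per_fold(
    StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
)

print("Class 1 per test fold, KFold:          ", plain)
print("Class 1 per test fold, StratifiedKFold:", stratified)
```

StratifiedKFold places 2 of the 10 minority samples in every test fold, while the counts under plain KFold depend on the shuffle and can differ from fold to fold.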
**Implementation of Stratified K-Fold Cross-Validation**
1. Importing Required Libraries
We will be using Python's built-in statistics module and the scikit-learn library.

```python
from statistics import mean, stdev
from sklearn import preprocessing
from sklearn.model_selection import StratifiedKFold
from sklearn import linear_model
from sklearn import datasets
```
2. Loading Dataset and Extracting Features
Here we will be using the breast cancer dataset available in scikit-learn.
- x = cancer.data: feature/input values
- y = cancer.target: output/class labels (0 or 1)

```python
cancer = datasets.load_breast_cancer()

x = cancer.data    # feature matrix
y = cancer.target  # class labels (0 or 1)
```
3. Feature Scaling (Normalization)
- MinMaxScaler(): scales features to a range between 0 and 1
- fit_transform(x): fits the scaler on the data and applies the transformation

```python
scaler = preprocessing.MinMaxScaler()
x_scaled = scaler.fit_transform(x)
```
4. Model and K-Fold Object Setup
Here we will be using a logistic regression model.
- StratifiedKFold(...): sets up 10-fold stratified cross-validation
- lst_accu_stratified: empty list to store the accuracy score of each fold

```python
lr = linear_model.LogisticRegression()

# shuffle=True with a fixed random_state gives reproducible shuffled folds
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
lst_accu_stratified = []
```
5. Applying Stratified K-Fold and Training Model
- skf.split(x_scaled, y): splits the dataset into stratified train-test indices
- x_train_fold, x_test_fold: features for training and testing
- y_train_fold, y_test_fold: labels for training and testing

```python
for train_index, test_index in skf.split(x_scaled, y):
    # Select the rows that belong to this fold's training and test sets
    x_train_fold, x_test_fold = x_scaled[train_index], x_scaled[test_index]
    y_train_fold, y_test_fold = y[train_index], y[test_index]
    # Fit on the training fold, then record accuracy on the held-out fold
    lr.fit(x_train_fold, y_train_fold)
    lst_accu_stratified.append(lr.score(x_test_fold, y_test_fold))
```
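One caveat with the steps above is that the scaler is fitted on the whole dataset before splitting, so each test fold has already influenced the scaling. A possible variant, sketched here as an alternative rather than as the article's method, wraps the scaler and the model in a scikit-learn pipeline and lets cross_val_score refit both inside every fold. The variable names model and scores are chosen for this example.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import MinMaxScaler

X, y = load_breast_cancer(return_X_y=True)

# The pipeline refits MinMaxScaler on the training part of each fold only,
# so no information from the test fold leaks into the scaling step.
model = make_pipeline(MinMaxScaler(), LogisticRegression())

skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=1)
scores = cross_val_score(model, X, y, cv=skf, scoring="accuracy")

print("Accuracy per fold:", scores)
print("Mean accuracy:", scores.mean())
```

On a dataset like this the effect on the reported accuracy is likely small, but keeping preprocessing inside the cross-validation loop is the safer habit.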
6. Printing Accuracy Results
- max(): highest accuracy
- min(): lowest accuracy
- mean(): average accuracy
- stdev(): standard deviation of the fold accuracies

```python
print('List of fold accuracies:', lst_accu_stratified)
print('\nMaximum Accuracy:', max(lst_accu_stratified) * 100, '%')
print('\nMinimum Accuracy:', min(lst_accu_stratified) * 100, '%')
print('\nOverall Accuracy:', mean(lst_accu_stratified) * 100, '%')
print('\nStandard Deviation is:', stdev(lst_accu_stratified))
```
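**Output:**

The run reported an overall accuracy of 96.6% with a standard deviation of 0.02. The accuracy is high and varies little from fold to fold, which suggests the model performs consistently across different subsets of the data.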
By using Stratified K-Fold Cross Validation we can ensure that our machine learning model is evaluated fairly and consistently, so the reported scores give a more trustworthy picture of how it is likely to perform on real-world data.