Pruning Decision Trees

Last Updated : 2 Feb, 2026

When allowed to grow freely, decision trees tend to memorize noise and overly specific patterns in the training data, which leads to overfitting. Pruning addresses this by simplifying the tree structure. A smaller tree usually generalizes better to unseen data, is easier to interpret and is cheaper to evaluate, and it often matches or even improves the accuracy of the full tree on new data.

Decision Tree Pruning is a model optimization technique used to control the growth of decision tree models by removing unnecessary branches and nodes that do not contribute significantly to predictive performance.

Types of Decision Tree Pruning

Decision tree pruning techniques are broadly classified into two categories:

1. Pre-Pruning (Early Stopping)

Pre-pruning, also known as early stopping, restricts the growth of the decision tree during the training phase itself. Instead of letting the tree grow fully and trimming it afterwards, pre-pruning refuses to create a split whenever that split would violate a predefined constraint.

Working

Common Pre-Pruning Techniques

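1. Maximum Depth: limits how deep the tree can grow.

2. Minimum Samples per Split: a node is split only if it contains at least a given number of training samples.

3. Minimum Samples per Leaf: a split is allowed only if each resulting child keeps at least a given number of samples.

4. Maximum Features: limits how many features are evaluated when searching for the best split at each node.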
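In scikit-learn, these four constraints correspond to the `DecisionTreeClassifier` arguments `max_depth`, `min_samples_split`, `min_samples_leaf` and `max_features`. The following minimal sketch fits an unconstrained tree and a pre-pruned one on the full dataset; the limits (4, 10, 5, `"sqrt"`) and `random_state=0` are arbitrary illustrative values, not tuned settings.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)

# Unconstrained tree: keeps splitting until leaves are pure or cannot be split
full_tree = DecisionTreeClassifier(random_state=0).fit(X, y)

# Pre-pruned tree: illustrative (untuned) limits applied while the tree grows
pre_pruned = DecisionTreeClassifier(
    max_depth=4,           # 1. never grow deeper than 4 levels
    min_samples_split=10,  # 2. only split nodes holding at least 10 samples
    min_samples_leaf=5,    # 3. every leaf must keep at least 5 samples
    max_features="sqrt",   # 4. evaluate sqrt(n_features) candidate features per split
    random_state=0,
).fit(X, y)

for name, tree in [("full", full_tree), ("pre-pruned", pre_pruned)]:
    print(f"{name:>10}: depth={tree.get_depth()}, leaves={tree.get_n_leaves()}")

# Leaves are the nodes without a left child (marked with -1)
is_leaf = pre_pruned.tree_.children_left == -1
leaf_sizes = pre_pruned.tree_.n_node_samples[is_leaf]
print("smallest leaf in the pre-pruned tree:", leaf_sizes.min())
```

Because the constraints are checked while the tree is built, the pre-pruned tree never creates the extra branches in the first place.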

2. Post-Pruning (Pruning After Full Growth)

Post-pruning is a pruning strategy in which the decision tree is allowed to grow to its full depth first, after which unnecessary or weak branches are removed. Unlike pre-pruning, this approach does not restrict the tree during training. Instead, it analyzes the fully grown tree and evaluates whether certain subtrees contribute meaningfully to predictive performance.

Working

Common Post-Pruning Techniques

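1. Cost-Complexity Pruning (CCP): produces a sequence of progressively smaller subtrees by repeatedly collapsing the "weakest link", the subtree whose removal increases the total leaf impurity the least per leaf removed. A complexity parameter α (ccp_alpha in scikit-learn) controls how far pruning goes.

2. Reduced Error Pruning: replaces a subtree with a single leaf (predicting its majority class) whenever doing so does not lower accuracy on a held-out validation set.

3. Minimum Impurity Decrease: only splits that reduce impurity by at least a threshold are kept. (In scikit-learn, min_impurity_decrease enforces this while the tree is grown, so there it acts as a pre-pruning constraint.)

4. Minimum Leaf Size: every leaf must contain at least a minimum number of samples. (scikit-learn's min_samples_leaf likewise enforces this during growth.)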
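scikit-learn does not ship reduced error pruning, so here is a self-contained sketch of the idea on a hypothetical, hand-built tree stored as nested dictionaries (the node layout with `"feature"`, `"threshold"`, `"left"`, `"right"` and `"label"` keys is an assumption made purely for illustration). Working bottom-up, each internal node is replaced by a leaf predicting its stored majority label whenever that leaf makes no more errors on the validation samples reaching the node than the subtree does.

```python
# Hypothetical node layout (illustration only):
#   leaf:     {"label": c}
#   internal: {"feature": j, "threshold": t, "left": node, "right": node,
#              "label": majority class of the training samples at this node}

def predict(node, x):
    """Follow the splits until a leaf is reached."""
    while "feature" in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["label"]


def reduced_error_prune(node, X_val, y_val):
    """Bottom-up reduced error pruning (mutates internal nodes in place).
    X_val / y_val are the validation samples that reach `node`."""
    if "feature" not in node:
        return node  # a leaf cannot be pruned further

    # Route the validation samples reaching this node to its two children
    j, t = node["feature"], node["threshold"]
    left = [(x, c) for x, c in zip(X_val, y_val) if x[j] <= t]
    right = [(x, c) for x, c in zip(X_val, y_val) if x[j] > t]

    # Prune the children first
    node["left"] = reduced_error_prune(node["left"], [x for x, _ in left], [c for _, c in left])
    node["right"] = reduced_error_prune(node["right"], [x for x, _ in right], [c for _, c in right])

    # Collapse into a leaf if that is no worse on the validation samples (ties favor the smaller tree)
    subtree_errors = sum(predict(node, x) != c for x, c in zip(X_val, y_val))
    leaf_errors = sum(node["label"] != c for c in y_val)
    if leaf_errors <= subtree_errors:
        return {"label": node["label"]}
    return node


# Toy tree: the split on feature 1 only fits noise in the training data
tree = {
    "feature": 0, "threshold": 0.5, "label": 0,
    "left": {"label": 0},
    "right": {
        "feature": 1, "threshold": 0.5, "label": 1,
        "left": {"label": 1},
        "right": {"label": 0},
    },
}
X_val = [(0, 0), (0, 1), (1, 0), (1, 1)]
y_val = [0, 0, 1, 1]

pruned = reduced_error_prune(tree, X_val, y_val)
print(pruned)
# {'feature': 0, 'threshold': 0.5, 'label': 0, 'left': {'label': 0}, 'right': {'label': 1}}
```

The noisy split on feature 1 is collapsed because a single leaf makes fewer validation errors, while the root split survives because removing it would cost two errors.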

Implementation

Let's implement both approaches on the breast cancer dataset that ships with scikit-learn.

Step 1: Import Libraries and Load Dataset

We import the required modules from scikit-learn and Matplotlib; the dataset itself is loaded in the next step.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt
```

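Step 2: Load and Split the Dataset

```python
# 569 samples, 30 numeric features, binary target
X, y = load_breast_cancer(return_X_y=True)

# train_size=0.2 trains on 20% of the data (113 samples) and tests on 80% (456);
# the outputs shown below come from this split (test_size=0.2 gives the usual 80/20)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, train_size=0.2, random_state=42
)
```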

Step 3: Train the Original (Unpruned) Decision Tree

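```python
# Fully grown tree: no pre-pruning limits and no ccp_alpha.
# No random_state is set, so exact results can vary slightly between runs.
model = DecisionTreeClassifier(criterion="gini")
model.fit(X_train, y_train)
```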

Output:

[Figure: Model Training]

Step 4: Visualize the Original Decision Tree

```python
plt.figure(figsize=(20, 12))
plot_tree(model, filled=True, fontsize=11)
plt.title("Original Decision Tree", fontsize=18)
plt.show()
```

Output:

[Figure: Original Decision Tree]

Step 5: Model Accuracy Before Pruning

Here we evaluate the baseline (unpruned) model on the test set.

```python
accuracy_before_pruning = model.score(X_test, y_test)
print("Accuracy before pruning:", accuracy_before_pruning)
```

Output:

Accuracy before pruning: 0.8947368421052632

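Step 6: Pre-Pruning with a Hyperparameter Grid and GridSearchCV

GridSearchCV tries every combination of the pre-pruning limits (max_depth, max_features) together with the split criterion and splitter, scoring each with 5-fold cross-validation on the training set.

```python
from sklearn.model_selection import GridSearchCV

parameters = {
    'criterion': ['gini', 'entropy', 'log_loss'],  # 'log_loss' needs scikit-learn >= 1.1
    'splitter': ['best', 'random'],
    'max_depth': [1, 2, 3, 4, 5],        # pre-pruning: depth limit
    'max_features': ['sqrt', 'log2'],    # pre-pruning: features evaluated per split
}

dt = DecisionTreeClassifier()  # no random_state: the best parameters can differ between runs
cv = GridSearchCV(dt, param_grid=parameters, cv=5)
cv.fit(X_train, y_train)
```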

Output:

[Figure: GridSearchCV]

Step 7: Evaluate the Pre-Pruned Model

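```python
# cv.score evaluates the best estimator (refit on the whole training set) on the test set
print("Best Accuracy:", cv.score(X_test, y_test))
print("Best Parameters:", cv.best_params_)
```

Output: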

Best Accuracy: 0.9276315789473685
Best Parameters: {'criterion': 'entropy', 'max_depth': 4, 'max_features': 'log2', 'splitter': 'best'}

Step 8: Compute Pruning Path

```python
# Effective alphas at which subtrees of a fully grown tree are pruned away
path = model.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas
```
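The values returned by cost_complexity_pruning_path come from minimal cost-complexity pruning as described in the scikit-learn user guide. For a tree T with set of leaves \widetilde{T}, let R(T) be the total sample-weighted impurity of its leaves and α ≥ 0 the complexity parameter; pruning seeks the subtree minimizing the first expression below. For an internal node t with subtree T_t, the second expression is its effective alpha: the increase in leaf impurity caused by collapsing T_t into the single leaf t, divided by the number of leaves removed.

```latex
R_\alpha(T) = R(T) + \alpha \, |\widetilde{T}|

\alpha_{\mathrm{eff}}(t) = \frac{R(t) - R(T_t)}{|\widetilde{T}_t| - 1}
```

scikit-learn repeatedly collapses the node with the smallest effective alpha. ccp_alphas records these values from smallest to largest, and fitting with ccp_alpha=α keeps pruning until every remaining node's effective alpha exceeds α.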

Step 9: Train Pruned Models

```python
pruned_models = []

# Fit one tree per candidate alpha; larger alphas prune more aggressively
for alpha in ccp_alphas:
    pruned_model = DecisionTreeClassifier(criterion="gini", ccp_alpha=alpha)
    pruned_model.fit(X_train, y_train)
    pruned_models.append(pruned_model)
```
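Each value in ccp_alphas marks the point at which another weakest-link subtree disappears, so trees fitted with larger ccp_alpha values are progressively smaller. The sketch below is self-contained and makes its own assumptions (a 75/25 split and `random_state=0` everywhere, so every refit starts from the same fully grown tree); it prints the leaf count and test accuracy for each alpha.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# Same random_state for the path and for every refit, so the trees are nested
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

trees = [
    DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)
    for alpha in ccp_alphas
]

for alpha, tree in zip(ccp_alphas, trees):
    print(f"alpha={alpha:.5f}  leaves={tree.get_n_leaves():3d}  "
          f"test acc={tree.score(X_test, y_test):.3f}")

leaf_counts = [tree.get_n_leaves() for tree in trees]
```

With a fixed random_state the leaf count never increases as alpha grows, and the last alpha collapses the tree to its root alone, which is why that value is often dropped from the candidates.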

Step 10: Select the Best Pruned Model

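```python
best_accuracy = 0
best_pruned_model = None

# Keep the pruned tree with the highest test accuracy.
# Caveat: choosing alpha on the test set makes this score optimistic;
# a separate validation split or cross-validation avoids that.
for m in pruned_models:
    acc = m.score(X_test, y_test)
    if acc > best_accuracy:
        best_accuracy = acc
        best_pruned_model = m

print("Accuracy after pruning:", best_accuracy)
```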
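Step 10 picks alpha by looking at test accuracy, so the reported accuracy after pruning is an optimistic estimate: the test set has been used for model selection. A common alternative, sketched here under our own choices (a stratified 75/25 split, `random_state=0`, 5-fold cross-validation), is to choose ccp_alpha with GridSearchCV on the training data only and touch the test set once at the end.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, stratify=y, random_state=0
)

# Candidate alphas are derived from the training data only
path = DecisionTreeClassifier(random_state=0).cost_complexity_pruning_path(X_train, y_train)
candidate_alphas = np.unique(path.ccp_alphas).tolist()

# 5-fold cross-validation on the training set selects ccp_alpha
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={"ccp_alpha": candidate_alphas},
    cv=5,
)
search.fit(X_train, y_train)

best_alpha = search.best_params_["ccp_alpha"]
print("Chosen ccp_alpha:", best_alpha)
print("Mean CV accuracy:", search.best_score_)

# The test set is used exactly once, for the final estimate
test_accuracy = search.score(X_test, y_test)
print("Test accuracy:", test_accuracy)
```

Because the test set plays no part in choosing ccp_alpha here, its accuracy is a fairer estimate of performance on new data.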

Output:

Accuracy after pruning: 0.9166666666666666

Step 11: Visualize the Pruned Decision Tree

```python
plt.figure(figsize=(22, 14))
plot_tree(best_pruned_model, filled=True, fontsize=11)
plt.title("Pruned Decision Tree", fontsize=18)
plt.show()
```

Output:

[Figure: Pruned Decision Tree]

Advantages