What is the xgb.plot_tree() function in Python?

XGBoost (eXtreme Gradient Boosting) is a well-known machine learning library that implements gradient boosting, a powerful ensemble learning technique that combines the predictions of many weak learners (typically decision trees) into a single strong learner.

The xgb.plot_tree() function

The xgb.plot_tree() function is an invaluable tool that XGBoost provides for visualizing the individual decision trees that make up the ensemble.

Decision trees can become complex, and visualizing them can help us better comprehend the model's decision-making process, feature relevance, and possible overfitting. (Overfitting occurs when a machine learning model performs well on the training data but poorly on unseen data, indicating it has memorized the training set rather than learned to generalize.)

Syntax

Here, we will show the basic syntax for the xgb.plot_tree() function:

xgb.plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs)
Syntax for xgb.plot_tree() method
  • booster is a required parameter representing the trained model to be visualized. It can be a fitted scikit-learn wrapper (XGBRegressor or XGBClassifier) or the underlying Booster object.

  • fmap is the name of the feature map file.

  • num_trees represents the index of the tree to be plotted. The default value is 0.

  • rankdir is an optional parameter representing the direction of the graph layout. The value can be "TB" for top-to-bottom or "LR" for left-to-right.

  • ax is an optional parameter representing the matplotlib axes object to plot the tree.

  • **kwargs is an optional parameter showing additional keyword arguments that can be passed to the plot function.

Note: Make sure you have the XGBoost library installed. In addition, xgb.plot_tree() depends on matplotlib and Graphviz for rendering.
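One way to set this up (assuming a standard pip environment; note that the graphviz Python package also requires the Graphviz system binaries to be installed separately) is:

```shell
# Install XGBoost along with the plotting dependencies
pip install xgboost matplotlib graphviz
```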

Code

Let's look at a code example that uses the xgb.plot_tree() function, given below:

import xgboost as xgb
from xgboost import plot_tree
import numpy as np
import matplotlib.pyplot as plt

# Creating a synthetic dataset
np.random.seed(42)
X = np.random.rand(100, 3)
y = np.random.randint(0, 2, 100)

# Creating an XGBoost classifier
model = xgb.XGBClassifier()

# Training the model on the dataset
model.fit(X, y)

# Visualizing the first decision tree in the ensemble
plot_tree(model, num_trees=0)
plt.show()

Code explanation

  • Line 1–2: Firstly, we import the xgboost library as xgb and the plot_tree function to visualize decision trees.

  • Line 3–4: Next, we import the numpy library and the pyplot module from the matplotlib library.

  • Line 7–9: Now, we create a small synthetic dataset with 100 samples and 3 features using the random.rand() and random.randint() functions. The target variable y is binary, taking the values 0 or 1.

  • Line 12: In this line, we create an XGBoost classifier with default hyperparameters and store it in the variable model.

  • Line 15: Moving on, we train the model on the entire synthetic dataset X and y.

  • Line 18: Now, we visualize the first decision tree in the ensemble using the plot_tree function. The parameter num_trees=0 specifies to plot the first tree in the ensemble.

  • Line 19: Finally, we display the plot using plt.show().

Output

Upon execution, the code uses the plot_tree() function to visualize the first decision tree in the XGBoost ensemble model.

The output decision tree looks like this:

The decision tree in the ensemble model

In the plot above, we can see that the tree's internal nodes reflect splitting conditions on specific features, while the leaves hold the leaf values (raw prediction scores that are summed across trees to produce the final prediction). This helps in understanding how the model makes decisions based on the features in the dataset.

Conclusion

In summary, the xgb.plot_tree() function is useful for visualizing the decision trees in an XGBoost ensemble model. Using this function, we can learn about the model's decision-making process, feature relevance, and potential overfitting. This deepens our understanding of the XGBoost model, making communication easier and enabling better model debugging and tuning.


Copyright ©2025 Educative, Inc. All rights reserved