
Scikit-learn decision tree: A step-by-step guide

Mehwish Fatima
May 02, 2024
11 min read

In this blog, we'll learn how to implement decision trees in Python with the scikit-learn library, going over their features one by one. Decision trees are useful tools for classification problems. To demonstrate the process, we'll use the well-known wine dataset, a classic for multi-class classification. It suits decision trees well because it describes each wine by 13 chemical attributes (such as alcohol, flavanoids, and color intensity) and divides the wines into three classes.

Introduction

Decision trees are a fundamental tool in machine learning. Because they make predictions through a sequence of simple, readable tests, they offer logical insights into complex datasets. A decision tree is a non-parametric supervised learning algorithm used for both classification and regression problems. It has a hierarchical tree structure with a root node, branches, internal nodes, and leaf nodes.

Let's first understand what a decision tree is and then go into the coding-related details.

Decision tree terminology

Before going deeper into decision trees, let's get familiar with some terminology, shown in the illustration below:

  • Root node: The root node is the starting point of a decision tree. It holds the whole dataset, which is first split based on one of the dataset's features.

  • Decision nodes: Nodes that have child nodes; each one represents a decision to be made. The root node (if it has children) is also a decision node.

  • Leaf nodes: Nodes that hold the final classification or result, where no further splitting takes place. Leaf nodes are also called terminal nodes.

  • Branches or subtrees: A branch or subtree is a section of the larger tree that starts at one of its nodes. It represents a particular path of decisions and the outcomes they lead to.

  • Pruning: The practice of removing particular nodes or subtrees from a decision tree to simplify the model and avoid overfitting.

  • Parent and child nodes: In a decision tree, a node that is split into sub-nodes is called a parent node, and the nodes that emerge from it are its child nodes. The parent node represents a decision or condition, and the child nodes represent the possible outcomes or further decisions based on it.

A decision tree
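To connect these terms to scikit-learn, here is a minimal sketch that fits a small tree on the wine dataset and labels each node as the root, a decision node, or a leaf node, together with its parent. It reads the fitted classifier's tree_ attribute, whose children_left and children_right arrays use -1 for a missing child. describe_nodes is a hypothetical helper written for this example, and max_depth=2 is an arbitrary choice that keeps the tree small.

```python
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier

# max_depth=2 is an arbitrary limit that keeps the tree small enough to read
X, y = load_wine(return_X_y=True)
clf = DecisionTreeClassifier(max_depth=2, random_state=0).fit(X, y)

def describe_nodes(tree):
    """Label each node as root, decision, or leaf, and record its parent node."""
    roles, parent = {}, {0: None}  # node 0 is always the root, which has no parent
    for node in range(tree.node_count):
        left = int(tree.children_left[node])
        right = int(tree.children_right[node])
        if left == -1:  # -1 marks a missing child, so this node is a leaf
            roles[node] = "leaf"
            continue
        # A node with children is a decision node and the parent of both children
        roles[node] = "root (decision)" if node == 0 else "decision"
        parent[left] = node
        parent[right] = node
    return roles, parent

roles, parent = describe_nodes(clf.tree_)
for node, role in roles.items():
    print(f"node {node}: {role}, parent = {parent[node]}")
```

Because every decision node in a scikit-learn tree has exactly two children, the number of leaf nodes is always one more than the number of decision nodes.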

Decision trees in scikit-learn

Now that we understand the basic concepts of decision trees, let's use scikit-learn to explore how they work in practice.

The dataset

We use the wine dataset, a classic for multi-class classification. Let’s explore the dataset:

import pandas as pd
from sklearn.datasets import load_wine
data = load_wine() # Load the dataset as a Bunch object
wine = pd.DataFrame(data['data'], columns=data['feature_names']) # Convert the features to a DataFrame for easier viewing
wine['target'] = data['target'] # Add the class labels as a column
pd.set_option("display.max_rows", None, "display.max_columns", None) # Configure pandas to show all rows and columns
print(wine.head())
print("Total number of observations: {}".format(len(wine)))

Code explanation

Let’s review the code:

  • Line 3: We load the wine dataset into a variable named data.

  • Lines 4–5: We convert the feature array of the wine dataset into a pandas DataFrame, using the feature names as column labels. We then add the target values as a column so that each observation's class appears next to its features.

  • Line 6: We set pandas' display options so that all rows and columns are shown instead of being truncated.

  • Line 7: We print the first five observations of the wine dataset using the head() method.

  • Line 8: We print the total number of observations, which is 178.

The target

Let’s explore the target values to find how many classes we have in this dataset:

print(wine['target'].head())
shuffled = wine.sample(frac=1, random_state=1).reset_index(drop=True)
print(shuffled['target'].head())

Code explanation

Let’s review the code:

  • Line 1: We print the target of the first five observations. All of them belong to class 0, because the dataset is ordered by class.

  • Line 2: We shuffle the rows with sample(frac=1), fix the shuffle with random_state=1 so it is reproducible, and reset the index.

  • Line 3: We print the target of the first five observations again. This time, the rows are drawn from across the whole dataset, which contains three classes: 0, 1, and 2.

Let's sum up the properties of the wine dataset, as listed in the scikit-learn documentation:

Properties of the wine dataset

  • Classes: 3

  • Samples per class: [59, 71, 48]

  • Total observations: 178

  • Total features: 13
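These properties can also be confirmed directly from the arrays that load_wine returns. This short sketch uses return_X_y=True to get the feature matrix and the labels, then counts classes, samples per class, observations, and features.

```python
import numpy as np
from sklearn.datasets import load_wine

X, y = load_wine(return_X_y=True)
print("Classes:", len(np.unique(y)))                  # Classes: 3
print("Samples per class:", np.bincount(y).tolist())  # Samples per class: [59, 71, 48]
print("Total observations:", X.shape[0])              # Total observations: 178
print("Total features:", X.shape[1])                  # Total features: 13
```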

Step-by-step guide to decision trees

Let’s break down the decision tree algorithm into simple steps for the wine dataset.

We will predict the class of a wine from its features. The root node represents all the training instances. In the tree shown below, the root splits on the color_intensity feature. Depending on the outcome of that test, a sample follows one branch to the next node. At the next level, the tree splits on two different features: proline and flavanoids. The algorithm keeps comparing the sample's attribute values against each node's test and moving down the tree until it reaches a leaf node.

A decision tree created by the sample code below
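To see this root-to-leaf walk concretely, the sketch below follows one test sample through a trained tree using the classifier's decision_path method, which reports the ids of the nodes a sample visits. The features and thresholds along the path depend on the fitted tree, so they may differ from the figure. trace_path is a hypothetical helper, and random_state=0 on the classifier is an added assumption that makes the result repeatable.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.2, random_state=42)
# random_state=0 is an added assumption so the learned tree is reproducible
clf = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

def trace_path(clf, sample, feature_names):
    """Describe each test a single sample goes through, from the root to a leaf."""
    tree = clf.tree_
    path = clf.decision_path(sample.reshape(1, -1)).indices  # visited node ids, root first
    steps = []
    for i, node in enumerate(path):
        if tree.children_left[node] == -1:  # no children: this is the leaf node
            steps.append(f"leaf node {node}")
            break
        feature = tree.feature[node]
        # Read the direction from the next node on the path (left means value <= threshold)
        op = "<=" if path[i + 1] == tree.children_left[node] else ">"
        steps.append(f"node {node}: {feature_names[feature]} = {sample[feature]:.2f} "
                     f"{op} {tree.threshold[node]:.2f}")
    return steps, path

sample = X_test[0]
steps, path = trace_path(clf, sample, wine.feature_names)
print("\n".join(steps))
print("Predicted class:", wine.target_names[clf.predict(sample.reshape(1, -1))[0]])
```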

The following algorithm can help you better understand the entire process:

  • Begin with the root node:
    The root node represents the whole dataset of wines; this is where the algorithm starts.

  • Find the best attribute:
    Each wine is described by several characteristics, such as alcohol content, malic acid, and color intensity. The algorithm must determine which of them is most useful for dividing the wines into their appropriate groups, such as wine varieties. It chooses the best attribute to split on using an attribute selection measure (ASM), such as information gain or the Gini index. The chosen attribute should maximize the information gain or minimize impurity.

  • Divide the dataset:
    The algorithm separates the dataset into smaller subsets, each containing wines with similar values of the selected attribute. In scikit-learn, every split is binary: samples whose value is at or below a learned threshold go to one subset, and the rest go to the other.

  • Generate decision tree nodes:
    At each stage, the algorithm adds a new node to the tree to represent the selected attribute. These nodes act as decision points that direct the algorithm to the next stage.

  • Recursive tree building:
    The algorithm repeats this process recursively on each subset, adding new branches and nodes, until it can no longer divide the data or a stopping criterion is met. The final nodes, the leaf nodes, stand for the predicted wine classes.
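The steps above can be sketched from scratch in a few dozen lines of NumPy. This is a simplified illustration under stated assumptions (Gini impurity, binary splits at midpoints between sorted feature values, and a max_depth stopping rule), not how scikit-learn implements trees internally. gini, best_split, build, and predict_one are hypothetical names chosen for this example.

```python
import numpy as np
from sklearn.datasets import load_wine

def gini(y):
    """Gini impurity: 1 - sum_k p_k^2 over the class proportions in y."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def best_split(X, y):
    """Find the attribute and threshold with the lowest weighted child impurity."""
    best = (None, None, gini(y))  # a useful split must beat the parent's impurity
    n = len(y)
    for f in range(X.shape[1]):
        values = np.unique(X[:, f])
        # Candidate thresholds: midpoints between consecutive distinct values
        for t in (values[:-1] + values[1:]) / 2:
            left = X[:, f] <= t
            score = (left.sum() * gini(y[left]) + (~left).sum() * gini(y[~left])) / n
            if score < best[2]:
                best = (f, t, score)
    return best

def build(X, y, depth=0, max_depth=3):
    """Recursively divide the data until a node is pure, unsplittable, or too deep."""
    majority = int(np.bincount(y).argmax())
    if depth == max_depth or gini(y) == 0.0:
        return {"leaf": majority}
    f, t, _ = best_split(X, y)
    if f is None:  # no split lowers the impurity, so this becomes a leaf node
        return {"leaf": majority}
    left = X[:, f] <= t
    return {"feature": f, "threshold": t,
            "left": build(X[left], y[left], depth + 1, max_depth),
            "right": build(X[~left], y[~left], depth + 1, max_depth)}

def predict_one(node, x):
    """Follow decision nodes from the root until a leaf node is reached."""
    while "leaf" not in node:
        node = node["left"] if x[node["feature"]] <= node["threshold"] else node["right"]
    return node["leaf"]

X, y = load_wine(return_X_y=True)
tree = build(X, y, max_depth=3)
preds = np.array([predict_one(tree, x) for x in X])
print("Training accuracy:", (preds == y).mean())
```

Each call to build either returns a leaf (when the node is pure, no split lowers the impurity, or the depth limit is reached) or a decision node whose two children are built recursively, mirroring the last two steps of the list.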

Implementation

Let's apply this algorithm to the wine dataset, which contains attributes of wine samples labeled with one of three classes: [class_0, class_1, class_2]. We'll use Python's scikit-learn library to implement the decision tree classifier. The classifier assigns wines to classes using decision rules based on the values of their attributes. For example, a decision rule could say that wines with certain levels of alcohol, malic acid, and color intensity belong to class_0, while wines with different values belong to class_1 or class_2. The decision tree algorithm learns these rules during training from the patterns and relationships it finds in the dataset.

from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree
import matplotlib.pyplot as plt

wine = load_wine()
X = wine.data
y = wine.target

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # Splitting the dataset into training and testing sets
clf = DecisionTreeClassifier() # Initialize the decision tree classifier
clf.fit(X_train, y_train) # Fitting the classifier on the training data
# Plot the decision tree
plt.figure(figsize=(15, 10))
plot_tree(clf, filled=True, feature_names=wine.feature_names, class_names=wine.target_names)
plt.savefig("output/plot.png", bbox_inches='tight')
plt.show()


Code explanation

Let’s review the code:

  • Lines 1–4: We import what we need: from scikit-learn, the functions and classes for loading the dataset, splitting it into training and testing sets, building a decision tree classifier, and plotting the tree; and matplotlib for displaying the plot.

  • Lines 6–8: We load the wine dataset using the load_wine function and assign the feature data to X and the target labels to y.

  • Line 10: We split the dataset into training and testing sets using the train_test_split function, where 80% of the data is used for training (X_train and y_train) and 20% for testing (X_test and y_test). The random_state parameter ensures the reproducibility of the split.

  • Line 11: We initialize a decision tree classifier (clf) with its default hyperparameters, so the tree keeps splitting until its leaves are pure or can't be split further.

  • Line 12: We fit the decision tree classifier (clf) to the training data (X_train and y_train) using the fit method.

  • Line 14: We create a figure for plotting the decision tree with a specific size using plt.figure.

  • Line 15: We use the plot_tree function to visualize the decision tree (clf). We set filled=True to fill the decision tree nodes with colors representing the majority class. We specify the feature and class names for better visualization.

  • Line 16: We save the plotted decision tree as an image file named plot.png in the output directory using plt.savefig. The output directory must already exist, because plt.savefig doesn't create it.

  • Line 17: We display the plotted decision tree using plt.show().
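The code above creates a test set but never uses it. A natural next step, sketched below, is to predict the held-out 20% of the data and measure performance with scikit-learn's accuracy_score and classification_report. The random_state=0 on the classifier is an addition so that repeated runs build the same tree; no scores are quoted here because they depend on the fitted tree.

```python
from sklearn.datasets import load_wine
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.2, random_state=42)  # same split as above
clf = DecisionTreeClassifier(random_state=0)  # random_state added for repeatable results
clf.fit(X_train, y_train)

y_pred = clf.predict(X_test)  # predict the 20% held-out samples
accuracy = accuracy_score(y_test, y_pred)  # fraction of correct predictions
print(f"Test accuracy: {accuracy:.3f}")
# Per-class precision, recall, and F1, labeled with the wine class names
print(classification_report(y_test, y_pred, target_names=wine.target_names))
```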

Interpretation

The decision tree model classifies instances based on the attributes and decision rules it learned during training. At each decision node, the model compares the value of a specific attribute against a threshold and sends the instance to one of two child nodes (scikit-learn builds binary trees). During training, this splitting continues recursively until the algorithm determines that further division isn't beneficial or a stopping criterion is met. Each leaf node represents a final classification, that is, the predicted class for the instances that reach it.
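To read the learned decision rules as text instead of as a plot, scikit-learn provides export_text. The sketch below prints the rules of a tree limited to max_depth=3, an illustrative choice that keeps the output short. Each `|---` line with a comparison is a test at a decision node, and each `class:` line is a leaf node with its predicted class.

```python
from sklearn.datasets import load_wine
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

wine = load_wine()
X_train, X_test, y_train, y_test = train_test_split(
    wine.data, wine.target, test_size=0.2, random_state=42)
# max_depth=3 is an illustrative limit that keeps the printed rules short
clf = DecisionTreeClassifier(max_depth=3, random_state=0).fit(X_train, y_train)

# One indented line per test; "class:" lines mark the leaf nodes
rules = export_text(clf, feature_names=list(wine.feature_names))
print(rules)
```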

Information gain (IG) or Gini index

Information gain (IG) and the Gini index play a central role in how the decision tree algorithm chooses its splits. IG measures how well an attribute classifies the data by quantifying the reduction in entropy (uncertainty) about the class of the data points after splitting them on that attribute. The Gini index, by contrast, measures the impurity of a set of samples: the probability that a randomly selected element would be misclassified if its label were assigned at random according to the distribution of labels in the set. These metrics help the algorithm decide which attribute to split on at each node, aiming to maximize information gain or minimize impurity in the resulting subsets.
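The formulas behind these measures are short. For class proportions p_k at a node, entropy is H = -Σ p_k log2 p_k and the Gini index is G = 1 - Σ p_k²; information gain is the parent's entropy minus the size-weighted entropy of its children. The sketch below implements them and applies them to a toy split of eight labels. impurity_decrease is a hypothetical helper name; with impurity=entropy it computes information gain.

```python
import numpy as np

def entropy(y):
    """H = sum_k p_k * log2(1 / p_k) over the class proportions p_k in y."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(np.sum(p * np.log2(1 / p)))

def gini(y):
    """G = 1 - sum_k p_k^2: the chance a random element is mislabeled when its
    label is drawn at random from the node's own class distribution."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))

def impurity_decrease(parent, left, right, impurity=entropy):
    """Parent impurity minus the size-weighted impurity of the two children.
    With impurity=entropy this is the information gain."""
    n = len(parent)
    weighted = len(left) / n * impurity(left) + len(right) / n * impurity(right)
    return impurity(parent) - weighted

parent = np.array([0, 0, 0, 0, 1, 1, 1, 1])
perfect = (parent[:4], parent[4:])                         # each child is pure
mixed = (np.array([0, 0, 0, 1]), np.array([0, 1, 1, 1]))  # each child is 3:1

print(impurity_decrease(parent, *perfect))                 # 1.0
print(impurity_decrease(parent, *perfect, impurity=gini))  # 0.5
print(round(impurity_decrease(parent, *mixed), 3))         # 0.189
print(impurity_decrease(parent, *mixed, impurity=gini))    # 0.125
```

The perfect split removes all uncertainty, so it earns the maximum gain under both measures, while the mixed split earns much less.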

Decision rule

At each node, the decision tree algorithm evaluates every available attribute (and, for numeric attributes, the candidate thresholds), computes the resulting IG or Gini index, and chooses the split with the highest IG or the lowest Gini index. Choosing splits this way makes the resulting subsets as pure and informative as possible, which supports accurate classification. Repeating this at every node lets the algorithm learn decision rules that partition the data and assign instances to the correct classes based on their attribute values.
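In scikit-learn, the measure is chosen with DecisionTreeClassifier's criterion parameter ("gini" by default, or "entropy"). The sketch below fits one tree per criterion on the full wine dataset and reads the root split back from the fitted tree_: the attribute, its threshold, the root's impurity, and how much the split lowers the size-weighted impurity of the two children. The two criteria may or may not choose the same root attribute; the code only reports what each one picked.

```python
from sklearn.datasets import load_wine
from sklearn.tree import DecisionTreeClassifier

wine = load_wine()
X, y = wine.data, wine.target

results = {}
for criterion in ("gini", "entropy"):
    # criterion selects the impurity measure scikit-learn minimizes at each split
    clf = DecisionTreeClassifier(criterion=criterion, random_state=0).fit(X, y)
    t = clf.tree_
    left, right = t.children_left[0], t.children_right[0]
    w = t.weighted_n_node_samples  # samples reaching each node (every weight is 1 here)
    # Size-weighted impurity of the root's two children
    child = (w[left] * t.impurity[left] + w[right] * t.impurity[right]) / w[0]
    results[criterion] = (wine.feature_names[t.feature[0]], t.impurity[0], t.impurity[0] - child)
    name, root_impurity, decrease = results[criterion]
    print(f"{criterion}: root splits on {name} <= {t.threshold[0]:.2f}; "
          f"root impurity {root_impurity:.3f}, decrease {decrease:.3f}")
```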

Advantages of decision trees

Here are some advantages of decision trees:

  • Simplicity: Decision trees are easy to understand because they closely resemble how humans make decisions. Even nonexperts can follow their reasoning.

  • Flexible problem-solving: Decision trees are adaptable tools that can solve many kinds of decision-related problems across industries, including healthcare and finance.

  • Easy outcome analysis: Decision trees let you investigate every scenario and its consequences methodically, because each path from the root to a leaf spells out one possible outcome.

  • Less data cleaning: Decision trees usually require less preprocessing and data cleaning than other machine learning algorithms (for example, they don't need feature scaling), saving time and effort in data preparation.
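The point about feature scaling can be checked directly. Because a split depends only on how values are ordered within a feature, standardizing the features should leave the tree's shape unchanged. This sketch assumes StandardScaler as the example preprocessing step and fixes random_state so both trees consider features in the same order; it compares which feature each node tests, since the thresholds themselves shift into the scaled units.

```python
import numpy as np
from sklearn.datasets import load_wine
from sklearn.preprocessing import StandardScaler
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
X_scaled = StandardScaler().fit_transform(X)  # each feature: zero mean, unit variance

# Same random_state, so both trees break ties between features the same way
raw_tree = DecisionTreeClassifier(random_state=0).fit(X, y)
scaled_tree = DecisionTreeClassifier(random_state=0).fit(X_scaled, y)

# Compare which feature every node tests; thresholds differ because the units differ
same_structure = (
    raw_tree.tree_.node_count == scaled_tree.tree_.node_count
    and np.array_equal(raw_tree.tree_.feature, scaled_tree.tree_.feature)
)
print("Same tree structure after scaling:", same_structure)
```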

Disadvantages of decision trees

Let’s discuss some drawbacks of decision trees:

  • Layering complexity: As decision trees grow larger, they can branch out into many levels. The decisions of a deep tree can be difficult to follow and interpret.

  • Risk of overfitting: Decision trees can overfit, learning noise or unimportant patterns in the training set, which impairs their ability to generalize to new data. This problem can be reduced by limiting the tree's depth, pruning it, or using ensembles such as random forests, which combine many decision trees.

  • Computational complexity: Training cost grows with the size of the dataset. The scikit-learn documentation gives roughly O(n_samples × n_features × log(n_samples)) for building a balanced tree, and datasets with many class labels make each candidate split more expensive to evaluate. This mostly affects training time and can require additional computational resources.
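Pruning, introduced in the terminology section, is one remedy for overfitting. The sketch below compares an unconstrained tree, a tree pruned with scikit-learn's minimal cost-complexity pruning (the ccp_alpha parameter, with candidate values from cost_complexity_pruning_path), and a random forest on the same train/test split. Taking the middle candidate alpha is an arbitrary choice for illustration; in practice, ccp_alpha is usually chosen with cross-validation. No scores are quoted because they depend on the split.

```python
from sklearn.datasets import load_wine
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_wine(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# An unconstrained tree grows until its leaves are pure
full = DecisionTreeClassifier(random_state=0).fit(X_train, y_train)

# Candidate ccp_alpha values; larger values prune away more subtrees
path = full.cost_complexity_pruning_path(X_train, y_train)
alpha = path.ccp_alphas[len(path.ccp_alphas) // 2]  # middle candidate, for illustration only
pruned = DecisionTreeClassifier(random_state=0, ccp_alpha=alpha).fit(X_train, y_train)

# An ensemble of 100 trees, each trained on a bootstrap sample of the training set
forest = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

for name, model in [("full tree", full), ("pruned tree", pruned), ("random forest", forest)]:
    print(f"{name}: train accuracy {model.score(X_train, y_train):.3f}, "
          f"test accuracy {model.score(X_test, y_test):.3f}")
print("Leaves in full tree:", full.get_n_leaves(), "| pruned tree:", pruned.get_n_leaves())
```

An unconstrained tree typically fits its training data perfectly; a large gap between training and test accuracy is the usual sign of overfitting.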

Conclusion

In this blog, we explored how to create decision trees with the scikit-learn library. Decision trees are useful tools that offer logical insights into complex data and help solve classification problems. To clarify the hierarchical structure and function of decision trees, we examined key related terms, such as root nodes, decision nodes, leaf nodes, and branches.

We discussed the benefits of decision trees, including their simplicity, adaptability, and ease of outcome analysis, which make them suitable for many decision-related problems across industries. We also pointed out shortcomings that must be addressed for best results: layering complexity, the risk of overfitting, and computational complexity.

In addition, we broke the decision tree algorithm into manageable steps and applied it to the wine dataset, a well-known example of multi-class classification. By putting a decision tree classifier into practice with scikit-learn, we showed how to visualize and interpret the tree's structure.
