What is a scikit-learn decision tree?

Scikit-learn is one of the most popular machine learning libraries available. It provides algorithms for both supervised learning (algorithms that learn from labeled datasets) and unsupervised learning (algorithms that work with unlabeled datasets).

In this Answer, we will discuss decision tree-based algorithms, how they work, and their implementation in the scikit-learn library.

What is a decision tree?

A decision tree is a method that uses conditions or if-else statements to arrive at a final decision. In machine learning, decision tree algorithms are used for both classification and regression problems and fall in the supervised machine learning category.

How decision trees work

A decision tree starts with a single node, called the root node, and then splits into two or more further nodes, called decision nodes. Criteria such as entropy, information gain, and the Gini index (among others) are used to determine which feature to split on at each node, starting with the root. These criteria score candidate splits differently, and the right choice depends on the types of features available and the weight each feature carries.
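
To make these criteria concrete, here is a minimal sketch that computes the Gini impurity and entropy of a node, and the information gain of a candidate split. The toy labels and the split itself are made up purely for illustration:

import numpy as np

def gini(labels):
    # Gini impurity: 1 - sum of squared class probabilities
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def entropy(labels):
    # Shannon entropy in bits: -sum(p * log2(p))
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

# Toy parent node: 6 samples of class 0 and 4 of class 1
parent = np.array([0] * 6 + [1] * 4)

# One candidate split of the parent into two child nodes
left = np.array([0, 0, 0, 0, 0, 1])
right = np.array([0, 1, 1, 1])

# Information gain = parent entropy - weighted average child entropy
weighted = (len(left) * entropy(left) + len(right) * entropy(right)) / len(parent)
print("Gini(parent):", gini(parent))                    # 0.48
print("Information gain:", entropy(parent) - weighted)  # ~0.26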

[Diagram: a root node splits into decision nodes, which split further until they end in terminal nodes]
Example of a decision tree

Decision nodes split further, while leaf nodes do not. Leaf nodes, also referred to as terminal nodes or end nodes, represent a final outcome. Since decision trees can become very complex, especially when there are many possible options, an important technique for limiting complexity, and therefore curbing overfitting, is pruning. Pruning cuts down the decision tree and prevents the model from paying attention to irrelevant information in the dataset.

There are a few important terminologies used in this Answer. Let’s have a look at what they mean.

  • Root node: This is the node where the decision tree starts splitting.
  • Terminal/leaf/end node: These are nodes that do not split any further. They indicate that no further decisions need to be made.
  • Decision nodes: These are nodes that split further in the tree.
  • Splitting/branching: This refers to the process of the decision tree dividing into two or more nodes.
  • Pruning: This refers to a decision tree “cutting” or reducing the number of nodes. It’s an important mechanism used to reduce overfitting. In scikit-learn, certain hyperparameters, such as max_depth, are used to pre-prune the decision tree, as shown in the sketch after this list.
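
For instance, here is a minimal sketch of pre-pruning with max_depth. The synthetic dataset and the depth cap of 3 are arbitrary choices for illustration:

from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=100, n_features=6, random_state=42)

# An unconstrained tree keeps splitting until its leaves are pure
unpruned = DecisionTreeClassifier(random_state=42).fit(X, y)

# max_depth caps how many levels of splits are allowed (pre-pruning)
pruned = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X, y)

print("Unpruned depth:", unpruned.get_depth())
print("Pruned depth:", pruned.get_depth())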

Example of a real-life decision tree

Let’s assume we want to decide whether or not we need to carry a sweater depending on the weather. A simple decision tree can be used to weigh the options and make a decision.

A real-life example of a decision tree
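
In code, this decision process is just nested conditionals. Here is one possible way to express it; the carry_sweater function, its inputs, and its thresholds are hypothetical, and the conditions in your own tree might differ:

def carry_sweater(temperature_c, is_raining):
    # Hypothetical thresholds, chosen purely for illustration
    if temperature_c < 15:
        return True   # cold: carry a sweater
    elif is_raining:
        return True   # mild but wet: carry one just in case
    else:
        return False  # warm and dry: leave it at home

print(carry_sweater(temperature_c=10, is_raining=False))  # True
print(carry_sweater(temperature_c=20, is_raining=False))  # False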

Code example

Let's look at a code example that builds a decision tree classifier on a synthetic dataset and visualizes the resulting tree:

from sklearn.datasets import make_classification
from sklearn import tree
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

X, y = make_classification(
    n_samples=100,
    n_features=6,
    random_state=42
)

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)

fig = plt.figure(figsize=(25, 20))
_ = tree.plot_tree(clf, filled=True)

Code explanation

  • Lines 1–5: We import the necessary packages.

  • Lines 7–11: We create a synthetic classification dataset with 100 samples and 6 features.

  • Line 13: We initialize our decision tree classifier model.

  • Line 14: We fit the model to the dataset.

  • Lines 16–17: We plot the decision tree created by the model.
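
Once fitted, the classifier can also make predictions. The following sketch retrains the same model and calls predict and predict_proba, reusing a few training rows as stand-in inputs purely for illustration:

from sklearn.datasets import make_classification
from sklearn import tree

X, y = make_classification(n_samples=100, n_features=6, random_state=42)
clf = tree.DecisionTreeClassifier().fit(X, y)

# Predicted classes and class probabilities for three samples
print(clf.predict(X[:3]))
print(clf.predict_proba(X[:3]))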
