Scikit-learn is one of the most popular machine learning libraries available. It contains algorithms in both supervised domains (algorithms that use labeled datasets) and unsupervised domains (algorithms that do not use pre-labeled datasets).
In this Answer, we will discuss decision tree-based algorithms, how they work, and their implementation in the scikit-learn library.
A decision tree is a method that uses conditions or if-else statements to arrive at a final decision. In machine learning, decision tree algorithms are used for both classification and regression problems and fall in the supervised machine learning category.
A decision tree starts with a single node, called the root node, and then splits into two or more further nodes, called decision nodes. Methods such as entropy, information gain, and the Gini index (among others) are used to determine the main feature to be employed at the root node to split the data. These methods use different criteria, depending on the types of features available and the weight each feature carries.
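To make these criteria concrete, here is a minimal sketch (not scikit-learn's internal implementation) of how entropy, the Gini index, and information gain can be computed for a set of class labels:

```python
import numpy as np

def entropy(labels):
    # Entropy: -sum(p * log2(p)) over the class proportions p.
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def gini(labels):
    # Gini index: 1 - sum(p^2) over the class proportions p.
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def information_gain(parent, left, right):
    # Reduction in entropy after splitting parent into left and right subsets.
    n = len(parent)
    weighted = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - weighted

labels = np.array([0, 0, 0, 1, 1, 1, 1, 1])
print(entropy(labels))  # ~0.954
print(gini(labels))     # ~0.469
print(information_gain(labels, labels[:3], labels[3:]))  # ~0.954: a perfect split
```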
Decision nodes can split further, while the last nodes that do not split are referred to as leaf nodes, terminal nodes, or end nodes, and they represent a final outcome. Since decision trees can become very complex, especially if there are many different options, an important concept used to reduce complexity and, therefore, curb overfitting is called pruning. Pruning cuts down the decision tree and prevents the model from paying attention to irrelevant information in the dataset.
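One way to prune in scikit-learn is cost-complexity pruning via the ccp_alpha parameter of its tree estimators; larger values prune more aggressively. The sketch below shows the idea, with the ccp_alpha value of 0.02 chosen arbitrarily for illustration (in practice it is usually tuned, for example, with cross-validation):

```python
from sklearn.datasets import make_classification
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=100, n_features=6, random_state=42)

# An unpruned tree is free to grow deep and may overfit.
unpruned = DecisionTreeClassifier(random_state=42).fit(X, y)

# Cost-complexity pruning: a positive ccp_alpha collapses branches whose
# complexity is not justified by the impurity reduction they provide.
pruned = DecisionTreeClassifier(ccp_alpha=0.02, random_state=42).fit(X, y)

# The pruned tree is typically shallower, with fewer leaves.
print(unpruned.get_depth(), unpruned.get_n_leaves())
print(pruned.get_depth(), pruned.get_n_leaves())
```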
There are a few important terms used in this Answer; let's have a look at what they mean. Hyperparameters, such as max_depth, are used to pre-prune the decision tree, that is, to limit its growth before training rather than cutting it back afterward.
Let's assume we want to decide whether or not we need to carry a sweater depending on the weather. A simple decision tree can be used to weigh the options and make a decision, as sketched below.
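Under the hood, such a tree is just nested if-else logic. The thresholds and conditions below are hypothetical, chosen only to mirror the structure of a root node, a decision node, and leaf nodes:

```python
def carry_sweater(temperature_c, is_windy):
    # Hypothetical thresholds, for illustration only.
    if temperature_c < 15:   # root node: is it cold?
        return True          # leaf node: carry a sweater
    elif is_windy:           # decision node: mild but windy?
        return True          # leaf node: carry a sweater
    else:
        return False         # leaf node: no sweater needed

print(carry_sweater(10, False))  # True
print(carry_sweater(22, False))  # False
```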
Let's see a code example:
```python
from sklearn.datasets import make_classification
from sklearn import tree
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

X, y = make_classification(
    n_samples=100,
    n_features=6,
    random_state=42
)

clf = tree.DecisionTreeClassifier()
clf = clf.fit(X, y)

fig = plt.figure(figsize=(25,20))
_ = tree.plot_tree(clf, filled=True)
```
Lines 1–5: We import the necessary packages.
Lines 7–11: We create a sample data set with 100 samples and 6 features.
Line 13: We initialize our decision tree classifier model.
Line 14: We fit the model to the data set.
Lines 16–17: We plot the decision tree created by the model.
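Once the model is fitted, it can also make predictions. As a quick illustrative check, we can extend the snippet above to predict labels for the first few training samples:

```python
# Predict labels for the first three samples and compare with the true labels.
print(clf.predict(X[:3]))
print(y[:3])
```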