How to build a decision tree with the IRIS dataset in Python

A decision tree is a machine learning algorithm that uses a tree-like model of decisions and their possible consequences to arrive at a particular decision. It is a supervised machine learning model in which the data is repeatedly split according to a certain parameter (a feature and a threshold) until a final decision is made.

Usually, a decision tree is drawn upside down, with the root node at the top and the leaf nodes at the bottom. A decision tree usually contains three types of nodes:

  1. Root node: The very top node that represents the entire population or sample.
  2. Decision nodes: Internal sub-nodes that split further into other sub-nodes based on a test on one feature.
  3. Leaf nodes: Nodes with no children, also known as terminal nodes.
Structure of a Decision Tree
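To see these node types in a fitted model, the sketch below assumes scikit-learn and fits a tree on the full IRIS dataset (the dataset used later in this article). It reads the fitted tree's low-level tree_ arrays, in which node 0 is the root and a leaf is marked by a child index of -1. The choice random_state=0 is arbitrary and only makes the tree reproducible.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
clf = DecisionTreeClassifier(random_state=0).fit(data.data, data.target)
tree = clf.tree_

# Node 0 is the root; a leaf has no children (child index == -1)
is_leaf = tree.children_left == -1

print("Total nodes:", tree.node_count)
print("Samples at the root node:", tree.n_node_samples[0])  # the whole dataset
print("Decision nodes (including the root):", int(np.sum(~is_leaf)))
print("Leaf nodes:", int(np.sum(is_leaf)))
```

In a binary tree like this one, the number of leaves is always one more than the number of decision nodes.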


In supervised machine learning, there are two main types of problems:

  • Regression
  • Classification

Decision trees can be used for both regression and classification problems.

  • Regression tree: These are used to predict continuous variables. For example, predicting rainfall in a region or predicting the revenue that a company might generate in the future.

  • Classification tree: These are used to predict discrete (categorical) variables. For example, classifying whether the temperature of a day will be high or low, or predicting whether a team will win a match.
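As a minimal sketch of the difference, assuming scikit-learn's DecisionTreeClassifier and DecisionTreeRegressor, the code below fits both on the same one-feature toy data. The feature values and targets are hypothetical and made up for illustration: the classifier returns a discrete label, while the regressor returns a continuous value.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

# Hypothetical toy data with a single feature
X = np.array([[10.0], [12.0], [25.0], [28.0]])
# Discrete labels for a classification tree: 0 = "low", 1 = "high"
y_class = np.array([0, 0, 1, 1])
# Continuous targets for a regression tree (made-up values)
y_reg = np.array([5.2, 4.8, 1.1, 0.9])

clf = DecisionTreeClassifier().fit(X, y_class)
reg = DecisionTreeRegressor().fit(X, y_reg)

print(clf.predict([[10.5]]))  # [0]   -> a class label
print(reg.predict([[10.5]]))  # [5.2] -> a real number
```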

How decision trees work

Decision trees work in a step-wise manner: starting from the root, each node is split into child nodes using the feature (and threshold) that best satisfies a defined criterion, and the same process is repeated on each child until a stopping condition is reached, such as a node that contains only one class. The main criteria used to choose splits are:

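  • Gini impurity: Measures the impurity of a node, i.e., the probability that a randomly chosen sample would be misclassified if it were labeled according to the node's class distribution. A pure node has a Gini impurity of 0.

  • Entropy: Measures the randomness (uncertainty) of the class labels in a node. The split that reduces entropy the most, i.e., the one with the highest information gain, is preferred.

  • Variance: Normally used in regression trees, it measures how much the target values in a node vary from their mean. The split that reduces variance the most is preferred.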
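The criteria above can be sketched in a few lines of NumPy. This is an illustrative re-implementation, not scikit-learn's internal code: the helper names gini, entropy, variance, and best_threshold are hypothetical, entropy is measured in bits (base-2 logarithm), and best_threshold searches only one feature, whereas a real tree repeats this search over every feature and keeps the best split overall.

```python
import numpy as np
from sklearn.datasets import load_iris


def gini(labels):
    """Gini impurity: 1 - sum of squared class proportions (0 for a pure node)."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return float(1.0 - np.sum(p ** 2))


def entropy(labels):
    """Entropy in bits: -sum(p * log2(p)) over the class proportions p."""
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()  # every p > 0, so log2 is safe
    return float(-np.sum(p * np.log2(p)))


def variance(values):
    """Mean squared deviation of regression targets from their mean."""
    values = np.asarray(values, dtype=float)
    return float(np.mean((values - values.mean()) ** 2))


def best_threshold(feature, labels, impurity=gini):
    """Greedy split search on one feature: try the midpoint between each pair
    of neighboring unique values and keep the one with the lowest weighted
    impurity of the two child nodes."""
    values = np.unique(feature)  # sorted unique values
    best_t, best_score = None, np.inf
    for t in (values[:-1] + values[1:]) / 2:
        left, right = labels[feature <= t], labels[feature > t]
        score = (len(left) * impurity(left) + len(right) * impurity(right)) / len(labels)
        if score < best_score:
            best_t, best_score = float(t), float(score)
    return best_t, best_score


data = load_iris()
print(round(gini(data.target), 3))     # 0.667: three equally sized classes
print(round(entropy(data.target), 3))  # 1.585: log2(3)
t, score = best_threshold(data.data[:, 2], data.target)  # column 2 = petal length (cm)
print(round(t, 2), round(score, 3))    # 2.45 0.333: separates setosa from the rest
```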

Practical implementation

Let’s apply a decision tree to a real-world dataset in Python. Follow the steps below to build a working decision tree classifier:

Import the libraries

We import the required libraries for the model: load_iris from sklearn.datasets to load the dataset, and accuracy_score from sklearn.metrics to evaluate the model.

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

Gather the data

We will be using the IRIS dataset to build a decision tree classifier. The dataset contains 150 samples, 50 for each of three classes of the IRIS plant, namely IRIS Setosa, IRIS Versicolour, and IRIS Virginica, described by four attributes measured in centimeters: sepal length, sepal width, petal length, and petal width.

data = load_iris()
# Extracting Attributes / Features
X = data.data
# Extracting Target / Class Labels
y = data.target
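To get a feel for the data before modeling, the pandas import can be put to use. The sketch below, assuming scikit-learn's load_iris and pandas, builds a DataFrame with one column per attribute plus a species column (the column name "species" is an arbitrary choice), then prints its shape, the first rows, and the number of samples per class.

```python
import pandas as pd
from sklearn.datasets import load_iris

data = load_iris()

# One column per attribute, e.g. "petal length (cm)"
df = pd.DataFrame(data.data, columns=data.feature_names)
# Map the numeric targets 0, 1, 2 to species names
df["species"] = pd.Categorical.from_codes(data.target, data.target_names)

print(df.shape)  # (150, 5)
print(df.head())
print(df["species"].value_counts())  # 50 samples per species
```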

Import the required Python library and split the data

Import train_test_split and split the dataset into training and testing data.

# Import Library for splitting data
from sklearn.model_selection import train_test_split
# Creating Train and Test datasets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=50, test_size=0.25)
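To check what the split produces, the sketch below uses the same random_state=50 and test_size=0.25 as above and prints the shapes of the resulting arrays. With 150 samples, 25% is 37.5; scikit-learn rounds the test-set size up, so the test set gets 38 rows and the training set 112. Running the split a second time with the same random_state shows that it is reproducible.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

data = load_iris()
X, y = data.data, data.target

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=50, test_size=0.25)
print(X_train.shape, X_test.shape)  # (112, 4) (38, 4)

# The same random_state yields exactly the same split on every run
X_train2, X_test2, _, _ = train_test_split(X, y, random_state=50, test_size=0.25)
print(np.array_equal(X_test, X_test2))  # True
```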

Create the model in Python

Import DecisionTreeClassifier, create the classifier, and fit it to the training data.

from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
# Fit the classifier on the training data
clf.fit(X_train, y_train)
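After fitting, you can inspect what the tree has learned. The sketch below assumes a recent scikit-learn version in which get_depth, get_n_leaves, and sklearn.tree.export_text are available; it prints the depth, the number of leaves, and the learned splits as indented text rules. The exact tree can vary between runs, because DecisionTreeClassifier breaks ties between equally good splits randomly unless random_state is set.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=50, test_size=0.25
)

clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)

# Size of the learned tree
print("Depth:", clf.get_depth())
print("Leaves:", clf.get_n_leaves())

# The learned splits as if/else-style rules, starting from the root
rules = export_text(clf, feature_names=list(data.feature_names))
print(rules)
```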

Use the test dataset to make a prediction

Our aim is to predict the class of the IRIS plant based on the given attributes.
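A minimal sketch of this step, assuming the same split as above: predict is called on the test features, the numeric predictions are mapped back to species names through data.target_names, and one new flower is classified. The new flower's measurements are hypothetical values chosen for illustration. predict_proba shows the class probabilities the tree assigns; for a fully grown tree whose leaves are pure, these are typically 0 or 1.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=50, test_size=0.25
)
clf = DecisionTreeClassifier().fit(X_train, y_train)

# Predict the classes of the held-out test samples
y_pred = clf.predict(X_test)
print(data.target_names[y_pred[:5]])  # species names of the first five predictions

# Classify a new (hypothetical) flower; columns are
# sepal length, sepal width, petal length, petal width (cm)
new_flower = np.array([[5.0, 3.4, 1.5, 0.2]])
print(data.target_names[clf.predict(new_flower)])
print(clf.predict_proba(new_flower))  # one probability per class
```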

Complete code

Let’s take a look at the code.

import pandas as pd
import numpy as np
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score

# Loading the built-in IRIS dataset
data = load_iris()

# Extracting Attributes / Features
X = data.data

# Extracting Target / Class Labels
y = data.target

# Import Library for splitting data
from sklearn.model_selection import train_test_split

# Creating Train and Test datasets
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=50, test_size=0.25)

# Creating Decision Tree Classifier
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier()
clf.fit(X_train, y_train)

# Predict Accuracy Score
y_pred = clf.predict(X_test)
print("Train data accuracy:", accuracy_score(y_true=y_train, y_pred=clf.predict(X_train)))
print("Test data accuracy:", accuracy_score(y_true=y_test, y_pred=y_pred))


  • Lines 1-4: We import the necessary libraries to load and analyze the dataset.

  • Line 7: We store the IRIS dataset in the variable data. Since scikit-learn ships with the IRIS dataset, you do not need to download or upload it separately.

  • Line 10: We extract all of the attributes into the variable X.

  • Line 13: We extract the target, i.e., the class labels, into the variable y.

  • Line 16: We import the train_test_split function.

  • Line 19: We call train_test_split(). The random_state parameter can be set to any integer, but the same value must be reused to reproduce the same split. The test_size parameter can also be adjusted as needed. Here, test_size=0.25 assigns 25% of the dataset to the test set and the remaining 75% to the training set.

  • Lines 22-24: We create a decision tree classifier and fit it to the training dataset. By default, the criterion parameter is set to "gini".

  • Lines 27-29: We predict the labels of the test data and use accuracy_score (imported in line 4) to compute the accuracy on both the training and test data. The output is 1.0, i.e., 100%, for the training data and 0.947, i.e., approximately 95%, for the test data.
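Because the classifier uses the Gini criterion by default, a natural follow-up is to compare it with entropy on the same split. The sketch below assumes the same random_state=50 split; random_state=0 for the classifier is an arbitrary choice that fixes tie-breaking between equally good splits, so the scores you see may differ slightly from the figures reported above.

```python
from sklearn.datasets import load_iris
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

data = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, random_state=50, test_size=0.25
)

results = {}
for criterion in ("gini", "entropy"):
    # Same data, different splitting criterion
    clf = DecisionTreeClassifier(criterion=criterion, random_state=0)
    clf.fit(X_train, y_train)
    train_acc = accuracy_score(y_train, clf.predict(X_train))
    test_acc = accuracy_score(y_test, clf.predict(X_test))
    results[criterion] = (train_acc, test_acc)
    print(criterion, "train:", train_acc, "test:", round(test_acc, 3))
```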
