Exercise: A Decision Tree in scikit-learn
Learn how to model a decision tree on our case study data and visualize it using graphviz.
Modeling a decision tree on the case study data
In this exercise, we will use the case study data to grow a decision tree, where we specify the maximum depth. We’ll also use some handy functionality to visualize the decision tree, in the form of the graphviz package. Perform the following steps to complete the exercise:
-
Load several of the packages that we’ve been using, and an additional one, graphviz, so that we can visualize decision trees:
import numpy as np #numerical computation
import pandas as pd #data wrangling
import matplotlib.pyplot as plt #plotting package
#Next line helps with rendering plots
%matplotlib inline
import matplotlib as mpl #additional plotting functionality
mpl.rcParams['figure.dpi'] = 400 #high res figures
import graphviz #to visualize decision trees
-
Load the cleaned case study data:
df = pd.read_csv('Chapter_1_cleaned_data.csv')
-
Get a list of column names of the DataFrame:
features_response = df.columns.tolist()
-
Make a list of columns to remove that aren’t features or the response variable:
items_to_remove = ['ID', 'SEX', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'EDUCATION_CAT',
                   'graduate school', 'high school', 'none', 'others', 'university']
-
Use a list comprehension to remove these column names from our list of features and the response variable:
features_response = [item for item in features_response if item not in items_to_remove]
features_response
This should output the list of features and the response variable:
['LIMIT_BAL', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_1', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6', 'default payment next month']
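The filtering step above can be tried on a small toy list first. This is only a sketch: the names `all_columns`, `columns_to_drop`, and `'response'` are hypothetical stand-ins, not the real case study columns. It shows that the list comprehension keeps the original order, so the response stays in the last position and can be sliced off.

```python
# Toy column names; hypothetical stand-ins for the case study columns.
all_columns = ['ID', 'LIMIT_BAL', 'SEX', 'AGE', 'PAY_1', 'response']
columns_to_drop = ['ID', 'SEX']

# Keep every name that is not in the removal list, preserving order.
kept_columns = [name for name in all_columns if name not in columns_to_drop]

# The response is the last element, so slicing it off leaves the features.
feature_names = kept_columns[:-1]
response_name = kept_columns[-1]

print(kept_columns)
print(feature_names, response_name)
```

Because order is preserved, the same `[:-1]` slice can be used later wherever only the feature names are needed.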
Now the list of features is prepared. Next, we will make some imports from scikit-learn. We want to make a train/test split, which we are already familiar with. We also want to import the decision tree functionality.
-
Run this code to make imports from scikit-learn:
from sklearn.model_selection import train_test_split
from sklearn import tree
The tree module of scikit-learn contains decision tree-related classes.
-
Split the data into training and testing sets using the same random seed that we have used throughout the course:
X_train, X_test, y_train, y_test = train_test_split(
    df[features_response[:-1]].values,
    df['default payment next month'].values,
    test_size=0.2, random_state=24)
Here, we use all but the last element of the list, features_response[:-1], to get the names of the features but not the response variable. We use this to select columns from the DataFrame, and then retrieve their values as a NumPy array using the .values attribute. We do something similar for the response variable, but specify the column name directly. In making the train/test split, we’ve used the same random seed as in previous work, as well as the same split size. This way, we can directly compare the work we will do in this section with previous results. Also, we continue to reserve the same “unseen test set” from the model development process.
Now we are ready to instantiate the decision tree class.
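To see why reusing the random seed keeps results comparable, here is a minimal sketch on synthetic data (the arrays `X` and `y` below are made up for illustration and are not the case study data). Under these assumptions, two calls to train_test_split with the same random_state and test_size return identical pieces.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in data: 10 samples, 2 features (not the case study data).
X = np.arange(20).reshape(10, 2)
y = np.array([0, 1] * 5)

# Two splits made with the same seed and the same test size.
split_a = train_test_split(X, y, test_size=0.2, random_state=24)
split_b = train_test_split(X, y, test_size=0.2, random_state=24)

# Each result is [X_train, X_test, y_train, y_test]; compare piece by piece.
same_split = all(np.array_equal(a, b) for a, b in zip(split_a, split_b))
print(same_split)

# With test_size=0.2, 2 of the 10 samples are held out for testing.
print(split_a[0].shape, split_a[1].shape)
```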
-
Instantiate the decision tree class by setting the max_depth parameter to 2:
dt = tree.DecisionTreeClassifier(max_depth=2)
We have used the DecisionTreeClassifier class because we have a classification problem. Because we specified max_depth=2, when we grow the decision tree using the case study data, the tree will grow to a depth of at most 2. Let’s now train this model.
-
Use this code to fit the decision tree model and grow the tree:
dt.fit(X_train, y_train)
This should display the following output:
DecisionTreeClassifier(max_depth=2)
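Once a tree has been fit, the depth limit can be checked directly. The following is a sketch on synthetic data generated with make_classification, used here only as a stand-in for the case study features and response; the names `X_demo`, `y_demo`, and `shallow` are illustrative.

```python
from sklearn.datasets import make_classification
from sklearn import tree

# Synthetic data as a stand-in for the case study features and response.
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=24)

shallow = tree.DecisionTreeClassifier(max_depth=2, random_state=0)
shallow.fit(X_demo, y_demo)

# get_depth() reports the depth actually reached, which max_depth caps.
print(shallow.get_depth())

# A binary tree of depth 2 can have at most 2**2 = 4 leaves.
print(shallow.get_n_leaves())
```

The same two calls should also work on the dt object from this exercise after it has been fit.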
Now that we have fit this decision tree model, we can use the graphviz package to display a graphical representation of the tree.
-
Export the trained model in a format that can be read by the graphviz package using this code:
dot_data = tree.export_graphviz(dt, out_file=None, filled=True, rounded=True,
                                feature_names=features_response[:-1],
                                proportion=True,
                                class_names=['Not defaulted', 'Defaulted'])
Here, we’ve provided a number of options for the export_graphviz function. First, we need to say which trained model we’d like to graph, which is dt. Next, we say we don’t want an output file: out_file=None. Instead, the function returns its output, which we store in the dot_data variable. The rest of the options are set as follows:
-
filled=True: Each node will be filled with a color.
-
rounded=True: The nodes will appear with rounded corners as opposed to plain rectangles.
-
feature_names=features_response[:-1]: The names of the features from our list will be used as opposed to generic names such as X[0].
-
proportion=True: The proportion of training samples in each node will be displayed (we’ll discuss this more later).
-
class_names=['Not defaulted', 'Defaulted']: The name of the predicted class will be displayed for each node.
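To get a feel for what some of these options change, the sketch below exports the same small tree twice, once with raw sample counts and once with proportion=True. It uses synthetic data, and the feature names `f_a` to `f_d` and class names `neg` and `pos` are hypothetical labels chosen for illustration.

```python
from sklearn.datasets import make_classification
from sklearn import tree

# Synthetic data and made-up feature names, for illustration only.
X_demo, y_demo = make_classification(n_samples=200, n_features=4, random_state=24)
demo_features = ['f_a', 'f_b', 'f_c', 'f_d']

dt_demo = tree.DecisionTreeClassifier(max_depth=2, random_state=0)
dt_demo.fit(X_demo, y_demo)

# Default: node labels report raw sample counts.
dot_counts = tree.export_graphviz(dt_demo, out_file=None,
                                  feature_names=demo_features,
                                  class_names=['neg', 'pos'])

# proportion=True: node labels report percentages of the training samples.
dot_props = tree.export_graphviz(dt_demo, out_file=None,
                                 feature_names=demo_features,
                                 class_names=['neg', 'pos'],
                                 proportion=True)

print('%' in dot_counts, '%' in dot_props)

# Which of the supplied feature names were actually used in a split.
print([name for name in demo_features if name in dot_counts])
```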
What is the output of this method?
If you examine the contents of dot_data, you will see that it is a long text string. The graphviz package can interpret this text string to create a visualization.
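One way to examine such a string without flooding the screen is to print only its first few lines. This sketch again assumes a small tree fit on synthetic data; `dt_demo` and `dot_text` are illustrative names rather than the exercise's dt and dot_data.

```python
from sklearn.datasets import make_classification
from sklearn import tree

# Small synthetic example, standing in for the tree fit in this exercise.
X_demo, y_demo = make_classification(n_samples=100, n_features=4, random_state=24)
dt_demo = tree.DecisionTreeClassifier(max_depth=2, random_state=0)
dt_demo.fit(X_demo, y_demo)

# With out_file=None, the DOT description is returned as a Python string.
dot_text = tree.export_graphviz(dt_demo, out_file=None)
print(type(dot_text))

# Show only the first few lines of the graph description.
for line in dot_text.splitlines()[:3]:
    print(line)
```

The text is written in the DOT graph description language, which is the format that Graphviz tools read.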
-
Use the Source class of the graphviz package to create an image from dot_data and display it:
graph = graphviz.Source(dot_data)
graph
The output should look like this: