Exercise: A Decision Tree in scikit-learn
Learn how to model a decision tree on our case study data and visualize it using graphviz.
Modeling a decision tree on the case study data
In this exercise, we will use the case study data to grow a decision tree, specifying its maximum depth. We’ll also use some handy functionality from the graphviz package to visualize the decision tree. Perform the following steps to complete the exercise:
- Load several of the packages that we’ve been using, and an additional one, graphviz, so that we can visualize decision trees:
```python
import numpy as np #numerical computation
import pandas as pd #data wrangling
import matplotlib.pyplot as plt #plotting package
#Next line helps with rendering plots
%matplotlib inline
import matplotlib as mpl #additional plotting functionality
mpl.rcParams['figure.dpi'] = 400 #high res figures
import graphviz #to visualize decision trees
```
- Load the cleaned case study data:
```python
df = pd.read_csv('Chapter_1_cleaned_data.csv')
```
- Get a list of column names of the DataFrame:
```python
features_response = df.columns.tolist()
```
- Make a list of columns to remove that aren’t features or the response variable:
```python
# No backslash is needed to continue a line inside square brackets
items_to_remove = ['ID', 'SEX', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6',
                   'EDUCATION_CAT', 'graduate school', 'high school', 'none',
                   'others', 'university']
```
- Use a list comprehension to remove these column names from our list of features and the response variable, then display the result:
```python
features_response = [item for item in features_response
                     if item not in items_to_remove]
features_response
```
This should output the list of features and the response variable:
```
['LIMIT_BAL', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_1', 'BILL_AMT1',
 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6',
 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6',
 'default payment next month']
```
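To see the column-filtering steps in isolation, here is a minimal sketch on a tiny, made-up DataFrame; the values and the shortened removal list are hypothetical stand-ins, not the case study data. It also shows `DataFrame.drop(columns=...)`, which removes the same columns directly from the DataFrame. That is offered only as an alternative, not as the approach this exercise uses.

```python
import pandas as pd

# Tiny hypothetical DataFrame standing in for the cleaned case study data
df = pd.DataFrame({
    'ID': ['a1', 'a2'],
    'LIMIT_BAL': [20000, 120000],
    'SEX': [2, 1],
    'PAY_1': [2, -1],
    'default payment next month': [1, 0],
})
items_to_remove = ['ID', 'SEX']  # hypothetical, shortened removal list

# df.columns is a pandas Index; tolist() turns it into a plain Python list
features_response = df.columns.tolist()

# Same filtering pattern as the exercise: keep names not marked for removal,
# preserving their original column order
features_response = [item for item in features_response
                     if item not in items_to_remove]

# DataFrame-level alternative: drop the unwanted columns directly
df_model = df.drop(columns=items_to_remove)

print(features_response)
print(df_model.columns.tolist() == features_response)
```

One design difference worth noting: the list comprehension silently ignores names in `items_to_remove` that are not actually columns, whereas `drop` raises a `KeyError` for them by default.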
Now the list of features is prepared. Next, we will make some imports from scikit-learn. We want to make a train/test split, which we are already familiar with. We also want to import the decision tree functionality.
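As a preview of how these two imports typically fit together, here is a minimal sketch on synthetic data. The random features, labels, seeds, and feature names are all hypothetical. It holds out a test set with `train_test_split`, fits a `tree.DecisionTreeClassifier` limited by `max_depth`, and exports the fitted tree as DOT text with `tree.export_graphviz(..., out_file=None)`, which the graphviz package can then render (for example via `graphviz.Source`).

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn import tree

# Synthetic stand-in data (hypothetical): 200 samples, 3 features,
# and a binary label that depends only on the first feature
rng = np.random.default_rng(seed=24)  # seed chosen arbitrarily
X = rng.normal(size=(200, 3))
y = (X[:, 0] > 0).astype(int)

# Hold out 20% of the samples as a test set
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=24)

# max_depth caps how many levels of splits the tree can grow
dt = tree.DecisionTreeClassifier(max_depth=2)
dt.fit(X_train, y_train)

# With out_file=None, export_graphviz returns the DOT source as a string
dot_data = tree.export_graphviz(dt, out_file=None,
                                feature_names=['f0', 'f1', 'f2'],
                                class_names=['0', '1'], filled=True)

print('tree depth:', dt.get_depth())
print('test accuracy:', dt.score(X_test, y_test))
print(dot_data.splitlines()[0])
```

Passing the DOT string to graphviz, rather than writing it to a file, keeps the whole workflow inside the notebook.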
- Run this code to make imports from scikit-learn:
```python
from sklearn.model_selection import train_test_split
from sklearn import tree
```
The ...