A More Convenient Approach to Cross-Validation
Learn to use the GridSearchCV class in scikit-learn for hyperparameter tuning.
Advantages of using GridSearchCV
In the “The Bias-Variance Trade-Off” chapter, we gained a deep understanding of cross-validation by writing our own function to do it, using the KFold class to generate the training and testing indices. This was helpful for getting a thorough understanding of how the process works. However, scikit-learn offers a convenient class that can do more of the heavy lifting for us: GridSearchCV. The GridSearchCV class can take as input a model that we want to find optimal hyperparameters for, such as a decision tree or a logistic regression, and a “grid” of hyperparameters that we want to perform cross-validation over. For example, in a logistic regression, we may want to get the average cross-validation score over all the folds for different values of the regularization parameter, C. With decision trees, we may want to explore different depths of trees.
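As a minimal sketch of the idea, the following cross-validates a logistic regression over a few values of C. The synthetic dataset from make_classification and the particular C values are assumptions for illustration only. It also shows that a KFold object, like the one used when we built cross-validation by hand, can be passed as the cv argument.

```python
# Minimal sketch: cross-validate a logistic regression over several values of C.
# The synthetic data and the candidate C values are illustrative assumptions.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, KFold

# Synthetic binary classification data standing in for a real dataset
X, y = make_classification(n_samples=500, n_features=10, random_state=1)

# The "grid": one hyperparameter (C) with four candidate values
param_grid = {'C': [0.01, 0.1, 1.0, 10.0]}

# A KFold splitter can be supplied directly as the cv argument
kfold = KFold(n_splits=4, shuffle=True, random_state=1)

search = GridSearchCV(LogisticRegression(max_iter=1000), param_grid, cv=kfold)
search.fit(X, y)

# Average testing-fold score (accuracy, the default for classifiers) for each C
for params, score in zip(search.cv_results_['params'],
                         search.cv_results_['mean_test_score']):
    print(params, round(score, 3))
print('Best:', search.best_params_)
```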
You can also search multiple parameters at once; for example, we might try different tree depths together with different values of max_features, the number of features to consider at each node split.
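A sketch of a two-parameter grid for a decision tree might look like this. The depth and max_features values and the synthetic data are hypothetical choices. ParameterGrid, which enumerates the combinations a grid describes, is used here only to make the set of candidates visible.

```python
# Sketch: searching two decision tree hyperparameters at once.
# The candidate values and synthetic data are illustrative assumptions.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, ParameterGrid
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=2)

param_grid = {
    'max_depth': [2, 4, 6],
    'max_features': [3, 6, 9],  # number of features considered at each split
}

# Every combination of the two lists is a candidate: 3 x 3 = 9
combos = list(ParameterGrid(param_grid))
print(len(combos), 'candidates')

search = GridSearchCV(DecisionTreeClassifier(random_state=2), param_grid, cv=4)
search.fit(X, y)
print('Best:', search.best_params_)
```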
GridSearchCV does what is called an exhaustive grid search over all the possible combinations of parameters that we supply. This means that if we supplied five different values for each of two hyperparameters, the cross-validation procedure would be run 5 x 5 = 25 times, and each of those runs fits one model per fold. If you are searching many values of many hyperparameters, the number of cross-validation runs can grow very quickly. In these cases, you may wish to use RandomizedSearchCV, which searches a random subset of hyperparameter combinations from the universe of all possibilities in the grid you supply.
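To make the arithmetic concrete, this sketch compares an exhaustive search over a 5 x 5 grid with a randomized search that samples only 8 combinations from the same grid. The grid values, n_iter=8, and the synthetic data are assumptions for illustration. The fit counts printed below exclude the single final refit that both searches perform by default.

```python
# Sketch: exhaustive versus randomized search over the same 5 x 5 grid.
# Grid values, n_iter, and the synthetic data are illustrative assumptions.
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=500, n_features=10, random_state=3)

# Five values for each of two hyperparameters: 5 x 5 = 25 combinations
param_grid = {
    'max_depth': [1, 2, 4, 6, 8],
    'max_features': [2, 4, 6, 8, 10],
}
n_folds = 4

# Exhaustive search: every combination is cross-validated
grid = GridSearchCV(DecisionTreeClassifier(random_state=3), param_grid,
                    cv=n_folds)
grid.fit(X, y)

# Randomized search: n_iter combinations are sampled from the same lists
rand = RandomizedSearchCV(DecisionTreeClassifier(random_state=3), param_grid,
                          n_iter=8, cv=n_folds, random_state=3)
rand.fit(X, y)

n_grid = len(grid.cv_results_['params'])
n_rand = len(rand.cv_results_['params'])
print('Grid search:', n_grid, 'combinations,', n_grid * n_folds, 'model fits')
print('Randomized search:', n_rand, 'combinations,', n_rand * n_folds, 'model fits')
```

When every hyperparameter is given as a list, RandomizedSearchCV samples combinations without replacement, so n_iter should not exceed the size of the grid.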
GridSearchCV can speed up your work by streamlining the cross-validation process. You should be familiar with the concepts of cross-validation from the previous chapter, so we proceed directly to listing all the options available for GridSearchCV.
Options for GridSearchCV
In the next lesson, we will get hands-on practice using GridSearchCV with the case study data, in order to search hyperparameters for a decision tree classifier. Here are the options for GridSearchCV:
The Options for GridSearchCV
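| Parameter | Possible values | Notes |
|---|---|---|
| estimator | Model object | This is a model object that you have instantiated from a model class. The hyperparameters will be updated as GridSearchCV does its work. |
| param_grid | Dictionary, or list of dictionaries | The "grid" to search: the keys are hyperparameter names and the values are lists of values to try. |
| scoring | String, callable, list, tuple, dictionary, or None | This represents the model assessment metric(s) you want to use to measure training and testing performance across the folds, for example, 'roc_auc'. If None, the model's default score method is used. |
| n_jobs | Integer or None | The number of processing jobs to run in parallel (-1 uses all processors). It may speed up cross-validation to run parallel jobs, but it is good to experiment to be sure. |
| pre_dispatch | Integer or string | The number of jobs, or a formula for the number of jobs such as '2*n_jobs', to dispatch. Relevant for parallel processing using n_jobs. |
| cv | Integer, cross-validation generator, iterable, or None | If supplying an integer, this is the number of folds to use for cross-validation. A splitter object such as KFold can also be supplied. |
| refit | Boolean, string, or callable | After doing the cross-validation, the "best" hyperparameters according to the metric specified in scoring can be used directly with the fitted GridSearchCV object, which retrains the model on all the data with those hyperparameters. When scoring contains multiple metrics, refit names the one used to choose the best hyperparameters. |
| verbose | Integer | Controls how much output you will see from the cross-validation procedure. Higher values print more messages. |
| error_score | 'raise' or numeric | What to do if an error happens during model fitting: 'raise' raises the error, while a numeric value is recorded as the score for that fold. |
| return_train_score | Boolean | Whether or not to compute and return training scores on the folds. It is not required for selecting the best hyperparameters based on testing fold scores, and for some datasets and models, this can take substantially more time. However, it does give insights into possible overfitting. |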
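The sketch below spells out most of these options explicitly on a decision tree, so each one can be seen in context. The synthetic data and the list of depths are assumptions; the option values shown are reasonable choices, not recommendations from this lesson.

```python
# Sketch: GridSearchCV with its options written out explicitly.
# The synthetic data and candidate depths are illustrative assumptions.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=4)

search = GridSearchCV(
    estimator=DecisionTreeClassifier(random_state=4),  # model object to tune
    param_grid={'max_depth': [1, 2, 4, 6, 8, 10]},      # hyperparameters to search
    scoring='roc_auc',         # metric for training and testing folds
    n_jobs=None,               # None runs one job; -1 would use all processors
    pre_dispatch='2*n_jobs',   # formula for how many jobs are dispatched
    cv=4,                      # integer: number of folds
    refit=True,                # retrain on all data with the best hyperparameters
    verbose=1,                 # print a short progress summary
    error_score=np.nan,        # record NaN for a fold whose fit raises an error
    return_train_score=True,   # also compute training-fold scores
)
search.fit(X, y)

print('Best max_depth:', search.best_params_['max_depth'])
print('Mean testing ROC AUC:', search.cv_results_['mean_test_score'].round(3))
print('Mean training ROC AUC:', search.cv_results_['mean_train_score'].round(3))

# Because refit=True, the fitted search object predicts with the best model,
# which is also available as search.best_estimator_
probs = search.predict_proba(X[:5])
```

Comparing mean_train_score with mean_test_score across depths is one way to look for the overfitting that return_train_score helps reveal.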
In the next lesson, we’ll make use of the standard error of the mean to create error bars. We’ll average the model performance metric across the testing folds, and the error bars will help us visualize how variable model performance is across the folds.
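As a preview, here is one way to compute the mean and the standard error of the mean across testing folds from a fitted search. It assumes the per-fold scores are read from the split{i}_test_score entries of cv_results_, and it computes the sample standard deviation (ddof=1) directly from those scores, divided by the square root of the number of folds. The data and depths are hypothetical.

```python
# Sketch: mean and standard error of the mean across testing folds.
# Synthetic data and candidate depths are illustrative assumptions.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=10, random_state=5)
n_folds = 4
depths = [1, 2, 4, 6, 8, 10]

search = GridSearchCV(DecisionTreeClassifier(random_state=5),
                      {'max_depth': depths}, scoring='roc_auc', cv=n_folds)
search.fit(X, y)

# Per-fold testing scores: one row per fold, one column per max_depth
fold_scores = np.vstack([search.cv_results_[f'split{i}_test_score']
                         for i in range(n_folds)])

# Average across the testing folds
mean_scores = fold_scores.mean(axis=0)
# Standard error of the mean: sample standard deviation / sqrt(number of folds)
sem_scores = fold_scores.std(axis=0, ddof=1) / np.sqrt(n_folds)

for depth, m, se in zip(depths, mean_scores, sem_scores):
    print(f'max_depth={depth}: {m:.3f} +/- {se:.3f}')
```

The sem_scores array could then be passed, for example, as the yerr argument of Matplotlib's plt.errorbar to draw the error bars.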
The standard error of the mean is also known as the standard deviation of the sampling distribution of the sample mean. That is a long name, but the concept isn’t too complicated. The idea is that the set of model performance metrics that we wish to make error bars for (one per testing fold) represents one possible sample from a theoretical, larger population of similar samples: for example, the metrics we would get if more data were available and we used it to create more testing folds. If we could take repeated samples from the larger population, each of these sampling events would result in a slightly different mean (the sample mean).
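A small simulation can make this concrete. It assumes, purely for illustration, that fold scores come from a normal population with mean 0.75 and standard deviation 0.05, and draws many samples of four scores each. The standard deviation of the resulting sample means should be close to the population standard deviation divided by the square root of the sample size, here 0.05 / 2 = 0.025.

```python
# Sketch: simulating the sampling distribution of the sample mean.
# The normal population and its parameters are illustrative assumptions.
import numpy as np

rng = np.random.default_rng(0)

pop_mean, pop_sd = 0.75, 0.05  # assumed "larger population" of fold scores
n_folds = 4                    # size of each sample (one score per fold)
n_repeats = 100_000            # number of repeated sampling events

# Each row is one sampling event: n_folds scores drawn from the population
samples = rng.normal(pop_mean, pop_sd, size=(n_repeats, n_folds))
sample_means = samples.mean(axis=1)

# Spread of the sample means versus the theoretical standard error
empirical_se = sample_means.std()
theoretical_se = pop_sd / np.sqrt(n_folds)
print(f'empirical: {empirical_se:.4f}, theoretical: {theoretical_se:.4f}')

# In practice there is only one sample, so the standard error is estimated from it
one_sample = samples[0]
estimated_se = one_sample.std(ddof=1) / np.sqrt(n_folds)
print(f'estimate from a single sample of {n_folds}: {estimated_se:.4f}')
```

The single-sample estimate is noisy with only four folds, which is worth keeping in mind when reading error bars built from a handful of testing folds.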
Constructing ...