Performing Cross-Validation
Learn how to use cross-validation to calculate accuracy estimates using tidymodels.
Coding the workflow
One of the many benefits of using the tidymodels family of R packages is the standardized approach it provides for coding machine learning workflows. The following code sets up a machine learning workflow for a CART classification decision tree in four steps:
Preparing the training data.
Declaring how the data should be used to train the model.
Specifying the machine learning algorithm to be used.
Orchestrating the workflow.
#================================================================================================
# Load libraries - suppress messages
#
suppressMessages(library(tidyverse))
suppressMessages(library(tidymodels))
suppressMessages(library(rattle))

#================================================================================================
# Load the Titanic training data and transform Embarked to a factor
#
titanic_train <- read_csv("titanic_train.csv", show_col_types = FALSE) %>%
  mutate(Sex = factor(Sex),
         Embarked = factor(case_when(
           Embarked == "C" ~ "Cherbourg",
           Embarked == "Q" ~ "Queenstown",
           Embarked == "S" ~ "Southampton",
           is.na(Embarked) ~ "missing")))

#================================================================================================
# Craft the recipe - recipes package
#
titanic_recipe <- recipe(Survived ~ Sex + Pclass + SibSp + Parch + Fare + Embarked,
                         data = titanic_train) %>%
  step_num2factor(Survived,
                  transform = function(x) x + 1,
                  levels = c("perished", "survived")) %>%
  step_num2factor(Pclass,
                  levels = c("first", "second", "third"))

#================================================================================================
# Specify the algorithm - parsnip package
#
titanic_model <- decision_tree() %>%
  set_engine("rpart") %>%
  set_mode("classification")

#================================================================================================
# Set up workflow - workflows package
#
titanic_workflow <- workflow() %>%
  add_recipe(titanic_recipe) %>%
  add_model(titanic_model)
Setting up cross-validation
The vfold_cv() function from the rsample package creates the folds to be used in cross-validation. The following code uses these ...
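As a rough illustration of how vfold_cv() builds folds, the sketch below uses a small synthetic data frame rather than the lesson's Titanic data. The names toy_data and toy_folds, the seed, and the choice of 10 folds repeated 3 times with stratification are assumptions made for this example only, not values taken from the lesson.

```r
# Minimal sketch: create stratified cross-validation folds with rsample.
# toy_data is synthetic and only stands in for a real training set.
library(rsample)

set.seed(42)  # make the random fold assignment reproducible

toy_data <- data.frame(
  Survived = factor(rep(c("perished", "survived"), times = c(60, 40))),
  Fare     = runif(100, min = 5, max = 100)
)

# v = 10 folds, repeated 3 times (illustrative values);
# strata keeps the class proportions similar across folds
toy_folds <- vfold_cv(toy_data, v = 10, repeats = 3, strata = Survived)

# One row per resample: 10 folds x 3 repeats
nrow(toy_folds)

# Each split holds an analysis (training) part and an assessment (held-out) part
first_split <- toy_folds$splits[[1]]
dim(analysis(first_split))
dim(assessment(first_split))
```

Folds created this way can typically be passed, together with a workflow object, to fit_resamples() from the tune package to obtain per-fold performance estimates.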