Chapter 22 K-fold cross-validation in R
Cross-validation is another method of estimating the quality of a model’s predictions. You’ll often want to use cross-validation when your dataset is small, in which case splitting the data into just two parts – a training set and a testing set – may not work well: the training set might be too small to train a good model, and the testing set might be too small to give a reliable picture of the quality of your model’s predictions on unseen data.
In k-fold cross-validation, you divide the entire data set into k equal-size subsets, use k - 1 of them for training, and use the remaining subset for testing and calculating the prediction error (i.e., prediction quality). You repeat this procedure k times (once for each of the k “folds”) and report the average prediction quality across the k runs.
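To make the procedure concrete, here is a minimal by-hand sketch of the fold assignment (the names k, n, and fold_ids are ours, purely for illustration; the helper functions introduced below do this bookkeeping for you):
# A by-hand sketch of the k-fold idea, for illustration only.
k <- 8
n <- nrow(mtcars)                             # mtcars is built into R
fold_ids <- sample(rep(1:k, length.out = n))  # assign each row to a fold
# For run i, rows with fold_ids == i form the test set and all
# remaining rows form the training set.
head(fold_ids)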
Fortunately, R has a variety of tools that help you perform k-fold cross-validation. In this lab activity, we will use functions from the tidyverse collection of packages to do so.
22.1 Dependencies
We’ll use the following packages (you will need to install any that you don’t already have installed):
library(tidyverse)
library(modelr)
library(rpart)
22.2 Loading the data
In this lab activity, we’ll use the mtcars dataset, which comes built in with your R installation. You can read more about it by running ?mtcars in your R console in RStudio.
data <- mtcars
22.3 Cross-validation with crossv_kfold
We can use the crossv_kfold function from the modelr package to split our data into k mutually exclusive partitions (i.e., k folds).
kfolds <- crossv_kfold(data, k = 8)
kfolds
## # A tibble: 8 × 3
## train test .id
## <named list> <named list> <chr>
## 1 <resample [28 x 11]> <resample [4 x 11]> 1
## 2 <resample [28 x 11]> <resample [4 x 11]> 2
## 3 <resample [28 x 11]> <resample [4 x 11]> 3
## 4 <resample [28 x 11]> <resample [4 x 11]> 4
## 5 <resample [28 x 11]> <resample [4 x 11]> 5
## 6 <resample [28 x 11]> <resample [4 x 11]> 6
## 7 <resample [28 x 11]> <resample [4 x 11]> 7
## 8 <resample [28 x 11]> <resample [4 x 11]> 8
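As a quick sanity check (a sketch, relying on the detail that modelr’s resample objects store their row indices in an idx field), we can verify that the 8 test folds together cover every one of mtcars’ 32 rows exactly once:
# Each test fold holds roughly 32 / 8 = 4 rows, and the folds partition
# the data, so the fold sizes should sum to nrow(data), i.e., 32.
sum(map_int(kfolds$test, ~length(.$idx)))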
Notice that the output of crossv_kfold is a tibble (i.e., a fancy data frame) where each row gives a different training/testing split. Each of those resample objects in the tibble can be turned back into a data frame using as.data.frame. For example,
# Convert one of the training samples into a more familiar data object
as.data.frame(kfolds$train[1])
## X1.mpg X1.cyl X1.disp X1.hp X1.drat X1.wt X1.qsec X1.vs
## Mazda RX4 21.0 6 160.0 110 3.90 2.620 16.46 0
## Mazda RX4 Wag 21.0 6 160.0 110 3.90 2.875 17.02 0
## Hornet 4 Drive 21.4 6 258.0 110 3.08 3.215 19.44 1
## Hornet Sportabout 18.7 8 360.0 175 3.15 3.440 17.02 0
## Valiant 18.1 6 225.0 105 2.76 3.460 20.22 1
## Duster 360 14.3 8 360.0 245 3.21 3.570 15.84 0
## Merc 230 22.8 4 140.8 95 3.92 3.150 22.90 1
## Merc 280 19.2 6 167.6 123 3.92 3.440 18.30 1
## Merc 280C 17.8 6 167.6 123 3.92 3.440 18.90 1
## Merc 450SE 16.4 8 275.8 180 3.07 4.070 17.40 0
## Merc 450SL 17.3 8 275.8 180 3.07 3.730 17.60 0
## Merc 450SLC 15.2 8 275.8 180 3.07 3.780 18.00 0
## Cadillac Fleetwood 10.4 8 472.0 205 2.93 5.250 17.98 0
## Lincoln Continental 10.4 8 460.0 215 3.00 5.424 17.82 0
## Chrysler Imperial 14.7 8 440.0 230 3.23 5.345 17.42 0
## Fiat 128 32.4 4 78.7 66 4.08 2.200 19.47 1
## Honda Civic 30.4 4 75.7 52 4.93 1.615 18.52 1
## Toyota Corona 21.5 4 120.1 97 3.70 2.465 20.01 1
## Dodge Challenger 15.5 8 318.0 150 2.76 3.520 16.87 0
## AMC Javelin 15.2 8 304.0 150 3.15 3.435 17.30 0
## Camaro Z28 13.3 8 350.0 245 3.73 3.840 15.41 0
## Pontiac Firebird 19.2 8 400.0 175 3.08 3.845 17.05 0
## Fiat X1-9 27.3 4 79.0 66 4.08 1.935 18.90 1
## Porsche 914-2 26.0 4 120.3 91 4.43 2.140 16.70 0
## Lotus Europa 30.4 4 95.1 113 3.77 1.513 16.90 1
## Ford Pantera L 15.8 8 351.0 264 4.22 3.170 14.50 0
## Maserati Bora 15.0 8 301.0 335 3.54 3.570 14.60 0
## Volvo 142E 21.4 4 121.0 109 4.11 2.780 18.60 1
## X1.am X1.gear X1.carb
## Mazda RX4 1 4 4
## Mazda RX4 Wag 1 4 4
## Hornet 4 Drive 0 3 1
## Hornet Sportabout 0 3 2
## Valiant 0 3 1
## Duster 360 0 3 4
## Merc 230 0 4 2
## Merc 280 0 4 4
## Merc 280C 0 4 4
## Merc 450SE 0 3 3
## Merc 450SL 0 3 3
## Merc 450SLC 0 3 3
## Cadillac Fleetwood 0 3 4
## Lincoln Continental 0 3 4
## Chrysler Imperial 0 3 4
## Fiat 128 1 4 1
## Honda Civic 1 4 2
## Toyota Corona 0 3 1
## Dodge Challenger 0 3 2
## AMC Javelin 0 3 2
## Camaro Z28 0 3 4
## Pontiac Firebird 0 3 2
## Fiat X1-9 1 4 1
## Porsche 914-2 1 5 2
## Lotus Europa 1 5 2
## Ford Pantera L 1 5 4
## Maserati Bora 1 5 8
## Volvo 142E 1 4 2
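Note that indexing with single brackets (kfolds$train[1]) returns a one-element list, which is why every column name above is prefixed with X1.. Extracting the resample object itself with double brackets keeps the original mtcars column names:
# Double brackets extract the resample object directly, so the original
# column names (mpg, cyl, disp, ...) are preserved.
head(as.data.frame(kfolds$train[[1]]))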
We can use the map function (from the purrr tidyverse package) to train a simple linear model (using lm) for each of our 8 folds:
# In this code, the map function runs the provided function (lm in this case)
# on each of the training sets inside of the kfolds tibble we created.
# After running this line, models_a will contain 8 models, each trained on one
# of our 8 training sets.
models_a <- map(kfolds$train, ~lm(mpg ~ wt, data = .))
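Each element of models_a is an ordinary lm fit, so the usual accessors work on it. For example, to peek at the coefficients of the model trained on the first fold:
# models_a[[1]] is a regular lm object; coef extracts its fitted
# coefficients (the intercept and the slope for wt).
coef(models_a[[1]])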
Next, we can use the map2_dbl function (from the purrr tidyverse package) to get the testing errors (root mean squared error, RMSE, in this case) for each model trained on a different fold:
# In short, the map2_dbl function "loops" over the contents of the first two
# arguments (models_a and kfolds$test), applying each pair of values (one from
# models_a and one from kfolds$test) to the function given in the third argument
# for the map2_dbl function (rmse).
errors_a <- map2_dbl(models_a, kfolds$test, rmse)
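If the map2_dbl call feels opaque, here is an equivalent explicit loop (a sketch; errors_manual is just an illustrative name). modelr’s rmse takes a model and a dataset (here, one held-out test fold) and returns the root mean squared error of the model’s predictions on it:
# An unrolled, loop-based version of the map2_dbl call above.
errors_manual <- vector("double", length(models_a))
for (i in seq_along(models_a)) {
  # Evaluate the model trained on fold i against fold i's test set.
  errors_manual[i] <- rmse(models_a[[i]], kfolds$test[[i]])
}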
It’s common to then report the average error across all folds:
mean(errors_a)
## [1] 2.911305
Cross-validation is particularly useful for comparing different models. For example, let’s train a slightly more complicated linear model (with wt and hp as predictors):
# Again, we use map and map2_dbl to help us train a model and compute its test
# error.
models_b <- map(kfolds$train, ~lm(mpg ~ wt + hp, data = .))
errors_b <- map2_dbl(models_b, kfolds$test, rmse)
mean(errors_b)
## [1] 2.540529
And a regression tree with wt and hp as predictors for mpg:
models_c <- map(kfolds$train, ~rpart(mpg ~ wt + hp, data = .))
errors_c <- map2_dbl(models_c, kfolds$test, rmse)
mean(errors_c)
## [1] 4.3221
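To compare the three models side by side, we can collect their average cross-validated errors into a small tibble (a sketch; the model labels are ours, purely for readability):
# Gather the mean cross-validated RMSE of each model for comparison.
tibble(
  model = c("mpg ~ wt (lm)", "mpg ~ wt + hp (lm)", "mpg ~ wt + hp (rpart)"),
  mean_rmse = c(mean(errors_a), mean(errors_b), mean(errors_c))
)
Based on these estimates, the linear model with both wt and hp as predictors has the lowest average test error of the three.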
22.4 Exercises
- The map family of functions from the purrr tidyverse package is pretty powerful. Check your understanding of how they work using the documentation: https://purrr.tidyverse.org/reference/map.html
- Why might k-fold cross-validation be a more robust method of evaluation than performing a single training/testing split (as we have done before) when working with small datasets?