Chapter 22 K-fold cross-validation in R

Cross-validation is another method for estimating the predictive quality of a model. You’ll often want to use cross-validation when your dataset is small, in which case splitting the data into two parts – a training set and a testing set – does not work well. That is, your training set might be too small to train a good model, and your testing set might be too small to give a reliable picture of how well your model predicts unseen data.

In k-fold cross-validation, you divide the entire dataset into k equal-size subsets (the “folds”), use k-1 of them for training, and use the remaining fold for testing and calculating the prediction error (i.e., prediction quality). You repeat this procedure k times (once per fold, so that every fold serves as the testing set exactly once) and report the average prediction quality across the k runs.
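
To make the procedure concrete, here’s what k-fold cross-validation looks like written as a plain R loop. The particular choices below (the built-in mtcars data, a single-predictor linear model, and root mean squared error) are just illustrative stand-ins; for the rest of this activity we’ll let modelr handle this bookkeeping for us.

# A plain-R sketch of the k-fold procedure described above
k <- 8
df <- mtcars                                        # stand-in dataset
fold_id <- sample(rep(1:k, length.out = nrow(df)))  # randomly assign rows to folds
errors <- numeric(k)
for (i in 1:k) {
  train <- df[fold_id != i, ]   # k-1 folds for training
  test  <- df[fold_id == i, ]   # the remaining fold for testing
  model <- lm(mpg ~ wt, data = train)
  preds <- predict(model, newdata = test)
  errors[i] <- sqrt(mean((test$mpg - preds)^2))     # RMSE on the held-out fold
}
mean(errors)  # average prediction error across the k folds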

Fortunately, R has a variety of tools that help you perform k-fold cross-validation. In this lab activity, we will use functions from the tidyverse collection of packages, along with modelr and rpart, to perform k-fold cross-validation.

22.1 Dependencies

We’ll use the following packages (you will need to install any that you don’t already have installed):

library(tidyverse)
library(modelr)
library(rpart)

22.2 Loading the data

In this lab activity, we’ll use the mtcars dataset, which comes built in with your R installation. Read more about it by running ?mtcars in your R console in RStudio.

# Store the built-in mtcars dataset in a variable for convenience
data <- mtcars
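
Before modeling, it can help to take a quick look at the data; glimpse (loaded with the tidyverse) prints a compact, one-line-per-column overview:

# Compact overview of the dataset: one line per column, with types and values
glimpse(data)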

22.3 Cross-validation with crossv_kfold

We can use the crossv_kfold function in the modelr package to split our data into k exclusive partitions (i.e., k folds).

kfolds <- crossv_kfold(
  data,
  k = 8
)
kfolds
## # A tibble: 8 × 3
##   train                test                .id  
##   <named list>         <named list>        <chr>
## 1 <resample [28 x 11]> <resample [4 x 11]> 1    
## 2 <resample [28 x 11]> <resample [4 x 11]> 2    
## 3 <resample [28 x 11]> <resample [4 x 11]> 3    
## 4 <resample [28 x 11]> <resample [4 x 11]> 4    
## 5 <resample [28 x 11]> <resample [4 x 11]> 5    
## 6 <resample [28 x 11]> <resample [4 x 11]> 6    
## 7 <resample [28 x 11]> <resample [4 x 11]> 7    
## 8 <resample [28 x 11]> <resample [4 x 11]> 8
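
One caveat: crossv_kfold assigns rows to folds at random, so the exact splits (and the error values computed from them later in this activity) will vary from run to run. If you want reproducible folds, set a random seed before creating them; the seed value below is arbitrary:

# Setting a seed before crossv_kfold makes the random fold assignment
# reproducible across runs (the seed value itself is arbitrary).
set.seed(42)
kfolds <- crossv_kfold(data, k = 8)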

Notice that the output of crossv_kfold is a tibble (i.e., a fancy data frame) in which each row gives a different training/testing split. Each of the resample objects you see in the tibble can be turned back into a data frame using as.data.frame. For example,

# Convert one of the training samples into a more familiar data object
as.data.frame(kfolds$train[1])
##                     X1.mpg X1.cyl X1.disp X1.hp X1.drat X1.wt X1.qsec X1.vs
## Mazda RX4             21.0      6   160.0   110    3.90 2.620   16.46     0
## Mazda RX4 Wag         21.0      6   160.0   110    3.90 2.875   17.02     0
## Hornet 4 Drive        21.4      6   258.0   110    3.08 3.215   19.44     1
## Hornet Sportabout     18.7      8   360.0   175    3.15 3.440   17.02     0
## Valiant               18.1      6   225.0   105    2.76 3.460   20.22     1
## Duster 360            14.3      8   360.0   245    3.21 3.570   15.84     0
## Merc 230              22.8      4   140.8    95    3.92 3.150   22.90     1
## Merc 280              19.2      6   167.6   123    3.92 3.440   18.30     1
## Merc 280C             17.8      6   167.6   123    3.92 3.440   18.90     1
## Merc 450SE            16.4      8   275.8   180    3.07 4.070   17.40     0
## Merc 450SL            17.3      8   275.8   180    3.07 3.730   17.60     0
## Merc 450SLC           15.2      8   275.8   180    3.07 3.780   18.00     0
## Cadillac Fleetwood    10.4      8   472.0   205    2.93 5.250   17.98     0
## Lincoln Continental   10.4      8   460.0   215    3.00 5.424   17.82     0
## Chrysler Imperial     14.7      8   440.0   230    3.23 5.345   17.42     0
## Fiat 128              32.4      4    78.7    66    4.08 2.200   19.47     1
## Honda Civic           30.4      4    75.7    52    4.93 1.615   18.52     1
## Toyota Corona         21.5      4   120.1    97    3.70 2.465   20.01     1
## Dodge Challenger      15.5      8   318.0   150    2.76 3.520   16.87     0
## AMC Javelin           15.2      8   304.0   150    3.15 3.435   17.30     0
## Camaro Z28            13.3      8   350.0   245    3.73 3.840   15.41     0
## Pontiac Firebird      19.2      8   400.0   175    3.08 3.845   17.05     0
## Fiat X1-9             27.3      4    79.0    66    4.08 1.935   18.90     1
## Porsche 914-2         26.0      4   120.3    91    4.43 2.140   16.70     0
## Lotus Europa          30.4      4    95.1   113    3.77 1.513   16.90     1
## Ford Pantera L        15.8      8   351.0   264    4.22 3.170   14.50     0
## Maserati Bora         15.0      8   301.0   335    3.54 3.570   14.60     0
## Volvo 142E            21.4      4   121.0   109    4.11 2.780   18.60     1
##                     X1.am X1.gear X1.carb
## Mazda RX4               1       4       4
## Mazda RX4 Wag           1       4       4
## Hornet 4 Drive          0       3       1
## Hornet Sportabout       0       3       2
## Valiant                 0       3       1
## Duster 360              0       3       4
## Merc 230                0       4       2
## Merc 280                0       4       4
## Merc 280C               0       4       4
## Merc 450SE              0       3       3
## Merc 450SL              0       3       3
## Merc 450SLC             0       3       3
## Cadillac Fleetwood      0       3       4
## Lincoln Continental     0       3       4
## Chrysler Imperial       0       3       4
## Fiat 128                1       4       1
## Honda Civic             1       4       2
## Toyota Corona           0       3       1
## Dodge Challenger        0       3       2
## AMC Javelin             0       3       2
## Camaro Z28              0       3       4
## Pontiac Firebird        0       3       2
## Fiat X1-9               1       4       1
## Porsche 914-2           1       5       2
## Lotus Europa            1       5       2
## Ford Pantera L          1       5       4
## Maserati Bora           1       5       8
## Volvo 142E              1       4       2
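
A quick aside on indexing: kfolds$train[1] (single brackets) returns a length-one list, which is why as.data.frame prefixed every column name above with X1.. Using double brackets extracts the resample object itself, and the resulting data frame keeps the original column names:

# Double brackets extract the resample object itself, so the converted data
# frame keeps the original column names (mpg, cyl, disp, ...).
as.data.frame(kfolds$train[[1]])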

We can use the map function (from the purrr tidyverse package) to train a simple linear model (using lm) on each of our 8 training sets:

# In this code, the map function runs the provided function (lm in this case)
# on each of the training sets inside the kfolds tibble we created. The ~
# syntax is purrr shorthand for an anonymous function, with . standing in for
# the current training set. After running this line, models_a will contain 8
# models, each trained on one of our 8 training sets.
models_a <- map(kfolds$train, ~lm(mpg ~ wt, data = .))
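
Each element of models_a is an ordinary lm object, so you can inspect any of the per-fold models with the usual tools. For example:

# Inspect the model trained on the first fold (coefficients, fit statistics)
summary(models_a[[1]])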

Next, we can use the map2_dbl function (also from the purrr package) to compute the testing error (root mean squared error, using modelr’s rmse function) of each model on the testing set from its fold:

# In short, map2_dbl "loops" over the contents of its first two arguments
# (models_a and kfolds$test) in parallel, passing each pair of values (one
# model and its corresponding testing set) to the function given as the third
# argument (rmse), and returns the results as a numeric vector.
errors_a <- map2_dbl(models_a, kfolds$test, rmse)

It’s common to then report the average error across all folds:

mean(errors_a)
## [1] 2.911305
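
Since errors_a holds one RMSE value per fold, it can also be worth looking at the individual values and their spread, rather than just the mean:

errors_a      # one RMSE value per fold
sd(errors_a)  # spread of the error estimates across folds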

Cross-validation is particularly useful for comparing different models. For example, let’s train a slightly more complicated linear model (with both wt and hp as predictors):

# Again, we use map and map2_dbl to help us train a model and compute its test
# error.
models_b <- map(kfolds$train, ~lm(mpg ~ wt + hp, data = .))
errors_b <- map2_dbl(models_b, kfolds$test, rmse)
mean(errors_b)
## [1] 2.540529

And here’s a regression tree (via rpart) with wt and hp as predictors for mpg:

models_c <- map(kfolds$train, ~rpart(mpg ~ wt + hp, data = .))
errors_c <- map2_dbl(models_c, kfolds$test, rmse)
mean(errors_c)
## [1] 4.3221
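
With all three sets of errors in hand, it can be convenient to collect the average errors into a single tibble for a side-by-side comparison (a small bookkeeping sketch, not part of the modelr workflow itself):

# Collect each model's average cross-validated RMSE into one tibble
tibble(
  model     = c("lm: mpg ~ wt", "lm: mpg ~ wt + hp", "rpart: mpg ~ wt + hp"),
  mean_rmse = c(mean(errors_a), mean(errors_b), mean(errors_c))
)

Lower RMSE is better; in the runs shown above, the two-predictor linear model outperforms both the single-predictor model and the regression tree.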

22.4 Exercises

  • The map family of functions from the purrr tidyverse package is pretty powerful. Check your understanding of how these functions work using the documentation: https://purrr.tidyverse.org/reference/map.html
  • Why might k-fold cross validation be a more robust method of evaluation than performing a single training/testing split (as we have done before) when working with small datasets?
