crossval function

Estimator of the Mean Squared Prediction Error using Cross-Validation.

Estimator of the mean squared prediction error of different learners using cross-validation.

crossval(
  y,
  X,
  Z = NULL,
  learners,
  cv_folds = 5,
  cv_subsamples = NULL,
  silent = FALSE,
  progress = NULL
)

Arguments

  • y: The outcome variable.

  • X: A (sparse) matrix of predictive variables.

  • Z: Optional additional (sparse) matrix of predictive variables.

  • learners: A list of lists, each containing up to four named elements:

    • fun: The base learner function. It must predict a named input y from a named input X.
    • args: Optional arguments to be passed to fun.
    • assign_X: An optional vector of column indices corresponding to variables in X that are passed to the base learner.
    • assign_Z: An optional vector of column indices corresponding to variables in Z that are passed to the base learner.

    If args is omitted, fun is called with its default arguments. If assign_X (and/or assign_Z) is omitted, all predictive variables in X (and/or Z) are included.

  • cv_folds: Number of folds used for cross-validation.

  • cv_subsamples: List of vectors with sample indices for cross-validation.

  • silent: Boolean to silence estimation updates.

  • progress: String to print before learner and cv fold progress.
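
The argument structure above can be sketched in base R. This is a minimal illustration and not the package's own code: my_lm and its predict method are hypothetical names, the sketch assumes that a base learner returns a fitted object whose predict method accepts new data, and it assumes that each element of cv_subsamples holds the row indices of one fold.

```r
# Hypothetical base learner: least squares with an intercept.
# It takes named inputs `y` and `X`, as the `fun` element requires.
my_lm <- function(y, X, ...) {
  X <- cbind(1, as.matrix(X))
  fit <- list(coef = qr.solve(X, y))  # least-squares solution
  class(fit) <- "my_lm"
  fit
}

# Assumed prediction interface: a predict method that takes new data.
predict.my_lm <- function(object, newdata, ...) {
  drop(cbind(1, as.matrix(newdata)) %*% object$coef)
}

# Two learner specifications: the first uses every column of X,
# the second only columns 1 and 2 (via assign_X).
learners <- list(
  list(fun = my_lm),
  list(fun = my_lm, assign_X = c(1, 2))
)

# Fold indices for n observations: a random partition of 1..n
# into cv_folds vectors of (nearly) equal length.
set.seed(1)
n <- 100
cv_folds <- 5
cv_subsamples <- split(sample(n), rep(seq_len(cv_folds), length.out = n))
```

Supplying the same cv_subsamples to several calls would keep the fold assignment fixed, which makes comparisons across calls less noisy.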

Returns

crossval returns a list containing the following components:

  • mspe: A vector of MSPE estimates, each corresponding to a base learner (in the order given in learners).
  • oos_resid: A matrix of out-of-sample prediction errors, each column corresponding to a base learner (in the order given in learners).
  • cv_subsamples: Pass-through of cv_subsamples. See above.
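
To make the relationship between these components concrete, here is a hand-rolled K-fold sketch in base R. It is not the package's implementation: it uses simulated data and plain least squares via lm.fit, and it assumes that mspe is the column-wise mean of the squared entries of oos_resid and that residuals are computed as outcome minus prediction (the sign convention is an assumption).

```r
# Simulated data (illustrative only)
set.seed(42)
n <- 200
X <- matrix(rnorm(n * 3), n, 3)
y <- drop(X %*% c(1, -1, 0.5)) + rnorm(n)

# Two stand-in "learners": least squares on all columns of X,
# and least squares on the first column only.
col_sets <- list(1:3, 1)

# Random partition of the sample into folds
cv_folds <- 4
cv_subsamples <- split(sample(n), rep(seq_len(cv_folds), length.out = n))

# One column of out-of-sample residuals per learner
oos_resid <- matrix(NA_real_, n, length(col_sets))
for (j in seq_along(col_sets)) {
  Xj <- cbind(1, X[, col_sets[[j]], drop = FALSE])
  for (k in seq_len(cv_folds)) {
    test <- cv_subsamples[[k]]
    # Fit on all observations outside fold k
    fit <- lm.fit(Xj[-test, , drop = FALSE], y[-test])
    # Predict the held-out fold and store the prediction errors
    pred <- drop(Xj[test, , drop = FALSE] %*% fit$coefficients)
    oos_resid[test, j] <- y[test] - pred
  }
}

# Assumed definition: mean squared out-of-sample residual per learner
mspe <- colMeans(oos_resid^2)
```

Because every observation is held out exactly once, each row of oos_resid is filled in exactly once per learner, and the learner with the smallest entry of mspe is the one that predicts best out of sample under this criterion.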

Examples

# Construct variables from the included Angrist & Evans (1998) data
y = AE98[, "worked"]
X = AE98[, c("morekids", "age", "agefst", "black", "hisp", "othrace", "educ")]

# Compare ols, lasso, and ridge using 4-fold cross-validation
cv_res <- crossval(y, X,
                   learners = list(list(fun = ols),
                                   list(fun = mdl_glmnet),
                                   list(fun = mdl_glmnet,
                                        args = list(alpha = 0))),
                   cv_folds = 4,
                   silent = TRUE)
cv_res$mspe

See Also

Other utilities: crosspred(), shortstacking()