predict.regression_forest function

Predict with a regression forest

Gets estimates of E[Y|X=x] using a trained regression forest.

## S3 method for class 'regression_forest'
predict(
  object,
  newdata = NULL,
  linear.correction.variables = NULL,
  ll.lambda = NULL,
  ll.weight.penalty = FALSE,
  num.threads = NULL,
  estimate.variance = FALSE,
  ...
)

Arguments

  • object: The trained forest.
  • newdata: Points at which predictions should be made. If NULL, makes out-of-bag predictions on the training set instead (i.e., provides predictions at Xi using only trees that did not use the i-th training example). Note that this matrix should have the same number of columns as the training matrix, and that the columns must appear in the same order.
  • linear.correction.variables: Optional subset of indexes for variables to be used in local linear prediction. If NULL, standard GRF prediction is used. Otherwise, we run a locally weighted linear regression on the included variables. Please note that this is a beta feature still in development, and may slow down prediction considerably. Defaults to NULL.
  • ll.lambda: Ridge penalty for local linear predictions. Defaults to NULL and will be cross-validated.
  • ll.weight.penalty: Option to standardize ridge penalty by covariance (TRUE), or penalize all covariates equally (FALSE). Defaults to FALSE.
  • num.threads: Number of threads used in prediction. If set to NULL, the software automatically selects an appropriate amount.
  • estimate.variance: Whether variance estimates for the predictions τ̂(x) are desired (for confidence intervals).
  • ...: Additional arguments (currently ignored).
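As an illustrative sketch of the arguments above, the following contrasts out-of-bag prediction (newdata = NULL) with test-set prediction; the data are toy values, not from the package documentation:

```r
library(grf)

# Toy training data.
X <- matrix(rnorm(100 * 10), 100, 10)
Y <- X[, 2] + rnorm(100)
forest <- regression_forest(X, Y)

# newdata = NULL: out-of-bag predictions on the training set.
oob.pred <- predict(forest)

# Test-set prediction: newdata must have the same columns, in the same order.
X.new <- matrix(rnorm(5 * 10), 5, 10)
test.pred <- predict(forest, X.new, num.threads = 2)
```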

Returns

Vector of predictions, along with estimates of the error and (optionally) its variance estimates. Column 'predictions' contains estimates of E[Y|X=x]. The square root of column 'variance.estimates' is the standard error of the test mean-squared error. Column 'excess.error' contains jackknife estimates of the Monte Carlo error. The sum of 'debiased.error' and 'excess.error' is the raw error attained by the current forest, and 'debiased.error' alone is an estimate of the error attained by a forest with an infinite number of trees. We recommend that users grow enough trees to make the 'excess.error' negligible.
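As a sketch, the returned columns can be inspected as follows (toy data; variance columns require estimate.variance = TRUE):

```r
library(grf)

# Toy training data.
X <- matrix(rnorm(50 * 10), 50, 10)
Y <- X[, 1] + rnorm(50)
forest <- regression_forest(X, Y, num.trees = 500)

pred <- predict(forest, estimate.variance = TRUE)

head(pred$predictions)                   # estimates of E[Y | X = x]
head(sqrt(pred$variance.estimates))      # standard errors
head(pred$debiased.error + pred$excess.error)  # raw error of this forest
```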

Examples

# Train a standard regression forest.
n <- 50
p <- 10
X <- matrix(rnorm(n * p), n, p)
Y <- X[, 1] * rnorm(n)
r.forest <- regression_forest(X, Y)

# Predict using the forest.
X.test <- matrix(0, 101, p)
X.test[, 1] <- seq(-2, 2, length.out = 101)
r.pred <- predict(r.forest, X.test)

# Predict on out-of-bag training samples.
r.pred <- predict(r.forest)

# Predict with confidence intervals; growing more trees is now recommended.
r.forest <- regression_forest(X, Y, num.trees = 100)
r.pred <- predict(r.forest, X.test, estimate.variance = TRUE)
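The beta local linear correction described under the arguments can be sketched as follows; the choice of variable index 1 is illustrative, and ll.lambda is left NULL so the ridge penalty is cross-validated:

```r
library(grf)

# Toy data with a linear signal in the first covariate.
n <- 200
p <- 10
X <- matrix(rnorm(n * p), n, p)
Y <- 2 * X[, 1] + rnorm(n)
forest <- regression_forest(X, Y)

# Locally weighted linear correction on covariate 1 (beta feature;
# may slow down prediction considerably).
ll.pred <- predict(forest, linear.correction.variables = 1)
head(ll.pred$predictions)
```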
  • Maintainer: Erik Sverdrup
  • License: GPL-3
  • Last published: 2024-11-15