TreeSurrogate function

Decision tree surrogate model

TreeSurrogate fits a decision tree on the predictions of a prediction model.


A conditional inference tree is fitted on the predictions \hat{y} from the machine learning model and the data. The tree is fitted with the ctree function from the partykit package. By default, the tree has a maximum depth of 2 to keep it interpretable.
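The idea can be sketched without iml by calling partykit directly. This is a minimal illustration, not the package's internal code: the linear model standing in for the black box, the mtcars data, and the column name y_hat are all assumptions made for the example.

```r
library(partykit)

# Stand-in "black box": any fitted model with a predict() method would do.
black_box <- lm(mpg ~ ., data = mtcars)

# Keep only the features and attach the black box predictions as the
# new target (y_hat is a hypothetical column name).
X <- mtcars[names(mtcars) != "mpg"]
X$y_hat <- predict(black_box, newdata = mtcars)

# Fit a shallow conditional inference tree on the predictions,
# not on the original outcome.
surrogate <- ctree(y_hat ~ ., data = X,
                   control = ctree_control(maxdepth = 2))

print(surrogate)
```

The tree is trained to mimic the model's output, so its splits describe the model's behavior rather than the data-generating process.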

To learn more about global surrogate models, read the Interpretable Machine Learning book.


library("iml")
library("randomForest")

# Fit a Random Forest on the Boston housing data set
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 50)

# Create a model object
mod <- Predictor$new(rf, data = Boston[-which(names(Boston) == "medv")])

# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod)

# Plot the resulting leaf nodes
plot(dt)

# Use the tree to predict new data
predict(dt, Boston[1:10, ])

# Extract the results
dat <- dt$results
head(dat)

# It also works for classification
rf <- randomForest(Species ~ ., data = iris, ntree = 50)
X <- iris[-which(names(iris) == "Species")]
mod <- Predictor$new(rf, data = X, type = "prob")

# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod, maxdepth = 2)

# Plot the resulting leaf nodes
plot(dt)

# If you want to visualize the tree directly:
plot(dt$tree)

# Use the tree to predict new data
set.seed(42)
iris.sample <- X[sample(1:nrow(X), 10), ]
predict(dt, iris.sample)
predict(dt, iris.sample, type = "class")

# Extract the dataset
dat <- dt$results
head(dat)


Super class

iml::InterpretationMethod -> TreeSurrogate

Public fields

  • tree: party

     The fitted tree. See also partykit::ctree.
  • maxdepth: numeric(1)

     The maximum tree depth.
  • r.squared: numeric(1|n.classes)

     R squared measures how well the decision tree approximates the underlying model. It is calculated as 1 - (variance of prediction differences / variance of black box model predictions). For the multi-class case, r.squared contains one measure per class.
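The r.squared formula can be illustrated in base R. The sketch below follows the description given for this field literally and is not taken from the package source; the function name surrogate_r_squared and the toy prediction vectors are hypothetical.

```r
# R squared of a surrogate, following the documented description:
# 1 - var(prediction differences) / var(black box predictions).
# Accepts vectors (regression) or matrices with one column per class.
surrogate_r_squared <- function(black_box, surrogate) {
  black_box <- as.matrix(black_box)
  surrogate <- as.matrix(surrogate)
  stopifnot(identical(dim(black_box), dim(surrogate)))
  vapply(seq_len(ncol(black_box)), function(j) {
    1 - var(black_box[, j] - surrogate[, j]) / var(black_box[, j])
  }, numeric(1))
}

# Toy regression case: the surrogate predicts one value per leaf.
bb_pred   <- c(1, 2, 3, 4)
tree_pred <- c(1.5, 1.5, 3.5, 3.5)
r2 <- surrogate_r_squared(bb_pred, tree_pred)
r2
```

Here the differences have variance 1/3 and the black box predictions have variance 5/3, so the result is 0.8. Values close to 1 mean the tree reproduces the model's predictions well; values near 0 mean the tree explains little of their variation.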


Public methods

Method new()

Create a TreeSurrogate object.


TreeSurrogate$new(predictor, maxdepth = 2, tree.args = NULL)


  • predictor: Predictor

     The object (created with `Predictor$new()`) holding the machine learning model and the data.
  • maxdepth: numeric(1)

     The maximum depth of the tree. Default is 2.
  • tree.args: (named list)

     Further arguments for `partykit::ctree()`.

Method predict()

Predict new data with the tree. See also predict.TreeSurrogate.


TreeSurrogate$predict(newdata, type = "prob", ...)


  • newdata: data.frame

     Data to predict on.
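  • type: character(1)

     The prediction type. The default is "prob"; the examples also use "class" for a classification surrogate.
  • ...: Further arguments passed to predict().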

Method clone()

The objects of this class are cloneable with this method.


TreeSurrogate$clone(deep = FALSE)


  • deep: Whether to make a deep clone.
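The difference between a shallow and a deep clone comes from R6 and applies to any R6 object. The sketch below uses two hypothetical toy classes, Inner and Outer, rather than TreeSurrogate itself, to show what the deep argument changes.

```r
library(R6)

# Hypothetical classes, for illustration only.
Inner <- R6Class("Inner", public = list(value = 1))
Outer <- R6Class("Outer", public = list(
  inner = NULL,
  initialize = function() {
    self$inner <- Inner$new()
  }
))

original <- Outer$new()
shallow  <- original$clone()             # fields holding R6 objects are shared
deep     <- original$clone(deep = TRUE)  # fields holding R6 objects are copied

# Modify the nested object through the original.
original$inner$value <- 99

shallow$inner$value  # the shallow clone sees the change
deep$inner$value     # the deep clone keeps its own copy
```

With a shallow clone, changes to nested R6 objects in one copy show up in the other, so deep = TRUE is the safer choice when the copies should be modified independently.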

  • Maintainer: Giuseppe Casalicchio
  • License: MIT + file LICENSE
  • Last published: 2025-02-24