TreeSurrogate fits a decision tree on the predictions of a prediction model.
Details
A conditional inference tree is fitted to the predictions ŷ of the machine learning model, using the features in the data as inputs. The tree is fitted with partykit::ctree(). By default, the tree has a maximum depth of 2 to keep it interpretable.
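To make the procedure concrete, here is a minimal sketch of the surrogate idea written directly against partykit, without calling TreeSurrogate. It assumes an lm() with an interaction term as a stand-in black box and uses iris purely as example data; the variable names are illustrative.

```r
library("partykit")

# Features the black box sees (iris is only example data)
X <- iris[, c("Sepal.Width", "Petal.Length", "Petal.Width", "Species")]

# 1. Fit a stand-in "black box" model and collect its predictions y_hat
black_box <- lm(Sepal.Length ~ Sepal.Width * Petal.Length + Petal.Width + Species,
                data = iris)
y_hat <- predict(black_box, newdata = X)

# 2. Fit a conditional inference tree to y_hat (not to the observed outcome),
#    limited to depth 2 like TreeSurrogate's default
surrogate_data <- cbind(X, y_hat = y_hat)
surrogate <- ctree(y_hat ~ ., data = surrogate_data,
                   control = ctree_control(maxdepth = 2))

# 3. The tree's leaf means approximate the black-box predictions
y_tree <- predict(surrogate, newdata = X)
print(surrogate)
```

The key design point is step 2: the tree is trained on the model's output rather than on the true target, so it describes the model's behaviour instead of the data.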
Examples
library("iml")
library("randomForest")
# Fit a random forest on the Boston housing data set
data("Boston", package = "MASS")
rf <- randomForest(medv ~ ., data = Boston, ntree = 50)
# Create a model object
mod <- Predictor$new(rf, data = Boston[-which(names(Boston) == "medv")])
# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod)
# Plot the resulting leaf nodes
plot(dt)
# Use the tree to predict new data
predict(dt, Boston[1:10, ])
# Extract the results
dat <- dt$results
head(dat)

# It also works for classification
rf <- randomForest(Species ~ ., data = iris, ntree = 50)
X <- iris[-which(names(iris) == "Species")]
mod <- Predictor$new(rf, data = X, type = "prob")
# Fit a decision tree as a surrogate for the whole random forest
dt <- TreeSurrogate$new(mod, maxdepth = 2)
# Plot the resulting leaf nodes
plot(dt)
# If you want to visualize the tree directly:
plot(dt$tree)
# Use the tree to predict new data
set.seed(42)
iris.sample <- X[sample(1:nrow(X), 10), ]
predict(dt, iris.sample)
predict(dt, iris.sample, type = "class")
# Extract the dataset
dat <- dt$results
head(dat)
References
Craven, M., & Shavlik, J. W. (1996). Extracting tree-structured representations of trained networks. In Advances in Neural Information Processing Systems (pp. 24-30).
See Also
predict.TreeSurrogate, plot.TreeSurrogate
For the tree implementation: partykit::ctree()
Super class
iml::InterpretationMethod -> TreeSurrogate
Public fields
tree: party
The fitted tree. See also partykit::ctree().
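Because the tree field holds a partykit party object, partykit's own methods can be applied to it. The sketch below assumes a stand-alone ctree() fit on iris in place of dt$tree and shows how to find the leaf each observation falls into with predict(..., type = "node") and nodeids().

```r
library("partykit")

# Stand-alone tree used in place of dt$tree (iris and maxdepth = 2 are illustrative)
tree <- ctree(Sepal.Length ~ ., data = iris,
              control = ctree_control(maxdepth = 2))

# Terminal node (leaf) id for each observation
leaf <- predict(tree, newdata = iris, type = "node")

# Ids of all terminal nodes in the tree
leaves <- nodeids(tree, terminal = TRUE)

# How many observations land in each leaf
table(leaf)
```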
maxdepth: numeric(1)
The maximum tree depth.
r.squared: numeric(1|n.classes)
R squared measures how well the decision tree approximates the underlying model. It is calculated as 1 - Var(tree predictions - black box predictions) / Var(black box predictions), where the differences are taken observation by observation. For the multi-class case, r.squared contains one measure per class.
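The r.squared formula can be sketched in a few lines of base R. The helper surrogate_r_squared() is hypothetical, written to mirror the description above rather than copied from iml, and the prediction values are made-up toy numbers.

```r
# Hypothetical helper mirroring the r.squared description; not iml's internal code
surrogate_r_squared <- function(pred_tree, pred_bb) {
  pred_tree <- as.matrix(pred_tree)
  pred_bb <- as.matrix(pred_bb)
  # Variance of the per-observation differences, one value per column
  var_diff <- apply(pred_tree - pred_bb, 2, var)
  # Variance of the black-box predictions, one value per column
  var_bb <- apply(pred_bb, 2, var)
  1 - var_diff / var_bb
}

# Regression: a single output column gives a single value
r2_reg <- surrogate_r_squared(c(1.5, 1.5, 3.5, 3.5), c(1, 2, 3, 4))
r2_reg

# Classification: one column of class probabilities per class gives one value per class
pred_bb <- cbind(a = c(0.9, 0.1, 0.8, 0.2), b = c(0.1, 0.9, 0.2, 0.8))
pred_tree <- cbind(a = c(0.85, 0.15, 0.85, 0.15), b = c(0.15, 0.85, 0.15, 0.85))
r2_class <- surrogate_r_squared(pred_tree, pred_bb)
r2_class
```

Because apply() works column by column, the same helper returns one value for a single regression output and one value per class for a matrix of class probabilities.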