Prediction wrapper functions designed for use with fastshap::explain()
Details
These prediction wrapper functions are designed to be passed as the pred_wrapper argument of fastshap::explain() from the fastshap package. The functions pred_nestcv_glmnet and pred_train work with nestcv.glmnet and nestcv.train models, respectively, for either binary classification or regression.
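As we understand the fastshap documentation, a pred_wrapper is a function of two arguments, the fitted model object and newdata, that returns a numeric vector with one prediction per row of newdata. The base R sketch below illustrates that contract with an lm regression model and a binomial glm. It is not the nestedcv source, and the names pred_lm and pred_glm are hypothetical. For regression the wrapper returns the predicted outcome; for binary classification it returns the predicted probability of one class rather than a class label.

```r
# Hypothetical wrappers illustrating the assumed pred_wrapper contract:
# function(object, newdata) returning one numeric value per row of newdata.

# Regression: return the predicted outcome
pred_lm <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata))
}

# Binary classification: return the predicted probability of one class
pred_glm <- function(object, newdata) {
  as.numeric(predict(object, newdata = newdata, type = "response"))
}

fit_reg <- lm(mpg ~ wt + hp, data = mtcars)
fit_bin <- glm(am ~ wt + hp, data = mtcars, family = binomial)

newx <- mtcars[1:5, c("wt", "hp")]
p_reg <- pred_lm(fit_reg, newx)   # numeric vector of length 5
p_bin <- pred_glm(fit_bin, newx)  # probabilities in [0, 1], length 5
```

Returning a plain numeric vector (as.numeric drops names and any matrix structure) keeps the wrapper's output in the simplest shape an explainer can consume.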
For multiclass classification, use pred_nestcv_glmnet_class(1), pred_nestcv_glmnet_class(2), etc., one for each class. Similarly, use pred_train_class(1), pred_train_class(2), etc. for nestcv.train objects.
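One wrapper is needed per class because each wrapper must return a single numeric vector, so each one can only report the predicted probability of one class. The base R sketch below is hypothetical: a toy nearest-centroid classifier fitted on iris stands in for a multiclass nestcv.glmnet or nestcv.train model, and make_class_wrapper(k) is an invented factory that mimics the one-wrapper-per-class pattern. It is not how nestedcv implements its class wrappers.

```r
# Hypothetical sketch: a factory that builds one prediction wrapper per class.
# The nearest-centroid "model" is a toy stand-in, for illustration only.

# Fit: store the mean of each predictor within each class
fit_centroids <- function(x, y) {
  cents <- do.call(rbind, lapply(split(x, y), colMeans))
  structure(list(centroids = cents), class = "toy_centroids")
}

# Predict: n x K matrix of class probabilities, from a softmax over
# negative squared distances to each class centroid
predict.toy_centroids <- function(object, newdata, ...) {
  newdata <- as.matrix(newdata)
  cents <- object$centroids
  d <- vapply(seq_len(nrow(cents)), function(j)
    rowSums(sweep(newdata, 2, cents[j, ])^2), numeric(nrow(newdata)))
  d <- matrix(d, nrow = nrow(newdata))  # keep a matrix even for one row
  e <- exp(-(d - apply(d, 1, min)))     # subtract row minimum for stability
  probs <- e / rowSums(e)
  colnames(probs) <- rownames(cents)
  probs
}

# Factory: returns a function(object, newdata) that extracts only the
# predicted probability of class k
make_class_wrapper <- function(k) {
  force(k)
  function(object, newdata) {
    as.numeric(predict(object, newdata)[, k])
  }
}

fit <- fit_centroids(iris[, 1:4], iris$Species)
pred_class2 <- make_class_wrapper(2)  # wrapper for the second class
p2 <- pred_class2(fit, iris[c(1, 51, 101), 1:4])
```

Each per-class wrapper would then be passed to explain() in a separate call, giving one set of SHAP values per class.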
Examples
library(nestedcv)
library(fastshap)

# Boston housing dataset
library(mlbench)
data(BostonHousing2)
dat <- BostonHousing2
y <- dat$cmedv
x <- subset(dat, select = -c(cmedv, medv, town, chas))

# Fit a glmnet model using nested CV
# Only 3 outer CV folds and 1 alpha value for speed
fit <- nestcv.glmnet(y, x, family = "gaussian", n_outer_folds = 3, alphaSet = 1)

# Generate SHAP values using fastshap::explain
# Only 1 Monte Carlo repetition (nsim = 1) here for speed; higher values of nsim are recommended
sh <- explain(fit, X = x, pred_wrapper = pred_nestcv_glmnet, nsim = 1)

# Plot overall variable importance
plot_shap_bar(sh, x)

# Plot beeswarm plot
plot_shap_beeswarm(sh, x, size = 1)