Instance Level Profile as Ceteris Paribus for Survival Models
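This function calculates ceteris paribus profiles for a single observation. With output_type "survival" or "chf", each profile depends on time as well as on the value of the varied variable, so the time dimension of the model output is taken into account.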
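The idea behind a ceteris paribus profile can be sketched directly with the survival package, without survex: copy the observation, vary one variable over a grid while holding every other variable fixed, and predict a full survival curve for each copy. This is a minimal illustration assuming a Cox model and a hypothetical five-point grid for karno; it is not how survex computes profiles internally.

```r
library(survival)

# Cox model on the veteran data shipped with the survival package
cph <- coxph(Surv(time, status) ~ ., data = veteran)

# Observation to explain; columns 3 and 4 are time and status
obs <- veteran[2, -c(3, 4)]

# Hypothetical grid for karno (5 uniform points); survex builds its own grid
karno_grid <- seq(min(veteran$karno), max(veteran$karno), length.out = 5)

# Ceteris paribus: one copy of the observation per grid value,
# with only karno changed and all other variables held fixed
cp_data <- obs[rep(1, length(karno_grid)), ]
cp_data$karno <- karno_grid

# One predicted survival curve per modified copy:
# rows of sf$surv are event times, columns follow the grid order
sf <- survfit(cph, newdata = cp_data)
cp_profile <- sf$surv
```

Reading across the columns of cp_profile at a fixed time gives the profile at that time point; doing this for every row traces how the profile changes over time.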
predict_profile(
  explainer,
  new_observation,
  variables = NULL,
  categorical_variables = NULL,
  ...,
  type = "ceteris_paribus",
  output_type = "survival",
  variable_splits_type = "uniform",
  center = FALSE
)

## S3 method for class 'surv_explainer'
predict_profile(
  explainer,
  new_observation,
  variables = NULL,
  categorical_variables = NULL,
  ...,
  type = "ceteris_paribus",
  output_type = "survival",
  variable_splits_type = "uniform",
  center = FALSE
)
Arguments
explainer: an explainer object, i.e. a model preprocessed by the explain() function
new_observation: a new observation for which the prediction needs to be explained
variables: a character vector containing names of variables to be explained
categorical_variables: a character vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the variables argument, they will be added at the end.
...: additional parameters passed to DALEX::predict_profile if output_type == "risk"
type: character, only "ceteris_paribus" is implemented
output_type: either "survival", "chf" or "risk"; the type of survival model output used for the explanations. If "survival", the explanations are based on the survival function; if "chf", on the cumulative hazard function. Otherwise ("risk"), the scalar risk predictions are explained with DALEX::predict_profile.
variable_splits_type: character, decides how the variable grids are calculated. Use "quantiles" for a grid of percentiles or "uniform" (default) for a uniform grid of points.
center: logical, whether the profiles should be centered around the average prediction
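The two variable_splits_type options can be illustrated with a small base R helper. make_grid() is hypothetical and not part of survex; the default of 101 points is an assumption chosen only for illustration. A uniform grid spreads points evenly between the observed minimum and maximum, while a quantile grid concentrates them where the data are dense.

```r
# Hypothetical helper, not part of survex: two ways to build a grid
# of values for one numerical variable
make_grid <- function(x, type = c("uniform", "quantiles"), n = 101) {
  type <- match.arg(type)
  if (type == "uniform") {
    # Evenly spaced points between the observed minimum and maximum
    seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = n)
  } else {
    # Empirical quantiles; duplicated values are dropped, so the grid
    # follows the distribution of the data
    unique(quantile(x, probs = seq(0, 1, length.out = n),
                    names = FALSE, na.rm = TRUE))
  }
}

x <- c(1, 2, 2, 3, 10)
make_grid(x, "uniform", n = 5)    # returns 1, 3.25, 5.5, 7.75, 10
make_grid(x, "quantiles", n = 5)  # returns 1, 2, 3, 10
```

With a skewed variable, the uniform grid spends most of its points in the sparse tail (here between 3 and 10), whereas the quantile grid stays where most observations lie.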
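For continuous-time models, the "survival" and "chf" outputs carry the same information on different scales. The standard identities linking the hazard h, the cumulative hazard H and the survival function S for an observation x are:

```latex
H(t \mid x) = \int_0^t h(u \mid x)\,du,
\qquad
S(t \mid x) = \exp\{-H(t \mid x)\},
\qquad
H(t \mid x) = -\log S(t \mid x).
```

For example, a survival probability of 0.5 corresponds to a cumulative hazard of log 2, approximately 0.693. Survival profiles therefore lie in [0, 1] and decrease over time, while cumulative hazard profiles are non-negative, increase over time and are unbounded. Whether survex derives one output from the other or estimates each directly is not stated here, so numerical values from a given model need not satisfy the identity exactly.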
Returns
An object of class c("predict_profile_survival", "surv_ceteris_paribus"). It is a list whose result element holds the final result.
Examples
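library(survival)
library(survex)

# Fit a Cox proportional hazards model and a random survival forest
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)

# Wrap both models in survex explainers
cph_exp <- explain(cph)
rsf_src_exp <- explain(rsf_src)

# Profiles for observation 2 (columns 3 and 4 are time and status);
# trt is stored as a number, so it is explicitly treated as categorical
cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)],
  variables = c("trt", "celltype", "karno", "age"),
  categorical_variables = "trt"
)
plot(cph_predict_profile, facet_ncol = 2)

# Profile of karno for observation 5 under the random survival forest
rsf_predict_profile <- predict_profile(rsf_src_exp, veteran[5, -c(3, 4)], variables = "karno")
plot(rsf_predict_profile, numerical_plot_type = "contours")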
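The documented output_type and variable_splits_type arguments can be combined in the same call. The following sketch reuses the Cox explainer setup from the examples and explains the same observation on the cumulative hazard scale with a quantile-based grid; the choice of variables is illustrative only.

```r
library(survival)
library(survex)

# Cox model with the data stored in the fit so that explain() can find it
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)

# Profiles on the cumulative hazard scale, with grids built from quantiles
cph_chf_profile <- predict_profile(
  cph_exp,
  veteran[2, -c(3, 4)],
  variables = c("karno", "age"),
  output_type = "chf",
  variable_splits_type = "quantiles"
)
plot(cph_chf_profile)
```

Unlike survival profiles, these cumulative hazard profiles grow over time, so higher values indicate a higher accumulated risk.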