predict_profile.surv_explainer function

Instance Level Profile as Ceteris Paribus for Survival Models

This function calculates Ceteris Paribus profiles for a specific observation, optionally taking the time dimension of the survival model's predictions into account.

predict_profile(explainer, new_observation, variables = NULL, categorical_variables = NULL, ..., type = "ceteris_paribus", output_type = "survival", variable_splits_type = "uniform", center = FALSE)

## S3 method for class 'surv_explainer'
predict_profile(explainer, new_observation, variables = NULL, categorical_variables = NULL, ..., type = "ceteris_paribus", output_type = "survival", variable_splits_type = "uniform", center = FALSE)

Arguments

  • explainer: an explainer object: a model preprocessed by the explain() function
  • new_observation: a new observation for which the prediction needs to be explained
  • variables: a character vector containing names of variables to be explained
  • categorical_variables: a character vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the variables argument, they will be added at the end.
  • ...: additional parameters passed to DALEX::predict_profile when output_type == "risk"
  • type: character, only "ceteris_paribus" is implemented
  • output_type: character, either "survival", "chf", or "risk"; the type of survival model output to be explained. If "survival", the explanations are based on the survival function; if "chf", on the cumulative hazard function; if "risk", the scalar risk predictions are explained with the DALEX::predict_profile function.
  • variable_splits_type: character, decides how the variable grids are calculated. Use "quantiles" for percentiles or "uniform" (default) for a uniform grid of points.
  • center: logical, whether the profiles should be centered around the average prediction
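The arguments above can be combined as in the following sketch, which assumes the survival and survex packages are installed and reuses the Cox model from the Examples; object names such as chf_profile are illustrative only. It requests centered profiles of the cumulative hazard function on a quantile-based grid. Because "trt" is listed in categorical_variables but not in variables, it is appended to the explained variables and treated as categorical, as described for categorical_variables.

```r
library(survival)
library(survex)

# Cox model fitted with model = TRUE, x = TRUE, y = TRUE so that explain() can extract the data
cph <- coxph(Surv(time, status) ~ ., data = veteran,
             model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)

# Centered cumulative hazard profiles on a quantile-based grid;
# columns 3 and 4 of veteran (time, status) are dropped from the observation
chf_profile <- predict_profile(
  cph_exp,
  veteran[2, -c(3, 4)],
  variables = c("karno", "age"),
  categorical_variables = "trt",   # not in `variables`, so it is appended at the end
  output_type = "chf",
  variable_splits_type = "quantiles",
  center = TRUE
)
```

With output_type = "risk" the call would instead be delegated to DALEX::predict_profile, and any extra arguments in ... would be forwarded there.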

Returns

An object of class c("predict_profile_survival", "surv_ceteris_paribus"). It is a list whose result element contains the computed profiles.
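A short sketch of inspecting the returned object, assuming survival and survex are installed; the name profile is illustrative. It only checks the classes listed above and looks at the result element, without assuming anything about its column layout.

```r
library(survival)
library(survex)

cph <- coxph(Surv(time, status) ~ ., data = veteran,
             model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)

# Survival-function profile of a single variable for the 2nd observation
profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)], variables = "karno")

# Both classes from the Returns section should be present
class(profile)

# The computed profiles are stored in the result element
head(profile$result)
```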

Examples

library(survival)
library(survex)

# Fit a Cox proportional hazards model and a random survival forest (requires randomForestSRC)
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)

# Wrap both models in survex explainers
cph_exp <- explain(cph)
rsf_src_exp <- explain(rsf_src)

# Profiles for the 2nd observation; columns 3 and 4 (time, status) are dropped
cph_predict_profile <- predict_profile(cph_exp, veteran[2, -c(3, 4)],
    variables = c("trt", "celltype", "karno", "age"),
    categorical_variables = "trt"
)
plot(cph_predict_profile, facet_ncol = 2)

# Profile of "karno" for the random survival forest, drawn as a contour plot
rsf_predict_profile <- predict_profile(rsf_src_exp, veteran[5, -c(3, 4)], variables = "karno")
plot(rsf_predict_profile, numerical_plot_type = "contours")

  • Maintainer: Mikołaj Spytek
  • License: GPL (>= 3)
  • Last published: 2023-10-24