model_profile.surv_explainer function

Dataset Level Variable Profile as Partial Dependence Explanations for Survival Models


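This function calculates dataset-level explanations that help explore the model response as a function of selected variables. The explanations extend Partial Dependence Profiles with a time dimension. For Partial Dependence (type = "partial"), the predicted survival functions (or cumulative hazard functions) are averaged over the observations for each value of the selected variable, so each profile depends on both the variable value and time.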
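The averaging behind a time-dependent Partial Dependence Profile can be sketched in base R. This is not survex's implementation. `toy_survival` is a hypothetical Weibull proportional-hazards model with made-up coefficients, and `pd_survival` is an illustrative helper. Centering is assumed here to subtract, at each time point, the average prediction on the unmodified data. For a continuous-time model the cumulative hazard is obtained as H(t) = -log S(t).

```r
# Hypothetical toy model (not survex): Weibull proportional hazards with
# S(t | x) = exp(-(t / 100)^1.5 * exp(0.03 * (age - 60) - 0.03 * (karno - 60)))
toy_survival <- function(newdata, times) {
  lp <- 0.03 * (newdata$age - 60) - 0.03 * (newdata$karno - 60)  # one value per row
  H0 <- (times / 100)^1.5                 # baseline cumulative hazard at each time
  exp(-outer(exp(lp), H0))                # n x length(times) matrix of S(t | x)
}

# Sketch of a time-dependent partial dependence profile: for each grid value,
# fix `variable` at that value for every observation and average the curves.
pd_survival <- function(predict_fn, data, variable, grid, times,
                        output_type = c("survival", "chf"), center = FALSE) {
  output_type <- match.arg(output_type)
  # For a continuous-time model the cumulative hazard is H(t) = -log S(t)
  to_output <- if (output_type == "chf") function(s) -log(s) else identity
  curves <- vapply(grid, function(value) {
    modified <- data
    modified[[variable]] <- value                     # same value for all rows
    colMeans(to_output(predict_fn(modified, times)))  # average over observations
  }, numeric(length(times)))
  profile <- matrix(curves, nrow = length(grid), byrow = TRUE)  # grid x times
  if (center) {
    # Assumption: center by the average prediction on the unmodified data
    average <- colMeans(to_output(predict_fn(data, times)))
    profile <- sweep(profile, 2, average)
  }
  dimnames(profile) <- list(paste0(variable, "=", grid), paste0("t=", times))
  profile
}

# Usage on simulated data
set.seed(1)
toy_data <- data.frame(age = round(runif(50, min = 35, max = 80)),
                       karno = round(runif(50, min = 20, max = 100)))
times <- c(30, 90, 180)
grid <- c(40, 60, 80)
pd <- pd_survival(toy_survival, toy_data, "age", grid = grid, times = times)
pd_chf <- pd_survival(toy_survival, toy_data, "age", grid = grid,
                      times = times, output_type = "chf")
pd_centered <- pd_survival(toy_survival, toy_data, "age", grid = grid,
                           times = times, center = TRUE)
round(pd, 3)
```

Each row of `pd` is one profile over time; with a positive age coefficient, the averaged survival curve drops faster for older ages.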

Usage

model_profile(
  explainer,
  variables = NULL,
  N = 100,
  ...,
  groups = NULL,
  k = NULL,
  type = "partial",
  center = FALSE,
  output_type = "survival"
)

## S3 method for class 'surv_explainer'
model_profile(
  explainer,
  variables = NULL,
  N = 100,
  ...,
  categorical_variables = NULL,
  grid_points = 51,
  variable_splits_type = "uniform",
  groups = NULL,
  k = NULL,
  center = FALSE,
  type = "partial",
  output_type = "survival"
)

Arguments

  • explainer: an explainer object - model preprocessed by the explain() function
  • variables: character, a vector of names of variables to be explained
  • N: number of observations used to calculate the aggregated profiles. Defaults to 100. If NULL, all observations are used.
  • ...: other parameters passed to DALEX::model_profile if output_type == "risk", otherwise ignored
  • groups: if output_type == "risk", the name of a variable used for grouping. Defaults to NULL, so no groups are calculated. Ignored if output_type == "survival".
  • k: passed to DALEX::model_profile if output_type == "risk", otherwise ignored
  • type: the type of variable profile, "partial" for Partial Dependence, "accumulated" for Accumulated Local Effects, or "conditional" (available only for output_type == "risk")
  • center: logical, whether the profiles should be centered around the average prediction
  • output_type: either "survival", "chf" or "risk", the type of survival model output on which the explanations are based. If "survival", the explanations are based on the survival function; if "chf", on the cumulative hazard function. Otherwise, the scalar risk predictions are passed to DALEX::model_profile.
  • categorical_variables: character, a vector of names of additional variables which should be treated as categorical (factors are automatically treated as categorical variables). If it contains variable names not present in the variables argument, they will be added at the end.
  • grid_points: maximum number of points used for profile calculations; the final number of points may be lower than grid_points. Passed to an internal function. Defaults to 51.
  • variable_splits_type: character, determines how the variable grids are calculated. Use "quantiles" for percentiles or "uniform" (default) for a uniform grid of points.
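The roles of N, grid_points and variable_splits_type can be illustrated with a small base-R sketch. The helpers make_grid and sample_observations are hypothetical, not survex's internal functions, and only assume grid construction of this general form. They show why the final grid can hold fewer than grid_points values: duplicates are removed, which happens often with quantile grids on discrete data. Using the observed levels of factor or character columns as their grid is also an assumption.

```r
# Hypothetical helper (not survex's internal function): build the grid of
# values at which a profile is evaluated.
make_grid <- function(x, grid_points = 51,
                      variable_splits_type = c("uniform", "quantiles")) {
  variable_splits_type <- match.arg(variable_splits_type)
  # Assumption: categorical variables use their observed levels as the grid
  if (is.factor(x) || is.character(x)) return(unique(x))
  grid <- if (variable_splits_type == "uniform") {
    # Evenly spaced points between the observed minimum and maximum
    seq(min(x, na.rm = TRUE), max(x, na.rm = TRUE), length.out = grid_points)
  } else {
    # Empirical percentiles of the observed values
    quantile(x, probs = seq(0, 1, length.out = grid_points),
             names = FALSE, na.rm = TRUE)
  }
  # Duplicates are dropped, so the grid can be shorter than grid_points
  unique(grid)
}

# Hypothetical helper: use N randomly sampled observations (all if N is NULL)
sample_observations <- function(data, N = 100) {
  if (is.null(N) || N >= nrow(data)) return(data)
  data[sample(nrow(data), N), , drop = FALSE]
}

x <- c(1, 1, 1, 2, 2, 3, 10)
make_grid(x, grid_points = 5)                                      # 1.00 3.25 5.50 7.75 10.00
make_grid(x, grid_points = 5, variable_splits_type = "quantiles")  # 1.0 2.0 2.5 10.0
```

With only seven mostly repeated values, the quantile grid collapses from five requested points to four, while the uniform grid keeps all five.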

Returns

An object of class model_profile_survival: a list whose element result contains the results of the calculation.

Examples

library(survival)
library(survex)

# Fit a Cox model and a random survival forest on the veteran data
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
rsf_src <- randomForestSRC::rfsrc(Surv(time, status) ~ ., data = veteran)

# Wrap both models in survex explainers
cph_exp <- explain(cph)
rsf_src_exp <- explain(rsf_src)

# Partial dependence of the survival function on age (Cox model)
cph_model_profile <- model_profile(cph_exp,
                                   output_type = "survival",
                                   variables = c("age"))
head(cph_model_profile$result)
plot(cph_model_profile)

# Accumulated local effects for age and celltype (random survival forest)
rsf_model_profile <- model_profile(rsf_src_exp,
                                   output_type = "survival",
                                   variables = c("age", "celltype"),
                                   type = "accumulated")
head(rsf_model_profile$result)
plot(rsf_model_profile,
     variables = c("age", "celltype"),
     numerical_plot_type = "contours")
  • Maintainer: Mikołaj Spytek
  • License: GPL (>= 3)
  • Last published: 2023-10-24