model_parts.surv_explainer function

Dataset Level Variable Importance for Survival Models

This function calculates variable importance as the change in the loss function after the values of a variable are permuted.
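The following is a minimal base-R sketch of permutation importance. It assumes only the survival package and is not survex's implementation. To keep it simple, it replaces the default Brier score with 1 minus Harrell's C-index of the Cox linear predictor. The names loss_one_minus_cindex, risk_on and permutation_importance are hypothetical. Each variable is shuffled B times and the loss is recomputed. The function then reports the mean increase over the full-model loss, which corresponds to the "difference" form.

```r
library(survival)

set.seed(42)
# Cox model on the veteran data shipped with the survival package
cph <- coxph(Surv(time, status) ~ trt + karno + age, data = veteran)
y <- Surv(veteran$time, veteran$status)

# Hypothetical loss for illustration: 1 - Harrell's C-index of the risk score.
# reverse = TRUE because a larger risk score should mean a shorter survival time.
loss_one_minus_cindex <- function(y_true, risk) {
  1 - concordance(y_true ~ risk, reverse = TRUE)$concordance
}

# Linear predictor of the fitted model on (possibly permuted) data
risk_on <- function(data) predict(cph, newdata = data, type = "lp")

full_loss <- loss_one_minus_cindex(y, risk_on(veteran))

permutation_importance <- function(vars, B = 10) {
  sapply(vars, function(v) {
    permuted_losses <- replicate(B, {
      shuffled <- veteran
      # Shuffling one column breaks its link with the outcome
      shuffled[[v]] <- sample(shuffled[[v]])
      loss_one_minus_cindex(y, risk_on(shuffled))
    })
    # "difference" form: mean loss after permutation minus full-model loss
    mean(permuted_losses) - full_loss
  })
}

importance <- permutation_importance(c("trt", "karno", "age"))
print(sort(importance, decreasing = TRUE))
```

A larger value means that permuting the variable hurts the model more, so the model relies on that variable more heavily.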

model_parts(explainer, ...)

## S3 method for class 'surv_explainer'
model_parts(
  explainer,
  loss_function = survex::loss_brier_score,
  ...,
  type = "difference",
  output_type = "survival",
  N = 1000
)

Arguments

  • explainer: an explainer object, i.e. a model preprocessed by the explain() function

  • ...: Arguments passed on to surv_feature_importance and surv_integrated_feature_importance:

    • B: numeric, number of permutations to be calculated
    • variables: a character vector, names of variables to be included in the calculation
    • variable_groups: a list of character vectors of names of explanatory variables. For each vector, a single variable-importance measure is computed for the joint effect of the variables whose names are provided in the vector. By default, variable_groups = NULL, in which case variable-importance measures are computed separately for all variables indicated in the variables argument
    • label: label of the model; if provided, overrides x$label

  • loss_function: a function used to assess variable importance; by default loss_brier_score for survival models. A custom function can be supplied, but it must have the named parameters (y_true, risk, surv, times), where y_true is the survival::Surv object with observed times and statuses, risk is the risk score calculated by the model, surv is the survival function for each observation evaluated at times, and times are the time points at which surv is evaluated.

  • type: a character string. If "raw", the results are the losses after permutation; if "ratio", they are of the form loss/loss_full_model; if "difference", they are of the form loss - loss_full_model. Defaults to "difference".

  • output_type: either "survival" or "risk", the type of survival model output used for the explanations. If "survival", the explanations are based on the survival function; otherwise, the scalar risk predictions are used by the DALEX::model_parts function.

  • N: number of observations that should be sampled for calculation of variable importance. If NULL then variable importance will be calculated on the whole dataset.
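A custom loss only needs to accept the four named parameters listed under loss_function. The sketch below is a hypothetical example. loss_one_minus_harrell_c is not part of survex. It uses risk alone and accepts surv and times without using them. This sketch does not check whether survex expects anything more from a custom loss, such as attributes.

```r
library(survival)

# Hypothetical custom loss following the documented contract:
# named parameters y_true, risk, surv and times.
# Only `risk` is used (1 - Harrell's C-index); `surv` and `times` are ignored.
loss_one_minus_harrell_c <- function(y_true = NULL, risk = NULL, surv = NULL, times = NULL) {
  # reverse = TRUE: a larger risk score is expected to mean a shorter survival time
  1 - concordance(y_true ~ risk, reverse = TRUE)$concordance
}

# Toy check: the earliest event gets the highest risk score
y <- Surv(c(5, 10, 15, 20), c(1, 1, 1, 0))
perfect_risk <- c(4, 3, 2, 1)
loss_one_minus_harrell_c(y_true = y, risk = perfect_risk)  # 0: every comparable pair is concordant
```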
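For a single pair of loss values, the three type options reduce to the arithmetic below. transform_importance is a hypothetical helper for illustration only and is not a survex function.

```r
# Hypothetical helper mirroring the three documented `type` options
transform_importance <- function(loss_permuted, loss_full,
                                 type = c("difference", "ratio", "raw")) {
  type <- match.arg(type)
  switch(type,
    raw = loss_permuted,                    # loss after permutation
    ratio = loss_permuted / loss_full,      # loss / loss_full_model
    difference = loss_permuted - loss_full  # loss - loss_full_model
  )
}

transform_importance(0.25, 0.20, "difference")  # prints 0.05
transform_importance(0.25, 0.20, "ratio")       # prints 1.25
```

Under "difference", a variable with no effect scores close to 0. Under "ratio", it scores close to 1.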

Returns

An object of class c("model_parts_survival", "surv_feature_importance"). It is a list with the explanations stored in the result element.

Details

Note: This function can be run within progressr::with_progress() to display a progress bar, because execution can take a long time, especially on large datasets.
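Here is a minimal sketch of that usage. It assumes the survex and progressr packages are installed. B = 3 is passed through ... only to keep the run short. This is an illustrative choice, not a recommended setting.

```r
library(survival)
library(survex)
library(progressr)

cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)
cph_exp <- explain(cph)

# Progress updates signalled inside the call are rendered as a progress bar
with_progress({
  cph_parts <- model_parts(cph_exp, B = 3)
})
print(head(cph_parts$result))
```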

Examples

library(survival)
library(survex)

# Cox proportional hazards model
cph <- coxph(Surv(time, status) ~ ., data = veteran, model = TRUE, x = TRUE, y = TRUE)

# Random survival forest fitted with ranger
rsf_ranger <- ranger::ranger(Surv(time, status) ~ ., data = veteran,
  respect.unordered.factors = TRUE, num.trees = 100, mtry = 3, max.depth = 5
)

cph_exp <- explain(cph)
# For the ranger model, data (without columns 3 and 4, time and status) and y are given explicitly
rsf_ranger_exp <- explain(rsf_ranger, data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status)
)

# Permutation importance with the default Brier score loss
cph_model_parts_brier <- model_parts(cph_exp)
print(head(cph_model_parts_brier$result))
plot(cph_model_parts_brier)

rsf_ranger_model_parts <- model_parts(rsf_ranger_exp)
print(head(rsf_ranger_model_parts$result))

# Compare both models in one plot
plot(cph_model_parts_brier, rsf_ranger_model_parts)

  • Maintainer: Mikołaj Spytek
  • License: GPL (>= 3)
  • Last published: 2023-10-24