Predictor function

Predictor object

A Predictor object holds any machine learning model (mlr, caret, randomForest, ...) and the data to be used for analyzing the model. The interpretation methods in the iml package need the machine learning model to be wrapped in a Predictor object.
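
A minimal sketch, assuming the `iml` package is installed. A base-R `lm()` fit on `mtcars` (an illustrative choice, not taken from the package examples) is wrapped together with its data. `lm()` objects have an S3 `predict()` method, so no `predict.function` is needed.

```r
library("iml")

# Fit an ordinary linear model on a built-in dataset (illustrative choice)
fit <- lm(mpg ~ wt + hp, data = mtcars)

# Wrap model and data; y names the target column inside `data`
pred <- Predictor$new(fit, data = mtcars[, c("mpg", "wt", "hp")], y = "mpg")

# Predictions come back as a data.frame with one row per input row
out <- pred$predict(mtcars[1:3, c("wt", "hp")])
print(out)
```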

Details

A Predictor object is a container for the prediction model and the data. Wrapping both together gives the interpretation methods a single, uniform way to call the model's prediction function, regardless of which package the model comes from.

Note: For classification, the model should return one column per class, each containing the predicted probability of that class.
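
The expected output shape can be sketched in plain base R, without `iml`. The example assumes a binomial `glm()` on `mtcars` (in `mtcars`, `am = 1` means manual transmission) and a hypothetical function `prob_fun()` with the `function(model, newdata)` signature that returns one probability column per class.

```r
# Logistic regression for transmission type: am = 1 (manual), am = 0 (automatic)
fit <- glm(am ~ wt + hp, data = mtcars, family = binomial())

# Hypothetical prediction function: one column per class,
# each column holding the probability of that class
prob_fun <- function(model, newdata) {
  p_manual <- unname(predict(model, newdata = newdata, type = "response"))
  data.frame(manual = p_manual, automatic = 1 - p_manual)
}

probs <- prob_fun(fit, mtcars[1:5, ])
print(probs)
```

Each row sums to 1, since the two columns are complementary probabilities for the same observation.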

Examples

```r
library("iml")

library("mlr")
task <- makeClassifTask(data = iris, target = "Species")
learner <- makeLearner("classif.rpart", minsplit = 7, predict.type = "prob")
mod.mlr <- train(learner, task)
mod <- Predictor$new(mod.mlr, data = iris)
mod$predict(iris[1:5, ])

# Return only the probability column for one class
mod <- Predictor$new(mod.mlr, data = iris, class = "setosa")
mod$predict(iris[1:5, ])

library("randomForest")
rf <- randomForest(Species ~ ., data = iris, ntree = 20)
mod <- Predictor$new(rf, data = iris, type = "prob")
mod$predict(iris[50:55, ])

# Feature importance needs the target vector, which needs to be supplied:
mod <- Predictor$new(rf, data = iris, y = "Species", type = "prob")
```

Public fields

  • data: data.frame

     Data object with the data for the model interpretation.
    
  • model: (any)

     The machine learning model.
    
  • batch.size: numeric(1)

     The number of rows passed to the model for prediction at once.
    
  • class: character(1)

     The class column to be returned.
    
  • prediction.colnames: character

     The column names of the predictions.
    
  • prediction.function: function

     The function used to predict on newdata.
    
  • task: character(1)

     The inferred prediction task: `"classification"` or `"regression"`.
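
A short sketch of reading the public fields, assuming `iml` is installed and reusing an illustrative `lm()` fit. `pred$data$feature.names` is assumed to be a field of the `Data` object stored in `data`; treat that field name as an assumption if your `iml` version differs.

```r
library("iml")

fit <- lm(mpg ~ wt + hp, data = mtcars)
pred <- Predictor$new(fit, data = mtcars[, c("mpg", "wt", "hp")],
                      y = "mpg", batch.size = 500)

pred$batch.size          # the value passed to Predictor$new()
is.null(pred$class)      # TRUE: no class was requested
pred$data$feature.names  # feature columns; the target "mpg" is not among them
```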
    

Methods

Public methods

Method new()

Create a Predictor object

Usage

Predictor$new(
  model = NULL,
  data = NULL,
  predict.function = NULL,
  y = NULL,
  class = NULL,
  type = NULL,
  batch.size = 1000
)

Arguments

  • model: any

     The machine learning model. Models from `mlr` and `caret` are recommended. Other machine learning models with an S3 `predict()` method work as well, but less robustly (e.g. `randomForest`).
    
  • data: data.frame

     The data to be used for analyzing the prediction model. Allowed column classes are: numeric, factor, integer, ordered and character.
     
     For some models the data can be extracted automatically. `Predictor$new()` throws an error when it can't extract the data automatically.
    
  • predict.function: function

     The function to predict newdata. Only needed if `model` is not a model from the `mlr` or `caret` package. The first argument of `predict.function` has to be the model, the second the `newdata`:
     
     ```
     function(model, newdata)
     ```
    
  • y: character(1) | numeric | factor

     The target vector or (preferably) the name of the target column in the `data` argument. Predictor tries to infer the target automatically from the model.
    
  • class: character(1)

     The class column to be returned. You should use the column name of the predicted class, e.g. `class="setosa"`.
    
  • type: character(1)

     This argument is passed to the prediction function of the model. For regression models you usually don't have to provide it. The classic use case is `type="prob"` for classification models. Consult the documentation of the machine learning package you use to find which `type` options are available. If both `predict.function` and `type` are used, `type` is passed as an argument to `predict.function`.
    
  • batch.size: numeric(1)

     The maximum number of rows passed to the model for prediction at once. Currently this is only respected by FeatureImp, Partial and Interaction.
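
The arguments above can be combined in a sketch, assuming `iml` is installed. It uses a binomial `glm()` on `mtcars` and a hypothetical `prob_fun()` as `predict.function`; both are illustrative choices. The example shows `y` given as a column name, `class` selecting one probability column, and `type` being forwarded to `predict()` when no `predict.function` is given.

```r
library("iml")

# Binary classifier: am = 1 (manual), am = 0 (automatic)
dat <- mtcars[, c("am", "wt", "hp")]
fit <- glm(am ~ wt + hp, data = dat, family = binomial())

# predict.function: model first, newdata second; one probability column per class
prob_fun <- function(model, newdata) {
  p <- unname(predict(model, newdata = newdata, type = "response"))
  data.frame(manual = p, automatic = 1 - p)
}

# y given as the name of the target column in `data`
pred_all <- Predictor$new(fit, data = dat, y = "am", predict.function = prob_fun)
both <- pred_all$predict(dat[1:4, ])          # columns: manual, automatic

# class keeps only the requested probability column
pred_manual <- Predictor$new(fit, data = dat, y = "am",
                             predict.function = prob_fun, class = "manual")
manual_only <- pred_manual$predict(dat[1:4, ])

# Without predict.function, type is forwarded to predict() (here predict.glm)
pred_type <- Predictor$new(fit, data = dat, y = "am", type = "response")
p_resp <- pred_type$predict(dat[1:4, ])
```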
    
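Before building a Predictor it can help to check that `data` only uses the allowed column classes. The base-R sketch below uses a hypothetical helper `allowed_column()` (not part of `iml`); `is.numeric()` covers numeric and integer columns, and `is.factor()` covers factor and ordered columns.

```r
# Hypothetical helper (not part of iml): TRUE when a column is
# numeric, integer, factor, ordered or character
allowed_column <- function(col) {
  is.numeric(col) || is.factor(col) || is.character(col)
}

df <- data.frame(
  x = c(1.5, 2.5),                                              # numeric
  n = c(1L, 2L),                                                # integer
  f = factor(c("a", "b")),                                      # factor
  o = factor(c("lo", "hi"), levels = c("lo", "hi"), ordered = TRUE),  # ordered
  s = c("u", "v"),                                              # character
  d = as.Date(c("2020-01-01", "2020-01-02")),                   # Date: not allowed
  stringsAsFactors = FALSE
)

ok <- vapply(df, allowed_column, logical(1))
print(ok)
names(ok)[!ok]  # columns to convert before calling Predictor$new()
```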

Method predict()

Predict new data with the machine learning model.

Usage

Predictor$predict(newdata)

Arguments

  • newdata: data.frame

     Data to predict on.
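
A minimal sketch of calling `predict()`, assuming `iml` is installed and using an illustrative `lm()` fit. Here `newdata` carries extra columns (all of `mtcars`); only `wt` and `hp` are model features, so the result matches calling `predict()` on the model directly.

```r
library("iml")

fit <- lm(mpg ~ wt + hp, data = mtcars)
pred <- Predictor$new(fit, data = mtcars[, c("mpg", "wt", "hp")], y = "mpg")

# newdata with extra, non-feature columns
out <- pred$predict(mtcars[1:5, ])
print(out)
```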
    

Method print()

Print the Predictor object.

Usage

Predictor$print()

Method clone()

The objects of this class are cloneable with this method.

Usage

Predictor$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.
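
A short sketch of cloning, assuming `iml` is installed and using an illustrative `lm()` fit. The clone's public fields can be changed without touching the original. With the default `deep = FALSE`, R6 objects stored in fields (such as the Data object in `data`) are shared between original and clone; `deep = TRUE` copies them as well.

```r
library("iml")

fit <- lm(mpg ~ wt + hp, data = mtcars)
pred <- Predictor$new(fit, data = mtcars[, c("mpg", "wt", "hp")], y = "mpg")

# Shallow clone, then change a field only on the copy
pred_small <- pred$clone()
pred_small$batch.size <- 10

c(original = pred$batch.size, clone = pred_small$batch.size)
```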

  • Maintainer: Giuseppe Casalicchio
  • License: MIT + file LICENSE
  • Last published: 2025-02-24