mlr_learners_torch_model

Learner Torch Model


Create a torch learner from an instantiated nn_module(). For classification, the network must output the raw, unnormalized scores (logits), i.e. the values before the softmax is applied.
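The "scores before the softmax" requirement can be illustrated with the torch package alone: the network returns raw scores, and class probabilities only appear after a softmax is applied over the class dimension. This is a minimal sketch on a toy batch of two random observations with 4 features; it does not involve the learner itself.

```r
library(torch)

# A linear layer mapping 4 features to 3 class scores, as in the Examples section
network = nn_linear(4, 3)

# Toy batch: 2 observations with 4 features each (random values, for illustration only)
x = torch_randn(2, 4)

# The network returns raw scores (logits); it must NOT end in a softmax layer
scores = network(x)

# Probabilities are obtained by applying the softmax over the class dimension (dim 2)
probs = nnf_softmax(scores, dim = 2)
```

Each row of `probs` sums to one, whereas `scores` is unconstrained; a cross-entropy loss such as `nn_cross_entropy_loss()` expects the unconstrained `scores`.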

Parameters

See LearnerTorch

Examples

library(mlr3torch)

# We show the learner using a classification task.
# The iris task has 4 features and 3 classes.
network = nn_linear(4, 3)
task = tsk("iris")

# This defines the dataloader.
# It loads all 4 features, which are all numeric.
# The shape is (NA, 4) because the batch dimension is not known in advance and is therefore NA.
ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)

# Create the learner and set the required parameters
learner = lrn("classif.torch_model",
  network = network,
  ingress_tokens = ingress_tokens,
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# A simple train-predict
ids = partition(task)
learner$train(task, ids$train)
learner$predict(task, ids$test)

See Also

Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image

Other Graph Network: ModelDescriptor(), TorchIngressToken(), mlr_pipeops_module, mlr_pipeops_torch, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, model_descriptor_to_learner(), model_descriptor_to_module(), model_descriptor_union(), nn_graph()

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchModel

Active bindings

  • ingress_tokens: (named list() with TorchIngressToken or NULL)

     The ingress tokens. Must be non-`NULL` when calling `$train()`.
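Because the ingress tokens may be NULL at construction and only have to be set before `$train()`, they can be supplied after the learner has been created. This sketch assumes that the `ingress_tokens` active binding accepts assignment, as its description suggests; it reuses the iris setup from the Examples section.

```r
library(mlr3torch)

task = tsk("iris")

# Create the learner without ingress tokens (they default to NULL)
learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  batch_size = 16,
  epochs = 1,
  device = "cpu"
)

# Assumption: the active binding accepts assignment, so the tokens can be set later
learner$ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)

# Training is possible now that ingress_tokens is non-NULL
learner$train(task)
```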
    

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorchModel$new(
  network = NULL,
  ingress_tokens = NULL,
  task_type,
  properties = NULL,
  optimizer = NULL,
  loss = NULL,
  callbacks = list(),
  packages = character(0),
  feature_types = NULL
)
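Instead of going through `lrn()`, the learner can also be constructed with `$new()`. In this sketch the optimizer and loss are passed explicitly as `t_opt("adam")` and `t_loss("cross_entropy")`; that is a choice made here for illustration, not a statement about the defaults. The required parameters `batch_size`, `epochs` and `device` are set afterwards through the parameter set.

```r
library(mlr3torch)

task = tsk("iris")
ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)

# Direct construction: task_type has no default and must be given
learner = LearnerTorchModel$new(
  network = nn_linear(4, 3),
  ingress_tokens = ingress_tokens,
  task_type = "classif",
  optimizer = t_opt("adam"),
  loss = t_loss("cross_entropy")
)

# Set the required hyperparameters before training
learner$param_set$set_values(batch_size = 16, epochs = 1, device = "cpu")

learner$train(task)
prediction = learner$predict(task)
```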

Arguments

  • network: (nn_module)

     An instantiated `nn_module`. It is not cloned during construction. For classification, the outputs must be the raw scores (before the softmax).
    
  • ingress_tokens: (list of TorchIngressToken())

     A list of ingress tokens that defines how the dataloader constructs the network inputs from the task.
    
  • task_type: (character(1))

     The task type, e.g. `"classif"` or `"regr"`.
    
  • properties: (NULL or character())

     The properties of the learner. Defaults to all available properties for the given task type.
    
  • optimizer: (TorchOptimizer)

     The torch optimizer.
    
  • loss: (TorchLoss)

     The loss to use for training.
    
  • callbacks: (list() of TorchCallbacks)

     The callbacks used during training. Must have unique ids. They are executed in the order in which they are provided.
    
  • packages: (character())

     The R packages this object depends on.
    
  • feature_types: (NULL or character())

     The feature types. Defaults to all available feature types.
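Callbacks are passed to the constructor as a list, for instance through `lrn()` as in the Examples section. This sketch adds the history callback via `t_clbk("history")`; the location where its results are stored (`learner$model$callbacks$history`) is an assumption and may differ between mlr3torch versions.

```r
library(mlr3torch)

task = tsk("iris")
ingress_tokens = list(
  input = TorchIngressToken(task$feature_names, batchgetter_num, c(NA, 4))
)

# Callbacks run in the order given and must have unique ids
learner = lrn("classif.torch_model",
  network = nn_linear(4, 3),
  ingress_tokens = ingress_tokens,
  callbacks = list(t_clbk("history")),
  batch_size = 16,
  epochs = 2,
  device = "cpu"
)
learner$train(task)

# Assumption: the history callback stores its results under its id "history"
history = learner$model$callbacks$history
```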
    

Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorchModel$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.