mlr_learners.tab_resnet

Tabular ResNet

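Tabular ResNet for classification and regression, based on the ResNet-like architecture for tabular data described by Gorishniy et al. (2021).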

Dictionary

This Learner can be instantiated using the sugar function lrn():

lrn("classif.tab_resnet", ...)
lrn("regr.tab_resnet", ...)

Parameters

Parameters from LearnerTorch, as well as:

  • n_blocks :: integer(1)

    The number of residual blocks.

  • d_block :: integer(1)

    The input and output dimension of a block.

  • d_hidden :: integer(1)

    The latent dimension of a block.

  • d_hidden_multiplier :: numeric(1)

    An alternative way to specify the latent dimension, which is then set to d_block * d_hidden_multiplier.

  • dropout1 :: numeric(1)

    The dropout probability of the first dropout layer in each block.

  • dropout2 :: numeric(1)

    The dropout probability of the second dropout layer in each block.
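
The parameters above describe how each block is shaped. As an illustration only, here is a minimal base-R sketch (no torch, no training) of a forward pass through such a network. It assumes the layout described by Gorishniy et al. (2021): an input projection to d_block, then n_blocks blocks of the form x + Dropout2(Linear2(Dropout1(ReLU(Linear1(BatchNorm(x)))))), then a BatchNorm, ReLU, Linear head. The function names (resolve_d_hidden, tab_resnet_block, tab_resnet_forward, ...) are hypothetical and not part of mlr3torch, and the way d_hidden_multiplier is resolved (used only when d_hidden is absent, truncated to an integer) is also an assumption.

```r
# Hypothetical base-R sketch of a Tabular ResNet forward pass (no torch, no
# training). The layout is an assumption taken from Gorishniy et al. (2021):
#   block(x) = x + Dropout2(Linear2(Dropout1(ReLU(Linear1(BatchNorm(x))))))
# None of these functions are part of mlr3torch.

# Assumption: d_hidden_multiplier is used only when d_hidden is NULL,
# and the product is truncated to an integer
resolve_d_hidden <- function(d_block, d_hidden = NULL, d_hidden_multiplier = NULL) {
  if (!is.null(d_hidden)) return(as.integer(d_hidden))
  as.integer(d_block * d_hidden_multiplier)
}

# Batch normalization with batch statistics and no learnable scale/shift
# (a real network would use running statistics outside of training)
batch_norm <- function(x, eps = 1e-5) {
  centered <- sweep(x, 2, colMeans(x))
  sweep(centered, 2, sqrt(colMeans(centered^2) + eps), "/")
}

# Inverted dropout: zero each entry with probability p, rescale the rest
dropout <- function(x, p, training) {
  if (!training || p == 0) return(x)
  x * (runif(length(x)) >= p) / (1 - p)
}

# Dense layer: x (n x d_in) %*% W (d_in x d_out), plus bias b (length d_out)
linear <- function(x, layer) sweep(x %*% layer$W, 2, layer$b, "+")

init_linear <- function(d_in, d_out) {
  list(W = matrix(rnorm(d_in * d_out, sd = 1 / sqrt(d_in)), d_in, d_out),
       b = numeric(d_out))
}

# One residual block: d_block -> d_hidden -> d_block, plus the skip connection
tab_resnet_block <- function(x, block, dropout1, dropout2, training) {
  h <- pmax(linear(batch_norm(x), block$lin1), 0)          # BatchNorm, Linear1, ReLU
  h <- dropout(h, dropout1, training)                       # first dropout
  h <- dropout(linear(h, block$lin2), dropout2, training)   # Linear2, second dropout
  x + h                                                     # residual connection
}

init_tab_resnet <- function(d_in, d_out, n_blocks, d_block, d_hidden) {
  list(
    input = init_linear(d_in, d_block),
    blocks = lapply(seq_len(n_blocks), function(i) list(
      lin1 = init_linear(d_block, d_hidden),
      lin2 = init_linear(d_hidden, d_block)
    )),
    head = init_linear(d_block, d_out)
  )
}

# Input projection, n_blocks residual blocks, then BatchNorm -> ReLU -> Linear head
tab_resnet_forward <- function(x, net, dropout1 = 0, dropout2 = 0, training = FALSE) {
  h <- linear(x, net$input)
  for (block in net$blocks) {
    h <- tab_resnet_block(h, block, dropout1, dropout2, training)
  }
  linear(pmax(batch_norm(h), 0), net$head)
}

# Usage mirroring the example on this page (4 iris features, 3 classes)
set.seed(1)
d_hidden <- resolve_d_hidden(d_block = 10, d_hidden_multiplier = 2)
net <- init_tab_resnet(d_in = 4, d_out = 3, n_blocks = 2, d_block = 10, d_hidden = d_hidden)
x <- matrix(rnorm(16 * 4), nrow = 16, ncol = 4)
out <- tab_resnet_forward(x, net, dropout1 = 0.3, dropout2 = 0.3, training = TRUE)
dim(out)  # 16 3
```

Under these assumptions, d_block = 10 with d_hidden_multiplier = 2 gives the same latent dimension (20) as setting d_hidden = 20 directly, and calling tab_resnet_forward() with training = FALSE disables both dropout layers.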

Examples

library(mlr3)
library(mlr3torch)

# Define the Learner and set parameter values
learner = lrn("classif.tab_resnet")
learner$param_set$set_values(
  epochs = 1, batch_size = 16, device = "cpu",
  n_blocks = 2, d_block = 10, d_hidden = 20,
  dropout1 = 0.3, dropout2 = 0.3
)

# Define a Task
task = tsk("iris")

# Create train and test set
ids = partition(task)

# Train the learner on the training ids
learner$train(task, row_ids = ids$train)

# Make predictions for the test rows
predictions = learner$predict(task, row_ids = ids$test)

# Score the predictions
predictions$score()

References

Gorishniy Y, Rubachev I, Khrulkov V, Babenko A (2021). Revisiting Deep Learning for Tabular Data. arXiv, 2106.11959.

See Also

Other Learner: mlr_learners.mlp, mlr_learners.torch_featureless, mlr_learners_torch, mlr_learners_torch_image, mlr_learners_torch_model

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> LearnerTorchTabResNet

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorchTabResNet$new(
  task_type,
  optimizer = NULL,
  loss = NULL,
  callbacks = list()
)

Arguments

  • task_type: (character(1))

     The task type, either `"classif"` or `"regr"`.
    
  • optimizer: (TorchOptimizer)

     The optimizer to use for training. By default, **adam** is used.
    
  • loss: (TorchLoss)

     The loss used to train the network. By default, **mse** is used for regression and **cross_entropy** for classification.
    
  • callbacks: (list() of TorchCallbacks)

     The callbacks to use during training. They must have unique ids.
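
The arguments of new() follow a few simple rules: two allowed task types, an optimizer default, a loss default that depends on the task type, and unique callback ids. The sketch below restates only those documented rules in base R. It is a hypothetical helper (resolve_constructor_args is not an mlr3torch function), and it uses plain strings as stand-ins for TorchOptimizer, TorchLoss and TorchCallback objects.

```r
# Hypothetical helper, NOT part of mlr3torch: it only mirrors the documented
# rules for the arguments of LearnerTorchTabResNet$new(), using plain strings
# in place of TorchOptimizer / TorchLoss / TorchCallback objects.
resolve_constructor_args <- function(task_type, optimizer = NULL, loss = NULL,
                                     callback_ids = character()) {
  # Only classification and regression are supported
  if (!task_type %in% c("classif", "regr")) {
    stop("task_type must be \"classif\" or \"regr\"")
  }
  # Documented default optimizer
  if (is.null(optimizer)) optimizer <- "adam"
  # Documented default loss depends on the task type
  if (is.null(loss)) {
    loss <- if (task_type == "regr") "mse" else "cross_entropy"
  }
  # Callbacks must have unique ids
  if (anyDuplicated(callback_ids) > 0) {
    stop("callback ids must be unique")
  }
  list(task_type = task_type, optimizer = optimizer, loss = loss,
       callback_ids = callback_ids)
}

args <- resolve_constructor_args("classif")
args$loss  # "cross_entropy"
```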
    

Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorchTabResNet$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.