mlr_pipeops_torch_model_classif

PipeOp Torch Classifier

Builds a torch classifier and trains it.

Parameters

See LearnerTorch

Input and Output Channels

There is one input channel "input" that takes a ModelDescriptor during training and a Task of the specified task_type during prediction. The output is NULL during training and a Prediction of the given task_type during prediction.
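To make the two modes concrete, here is a minimal sketch. It assumes mlr3torch and a working torch backend are installed, and it predicts on the same iris Task it was trained on, purely for illustration.

```r
library(mlr3torch)

# Training input: a ModelDescriptor produced by a graph of torch PipeOps
md = (po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]

po_model = po("torch_model_classif", batch_size = 50, epochs = 1)

# During training the output channel carries NULL
train_out = po_model$train(list(md))

# During prediction the input is a Task and the output is a Prediction
pred = po_model$predict(list(tsk("iris")))[[1L]]
print(pred)
```

Both $train() and $predict() take and return lists with one element per channel, which is why the inputs are wrapped in list() and the outputs are unpacked with [[1L]].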

State

A trained LearnerTorchModel.

Internals

A LearnerTorchModel is created by calling model_descriptor_to_learner() on the ModelDescriptor received through the input channel. Its parameters are then set to the values specified in PipeOpTorchModel, and its $train() method is called on the Task stored in the ModelDescriptor.
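The steps above can be approximated by hand. This is a rough sketch, not the PipeOp's actual code: it assumes mlr3torch with a working torch backend, a paradox version that provides $set_values() (paradox >= 1.0), and that the ModelDescriptor exposes its Task as md$task. Only the two parameters used in the Examples section are set here, whereas the PipeOp forwards all of its parameter values.

```r
library(mlr3torch)

# Build a ModelDescriptor for the iris task
graph = po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam")
md = graph$train(tsk("iris"))[[1L]]

# Step 1: turn the ModelDescriptor into a LearnerTorchModel
learner = model_descriptor_to_learner(md)

# Step 2: set the training parameters (here only batch_size and epochs)
learner$param_set$set_values(batch_size = 50, epochs = 1)

# Step 3: train on the Task stored in the ModelDescriptor
learner$train(md$task)
```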

Examples

library(mlr3torch)

# simple logistic regression

# configure the model descriptor
md = as_graph(po("torch_ingress_num") %>>%
  po("nn_head") %>>%
  po("torch_loss", "cross_entropy") %>>%
  po("torch_optimizer", "adam"))$train(tsk("iris"))[[1L]]
print(md)

# build the learner from the model descriptor and train it
po_model = po("torch_model_classif", batch_size = 50, epochs = 1)
po_model$train(list(md))
po_model$state

See Also

Other PipeOps: mlr_pipeops_nn_adaptive_avg_pool1d, mlr_pipeops_nn_adaptive_avg_pool2d, mlr_pipeops_nn_adaptive_avg_pool3d, mlr_pipeops_nn_avg_pool1d, mlr_pipeops_nn_avg_pool2d, mlr_pipeops_nn_avg_pool3d, mlr_pipeops_nn_batch_norm1d, mlr_pipeops_nn_batch_norm2d, mlr_pipeops_nn_batch_norm3d, mlr_pipeops_nn_block, mlr_pipeops_nn_celu, mlr_pipeops_nn_conv1d, mlr_pipeops_nn_conv2d, mlr_pipeops_nn_conv3d, mlr_pipeops_nn_conv_transpose1d, mlr_pipeops_nn_conv_transpose2d, mlr_pipeops_nn_conv_transpose3d, mlr_pipeops_nn_dropout, mlr_pipeops_nn_elu, mlr_pipeops_nn_flatten, mlr_pipeops_nn_gelu, mlr_pipeops_nn_glu, mlr_pipeops_nn_hardshrink, mlr_pipeops_nn_hardsigmoid, mlr_pipeops_nn_hardtanh, mlr_pipeops_nn_head, mlr_pipeops_nn_layer_norm, mlr_pipeops_nn_leaky_relu, mlr_pipeops_nn_linear, mlr_pipeops_nn_log_sigmoid, mlr_pipeops_nn_max_pool1d, mlr_pipeops_nn_max_pool2d, mlr_pipeops_nn_max_pool3d, mlr_pipeops_nn_merge, mlr_pipeops_nn_merge_cat, mlr_pipeops_nn_merge_prod, mlr_pipeops_nn_merge_sum, mlr_pipeops_nn_prelu, mlr_pipeops_nn_relu, mlr_pipeops_nn_relu6, mlr_pipeops_nn_reshape, mlr_pipeops_nn_rrelu, mlr_pipeops_nn_selu, mlr_pipeops_nn_sigmoid, mlr_pipeops_nn_softmax, mlr_pipeops_nn_softplus, mlr_pipeops_nn_softshrink, mlr_pipeops_nn_softsign, mlr_pipeops_nn_squeeze, mlr_pipeops_nn_tanh, mlr_pipeops_nn_tanhshrink, mlr_pipeops_nn_threshold, mlr_pipeops_nn_unsqueeze, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, mlr_pipeops_torch_loss, mlr_pipeops_torch_model, mlr_pipeops_torch_model_regr

Super classes

mlr3pipelines::PipeOp -> mlr3pipelines::PipeOpLearner -> mlr3torch::PipeOpTorchModel -> PipeOpTorchModelClassif

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

PipeOpTorchModelClassif$new(id = "torch_model_classif", param_vals = list())
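As a small illustration of the two arguments, the sketch below constructs the PipeOp directly with a custom id and preset hyperparameters. It assumes mlr3torch is installed; the id "classifier" is an arbitrary example name.

```r
library(mlr3torch)

# Custom id plus hyperparameters supplied at construction time
po_model = PipeOpTorchModelClassif$new(
  id = "classifier",
  param_vals = list(batch_size = 50, epochs = 1)
)

po_model$id                       # the custom identifier
po_model$param_set$values$epochs  # value taken from param_vals
```

The Examples section sets the same values through po("torch_model_classif", ...), which is the more common entry point.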

Arguments

  • id: (character(1))

     Identifier of the resulting object.
    
  • param_vals: (list())

     List of hyperparameter settings, overwriting the hyperparameter settings that would otherwise be set during construction.
    

Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpTorchModelClassif$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.