mlr_pipeops_module

Class for Torch Module Wrappers

PipeOpModule wraps an nn_module or function that is called during the train phase of this mlr3pipelines::PipeOp. This makes it possible to assemble PipeOpModules in a computational mlr3pipelines::Graph that represents either a neural network or a preprocessing graph of a lazy_tensor.

In most cases, it is easier to create such a network indirectly, by building a structurally related graph consisting of nodes of class PipeOpTorchIngress and PipeOpTorch. Training this graph then generates the graph consisting of PipeOpModules as part of the ModelDescriptor.

Input and Output Channels

The number and names of the input and output channels can be set during construction. During training, the channels receive and return "torch_tensor" objects. During prediction, they receive and return NULL, because the prediction phase currently serves no meaningful purpose.

State

The state is the value calculated by the public method shapes_out().

Parameters

No parameters.

Internals

During training, the wrapped nn_module / function is called with the provided inputs in the order in which the channels are defined. Arguments are not matched by name.
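The positional calling convention can be mimicked in base R. The sketch below is a hypothetical illustration, not mlr3torch's internal code: call_positionally() is a made-up helper that strips the channel names with unname() before passing the inputs to do.call(), so the wrapped function receives them in channel order regardless of its own argument names.

```r
# Hypothetical helper (not part of mlr3torch): call `fn` with the inputs
# in channel order, ignoring the channel names.
call_positionally = function(fn, inputs) {
  do.call(fn, unname(inputs))
}

# A function whose argument names differ from the channel names.
f = function(a, b) a - b

# The channels are named "x" and "z"; they are bound to `a` and `b` by position.
inputs = list(x = 10, z = 3)
call_positionally(f, inputs)

# Reordering the list changes the result, because only the order matters.
call_positionally(f, list(z = 3, x = 10))
```

Calling do.call(f, inputs) without unname() would fail here, because f has no arguments named x or z.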

Examples

## creating a PipeOpModule manually

# one input and output channel
po_module = po("module",
  id = "linear",
  module = torch::nn_linear(10, 20),
  inname = "input",
  outname = "output"
)
x = torch::torch_randn(16, 10)
# This calls the forward function of the wrapped module.
y = po_module$train(list(input = x))
str(y)

# multiple input and output channels
nn_custom = torch::nn_module("nn_custom",
  initialize = function(in_features, out_features) {
    self$lin1 = torch::nn_linear(in_features, out_features)
    self$lin2 = torch::nn_linear(in_features, out_features)
  },
  forward = function(x, z) {
    list(out1 = self$lin1(x), out2 = torch::nnf_relu(self$lin2(z)))
  }
)

module = nn_custom(3, 2)
po_module = po("module",
  id = "custom",
  module = module,
  inname = c("x", "z"),
  outname = c("out1", "out2")
)
x = torch::torch_randn(1, 3)
z = torch::torch_randn(1, 3)
out = po_module$train(list(x = x, z = z))
str(out)

# How such a PipeOpModule is usually generated
graph = po("torch_ingress_num") %>>% po("nn_linear", out_features = 10L)
result = graph$train(tsk("iris"))
# The PipeOpTorchLinear generates a PipeOpModule and adds it to a new (module) graph
result[[1]]$graph
linear_module = result[[1L]]$graph$pipeops$nn_linear
linear_module
formalArgs(linear_module$module)
linear_module$input$name

# Constructing a PipeOpModule using a simple function
po_add1 = po("module",
  id = "add_one",
  module = function(x) x + 1
)
input = list(torch::torch_tensor(1))
po_add1$train(input)$output

See Also

Other Graph Network: ModelDescriptor(), TorchIngressToken(), mlr_learners_torch_model, mlr_pipeops_torch, mlr_pipeops_torch_ingress, mlr_pipeops_torch_ingress_categ, mlr_pipeops_torch_ingress_ltnsr, mlr_pipeops_torch_ingress_num, model_descriptor_to_learner(), model_descriptor_to_module(), model_descriptor_union(), nn_graph()

Other PipeOp: mlr_pipeops_torch_callbacks, mlr_pipeops_torch_optimizer

Super class

mlr3pipelines::PipeOp -> PipeOpModule

Public fields

  • module: (nn_module)

     The torch module that is called during the training phase.
    

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

PipeOpModule$new(
  id = "module",
  module = nn_identity(),
  inname = "input",
  outname = "output",
  param_vals = list(),
  packages = character(0)
)

Arguments

  • id: (character(1))

     The id of the new object.
    
  • module: (nn_module or function())

     The torch module or function that is being wrapped.
    
  • inname: (character())

     The names of the input channels.
    
  • outname: (character())

     The names of the output channels. If this parameter has length 1, module must return a single tensor. Otherwise, it must return a list() of tensors of the same length.
    
  • param_vals: (named list())

     Parameter values to be set after construction.
    
  • packages: (character())

     The R packages this object depends on.
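The rule for outname can be illustrated with a small base R sketch. The helper as_output_channels() is hypothetical and only approximates how a single return value, or a list of return values, maps onto the named output channels; plain numeric vectors stand in for tensors so the snippet runs without torch.

```r
# Hypothetical helper (not part of mlr3torch): map the value returned by the
# wrapped module onto the named output channels.
as_output_channels = function(value, outname) {
  if (length(outname) == 1L) {
    # A single output channel: the module returns one object, which is wrapped.
    return(setNames(list(value), outname))
  }
  # Several output channels: the module must return a list of matching length.
  if (!is.list(value) || length(value) != length(outname)) {
    stop("module must return a list of length ", length(outname))
  }
  # Elements are assigned to the channels by position.
  setNames(value, outname)
}

# One output channel.
as_output_channels(c(1, 2, 3), outname = "output")

# Two output channels, as in the nn_custom example.
as_output_channels(list(1, 2), outname = c("out1", "out2"))
```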
    

Method clone()

The objects of this class are cloneable with this method.

Usage

PipeOpModule$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.