TorchOptimizer class

Torch Optimizer

This class wraps a torch::torch_optimizer_generator and annotates it with metadata, most importantly a ParamSet. The optimizer is created for the given network parameters by calling the $generate() method.

This class is usually used to configure the optimizer of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.

For a list of available optimizers, see mlr3torch_optimizers. Items from this dictionary can be retrieved using t_opt().
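A short sketch of browsing the dictionary before calling t_opt(). It relies on the `$keys()` method that mlr3 dictionaries provide; the choice of "adam" and the learning rate are only illustrative.

```r
library(mlr3torch)

# All optimizer keys registered in the dictionary
keys = mlr3torch_optimizers$keys()
print(keys)

# Retrieve one entry and set a hyperparameter at construction time
adam = t_opt("adam", lr = 0.001)
adam$param_set$values$lr
```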

Parameters

Defined by the constructor argument param_set. If no parameter set is provided during construction, it is constructed by creating one parameter for each argument of the wrapped optimizer, where the parameters are of type ParamUty.
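The following is a minimal sketch of supplying an explicit, typed `param_set` instead of relying on the inferred ParamUty parameters. It assumes the paradox helpers `ps()` and `p_dbl()` and that `$generate()` forwards the set values to the wrapped optimizer; the id `"sgd_custom"`, the label, and the `"train"` tag are illustrative assumptions, not entries taken from `mlr3torch_optimizers`.

```r
library(mlr3torch)
library(paradox)

# Typed, range-checked parameters instead of inferred ParamUty ones
# (id, label and the "train" tag are hypothetical choices for this sketch)
sgd_custom = TorchOptimizer$new(
  torch_optimizer = torch::optim_sgd,
  param_set = ps(
    lr = p_dbl(lower = 0, tags = "train"),
    momentum = p_dbl(lower = 0, upper = 1, tags = "train")
  ),
  id = "sgd_custom",
  label = "SGD (typed parameters)"
)

# Values are validated against the bounds, so lr = -1 would be rejected
sgd_custom$param_set$set_values(lr = 0.01, momentum = 0.9)

# The values end up in the instantiated torch optimizer
net = nn_linear(10, 1)
opt = sgd_custom$generate(net$parameters)
opt$param_groups[[1]]$lr
```

The benefit over the inferred set is that invalid values fail early, when they are set, rather than when the optimizer is created.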

Examples

library(mlr3torch)

# Create a new torch optimizer
torch_opt = TorchOptimizer$new(optim_ignite_adam, label = "adam")
torch_opt
# If the param set is not specified, parameters are inferred but are of class ParamUty
torch_opt$param_set
# Open the help page of the wrapped optimizer
# torch_opt$help()

# Retrieve an optimizer from the dictionary
torch_opt = t_opt("sgd", lr = 0.1)
torch_opt
torch_opt$param_set
torch_opt$label
torch_opt$id

# Create the optimizer for a network
net = nn_linear(10, 1)
opt = torch_opt$generate(net$parameters)
# is the same as
optim_sgd(net$parameters, lr = 0.1)

# Use in a learner
learner = lrn("regr.mlp", optimizer = t_opt("sgd"))
# The parameters of the optimizer are added to the learner's parameter set
learner$param_set

See Also

Other Torch Descriptor: TorchCallback, TorchDescriptor, TorchLoss, as_torch_callbacks(), as_torch_loss(), as_torch_optimizer(), mlr3torch_losses, mlr3torch_optimizers, t_clbk(), t_loss(), t_opt()

Super class

mlr3torch::TorchDescriptor -> TorchOptimizer

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

TorchOptimizer$new(
  torch_optimizer,
  param_set = NULL,
  id = NULL,
  label = NULL,
  packages = NULL,
  man = NULL
)

Arguments

  • torch_optimizer: (torch_optimizer_generator)

     The torch optimizer.
    
  • param_set: (ParamSet or NULL)

     The parameter set. If `NULL` (default) it is inferred from `torch_optimizer`.
    
  • id: (character(1))

     The id of the new object.
    
  • label: (character(1))

     Label for the new instance.
    
  • packages: (character())

     The R packages this object depends on.
    
  • man: (character(1))

     String in the format `[pkg]::[topic]` pointing to a manual page for this object. The referenced help page can be opened via method `$help()`.
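A sketch that passes every constructor argument explicitly, using `torch::optim_rmsprop` as the wrapped generator. The id `"rmsprop_custom"` is a hypothetical name chosen for illustration; since `param_set` is left `NULL`, the parameters are inferred from the arguments of `optim_rmsprop`.

```r
library(mlr3torch)

rmsprop = TorchOptimizer$new(
  torch_optimizer = torch::optim_rmsprop,
  param_set = NULL,              # inferred from the arguments of optim_rmsprop
  id = "rmsprop_custom",         # hypothetical id
  label = "RMSprop",
  packages = "torch",
  man = "torch::optim_rmsprop"   # format [pkg]::[topic]
)

rmsprop$id
rmsprop$label
rmsprop$param_set$ids()
# rmsprop$help()  # opens the manual page referenced by `man`
```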
    

Method generate()

Instantiates the optimizer.

Usage

TorchOptimizer$generate(params)

Arguments

  • params: (named list() of torch_tensors)

     The parameters of the network.
    

Returns

torch_optimizer
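To illustrate what the returned torch_optimizer is used for, here is a sketch of a plain training loop on random data. The toy linear model, data sizes, learning rate, and iteration count are arbitrary assumptions; only `$generate()` comes from this class, while `zero_grad()`, `backward()` and `step()` are the usual torch optimizer workflow.

```r
library(mlr3torch)

torch_manual_seed(1)
net = nn_linear(2, 1)

# generate() returns a ready-to-use torch optimizer for these parameters
opt = t_opt("sgd", lr = 0.1)$generate(net$parameters)

x = torch_randn(16, 2)
y = torch_randn(16, 1)
loss_before = nnf_mse_loss(net(x), y)$item()

for (i in 1:20) {
  opt$zero_grad()                  # reset accumulated gradients
  loss = nnf_mse_loss(net(x), y)   # forward pass
  loss$backward()                  # compute gradients
  opt$step()                       # update the parameters
}

loss_after = nnf_mse_loss(net(x), y)$item()
```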

Method clone()

The objects of this class are cloneable with this method.

Usage

TorchOptimizer$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.
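A short sketch of why `deep = TRUE` matters: the parameter set is itself an R6 object, so a deep clone gets its own copy and changing its values leaves the original untouched. The learning rates are arbitrary.

```r
library(mlr3torch)

opt1 = t_opt("sgd", lr = 0.1)
opt2 = opt1$clone(deep = TRUE)

# Only the clone's parameter set is modified
opt2$param_set$set_values(lr = 0.01)

opt1$param_set$values$lr  # still 0.1
opt2$param_set$values$lr  # 0.01
```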