TorchLoss function

Torch Loss

This wraps a torch::nn_loss and annotates it with metadata, most importantly a ParamSet. The loss function is created for the given parameter values by calling the $generate() method.

This class is usually used to configure the loss function of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.

For a list of available losses, see mlr3torch_losses. Items from this dictionary can be retrieved using t_loss().

Parameters

Defined by the constructor argument param_set. If no parameter set is provided during construction, one is inferred by creating a parameter of type ParamUty for each argument of the wrapped loss function.
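The difference between an inferred and an explicitly supplied parameter set can be sketched as follows. This is a minimal illustration, not part of the original reference: it assumes the packages torch, paradox and mlr3torch are installed, and the choice of a factor parameter for `reduction` is an assumption made here for demonstration.

```r
library(mlr3torch)
library(torch)
library(paradox)

# Inferred: every argument of nn_mse_loss becomes an untyped ParamUty
inferred = TorchLoss$new(torch_loss = nn_mse_loss, task_types = "regr")
inferred$param_set$class

# Explicit (illustrative assumption): restrict `reduction` to the
# values that torch::nn_mse_loss() accepts
explicit = TorchLoss$new(
  torch_loss = nn_mse_loss,
  task_types = "regr",
  param_set = ps(reduction = p_fct(levels = c("mean", "sum", "none")))
)
explicit$param_set$class

# Set a value and create the loss function from it
explicit$param_set$values$reduction = "sum"
loss_fn = explicit$generate()

# Sum of squared errors between input and target
loss_value = loss_fn(torch_tensor(c(1, 2, 3)), torch_tensor(c(1, 2, 5)))
loss_value
```

A typed parameter set lets invalid values be rejected when they are set, rather than only when the loss is generated.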

Examples

# Create a new torch loss
torch_loss = TorchLoss$new(torch_loss = nn_mse_loss, task_types = "regr")
torch_loss
# the parameters are inferred
torch_loss$param_set

# Retrieve a loss from the dictionary:
torch_loss = t_loss("mse", reduction = "mean")
# is the same as
torch_loss
torch_loss$param_set
torch_loss$label
torch_loss$task_types
torch_loss$id

# Create the loss function
loss_fn = torch_loss$generate()
loss_fn
# Is the same as
nn_mse_loss(reduction = "mean")

# open the help page of the wrapped loss function
# torch_loss$help()

# Use in a learner
learner = lrn("regr.mlp", loss = t_loss("mse"))
# The parameters of the loss are added to the learner's parameter set
learner$param_set

See Also

Other Torch Descriptor: TorchCallback, TorchDescriptor, TorchOptimizer, as_torch_callbacks(), as_torch_loss(), as_torch_optimizer(), mlr3torch_losses, mlr3torch_optimizers, t_clbk(), t_loss(), t_opt()

Super class

mlr3torch::TorchDescriptor -> TorchLoss

Public fields

  • task_types: (character())

     The task types this loss supports.
    

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

TorchLoss$new(
  torch_loss,
  task_types = NULL,
  param_set = NULL,
  id = NULL,
  label = NULL,
  packages = NULL,
  man = NULL
)

Arguments

  • torch_loss: (nn_loss)

     The loss module.
    
  • task_types: (character())

     The task types supported by this loss.
    
  • param_set: (ParamSet or NULL)

     The parameter set. If `NULL` (default) it is inferred from `torch_loss`.
    
  • id: (character(1))

     The id of the new object.
    
  • label: (character(1))

     Label for the new instance.
    
  • packages: (character())

     The R packages this object depends on.
    
  • man: (character(1))

     String in the format `[pkg]::[topic]` pointing to a manual page for this object. The referenced help page can be opened via method `$help()`.
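As an illustration of the metadata arguments above, the following sketch wraps torch's `nn_l1_loss` with a custom id, label and manual page. The id and label strings are hypothetical values chosen for this example only.

```r
library(mlr3torch)
library(torch)

l1_loss = TorchLoss$new(
  torch_loss = nn_l1_loss,
  task_types = "regr",
  id = "l1_example",                   # hypothetical id
  label = "Absolute Error (example)",  # hypothetical label
  man = "torch::nn_l1_loss"            # help page of the wrapped loss
)

# The metadata is available as fields of the object
l1_loss$id
l1_loss$label
l1_loss$task_types

# open the referenced help page
# l1_loss$help()
```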
    

Method print()

Prints the object.

Usage

TorchLoss$print(...)

Arguments

  • ...: any

Method clone()

The objects of this class are cloneable with this method.

Usage

TorchLoss$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.
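
The effect of a deep clone can be sketched as follows. This example assumes the usual R6 behaviour that `deep = TRUE` also copies the contained parameter set, so that changing a value on the copy leaves the original untouched.

```r
library(mlr3torch)
library(torch)

original = t_loss("mse", reduction = "mean")

# Deep clone, so that the parameter set is copied as well
copy = original$clone(deep = TRUE)
copy$param_set$values$reduction = "sum"

# The original keeps its own parameter values
original$param_set$values$reduction
copy$param_set$values$reduction
```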