This wraps a torch::nn_loss and annotates it with metadata, most importantly a ParamSet. The loss function is created for the given parameter values by calling the $generate() method.
This class is usually used to configure the loss function of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.
For a list of available losses, see mlr3torch_losses. Items from this dictionary can be retrieved using t_loss().
Parameters
Defined by the constructor argument param_set. If no parameter set is provided during construction, one is inferred by creating a parameter of type ParamUty for each argument of the wrapped loss function.
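As a minimal sketch of the alternative to inference, a typed parameter set can be passed explicitly through the param_set argument. The specific choice of a factor parameter for reduction, with the levels "mean", "sum" and "none", is an assumption made for illustration (based on the reduction argument of nn_mse_loss); it is not taken from this page.

```r
library(mlr3torch)

# Assumption: describe the `reduction` argument of nn_mse_loss as a factor
# parameter instead of relying on the inferred ParamUty.
custom_loss = TorchLoss$new(
  torch_loss = nn_mse_loss,
  task_types = "regr",
  param_set = paradox::ps(
    reduction = paradox::p_fct(levels = c("mean", "sum", "none"))
  )
)

# The explicitly provided parameter set is used as-is
custom_loss$param_set$ids()

# Set a value and create the loss function with it
custom_loss$param_set$set_values(reduction = "sum")
loss_fn = custom_loss$generate()
```

Compared to the inferred ParamUty parameters, a typed parameter set like this would let invalid values be rejected when they are set rather than when the loss is generated.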
Examples
# Create a new torch loss
torch_loss = TorchLoss$new(torch_loss = nn_mse_loss, task_types = "regr")
torch_loss
# the parameters are inferred
torch_loss$param_set
# Retrieve a loss from the dictionary:
torch_loss = t_loss("mse", reduction = "mean")
# is the same as
torch_loss
torch_loss$param_set
torch_loss$label
torch_loss$task_types
torch_loss$id
# Create the loss function
loss_fn = torch_loss$generate()
loss_fn
# Is the same as
nn_mse_loss(reduction = "mean")
# open the help page of the wrapped loss function
# torch_loss$help()
# Use in a learner
learner = lrn("regr.mlp", loss = t_loss("mse"))
# The parameters of the loss are added to the learner's parameter set
learner$param_set