TorchCallback function

Torch Callback

This wraps a CallbackSet and annotates it with metadata, most importantly a ParamSet. The callback is created for the given parameter values by calling the $generate() method.

This class is usually used to configure the callback of a torch learner, e.g. when constructing a learner or in a ModelDescriptor.

For a list of available callbacks, see mlr3torch_callbacks. To conveniently retrieve a TorchCallback, use t_clbk().

Parameters

Defined by the constructor argument param_set. If no parameter set is provided during construction, one is inferred by creating a parameter of type ParamUty for each argument of the wrapped callback set.
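
As an illustration of the alternative to inference, the following minimal sketch passes an explicit parameter set so that the parameters carry real types rather than ParamUty. It assumes that CallbackSetCheckpoint accepts the arguments path and freq (as in the Examples section) and that tagging the parameters with "train" mirrors what the inference would do; both points are assumptions, not statements from this page.

```r
library(mlr3torch)
library(paradox)

# Hand-written parameter set for the checkpoint callback.
# Assumption: the ids must match argument names of the wrapped callback set.
param_set = ps(
  path = p_uty(tags = "train"),
  freq = p_int(lower = 1L, tags = "train")
)

torch_callback = TorchCallback$new(
  callback_generator = CallbackSetCheckpoint,
  param_set = param_set
)

# The explicit parameter set is used instead of an inferred one
torch_callback$param_set
```

Typed parameters allow values to be validated when they are set, which untyped ParamUty parameters cannot do.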

Examples

# Create a new torch callback from an existing callback set
torch_callback = TorchCallback$new(CallbackSetCheckpoint)
# The parameters are inferred
torch_callback$param_set
# Retrieve a torch callback from the dictionary
torch_callback = t_clbk("checkpoint",
  path = tempfile(), freq = 1
)
torch_callback
torch_callback$label
torch_callback$id
# open the help page of the wrapped callback set
# torch_callback$help()
# Create the callback set
callback = torch_callback$generate()
callback
# is the same as
CallbackSetCheckpoint$new(
  path = tempfile(), freq = 1
)
# Use in a learner
learner = lrn("regr.mlp", callbacks = t_clbk("checkpoint"))
# the parameters of the callback are added to the learner's parameter set
learner$param_set

See Also

Other Callback: as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_callback_set.unfreeze, mlr_context_torch, t_clbk(), torch_callback()

Other Torch Descriptor: TorchDescriptor, TorchLoss, TorchOptimizer, as_torch_callbacks(), as_torch_loss(), as_torch_optimizer(), mlr3torch_losses, mlr3torch_optimizers, t_clbk(), t_loss(), t_opt()

Super class

mlr3torch::TorchDescriptor -> TorchCallback

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

TorchCallback$new(
  callback_generator,
  param_set = NULL,
  id = NULL,
  label = NULL,
  packages = NULL,
  man = NULL,
  additional_args = NULL
)

Arguments

  • callback_generator: (R6ClassGenerator)

     The class generator for the callback that is being wrapped.
    
  • param_set: (ParamSet or NULL)

     The parameter set. If `NULL` (default) it is inferred from `callback_generator`.
    
  • id: (character(1))

     The id of the new object.
    
  • label: (character(1))

     Label for the new instance.
    
  • packages: (character())

     The R packages this object depends on.
    
  • man: (character(1))

     String in the format `[pkg]::[topic]` pointing to a manual page for this object. The referenced help page can be opened via method `$help()`.
    
  • additional_args: (any)

     Additional arguments if necessary. For learning rate schedulers, this is the torch::LRScheduler.
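
The metadata arguments above can be combined as in the following sketch, which wraps CallbackSetCheckpoint under a custom id and label. The values "my_checkpoint" and "My Checkpoint" are hypothetical and chosen only for illustration.

```r
library(mlr3torch)

# Wrap an existing callback set with a custom id and label.
# The id and label values are illustrative placeholders.
torch_callback = TorchCallback$new(
  callback_generator = CallbackSetCheckpoint,
  id = "my_checkpoint",
  label = "My Checkpoint"
)

torch_callback$id
torch_callback$label
```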
    

Method clone()

The objects of this class are cloneable with this method.

Usage

TorchCallback$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.
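
A short sketch of why a deep clone can be useful: the wrapped parameter set is itself an R6 object, so a deep clone is expected to give the copy its own parameter set, whereas a shallow clone would share it with the original. This relies on standard R6 cloning semantics and is an assumption rather than behavior documented on this page.

```r
library(mlr3torch)

original = t_clbk("checkpoint", path = tempfile(), freq = 1)

# Deep clone, so that the parameter set is copied as well
copy = original$clone(deep = TRUE)
copy$param_set$values$freq = 2

# Compare the values held by the two objects
original$param_set$values$freq
copy$param_set$values$freq
```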