Base Class for Callbacks
Base class from which callbacks should inherit (see section Inheriting). A callback set is a collection of functions that are executed at different stages of the training loop. They can be used to gain more control over the training process of a neural network without having to write everything from scratch.
When used in a torch learner, the CallbackSet is wrapped in a TorchCallback. The latter's parameter set represents the arguments of the CallbackSet's $initialize() method.
For each available stage (see section Stages), a public method $on_<stage>() can be defined. The evaluation context (a ContextTorch) can be accessed via self$ctx, which contains the current state of the training loop. This context is assigned at the beginning of the training loop and removed afterwards. Different stages of a callback can communicate with each other by assigning values to self.
State: To be able to store information in the $model slot of a LearnerTorch, callbacks support a state API. You can override the $state_dict() public method to define what will be stored in learner$model$callbacks$<id> after training finishes. This also requires implementing a $load_state_dict(state_dict) method that defines how to load a previously saved callback state into a different callback. Note that $state_dict() should not include the parameter values that were used to initialize the callback.
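The state API described above can be sketched as follows. This is a minimal illustration, not a definitive implementation: the class name EpochTracker and the field epochs_seen are hypothetical, and the sketch assumes that callback_set() (the helper mentioned below) accepts state_dict and load_state_dict arguments, as it does in recent mlr3torch versions. The stage method is called by hand here only to avoid running a full training loop.

```r
library(mlr3torch)

# Hypothetical callback set that counts how many epochs it has seen.
EpochTracker = callback_set("EpochTracker",
  initialize = function() {
    self$epochs_seen = 0L
  },
  on_epoch_end = function() {
    self$epochs_seen = self$epochs_seen + 1L
  },
  # What gets stored in the learner's state after training.
  state_dict = function() {
    list(epochs_seen = self$epochs_seen)
  },
  # How a previously saved state is restored into another callback.
  load_state_dict = function(state_dict) {
    self$epochs_seen = state_dict$epochs_seen
  }
)

# Simulate one finished epoch by calling the stage method directly
# (normally the training loop does this).
trained = EpochTracker$new()
trained$on_epoch_end()

# Transfer the saved state into a fresh callback of the same class.
restored = EpochTracker$new()
restored$load_state_dict(trained$state_dict())
```

Note that the state contains only what the callback accumulated during training, not the arguments passed to $initialize().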
Inheriting
For creating custom callbacks, the function torch_callback() is recommended, which creates a CallbackSet and then wraps it in a TorchCallback. To create only a CallbackSet, the convenience function callback_set() can be used. These functions perform checks, for example that stage names are not accidentally misspelled.
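As a rough sketch of how torch_callback() might be used, the callback below counts the batches of each epoch and reports the count when the epoch ends. The id batch_counter and the field n_batches are hypothetical names chosen for illustration. The sketch assumes that the context exposes the current epoch as self$ctx$epoch and that the classif.torch_featureless learner is available.

```r
library(mlr3)
library(mlr3torch)

# Hypothetical callback: the stages share information by assigning to self.
batch_counter = torch_callback("batch_counter",
  on_epoch_begin = function() {
    self$n_batches = 0L
  },
  on_batch_end = function() {
    self$n_batches = self$n_batches + 1L
  },
  on_epoch_end = function() {
    # self$ctx is only available while the learner is training.
    cat("Epoch", self$ctx$epoch, "processed", self$n_batches, "batches\n")
  }
)

learner = lrn("classif.torch_featureless",
  batch_size = 50,
  epochs = 2,
  callbacks = batch_counter,
  device = "cpu"
)
learner$train(tsk("iris"))
```

A misspelled stage name is expected to raise an error at construction time, which is the main reason to prefer these helpers over defining the R6 class by hand.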
Stages
begin :: Run before the training loop begins.
epoch_begin :: Run at the beginning of each epoch.
batch_begin :: Run before the forward call.
after_backward :: Run after the backward call.
batch_end :: Run after the optimizer step.
batch_valid_begin :: Run before the forward call in the validation loop.
batch_valid_end :: Run after the forward call in the validation loop.
valid_end :: Run at the end of validation.
epoch_end :: Run at the end of each epoch.
end :: Run after the last epoch.
exit :: Run at last, using on.exit().
To stop training early, set the field $terminate of the ContextTorch to TRUE. This field is checked at the end of every epoch, and training stops if it is TRUE. This can, for example, be used to implement custom early stopping.
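The termination mechanism can be sketched with a deliberately simple stopping rule. The callback id stop_after and its argument max_epochs are hypothetical. The sketch assumes that torch_callback() infers the parameter set from the arguments of initialize, so that the value can be set on the learner as cb.stop_after.max_epochs, and that the current epoch is available as self$ctx$epoch. A real early-stopping callback would typically inspect validation scores instead of the epoch counter.

```r
library(mlr3)
library(mlr3torch)

# Hypothetical callback: request termination once max_epochs is reached.
stop_after = torch_callback("stop_after",
  initialize = function(max_epochs) {
    self$max_epochs = max_epochs
  },
  on_epoch_end = function() {
    if (self$ctx$epoch >= self$max_epochs) {
      # Checked by the training loop at the end of the epoch.
      self$ctx$terminate = TRUE
    }
  }
)

learner = lrn("classif.torch_featureless",
  batch_size = 50,
  epochs = 10,
  callbacks = stop_after,
  cb.stop_after.max_epochs = 3,
  device = "cpu"
)
learner$train(tsk("iris"))
```

Under these assumptions, training should end after the third epoch even though the learner is configured for ten.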
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_callback_set.unfreeze, mlr_context_torch, t_clbk(), torch_callback()
ctx
: (ContextTorch or NULL)
The evaluation context for the callback. This field should always be `NULL` except during the `$train()` call of the torch learner.
stages
: (character())
The active stages of this callback set.
print()
Prints the object.
CallbackSet$print(...)
...
: (any)
Currently unused.
state_dict()
Returns information that is kept in the LearnerTorch's state after training. This information should be loadable into the callback using $load_state_dict() to be able to continue training. This returns NULL by default.
CallbackSet$state_dict()
load_state_dict()
Loads the state dict into the callback to continue training.
CallbackSet$load_state_dict(state_dict)
state_dict
: (any)
The state dict as retrieved via `$state_dict()`.
clone()
The objects of this class are cloneable with this method.
CallbackSet$clone(deep = FALSE)
deep
: Whether to make a deep clone.