mlr_callback_set.unfreeze

Unfreezing Weights Callback

Unfreezes selected weights (parameters of the network) once training reaches a specified epoch or batch.

Examples

library(mlr3torch)
library(data.table)

task = tsk("iris")
cb = t_clbk("unfreeze")
mlp = lrn("classif.mlp",
  callbacks = cb,
  # everything is trainable from the start except these four parameters
  cb.unfreeze.starting_weights = select_invert(
    select_name(c("0.weight", "3.weight", "6.weight", "6.bias"))
  ),
  # unfreeze 0.weight at epoch 2, then 3.weight and 6.weight at epoch 5
  cb.unfreeze.unfreeze = data.table(
    epoch = c(2, 5),
    weights = list(select_name("0.weight"), select_name(c("3.weight", "6.weight")))
  ),
  epochs = 6, batch_size = 150, neurons = c(1, 1, 1)
)
mlp$train(task)
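To make the effect of the example schedule concrete, the following base-R sketch traces which parameters would be trainable in each of the 6 epochs. It does not call mlr3torch. It assumes the MLP's parameters are named `0.*`, `3.*` and `6.*` (the layers referenced in the example) plus `9.*`, which is assumed here to be the output layer. It also assumes that epochs are counted from 1 and that the check happens at the start of each epoch. Selectors are modelled as plain character vectors, and `trainable_by_epoch()` is a hypothetical helper.

```r
# Hypothetical trace of the example's unfreezing schedule (base R only).
# Assumption: these names mirror the example MLP; "9.*" as the output layer
# is an assumption, not something stated on this page.
param_names <- c("0.weight", "0.bias", "3.weight", "3.bias",
                 "6.weight", "6.bias", "9.weight", "9.bias")

# starting_weights = select_invert(select_name(...)): all but these four
frozen_at_start <- c("0.weight", "3.weight", "6.weight", "6.bias")
starting <- setdiff(param_names, frozen_at_start)

# unfreeze schedule: at which epoch which parameter names become trainable
schedule <- list(
  list(epoch = 2, weights = "0.weight"),
  list(epoch = 5, weights = c("3.weight", "6.weight"))
)

trainable_by_epoch <- function(n_epochs, starting, schedule) {
  trainable <- starting
  out <- vector("list", n_epochs)
  for (epoch in seq_len(n_epochs)) {
    # apply every schedule entry that fires at the start of this epoch
    for (entry in schedule) {
      if (entry$epoch == epoch) trainable <- union(trainable, entry$weights)
    }
    out[[epoch]] <- trainable
  }
  out
}

per_epoch <- trainable_by_epoch(6, starting, schedule)
for (e in seq_along(per_epoch)) {
  cat("epoch", e, "trainable:", paste(sort(per_epoch[[e]]), collapse = ", "), "\n")
}
```

Under these assumptions, `6.bias` is frozen at the start and never appears in the schedule, so it stays frozen for the whole run.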

See Also

Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_context_torch, t_clbk(), torch_callback()

Super class

mlr3torch::CallbackSet -> CallbackSetUnfreeze

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

CallbackSetUnfreeze$new(starting_weights, unfreeze)

Arguments

  • starting_weights: (Select)

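     A `Select` denoting the weights that are trainable from the start. All other weights start frozen and stay frozen until (and unless) the `unfreeze` schedule selects them.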
    
  • unfreeze: (data.table)

     A `data.table` with a column `weights` (a list column of `Select`s) and a column `epoch` or `batch`. The selector indicates which parameters to unfreeze, while the `epoch` or `batch` column indicates when to do so.
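The documented shape of `unfreeze` can be checked mechanically. The sketch below is a hypothetical validator written in base R, not part of mlr3torch. It encodes only the contract stated above, plus one assumption: exactly one of `epoch` or `batch` is present. Plain parameter-name vectors stand in for the `Select` objects in the `weights` column.

```r
# Hypothetical validator for the `unfreeze` argument (base R, not mlr3torch).
# Contract checked: a list column `weights` and exactly one timing column,
# `epoch` or `batch` (the "exactly one" part is an assumption).
check_unfreeze_table <- function(unfreeze) {
  stopifnot(is.data.frame(unfreeze))  # a data.table also inherits from data.frame
  if (!("weights" %in% names(unfreeze)) || !is.list(unfreeze$weights)) {
    stop("`unfreeze` needs a list column `weights`")
  }
  timing <- intersect(c("epoch", "batch"), names(unfreeze))
  if (length(timing) != 1L) {
    stop("`unfreeze` needs exactly one of the columns `epoch` or `batch`")
  }
  # report which counter the schedule is keyed on
  timing
}

# an epoch-based schedule, with a list column assigned after construction
by_epoch <- data.frame(epoch = c(2, 5))
by_epoch$weights <- list("0.weight", c("3.weight", "6.weight"))
check_unfreeze_table(by_epoch)
```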
    

Method on_begin()

Sets which weights are trainable at the start of training, as given by `starting_weights`.

Usage

CallbackSetUnfreeze$on_begin()
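Conceptually, setting the starting weights means switching off gradient tracking for every parameter and then switching it back on for the selected ones. The base-R sketch below models this with a named logical vector standing in for each parameter's `requires_grad` flag. It also models a selector as a function from parameter names to the selected names, which is an assumption made for illustration. `select_names_sketch()`, `invert_sketch()` and `set_starting_weights()` are hypothetical helpers, not the callback's actual code.

```r
# Hypothetical model of on_begin(): a named logical vector stands in for the
# requires_grad flag of each network parameter.
# Assumption: a selector is a function mapping parameter names to a subset.
select_names_sketch <- function(names) function(param_names) intersect(param_names, names)
invert_sketch <- function(select) function(param_names) setdiff(param_names, select(param_names))

set_starting_weights <- function(param_names, starting_weights) {
  requires_grad <- rep(FALSE, length(param_names))  # freeze everything first
  names(requires_grad) <- param_names
  requires_grad[starting_weights(param_names)] <- TRUE  # then enable the start set
  requires_grad
}

params <- c("0.weight", "0.bias", "3.weight", "3.bias")
flags <- set_starting_weights(params, invert_sketch(select_names_sketch("0.weight")))
flags
```

In a real torch network, the per-parameter switch would be a tensor's in-place `requires_grad_()` method; the logical vector just keeps the sketch free of dependencies.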

Method on_epoch_begin()

Unfreezes the weights scheduled for the current epoch, if any.

Usage

CallbackSetUnfreeze$on_epoch_begin()

Method on_batch_begin()

Unfreezes the weights scheduled for the current batch, if any.

Usage

CallbackSetUnfreeze$on_batch_begin()
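`on_epoch_begin()` and `on_batch_begin()` perform the same kind of check, keyed on a different counter. The base-R sketch below shows that check for either timing column. `unfreeze_if_due()` is a hypothetical helper, and the schedule uses plain parameter-name vectors in place of `Select` objects. This page does not say whether the batch counter restarts each epoch or runs across the whole training, so the sketch simply takes the counter value from the caller.

```r
# Hypothetical model of the unfreeze check run at the start of each epoch
# (`epoch` schedule) or each batch (`batch` schedule). `requires_grad` is a
# named logical vector; `schedule` has a timing column and a list column
# `weights` holding parameter-name vectors.
unfreeze_if_due <- function(requires_grad, schedule, timing, counter) {
  due <- which(schedule[[timing]] == counter)  # rows scheduled for this point
  for (i in due) {
    requires_grad[schedule$weights[[i]]] <- TRUE  # unfreeze the selected parameters
  }
  requires_grad
}

requires_grad <- c("0.weight" = FALSE, "0.bias" = TRUE, "3.weight" = FALSE)
schedule <- data.frame(batch = c(3, 7))
schedule$weights <- list("0.weight", "3.weight")

# simulate 8 batches; the counter is supplied by the caller
for (batch in 1:8) {
  requires_grad <- unfreeze_if_due(requires_grad, schedule, "batch", batch)
}
requires_grad
```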

Method clone()

The objects of this class are cloneable with this method.

Usage

CallbackSetUnfreeze$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.