mlr_context_torch

Context for Torch Learner


Context for training a torch learner. It holds the (mostly read-only) information that callbacks can access through `ctx`. For more information on callbacks, see CallbackSet.
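The following is a minimal sketch of a callback that only reads from the context. It assumes a recent mlr3torch version in which the hook functions of a callback created with torch_callback() reach the context through `self$ctx`. The id `log_epoch` is a hypothetical name chosen for illustration. At the end of each epoch, the callback prints the epoch counter and the loss of the last training batch.

```r
library(mlr3torch)

# Hypothetical callback that reads (but never modifies) context fields.
cb_log = torch_callback("log_epoch",
  on_epoch_end = function() {
    # ctx$epoch:        the current epoch
    # ctx$total_epochs: the total number of epochs the learner is trained for
    # ctx$last_loss:    the loss from the last training batch
    cat(sprintf("epoch %d/%d, last batch loss: %.4f\n",
      self$ctx$epoch, self$ctx$total_epochs, self$ctx$last_loss))
  }
)
```

Such a callback would then be passed to a torch learner, for example through the `callbacks` construction argument of `lrn("classif.mlp", ...)`.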

See Also

Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_callback_set.unfreeze, t_clbk(), torch_callback()

Public fields

  • learner: (Learner)

     The torch learner.
    
  • task_train: (Task)

     The training task.
    
  • task_valid: (Task or NULL)

     The validation task.
    
  • loader_train: (torch::dataloader)

     The data loader for training.
    
  • loader_valid: (torch::dataloader or NULL)

     The data loader for validation.
    
  • measures_train: (list() of Measures)

     Measures used for training.
    
  • measures_valid: (list() of Measures)

     Measures used for validation.
    
  • network: (torch::nn_module)

     The torch network.
    
  • optimizer: (torch::optimizer)

     The optimizer.
    
  • loss_fn: (torch::nn_module)

     The loss function.
    
  • total_epochs: (integer(1))

     The total number of epochs the learner is trained for.
    
  • last_scores_train: (named list() or NULL)

     The scores from the last training batch. Names are the ids of the training measures. If the `LearnerTorch` uses an `eval_freq` other than `1`, this is `NULL` in every epoch in which the model is not evaluated.
    
  • last_scores_valid: (named list() or NULL)

     The scores from the last validation batch. Names are the ids of the validation measures. If the `LearnerTorch` uses an `eval_freq` other than `1`, this is `NULL` in every epoch in which the model is not evaluated.
    
  • last_loss: (numeric(1))

     The loss from the last training batch.
    
  • epoch: (integer(1))

     The current epoch.
    
  • step: (integer(1))

     The current iteration.
    
  • prediction_encoder: (function())

     The learner's prediction encoder.
    
  • batch: (named list() of torch_tensors)

     The current batch.
    
  • terminate: (logical(1))

     If this field is set to `TRUE` at the end of an epoch, training stops.
    
  • device: (torch::torch_device)

     The device on which the network is trained.
    

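The `terminate` field is the main exception to the read-only rule. The sketch below assumes a recent mlr3torch version in which callbacks created with torch_callback() reach the context through `self$ctx`. The id `stop_below` and the threshold `0.05` are hypothetical values chosen for illustration. The callback sets `terminate` to `TRUE` at the end of the first epoch whose last training-batch loss falls below the threshold, so training stops there.

```r
library(mlr3torch)

# Hypothetical stopping rule: the id and the 0.05 threshold are illustrative only.
cb_stop = torch_callback("stop_below",
  on_epoch_end = function() {
    # ctx$last_loss is the loss from the last training batch of this epoch.
    # Setting ctx$terminate to TRUE at the end of an epoch stops training.
    if (self$ctx$last_loss < 0.05) {
      self$ctx$terminate = TRUE
    }
  }
)
```

Because `last_loss` comes from a single batch, it is noisy. A practical stopping rule would more likely be based on the validation scores in `last_scores_valid`.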
Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

ContextTorch$new(
  learner,
  task_train,
  task_valid = NULL,
  loader_train,
  loader_valid = NULL,
  measures_train = NULL,
  measures_valid = NULL,
  network,
  optimizer,
  loss_fn,
  total_epochs,
  prediction_encoder,
  eval_freq = 1L,
  device
)

Arguments

  • learner: (Learner)

     The torch learner.
    
  • task_train: (Task)

     The training task.
    
  • task_valid: (Task or NULL)

     The validation task.
    
  • loader_train: (torch::dataloader)

     The data loader for training.
    
  • loader_valid: (torch::dataloader or NULL)

     The data loader for validation.
    
  • measures_train: (list() of Measures or NULL)

     Measures used for training. Default is `NULL`.
    
  • measures_valid: (list() of Measures or NULL)

     Measures used for validation.
    
  • network: (torch::nn_module)

     The torch network.
    
  • optimizer: (torch::optimizer)

     The optimizer.
    
  • loss_fn: (torch::nn_module)

     The loss function.
    
  • total_epochs: (integer(1))

     The total number of epochs the learner is trained for.
    
  • prediction_encoder: (function())

     The learner's prediction encoder.
    
  • eval_freq: (integer(1))

     The evaluation frequency, i.e. how often (in epochs) the measures are evaluated.
    
  • device: (character(1))

     The device on which the network is trained.
    

Method clone()

The objects of this class are cloneable with this method.

Usage

ContextTorch$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.