Context for Torch Learner
Context for training a torch learner. This is the (mostly read-only) information that callbacks can access through the field `ctx` (`self$ctx` inside a callback method). For more information on callbacks, see CallbackSet.
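Below is a minimal sketch of how a callback reads this context. It assumes mlr3torch and a working torch installation. `torch_callback()` defines a callback whose stage methods read `self$ctx$step`, `self$ctx$epoch` and `self$ctx$last_loss`. The callback id `"inspect"`, the recorder environment `rec`, and the choice of `classif.mlp` on the iris task are illustrative assumptions, not part of this class.

```r
library(mlr3torch)

# Hypothetical recorder: an environment, so the callback methods can append to it.
rec = new.env()
rec$epochs = numeric(0)
rec$losses = numeric(0)
rec$steps = numeric(0)

cb = torch_callback("inspect",
  on_batch_end = function() {
    # Read the current iteration from the context after each training batch.
    rec$steps = c(rec$steps, self$ctx$step)
  },
  on_epoch_end = function() {
    # Read the epoch counter and the loss of the last training batch.
    rec$epochs = c(rec$epochs, self$ctx$epoch)
    rec$losses = c(rec$losses, self$ctx$last_loss)
  }
)

learner = lrn("classif.mlp", epochs = 3, batch_size = 50, callbacks = cb)
learner$train(tsk("iris"))

rec$epochs # one entry per completed epoch
```

The iris task has 150 rows and `batch_size = 50`, so each epoch should run three training batches. After three epochs, `rec$steps` should therefore hold nine entries.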
Other Callback: TorchCallback, as_torch_callback(), as_torch_callbacks(), callback_set(), mlr3torch_callbacks, mlr_callback_set, mlr_callback_set.checkpoint, mlr_callback_set.progress, mlr_callback_set.tb, mlr_callback_set.unfreeze, t_clbk(), torch_callback()
Public fields
learner: (Learner)
The torch learner.
task_train: (Task)
The training task.
task_valid: (Task or NULL)
The validation task.
loader_train: (torch::dataloader)
The data loader for training.
loader_valid: (torch::dataloader or NULL)
The data loader for validation.
measures_train: (list() of Measures)
Measures used for training.
measures_valid: (list() of Measures)
Measures used for validation.
network: (torch::nn_module)
The torch network.
optimizer: (torch::optimizer)
The optimizer.
loss_fn: (torch::nn_module)
The loss function.
total_epochs: (integer(1))
The total number of epochs the learner is trained for.
last_scores_train: (named list() or NULL)
The scores from the last training batch. Names are the ids of the training measures. If the `eval_freq` parameter of `LearnerTorch` is set to a value other than `1`, this is `NULL` in all epochs in which the model is not evaluated.
last_scores_valid: (named list() or NULL)
The scores from the last validation batch. Names are the ids of the validation measures. If the `eval_freq` parameter of `LearnerTorch` is set to a value other than `1`, this is `NULL` in all epochs in which the model is not evaluated.
last_loss: (numeric(1))
The loss from the last training batch.
epoch: (integer(1))
The current epoch.
step: (integer(1))
The current iteration.
prediction_encoder: (function())
The learner's prediction encoder.
batch: (named list() of torch_tensors)
The current batch.
terminate: (logical(1))
If this field is set to `TRUE` at the end of an epoch, training stops.
device: (torch::torch_device)
The device.
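Most fields above are only read by callbacks, but `terminate` is meant to be written. The sketch below assumes mlr3torch and torch are installed. It requests termination at the end of the second epoch even though `epochs = 10`. The callback id `"stopper"` and the recorder `rec` are hypothetical names. A practical stopping rule would more likely inspect `last_scores_valid` than the epoch counter.

```r
library(mlr3torch)

# Hypothetical recorder for the epochs that actually ran.
rec = new.env()
rec$epochs = numeric(0)

stopper = torch_callback("stopper",
  on_epoch_end = function() {
    rec$epochs = c(rec$epochs, self$ctx$epoch)
    # Setting terminate at the end of an epoch asks the learner to stop training.
    if (self$ctx$epoch >= 2) self$ctx$terminate = TRUE
  }
)

learner = lrn("classif.mlp", epochs = 10, batch_size = 50, callbacks = stopper)
learner$train(tsk("iris"))

rec$epochs # training is expected to end after epoch 2
```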
Methods
new()
Creates a new instance of this R6 class.
  ContextTorch$new(
    learner,
    task_train,
    task_valid = NULL,
    loader_train,
    loader_valid = NULL,
    measures_train = NULL,
    measures_valid = NULL,
    network,
    optimizer,
    loss_fn,
    total_epochs,
    prediction_encoder,
    eval_freq = 1L,
    device
  )
Arguments
learner: (Learner)
The torch learner.
task_train: (Task)
The training task.
task_valid: (Task or NULL)
The validation task.
loader_train: (torch::dataloader)
The data loader for training.
loader_valid: (torch::dataloader or NULL)
The data loader for validation.
measures_train: (list() of Measures or NULL)
Measures used for training. Default is `NULL`.
measures_valid: (list() of Measures or NULL)
Measures used for validation. Default is `NULL`.
network: (torch::nn_module)
The torch network.
optimizer: (torch::optimizer)
The optimizer.
loss_fn: (torch::nn_module)
The loss function.
total_epochs: (integer(1))
The total number of epochs the learner is trained for.
prediction_encoder: (function())
The learner's prediction encoder.
eval_freq: (integer(1))
The evaluation frequency: the model is evaluated with the measures every `eval_freq` epochs.
device: (character(1))
The device.
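The sketch below shows how `eval_freq` affects `last_scores_train`. It assumes mlr3torch and torch are installed, trains for four epochs with `eval_freq = 2` and `measures_train = msrs("classif.acc")`, and records in each epoch whether training scores were available. The callback id `"scorewatch"` and the recorder `rec` are hypothetical. `LearnerTorch` decides exactly which epochs get evaluated, so the sketch only expects a mix of evaluated and skipped epochs, with the final epoch evaluated.

```r
library(mlr3torch)

# Hypothetical recorder: one logical per epoch.
rec = new.env()
rec$evaluated = logical(0)

cb = torch_callback("scorewatch",
  on_epoch_end = function() {
    # last_scores_train is NULL in epochs in which the model is not evaluated.
    rec$evaluated = c(rec$evaluated, !is.null(self$ctx$last_scores_train))
  }
)

learner = lrn("classif.mlp",
  epochs = 4, batch_size = 50, eval_freq = 2,
  measures_train = msrs("classif.acc"),
  callbacks = cb
)
learner$train(tsk("iris"))

rec$evaluated # TRUE in evaluated epochs, FALSE otherwise
```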
clone()
The objects of this class are cloneable with this method.
  ContextTorch$clone(deep = FALSE)
Arguments
deep: Whether to make a deep clone.