Base Class for Torch Learners
This base class provides the basic functionality for training and prediction of a neural network. All torch learners should inherit from this class.
To specify the validation data, you can set the $validate field of the Learner, which can be set to:
NULL: no validation.
ratio: only a proportion 1 - ratio of the task is used for training, and ratio is used for validation.
"test": the "test" task of a resampling is used. This is not possible when calling $train() manually.
"predefined": this uses the predefined $internal_valid_task of a mlr3::Task.
This validation data can also be used for early stopping, see the description of the Learner's parameters.
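The ratio option can be pictured with a small base-R sketch. The helper split_train_valid() is hypothetical and not part of mlr3 or mlr3torch. It only illustrates the arithmetic of holding out a proportion ratio of the row ids for validation while training on the remaining 1 - ratio. How mlr3 actually samples the validation rows may differ.

```r
# Hypothetical helper (not an mlr3/mlr3torch function): hold out a proportion
# `ratio` of the row ids for validation and keep 1 - ratio for training.
split_train_valid = function(row_ids, ratio) {
  stopifnot(is.numeric(ratio), length(ratio) == 1L, ratio > 0, ratio < 1)
  n_valid = round(length(row_ids) * ratio)
  # index via sample.int() so that row_ids of length 1 are handled safely
  valid = row_ids[sample.int(length(row_ids), n_valid)]
  list(train = setdiff(row_ids, valid), valid = valid)
}

set.seed(1)
ids = 1:100
parts = split_train_valid(ids, ratio = 0.25)
length(parts$train)  # 75
length(parts$valid)  # 25
```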
In order to save a LearnerTorch for later use, it is necessary to call the $marshal() method on the Learner before writing it to disk, as the object will otherwise not be saved correctly. After loading a marshaled LearnerTorch into R again, you then need to call $unmarshal() to transform it into a usable state.
In order to prevent overfitting, the LearnerTorch class allows using early stopping via the patience and min_delta parameters, see the Learner's parameters. When tuning a LearnerTorch, it is also possible to combine explicit tuning via mlr3tuning with the LearnerTorch's internal tuning of the epochs via early stopping. To do so, you just need to include epochs = to_tune(upper = <upper>, internal = TRUE) in the search space, where <upper> is the maximally allowed number of epochs, and configure the early stopping.
The model is a list of class "learner_torch_model" with the following elements:
network :: The trained network.
optimizer :: The $state_dict() of the optimizer used to train the network.
loss_fn :: The $state_dict() of the loss used to train the network.
callbacks :: The callbacks used to train the network.
seed :: The seed that was / is used for training and prediction.
epochs :: How many epochs the model was trained for (early stopping).
task_col_info :: A data.table() containing information about the train-task.
General :
The parameters of the optimizer, loss and callbacks, prefixed with "opt.", "loss." and "cb.<callback id>." respectively, as well as:
epochs
:: integer(1)
The number of epochs.
device
:: character(1)
The device. One of "auto", "cpu", "cuda", or other values defined in mlr_reflections$torch$devices. The value is initialized to "auto", which selects "cuda" if possible, then tries "mps", and otherwise falls back to "cpu".
num_threads
:: integer(1)
The number of threads for intraop parallelization (if device is "cpu"). This value is initialized to 1.
num_interop_threads
:: integer(1)
The number of threads for interop parallelization (if device is "cpu"). This value is initialized to 1. Note that this can only be set once per session; changing the value within an R session will raise a warning.
seed
:: integer(1) or "random" or NULL
The torch seed that is used during training and prediction. This value is initialized to "random", which means that a random seed is sampled at the beginning of the training phase. This seed (either set or randomly sampled) is available via $model$seed after training and is used during prediction. Note that because the seed is sampled during the training phase, clones of the learner will by default (i.e. when seed is "random") use different seeds. If set to NULL, no seeding is done.
tensor_dataset
:: logical(1) | "device"
Whether to load all batches at once at the beginning of training and stack them. This is initialized to FALSE. If set to "device", the tensors are placed on the device given by the device parameter, which can avoid unnecessary moving of tensors between devices. When your dataset fits into memory, this makes the loading of batches faster. Note that this should not be set for datasets that contain lazy_tensors with random data augmentation, as the augmentation will then only be applied once at the beginning of training.
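The fallback order of device = "auto" can be written as a small decision function. This is a hypothetical sketch, not mlr3torch's implementation: resolve_device() and its arguments are made up, and the availability flags are passed in explicitly so the logic runs without torch installed. In real code, CUDA availability could be queried with torch::cuda_is_available().

```r
# Hypothetical sketch of the "auto" fallback order: "cuda" if available,
# then "mps", otherwise "cpu". Explicit device choices are returned unchanged.
resolve_device = function(device = "auto", cuda_available = FALSE, mps_available = FALSE) {
  if (device != "auto") {
    return(device)
  }
  if (cuda_available) {
    "cuda"
  } else if (mps_available) {
    "mps"
  } else {
    "cpu"
  }
}

resolve_device("auto", cuda_available = FALSE, mps_available = TRUE)  # "mps"
resolve_device("cpu", cuda_available = TRUE)                          # "cpu"
```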
Evaluation :
measures_train
:: Measure or list() of Measures
Measures to be evaluated during training.
measures_valid
:: Measure or list() of Measures
Measures to be evaluated during validation.
eval_freq
:: integer(1)
How often the train / validation predictions are evaluated using measures_train / measures_valid. This is initialized to 1. Note that the final model is always evaluated.
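Assuming that evaluation happens on every epoch that is a multiple of eval_freq, plus the final epoch, the evaluated epochs can be listed as follows. evaluated_epochs() is a hypothetical helper for illustration only; the exact scheduling inside mlr3torch is not shown here.

```r
# Hypothetical helper: epochs at which measures would be evaluated, assuming
# every eval_freq-th epoch is evaluated and the final epoch always is.
evaluated_epochs = function(epochs, eval_freq = 1L) {
  e = seq_len(epochs)
  sort(unique(c(e[e %% eval_freq == 0L], epochs)))
}

evaluated_epochs(10, eval_freq = 3)  # 3 6 9 10
evaluated_epochs(5, eval_freq = 1)   # 1 2 3 4 5
```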
Early Stopping :
patience
:: integer(1)
This activates early stopping using the validation scores. If the performance of a model does not improve for patience evaluation steps, training is ended. Note that the final model is stored in the learner, not the best model. This is initialized to 0, which means no early stopping. The first entry from measures_valid is used as the metric. This also requires specifying the $validate field of the Learner, as well as measures_valid. If early stopping is active, the epoch after which no improvement was observed can be accessed via the $internal_tuned_values field of the learner.
min_delta
:: double(1)
The minimum improvement threshold for early stopping. This is initialized to 0.
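The interplay of patience and min_delta can be illustrated with a simplified, hypothetical sketch that works on a vector of validation scores for a measure that is minimized. It is not mlr3torch's implementation; for instance, measures that should be maximized are not handled. Here best_step plays the role of the step after which no improvement was observed. With eval_freq > 1, each step covers several epochs, and the learner keeps the final model rather than the one from best_step.

```r
# Hypothetical sketch of patience-based early stopping with a minimum
# improvement threshold. `scores` holds one validation score per evaluation
# step for a measure that should be minimized (e.g. a classification error).
early_stopping = function(scores, patience = 0L, min_delta = 0) {
  best = Inf
  best_step = NA_integer_
  steps_without_improvement = 0L
  for (i in seq_along(scores)) {
    if (scores[i] < best - min_delta) {
      # improvement larger than min_delta: remember it and reset the counter
      best = scores[i]
      best_step = i
      steps_without_improvement = 0L
    } else {
      steps_without_improvement = steps_without_improvement + 1L
      # patience = 0 disables early stopping
      if (patience > 0L && steps_without_improvement >= patience) {
        return(list(stopped_at = i, best_step = best_step))
      }
    }
  }
  list(stopped_at = length(scores), best_step = best_step)
}

scores = c(0.90, 0.70, 0.65, 0.66, 0.64, 0.70)
early_stopping(scores, patience = 2L, min_delta = 0.02)  # stops at step 5, best step 3
early_stopping(scores, patience = 2L)                    # runs to step 6, best step 5
```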
Dataloader :
batch_size
:: integer(1)
The batch size (required).
shuffle
:: logical(1)
Whether to shuffle the instances in the dataset. This is initialized to TRUE, which differs from the default (FALSE).
sampler
:: torch::sampler
Object that defines how the dataloader draws samples.
batch_sampler
:: torch::sampler
Object that defines how the dataloader draws batches.
num_workers
:: integer(1)
The number of workers for data loading (batches are loaded in parallel). The default is 0, which means that data is loaded in the main process.
collate_fn
:: function
How to merge a list of samples to form a batch.
pin_memory
:: logical(1)
Whether the dataloader copies tensors into CUDA pinned memory before returning them.
drop_last
:: logical(1)
Whether to drop the last training batch in each epoch if it is incomplete. Default is FALSE.
timeout
:: numeric(1)
The timeout value for collecting a batch from workers. Negative values mean no timeout, and the default is -1.
worker_init_fn
:: function(id)
A function that receives the worker id (in [1, num_workers]) and is executed after seeding on the worker but before data loading.
worker_globals
:: list() | character()
When loading data in parallel, this allows exporting globals to the workers. If this is a character vector, the objects in the global environment with those names are copied to the workers.
worker_packages
:: character()
Which packages to load on the workers.
Also see torch::dataloader for more information.
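To make the roles of batch_size, shuffle and drop_last concrete, the sketch below computes the row indices of each batch for one epoch in plain R. make_batches() is hypothetical. The real batching is done by torch::dataloader, and custom sampler or batch_sampler objects would replace this logic.

```r
# Hypothetical sketch: row indices of each batch in one epoch, depending on
# batch_size, shuffle and drop_last (plain R, no torch required).
make_batches = function(n, batch_size, shuffle = TRUE, drop_last = FALSE) {
  idx = if (shuffle) sample.int(n) else seq_len(n)
  batches = unname(split(idx, ceiling(seq_along(idx) / batch_size)))
  # drop only a trailing batch that is smaller than batch_size
  if (drop_last && n %% batch_size != 0L) {
    batches = batches[-length(batches)]
  }
  batches
}

lengths(make_batches(10, batch_size = 4, shuffle = FALSE))                    # 4 4 2
lengths(make_batches(10, batch_size = 4, shuffle = FALSE, drop_last = TRUE))  # 4 4
```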
There are no separate classes for classification and regression to inherit from. Instead, the task_type must be specified as a construction argument. Currently, only classification and regression are supported.
When inheriting from this class, one should overwrite two private methods:
.network(task, param_vals)
(Task, list()) -> nn_module
Construct a torch::nn_module object for the given task and parameter values, i.e. the neural network that is trained by the learner. For classification, the output of this network is expected to be the scores before the application of the final softmax layer.
.dataset(task, param_vals)
(Task, list()) -> torch::dataset
Create the dataset for the task. The dataset must return a named list where:
x is a list of torch tensors that are the input to the network. For networks with more than one input, the names must correspond to the inputs of the network.
y is the target tensor.
.index are the indices of the batch (integer() or a torch_int()).
Moreover, one needs to respect the row ids of the provided task.
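The expected shape of one batch can be checked with plain R objects standing in for torch tensors. The function check_batch() and the input names "num" and "img" are hypothetical. They only mirror the structure described above: x as a named list matching the network inputs, plus y and .index.

```r
# Hypothetical checker for the batch structure, using plain R vectors and
# arrays in place of torch tensors. `input_names` stands for the names of the
# network's inputs.
check_batch = function(batch, input_names) {
  stopifnot(
    is.list(batch),
    all(c("x", "y", ".index") %in% names(batch)),
    is.list(batch$x),
    length(batch$x) == length(input_names),
    setequal(names(batch$x), input_names)
  )
  invisible(TRUE)
}

# a batch of 3 rows with two (hypothetical) network inputs
batch = list(
  x = list(num = matrix(rnorm(6), nrow = 3), img = array(0, dim = c(3, 1, 2, 2))),
  y = c(1L, 2L, 1L),
  .index = c(7L, 2L, 9L)
)
check_batch(batch, input_names = c("num", "img"))
```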
It is also possible to overwrite the private .dataloader() method. This must respect the dataloader parameters from the ParamSet.
.dataloader(dataset, param_vals)
(torch::dataset, list()) -> torch::dataloader
Create a dataloader from the dataset. Needs to respect at least batch_size and shuffle (otherwise predictions can be permuted).
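Why shuffle must be respected during prediction can be shown without torch: predictions are collected in the order in which batches arrive, so if rows were shuffled, the .index of each batch is needed to map predictions back to the task's row order. In this hypothetical sketch, index * 10 stands in for the network output so that each prediction can be traced to its row.

```r
# Hypothetical sketch: restore the task's row order from shuffled batches
# using the `.index` element of each batch.
set.seed(3)
n = 6
row_order = sample.int(n)                               # shuffled row order
batches = split(row_order, ceiling(seq_len(n) / 4))     # batch_size = 4
preds = lapply(batches, function(index) {
  list(.index = index, pred = index * 10)               # stand-in for network output
})
index = unlist(lapply(preds, `[[`, ".index"), use.names = FALSE)
pred = unlist(lapply(preds, `[[`, "pred"), use.names = FALSE)
pred_in_row_order = pred[order(index)]
pred_in_row_order  # 10 20 30 40 50 60
```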
To change the predict types, it is possible to overwrite the method below:
.encode_prediction(predict_tensor, task)
(torch_tensor, Task) -> list()
Take in the raw predictions from self$network (predict_tensor) and encode them into a format that can be converted to valid mlr3 predictions using mlr3::as_prediction_data(). This method must take self$predict_type into account.
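For classification, the network returns scores before the final softmax, so encoding predictions typically involves a softmax for "prob" and an argmax for "response". The sketch below does this with a base-R matrix in place of a torch_tensor. encode_scores() is hypothetical, and the exact list format that mlr3::as_prediction_data() accepts should be checked against the mlr3 documentation.

```r
# Hypothetical sketch: turn raw class scores (one column per class) into a
# named list with `response` and, if requested, `prob`.
encode_scores = function(scores, class_names, predict_type = "response") {
  # row-wise softmax, shifted by the row maximum for numerical stability
  shifted = scores - apply(scores, 1L, max)
  prob = exp(shifted) / rowSums(exp(shifted))
  colnames(prob) = class_names
  # argmax per row gives the predicted class
  response = factor(class_names[max.col(scores, ties.method = "first")], levels = class_names)
  if (predict_type == "prob") {
    list(response = response, prob = prob)
  } else {
    list(response = response)
  }
}

scores = rbind(c(2, 1, 0), c(0, 0, 3))
encode_scores(scores, c("a", "b", "c"), predict_type = "prob")
```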
While it is possible to add parameters by specifying the param_set construction argument, it is currently not possible to remove existing parameters, i.e. those listed in section Parameters. None of the parameters provided in param_set can have an id that starts with "loss.", "opt.", or "cb.", as these are reserved for the dynamically constructed parameters of the optimizer, the loss function, and the callbacks.
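A quick way to detect ids that collide with the reserved prefixes is sketched below. reserved_ids() and the example ids are hypothetical and only illustrate the rule.

```r
# Hypothetical check: return the ids that start with one of the reserved
# prefixes "loss.", "opt." or "cb.".
reserved_ids = function(ids) {
  reserved = startsWith(ids, "loss.") | startsWith(ids, "opt.") | startsWith(ids, "cb.")
  ids[reserved]
}

reserved_ids(c("neurons", "opt.lr", "dropout", "cb.checkpoint.freq"))
# "opt.lr" "cb.checkpoint.freq"
```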
To perform additional input checks on the task, the private .verify_train_task(task, param_vals) and .verify_predict_task(task, param_vals) methods can be overwritten.
For learners that have other construction arguments that should change the hash of a learner, it is required to implement the private $.additional_phash_input().
Other Learner: mlr_learners.mlp, mlr_learners.tab_resnet, mlr_learners.torch_featureless, mlr_learners_torch_image, mlr_learners_torch_model
mlr3::Learner -> LearnerTorch
validate
: How to construct the internal validation data. This parameter can be either NULL, a ratio in (0, 1), "test", or "predefined".
loss
: (TorchLoss)
The torch loss.
optimizer
: (TorchOptimizer)
The torch optimizer.
callbacks
: (list() of TorchCallbacks)
List of torch callbacks. The ids will be set as the names.
internal_valid_scores
: Retrieves the internal validation scores as a named list(). Specify the $validate field and the measures_valid parameter to configure this. Returns NULL if the learner is not trained yet.
internal_tuned_values
: When early stopping is activated, this returns a named list with the early-stopped epochs, otherwise an empty list is returned. Returns NULL if the learner is not trained yet.
marshaled
: (logical(1))
Whether the learner is marshaled.
network
: (nn_module())
Shortcut for `learner$model$network`.
param_set
: (ParamSet)
The parameter set.
hash
: (character(1))
Hash (unique identifier) for this object.
phash
: (character(1))
Hash (unique identifier) for this partial object, excluding some components which are varied systematically during tuning (parameter values).
new()
Creates a new instance of this R6 class.
LearnerTorch$new(
id,
task_type,
param_set,
properties,
man,
label,
feature_types,
optimizer = NULL,
loss = NULL,
packages = character(),
predict_types = NULL,
callbacks = list()
)
id
: (character(1))
The id of the new object.
task_type
: (character(1))
The task type.
param_set
: (ParamSet or alist())
Either a parameter set, or an `alist()` containing different values of self, e.g. `alist(private$.param_set1, private$.param_set2)`, from which a `ParamSet` collection should be created.
properties
: (character())
The properties of the object. See `mlr_reflections$learner_properties` for available values.
man
: (character(1))
String in the format `[pkg]::[topic]` pointing to a manual page for this object. The referenced help package can be opened via method `$help()`.
label
: (character(1))
Label for the new instance.
feature_types
: (character())
The feature types. See `mlr_reflections$task_feature_types` for available values. Additionally, `"lazy_tensor"` is supported.
optimizer
: (NULL or TorchOptimizer)
The optimizer to use for training. Defaults to adam.
loss
: (NULL or TorchLoss)
The loss to use for training. Defaults to MSE for regression and cross entropy for classification.
packages
: (character())
The R packages this object depends on.
predict_types
: (character())
The predict types. See `mlr_reflections$learner_predict_types` for available values. For regression, the default is `"response"`. For classification, this defaults to `"response"` and `"prob"`. To deviate from the defaults, it is necessary to overwrite the private `$.encode_prediction()` method, see section **Inheriting**.
callbacks
: (list() of TorchCallbacks)
The callbacks to use for training. Defaults to an empty `list()`, i.e. no callbacks.
format()
Helper for print outputs.
LearnerTorch$format(...)
...
: (ignored).
print()
Prints the object.
LearnerTorch$print(...)
...
: (any)
Currently unused.
marshal()
Marshal the learner.
LearnerTorch$marshal(...)
...
: (any)
Additional parameters.
self
unmarshal()
Unmarshal the learner.
LearnerTorch$unmarshal(...)
...
: (any)
Additional parameters.
self
dataset()
Create the dataset for a task.
LearnerTorch$dataset(task)
task
: Task
The task.
dataset
clone()
The objects of this class are cloneable with this method.
LearnerTorch$clone(deep = FALSE)
deep
: Whether to make a deep clone.