Create a Dataset from a Task
Creates a torch dataset from an mlr3 Task. The resulting dataset's $.getbatch() method returns a list with elements x, y and .index:
x is a list of tensors, whose content is defined by the parameter feature_ingress_tokens.
y is the target variable; its content is defined by the parameter target_batchgetter.
.index is the index of the batch in the task's data.
The data is returned on the device specified by the parameter device.
task_dataset(task, feature_ingress_tokens, target_batchgetter = NULL)
task: (Task)
The task for which to build the dataset.
feature_ingress_tokens: (named list() of TorchIngressToken)
Each ingress token defines one item in the $x value of a batch with corresponding names.
target_batchgetter: (function(data, device))
A function taking in arguments data, which is a data.table containing only the target variable, and device. It must return the target as a torch tensor on the selected device.
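The contract of target_batchgetter can be illustrated with a minimal sketch. The helper name numeric_target_batchgetter and the stand-in data.frame are hypothetical and not part of mlr3torch; the sketch assumes a numeric target and makes device optional so that the function can be called with or without it.

```r
library(torch)

# Hypothetical helper (not part of mlr3torch): builds the target tensor from a
# one-column table. `device` is optional; NULL falls back to torch's default.
numeric_target_batchgetter = function(data, device = NULL) {
  # data[[1L]] extracts the single target column as a plain vector
  y = torch_tensor(data[[1L]], dtype = torch_float32(), device = device)
  # add a trailing dimension so the result has shape (batch_size, 1)
  y$unsqueeze(2)
}

# Stand-in for the one-column target table the dataset would pass in
# (a plain data.frame is used here instead of a data.table)
target_data = data.frame(y = c(0.5, 1.5, 2.5))
out = numeric_target_batchgetter(target_data)
out$shape
```

Keeping the target as a two-dimensional tensor is a design choice of this sketch, mirroring the $unsqueeze(2) call in the example for this function.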
Returns a torch::dataset.
library(mlr3torch)

task = tsk("iris")

# one ingress token per group of features
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)

# returns the target as a float tensor with a trailing dimension
target_batchgetter = function(data) {
  torch_tensor(data = data[[1L]], dtype = torch_float32())$unsqueeze(2)
}

dataset = task_dataset(task, ingress_tokens, target_batchgetter)
batch = dataset$.getbatch(1:10)
batch
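A further minimal sketch of how the names of the ingress tokens carry over to the x element of a batch. It assumes the behaviour described for feature_ingress_tokens and leaves target_batchgetter at its default of NULL; the variable names tokens and ds are illustrative only.

```r
library(mlr3torch)

task = tsk("iris")

# One token per feature group; the list names are expected to reappear
# as the names of the tensors in batch$x
tokens = list(
  sepal = TorchIngressToken(
    features = c("Sepal.Length", "Sepal.Width"),
    batchgetter = batchgetter_num,
    shape = c(NA, 2)
  ),
  petal = TorchIngressToken(
    features = c("Petal.Length", "Petal.Width"),
    batchgetter = batchgetter_num,
    shape = c(NA, 2)
  )
)

# target_batchgetter is left at its default (NULL) in this sketch
ds = task_dataset(task, tokens)
batch = ds$.getbatch(1:10)

# Inspect the names of the feature tensors and the shape of one of them
names(batch$x)
batch$x$sepal$shape
```

Checking the names and shapes of a single batch in this way is a quick sanity check before passing the dataset to a dataloader.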