task_dataset function

Create a Dataset from a Task

Creates a torch dataset from an mlr3 Task. The resulting dataset's $.getbatch() method returns a list with elements x, y and .index:

  • x is a list of tensors, whose content is defined by the parameter feature_ingress_tokens.
  • y is the target variable; its content is defined by the parameter target_batchgetter.
  • .index contains the indices of the batch's rows in the task's data.

The data is returned on the device specified by the parameter device.
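
To make the batch layout described above concrete, the following is a minimal sketch built directly with torch::dataset(). It is a toy stand-in, not the implementation behind task_dataset(): the name toy_dataset, the single x entry called features, and the use of a plain integer vector for .index are all assumptions made for illustration.

```r
library(torch)

# Hypothetical toy dataset that mimics the documented batch layout:
# a list with elements x (named list of tensors), y and .index.
toy_dataset = dataset(
  name = "toy_dataset",
  initialize = function(x, y) {
    self$x = x
    self$y = y
  },
  .getbatch = function(index) {
    list(
      x = list(features = self$x[index, ]),  # one named tensor per "ingress"
      y = self$y[index, ],                   # target tensor for the same rows
      .index = index                         # which rows the batch came from
    )
  },
  .length = function() {
    self$x$shape[1]
  }
)

ds = toy_dataset(torch_randn(20, 2), torch_randn(20, 1))
batch = ds$.getbatch(1:5)
names(batch)
batch$x$features$shape
```

A dataset returned by task_dataset() is queried the same way, with one entry in x per ingress token.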

task_dataset(task, feature_ingress_tokens, target_batchgetter = NULL)

Arguments

  • task: (Task)

    The task for which to build the dataset.

  • feature_ingress_tokens: (named list() of TorchIngressToken)

    Each ingress token defines one item in the $x value of a batch with corresponding names.

  • target_batchgetter: (function(data, device))

    A function taking the arguments data, a data.table containing only the target variable, and device. It must return the target as a torch tensor on the selected device.
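
As an illustration of the target_batchgetter contract, here is a minimal sketch for a numeric target. The function name target_batchgetter_regr and the toy table target_dt are hypothetical. Giving device a default of "cpu" is also an assumption; it lets the function be called with or without that argument.

```r
library(torch)
library(data.table)

# Hypothetical batchgetter for a numeric target: takes a one-column
# data.table and returns an (n, 1) float tensor on the requested device.
target_batchgetter_regr = function(data, device = "cpu") {
  torch_tensor(
    data[[1L]],
    dtype = torch_float32(),
    device = device
  )$unsqueeze(2)  # add a trailing dimension: (n) -> (n, 1)
}

# Toy target column, standing in for the data passed by the dataset.
target_dt = data.table(y = c(1.5, 2.5, 3.5))
y = target_batchgetter_regr(target_dt, device = "cpu")
y$shape
```

The trailing dimension added by $unsqueeze(2) gives the target one column per row of the batch, which is the shape most regression losses expect alongside an (n, 1) prediction.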

Returns

torch::dataset

Examples

task = tsk("iris")
sepal_ingress = TorchIngressToken(
  features = c("Sepal.Length", "Sepal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
petal_ingress = TorchIngressToken(
  features = c("Petal.Length", "Petal.Width"),
  batchgetter = batchgetter_num,
  shape = c(NA, 2)
)
ingress_tokens = list(sepal = sepal_ingress, petal = petal_ingress)
target_batchgetter = function(data) {
  torch_tensor(data = data[[1L]], dtype = torch_float32())$unsqueeze(2)
}
dataset = task_dataset(task, ingress_tokens, target_batchgetter)
batch = dataset$.getbatch(1:10)
batch