mlr_learners.torchvision

AlexNet Image Classifier


Classic image classification networks from torchvision.

Parameters

Parameters from LearnerTorchImage and the following:

  • pretrained :: logical(1)

    Whether to use the pretrained model. If TRUE, the final linear layer is replaced with a new nn_linear whose number of outputs equals the number of classes inferred from the Task.
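A minimal sketch of what replacing the final linear layer amounts to for AlexNet, using the R torch and torchvision packages directly. The helper replace_alexnet_head() is hypothetical (it is not an mlr3torch function), and it assumes that torchvision's AlexNet classifier is an nn_sequential whose last module, an nn_linear, is registered under the name `6`. To avoid downloading weights the sketch uses pretrained = FALSE; a pretrained AlexNet has the same 1000-class head.

```r
library(torch)
library(torchvision)

# Hypothetical helper (not part of mlr3torch): swap AlexNet's final linear
# layer for a freshly initialized one with `num_classes` outputs.
# Assumes the last module of `network$classifier` is named "6".
replace_alexnet_head <- function(network, num_classes) {
  old_head <- network$classifier$`6`
  network$classifier$`6` <- nn_linear(
    in_features = old_head$in_features,
    out_features = num_classes
  )
  network
}

# Untrained AlexNet with torchvision's default 1000-class head.
net <- model_alexnet(pretrained = FALSE)
net <- replace_alexnet_head(net, num_classes = 10)

# A batch of two 3 x 224 x 224 images now yields 10 scores per image.
net$eval()
out <- net(torch_randn(2, 3, 224, 224))
out$shape
```

Only the head is new: all earlier layers (and, for a pretrained model, their weights) are kept, which is what makes fine-tuning on a task with a different number of classes possible.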

Properties

  • Supported task types: "classif"
  • Predict Types: "response" and "prob"
  • Feature Types: "lazy_tensor"
  • Required packages: "mlr3torch", "torch", "torchvision"

Super classes

mlr3::Learner -> mlr3torch::LearnerTorch -> mlr3torch::LearnerTorchImage -> LearnerTorchVision

Methods

Public methods

Method new()

Creates a new instance of this R6 class.

Usage

LearnerTorchVision$new(
  name,
  module_generator,
  label,
  optimizer = NULL,
  loss = NULL,
  callbacks = list()
)

Arguments

  • name: (character(1))

     The name of the network.
    
  • module_generator: (function(pretrained, num_classes))

     Function that generates the network.
    
  • label: (character(1))

     The label of the network.
    
  • optimizer: (TorchOptimizer)

     The optimizer to use for training. By default, **adam** is used.
    
  • loss: (TorchLoss)

     The loss used to train the network. By default, **mse** is used for regression and **cross_entropy** for classification.
    
  • callbacks: (list() of TorchCallbacks)

     The callbacks to run during training. Their ids must be unique.
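The optimizer, loss, and callbacks arguments take objects created with mlr3torch's t_opt(), t_loss(), and t_clbk() helpers. The sketch below assumes that the AlexNet variant is registered in the learner dictionary under the key "classif.alexnet", and that lrn() forwards arguments matching the constructor's formals (optimizer, loss, callbacks) to it while setting the remaining ones as hyperparameters. The learning rate, epochs, and batch size are illustrative values only.

```r
library(mlr3)
library(mlr3torch)

# Assumes "classif.alexnet" is registered by mlr3torch. optimizer, loss and
# callbacks are constructor arguments; pretrained, epochs and batch_size
# are set as hyperparameters.
learner <- lrn("classif.alexnet",
  optimizer = t_opt("sgd", lr = 0.01),
  loss = t_loss("cross_entropy"),
  callbacks = list(t_clbk("history")),
  pretrained = FALSE,
  epochs = 1,
  batch_size = 32
)

learner$id
learner$param_set$values$pretrained
```

Passing the optimizer and loss explicitly overrides the defaults described above; leaving them NULL keeps adam and cross_entropy for this classification learner.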
    

Method clone()

The objects of this class are cloneable with this method.

Usage

LearnerTorchVision$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.