train function

Train an imputation model using MIDAS

Build and run a MIDAS neural network on the supplied missing data.

train(
  data,
  binary_columns = NULL,
  softmax_columns = NULL,
  training_epochs = 10L,
  layer_structure = c(256, 256, 256),
  learn_rate = 4e-04,
  input_drop = 0.8,
  seed = 123L,
  train_batch = 16L,
  latent_space_size = 4,
  cont_adj = 1,
  binary_adj = 1,
  softmax_adj = 1,
  dropout_level = 0.5,
  vae_layer = FALSE,
  vae_alpha = 1,
  vae_sample_var = 1
)

Arguments

  • data: A data.frame (or coercible) object, or an object of class midas_pre created from rMIDAS::convert()
  • binary_columns: A vector of column names, containing binary variables. NOTE: if data is a midas_pre object, this argument will be overwritten.
  • softmax_columns: A list of lists, each internal list corresponding to a single categorical variable and containing names of the one-hot encoded variable names. NOTE: if data is a midas_pre object, this argument will be overwritten.
  • training_epochs: An integer, indicating the number of forward passes to conduct when running the model.
  • layer_structure: A vector of integers, The number of nodes in each layer of the network (default = c(256, 256, 256), denoting a three-layer network with 256 nodes per layer). Larger networks can learn more complex data structures but require longer training and are more prone to overfitting.
  • learn_rate: A number, the learning rate γ (default = 4e-04), which controls the size of the weight adjustment in each training epoch. In general, higher values reduce training time at the expense of less accurate results.
  • input_drop: A number between 0 and 1. The probability of corruption for input columns in training mini-batches (default = 0.8). Higher values increase training time but reduce the risk of overfitting. In our experience, values between 0.7 and 0.95 deliver the best performance.
  • seed: An integer, the value to which Python's pseudo-random number generator is initialized. This enables users to ensure that data shuffling, weight and bias initialization, and missingness indicator vectors are reproducible.
  • train_batch: An integer, the number of observations in training mini-batches (default = 16).
  • latent_space_size: An integer, the number of normal dimensions used to parameterize the latent space.
  • cont_adj: A number, weights the importance of continuous variables in the loss function.
  • binary_adj: A number, weights the importance of binary variables in the loss function.
  • softmax_adj: A number, weights the importance of categorical variables in the loss function.
  • dropout_level: A number between 0 and 1, the proportion of nodes dropped to "thin" the network.
  • vae_layer: Boolean, specifies whether to include a variational autoencoder layer in the network
  • vae_alpha: A number, the strength of the prior imposed on the Kullback-Leibler divergence term in the variational autoencoder loss functions.
  • vae_sample_var: A number, the sampling variance of the normal distributions used to parameterize the latent space.
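To illustrate the expected shape of binary_columns and softmax_columns when data is passed directly rather than via rMIDAS::convert(), the sketch below uses hypothetical column names and assumes the input has already been one-hot encoded and scaled: binary_columns is a character vector, while softmax_columns is a list whose elements each hold the one-hot column names for a single categorical variable.

```r
# Hypothetical column names; assumes encoded_df has already been
# one-hot encoded and min-max scaled outside of convert().
bin_cols <- c("smoker", "employed")  # binary variables

# One inner element per categorical variable
cat_cols <- list(
  c("region_north", "region_south", "region_west"),
  c("edu_primary", "edu_secondary", "edu_tertiary")
)

imp_model <- train(encoded_df,
                   binary_columns = bin_cols,
                   softmax_columns = cat_cols,
                   training_epochs = 20L)
```

If data is instead a midas_pre object from convert(), both arguments are ignored and the column metadata stored during pre-processing is used.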

Returns

An object of class midas from which completed datasets can be drawn using rMIDAS::complete().

Details

For more information, see Lall and Robinson (2023): doi:10.18637/jss.v107.i09.

Examples

library(rMIDAS)
library(data.table)

# Generate raw data, with numeric, binary, and categorical variables
## Not run: # Run where Python available and configured correctly
if (python_configured()) {
  set.seed(89)
  n_obs <- 10000
  raw_data <- data.table(a = sample(c("red", "yellow", "blue", NA), n_obs, replace = TRUE),
                         b = 1:n_obs,
                         c = sample(c("YES", "NO", NA), n_obs, replace = TRUE),
                         d = runif(n_obs, 1, 10),
                         e = sample(c("YES", "NO"), n_obs, replace = TRUE),
                         f = sample(c("male", "female", "trans", "other", NA), n_obs, replace = TRUE))

  # Names of bin./cat. variables
  test_bin <- c("c", "e")
  test_cat <- c("a", "f")

  # Pre-process data
  test_data <- convert(raw_data,
                       bin_cols = test_bin,
                       cat_cols = test_cat,
                       minmax_scale = TRUE)

  # Run imputations
  test_imp <- train(test_data)

  # Generate datasets
  complete_datasets <- complete(test_imp, m = 5, fast = FALSE)

  # Use Rubin's rules to combine m regression models
  midas_pool <- combine(formula = d ~ a + c + e + f, complete_datasets)
}
## End(Not run)

References

Lall, R. and Robinson, T. (2023). "Efficient Multiple Imputation for Diverse Data in Python and R: MIDASpy and rMIDAS." Journal of Statistical Software, 107(9), 1-38. doi:10.18637/jss.v107.i09.

  • Maintainer: Thomas Robinson
  • License: Apache License (>= 2.0)
  • Last published: 2023-10-11