tf_compile_model function

Compile a Keras model for truncated data under dist

tf_compile_model( inputs, intermediate_output, dist, optimizer, censoring = TRUE, truncation = TRUE, metrics = NULL, weighted_metrics = NULL )

Arguments

  • inputs: List of Keras input layers

  • intermediate_output: Intermediate model layer to be used as the input for computing the distribution parameters

  • dist: A Distribution to use for compiling the loss and parameter outputs

  • optimizer: String (name of optimizer) or optimizer instance. See optimizer_* family.

  • censoring: A flag indicating whether the compiled model should support censored observations. Set to FALSE for higher efficiency. fit(...) will raise an error if the resulting model is used to fit censored observations.

  • truncation: A flag indicating whether the compiled model should support truncated observations. Set to FALSE for higher efficiency. fit(...) will issue a warning if the resulting model is used to fit truncated observations.

  • metrics: List of metrics to be evaluated by the model during training and testing. Each of these can be:

    • a string (name of a built-in function),

    • a function, optionally with a "name" attribute, or

    • a Metric() instance. See the metric_* family of functions.

    Typically you will use metrics = c('accuracy'). A function is any callable with the signature result = fn(y_true, y_pred). To specify different metrics for different outputs of a multi-output model, you can also pass a named list, such as metrics = list(a = 'accuracy', b = c('accuracy', 'mse')), or an unnamed list giving a metric or a list of metrics for each output, such as metrics = list(c('accuracy'), c('accuracy', 'mse')) or metrics = list('accuracy', c('accuracy', 'mse')).

    When you pass the strings 'accuracy' or 'acc', they are converted to one of metric_binary_accuracy(), metric_categorical_accuracy() or metric_sparse_categorical_accuracy(), based on the shapes of the targets and of the model output. A similar conversion is done for the strings "crossentropy" and "ce". The metrics passed here are evaluated without sample weighting; if you would like sample weighting to apply, specify your metrics via the weighted_metrics argument instead.

    If providing an anonymous R function, you can customize the printed name during training by assigning attr(<fn>, "name") <- "my_custom_metric_name", or by calling custom_metric("my_custom_metric_name", <fn>).

  • weighted_metrics: List of metrics to be evaluated and weighted by sample_weight or class_weight during training and testing.
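A minimal sketch of the metric-naming mechanism described under metrics. The metric mean_pred and the printed name "mean_prediction" are hypothetical. This page does not describe how y_true and y_pred are laid out for a reservr_keras_model, so the body (built from keras3::op_mean()) is only a placeholder. Defining the function merely stores the call; keras3 is not used until Keras actually evaluates the metric.

```r
# Hypothetical custom metric with the signature result = fn(y_true, y_pred).
# Placeholder body: averages the model output. The body is only evaluated
# when Keras calls the metric, so defining it needs no Keras backend.
mean_pred <- function(y_true, y_pred) {
  keras3::op_mean(y_pred)
}

# Customize the name printed during training via the "name" attribute.
attr(mean_pred, "name") <- "mean_prediction"

# Metrics are supplied as a list, e.g. (hypothetically)
# tf_compile_model(..., metrics = my_metrics)
my_metrics <- list(mean_pred)
```

Because the name is stored as a plain attribute, the same function can be reused under different names by copying it and assigning a new attribute.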

Returns

A reservr_keras_model that can be trained on truncated and censored observations from dist, using input data from inputs.
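Since censoring = FALSE and truncation = FALSE trade generality for efficiency, the flags can be chosen from the data before compiling. The sketch below uses a hypothetical helper, needed_flags(), and assumes each observation is given as an interval [xmin, xmax] (censored when xmin < xmax) together with truncation bounds [tmin, tmax] (truncated when either bound is finite). This representation is an assumption for illustration, not part of the tf_compile_model() interface.

```r
# Hypothetical helper (not part of reservr): report which tf_compile_model()
# flags a data set needs. Assumes interval observations [xmin, xmax] and
# truncation bounds [tmin, tmax]; infinite bounds mean "no truncation".
needed_flags <- function(xmin, xmax, tmin = -Inf, tmax = Inf) {
  list(
    # An observation is censored if its interval has positive width.
    censoring = any(xmin < xmax),
    # An observation is truncated if either truncation bound is finite.
    truncation = any(is.finite(tmin) | is.finite(tmax))
  )
}

obs_x <- c(0.5, 1.2, 2.3)

# Exact, untruncated observations need neither feature.
flags <- needed_flags(xmin = obs_x, xmax = obs_x)
flags$censoring   # FALSE
flags$truncation  # FALSE

# Right-censoring the last observation at 2.0 requires censoring = TRUE.
flags_cens <- needed_flags(xmin = c(0.5, 1.2, 2.0), xmax = c(0.5, 1.2, Inf))

# Observations truncated to [0, 3] require truncation = TRUE.
flags_trunc <- needed_flags(xmin = obs_x, xmax = obs_x, tmin = 0, tmax = 3)

# The flags could then be passed on, e.g.
# tf_compile_model(..., censoring = flags$censoring,
#                  truncation = flags$truncation)
```

Setting a flag to FALSE only when the data contain no such observations keeps the efficiency gain without triggering the error (censoring) or warning (truncation) from fit(...).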

Examples

dist <- dist_exponential()
params <- list(rate = 1.0)
N <- 100L
# Simulate N random inputs and N exponential observations with rate 1
rand_input <- runif(N)
x <- dist$sample(N, with_params = params)

if (interactive()) {
  # A single numeric input, used directly as the intermediate output
  tf_in <- keras3::layer_input(1L)
  mod <- tf_compile_model(
    inputs = list(tf_in),
    intermediate_output = tf_in,
    dist = dist,
    optimizer = keras3::optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
}