callback_debug_dist_gradients function

Callback to monitor likelihood gradient components

Provides a Keras callback to monitor the individual components of the censored and truncated likelihood. Useful for debugging TensorFlow implementations of Distributions.

callback_debug_dist_gradients(
  object,
  data,
  obs,
  keep_grads = FALSE,
  stop_on_na = TRUE,
  verbose = TRUE
)

Arguments

  • object: A reservr_keras_model created by tf_compile_model().
  • data: Input data for the model.
  • obs: Observations associated to data.
  • keep_grads: Log actual gradients? (memory hungry!)
  • stop_on_na: Stop if any likelihood component has NaN in its gradients?
  • verbose: Print a message if training is halted? The message will contain information about which likelihood components have NaN in their gradients.
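
The kind of check that stop_on_na and verbose describe can be sketched in base R. This is only an illustration, not the callback's actual implementation: the layout of grad_log (a named list with one numeric vector per likelihood component), the component names cens and trunc, and the message wording are all assumptions. Only the name dens appears on this page.

```r
# Hypothetical gradient log for one epoch: one numeric vector of gradients
# per likelihood component. Only `dens` is taken from this page; `cens` and
# `trunc` are made-up names for illustration.
grad_log <- list(
  dens = c(0.12, -0.03, 0.40),
  cens = c(NaN, 0.10, 0.05),
  trunc = c(0.00, 0.20, -0.10)
)

# Flag every component that has at least one NaN gradient entry.
has_nan <- vapply(grad_log, function(g) any(is.nan(g)), logical(1))

# Report the affected components; the wording of the message is an assumption.
bad_components <- names(has_nan)[has_nan]
if (length(bad_components) > 0) {
  message("NaN gradients in: ", paste(bad_components, collapse = ", "))
}
```

With logs in this assumed shape, the same vapply() call could be mapped over every epoch to find the first epoch in which a component's gradients turn NaN.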

Returns

A KerasCallback suitable for passing to keras3::fit().

Examples

dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
global_fit <- fit(dist, x)

if (interactive()) {
  library(keras3)
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  tf_initialise_model(mod, global_fit$params)

  gradient_tracker <- callback_debug_dist_gradients(
    mod, as_tensor(group, config_floatx()), x, keep_grads = TRUE
  )

  fit_history <- fit(
    mod,
    x = as_tensor(group, config_floatx()),
    y = x,
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
      gradient_tracker,
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
    )
  )

  gradient_tracker$gradient_logs[[20]]$dens

  plot(fit_history)

  predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}