Callback to monitor likelihood gradient components
Provides a keras callback to monitor the individual components of the censored and truncated likelihood. Useful for debugging TensorFlow implementations of Distributions.
Usage

callback_debug_dist_gradients(
  object,
  data,
  obs,
  keep_grads = FALSE,
  stop_on_na = TRUE,
  verbose = TRUE
)
Arguments

object
: A reservr_keras_model created by tf_compile_model().

data
: Input data for the model.

obs
: Observations associated to data.

keep_grads
: Log actual gradients? (memory hungry!)

stop_on_na
: Stop if any likelihood component has NaN in its gradients?

verbose
: Print a message if training is halted? The message will contain information about which likelihood components have NaN in their gradients.

Value

A KerasCallback suitable for passing to keras3::fit().
Examples

dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
global_fit <- fit(dist, x)

if (interactive()) {
  library(keras3)
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  tf_initialise_model(mod, global_fit$params)
  gradient_tracker <- callback_debug_dist_gradients(
    mod,
    as_tensor(group, config_floatx()),
    x,
    keep_grads = TRUE
  )
  fit_history <- fit(
    mod,
    x = as_tensor(group, config_floatx()),
    y = x,
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
      gradient_tracker,
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
    )
  )
  gradient_tracker$gradient_logs[[20]]$dens

  plot(fit_history)

  predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}
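With keep_grads = TRUE, the logged gradients can be inspected after training to locate problematic epochs. The sketch below is a hypothetical post-hoc check, not part of the reservr API: it assumes each entry of gradient_logs carries a numeric dens component (as accessed in the example above) and scans for epochs whose density gradients contain non-finite values.

```r
# Hypothetical debugging sketch (names gradient_tracker, gradient_logs and
# the $dens component are taken from the example above; the exact structure
# of each log entry is an assumption).
bad_epochs <- which(vapply(
  gradient_tracker$gradient_logs,
  function(log) any(!is.finite(unlist(log$dens))),
  logical(1L)
))
bad_epochs  # indices of epochs with NaN/Inf in the density gradients
```

If stop_on_na = TRUE, training halts at the first such epoch, so a non-empty result here typically points at the final recorded log entry.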