callback_adaptive_lr function

Keras Callback for adaptive learning rate with weight restoration

Provides a Keras callback similar to keras3::callback_reduce_lr_on_plateau(), but one that also restores the weights to the best seen so far whenever a learning rate reduction occurs, and that applies slightly more restrictive improvement detection.

callback_adaptive_lr(
  monitor = "val_loss",
  factor = 0.1,
  patience = 10L,
  verbose = 0L,
  mode = c("auto", "min", "max"),
  delta_abs = 1e-04,
  delta_rel = 0,
  cooldown = 0L,
  min_lr = 0,
  restore_weights = TRUE
)

Arguments

  • monitor: quantity to be monitored.
  • factor: factor by which the learning rate will be reduced. new_lr = old_lr * factor.
  • patience: number of epochs with no significant improvement after which the learning rate will be reduced.
  • verbose: integer. Set to 1 to receive update messages.
  • mode: Optimisation mode. "auto" detects the mode from the name of monitor. "min" monitors for decreasing metrics. "max" monitors for increasing metrics.
  • delta_abs: Minimum absolute metric improvement per epoch. The learning rate will be reduced if the average improvement is less than delta_abs per epoch for patience epochs.
  • delta_rel: Minimum relative metric improvement per epoch. The learning rate will be reduced if the average improvement is less than |metric| * delta_rel per epoch for patience epochs (see the sketch after this list).
  • cooldown: number of epochs to wait before resuming normal operation after learning rate has been reduced. The minimum number of epochs between two learning rate reductions is patience + cooldown.
  • min_lr: lower bound for the learning rate. If a learning rate reduction would lower the learning rate below min_lr, it will be clipped at min_lr instead and no further reductions will be performed.
  • restore_weights: Bool. If TRUE, the best weights will be restored at each learning rate reduction. This is very useful if the metric oscillates.
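
The combined improvement test can be illustrated as follows. This is a minimal sketch, not the package's internal code: the helper name is_significant_improvement and the exact windowing are assumptions, shown here for "min" mode (lower is better).

# Hypothetical sketch of the plateau test; not the package's actual internals.
is_significant_improvement <- function(metric_history, patience,
                                       delta_abs = 1e-4, delta_rel = 0) {
  n <- length(metric_history)
  old <- metric_history[n - patience]
  new <- metric_history[n]
  # Average improvement per epoch over the last `patience` epochs ("min" mode).
  avg_improvement <- (old - new) / patience
  # Required improvement: the larger of the absolute and relative thresholds.
  avg_improvement >= max(delta_abs, abs(new) * delta_rel)
}

# A loss creeping down by only 5e-5 per epoch fails the default
# delta_abs = 1e-4 test, so a learning rate reduction would be triggered.
is_significant_improvement(c(0.5000, 0.49995, 0.4999), patience = 2)
#> FALSE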

Returns

A KerasCallback suitable for passing to keras3::fit().

Details

Note that while keras3::callback_reduce_lr_on_plateau() automatically logs the learning rate as a metric 'lr', doing the same from a callback defined in R is currently not possible. Thus, if you also want to log the learning rate, you should add a keras3::callback_reduce_lr_on_plateau() with a high min_lr to effectively disable it while still recording the learning rate.
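
For example, the workaround might look like this. This is a sketch; the monitor name and parameter values are placeholders.

# Pair the adaptive callback with an effectively disabled
# callback_reduce_lr_on_plateau() so keras still logs the 'lr' metric.
callbacks <- list(
  callback_adaptive_lr("val_loss", factor = 0.5, patience = 5L),
  # min_lr = 1.0 is at or above the initial learning rate, so this callback
  # never reduces it; it only records 'lr' in the training history.
  keras3::callback_reduce_lr_on_plateau("val_loss", min_lr = 1.0)
)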

Examples

dist <- dist_exponential()
group <- sample(c(0, 1), size = 100, replace = TRUE)
x <- dist$sample(100, with_params = list(rate = group + 1))
global_fit <- fit(dist, x)

if (interactive()) {
  library(keras3)
  l_in <- layer_input(shape = 1L)
  mod <- tf_compile_model(
    inputs = list(l_in),
    intermediate_output = l_in,
    dist = dist,
    optimizer = optimizer_adam(),
    censoring = FALSE,
    truncation = FALSE
  )
  tf_initialise_model(mod, global_fit$params)
  fit_history <- fit(
    mod,
    x = as_tensor(group, config_floatx()),
    y = as_trunc_obs(x),
    epochs = 20L,
    callbacks = list(
      callback_adaptive_lr("loss", factor = 0.5, patience = 2L, verbose = 1L, min_lr = 1.0e-4),
      callback_reduce_lr_on_plateau("loss", min_lr = 1.0) # to track lr
    )
  )

  plot(fit_history)

  predicted_means <- predict(mod, data = as_tensor(c(0, 1), config_floatx()))
}