Keras Callback for adaptive learning rate with weight restoration
Provides a Keras callback similar to keras3::callback_reduce_lr_on_plateau(), but which additionally restores the model weights to the best ones seen so far whenever the learning rate is reduced, and which uses slightly stricter improvement detection.
factor: Factor by which the learning rate will be reduced: new_lr = old_lr * factor.
patience: Number of epochs with no significant improvement after which the learning rate will be reduced.
verbose: Integer. Set to 1 to receive update messages.
mode: Optimisation mode. "auto" detects the mode from the name of monitor. "min" monitors for decreasing metrics. "max" monitors for increasing metrics.
delta_abs: Minimum absolute metric improvement per epoch. The learning rate will be reduced if the average improvement is less than delta_abs per epoch for patience epochs.
delta_rel: Minimum relative metric improvement per epoch. The learning rate will be reduced if the average improvement is less than |metric| * delta_rel per epoch for patience epochs.
cooldown: Number of epochs to wait before resuming normal operation after the learning rate has been reduced. The minimum number of epochs between two learning rate reductions is patience + cooldown.
min_lr: Lower bound for the learning rate. If a learning rate reduction would lower the learning rate below min_lr, it is clipped at min_lr instead and no further reductions are performed.
restore_weights: Logical. If TRUE, the best weights seen so far will be restored at each learning rate reduction. This is very useful if the monitored metric oscillates.
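To illustrate how delta_abs and delta_rel interact, the sketch below shows a simplified improvement check in plain R. This is not the package's actual implementation: the function name and the default values are made up for illustration, mode = "min" is assumed, and the per-epoch averaging over the patience window is omitted.

```r
# Simplified sketch: an improvement only counts as significant if the
# metric drops by more than BOTH delta_abs and |best| * delta_rel,
# i.e. by more than the larger of the two thresholds (mode = "min").
is_significant_improvement <- function(current, best,
                                       delta_abs = 1e-4,   # hypothetical default
                                       delta_rel = 1e-3) { # hypothetical default
  threshold <- max(delta_abs, abs(best) * delta_rel)
  (best - current) > threshold
}

is_significant_improvement(0.90, 1.00)    # large drop: TRUE
is_significant_improvement(0.9999, 1.00)  # drop below threshold: FALSE
```

In the callback itself, this comparison would be applied to the average improvement per epoch over the last patience epochs rather than to a single step.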
Returns
A KerasCallback suitable for passing to keras3::fit().
Details
Note that while keras3::callback_reduce_lr_on_plateau() automatically logs the learning rate as a metric 'lr', this is currently not possible from R. Thus, if you also want to log the learning rate, you should add a keras3::callback_reduce_lr_on_plateau() with a min_lr high enough that it never triggers a reduction; this effectively disables the callback while still logging the learning rate.
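A callbacks list following this workaround might look like the configuration fragment below. The name callback_adaptive_lr() is a hypothetical stand-in for this package's constructor (the documentation above does not name it), and min_lr = 1 assumes the actual learning rate stays well below 1, so the keras3 callback never performs a reduction and only records 'lr'.

```r
library(keras3)

callbacks <- list(
  # hypothetical name for the callback documented here
  callback_adaptive_lr(monitor = "val_loss", patience = 5),
  # disabled by the high min_lr, but still logs the 'lr' metric
  callback_reduce_lr_on_plateau(monitor = "val_loss", min_lr = 1)
)
```

The list would then be passed as the callbacks argument to keras3::fit().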