Applies weight normalization to a parameter in the given module.
Returns
The original module with the weight_v and weight_g parameters.
Details
\eqn{\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}}
Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v').
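The decomposition in the formula above can be checked numerically without torch. The following is a minimal base-R sketch, assuming a weight matrix laid out as (out_features x in_features) with one magnitude per row. The helpers row_norms() and weight_from_gv() are hypothetical illustrations, not part of the torch API. The sketch shows that initializing v = w and g = ||w|| reproduces w, that rescaling v leaves w unchanged (direction only), and that scaling g scales the row norms (magnitude only).

```r
# Base-R sketch of w = g * v / ||v|| (not the torch implementation).
# Assumption: w is (out_features x in_features), norm taken per row.

# Hypothetical helper: Euclidean norm of each row of a matrix.
row_norms <- function(m) sqrt(rowSums(m^2))

# Hypothetical helper: rebuild w from magnitude g and direction v.
weight_from_gv <- function(g, v) {
  # g / row_norms(v) has one entry per row; R recycles it down the
  # columns, so row i of v is multiplied by the i-th scalar.
  v * (g / row_norms(v))
}

set.seed(1)
w <- matrix(rnorm(40 * 20), nrow = 40, ncol = 20)

# Initialize the two parameters from the existing weight.
v <- w
g <- row_norms(w)
w_rebuilt <- weight_from_gv(g, v)

# Rescaling v changes only its length, so the rebuilt weight is unchanged:
# the magnitude is carried entirely by g.
w_scaled_v <- weight_from_gv(g, 3 * v)

# Doubling g doubles every row norm of w; the directions stay the same.
w_double_g <- weight_from_gv(2 * g, v)
```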
Note
In PyTorch, weight normalization is implemented via a hook that recomputes the weight tensor from the magnitude and direction before every forward() call. Since torch for R does not yet support hooks, the weight recomputation must be done explicitly inside the forward() definition through a call to the recompute() method. See examples.
By default, with dim = 0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim = NULL.
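To make the difference between a per-channel norm and a norm over the entire tensor concrete, here is a base-R sketch (no torch involved). It assumes a (out_features x in_features) matrix in which "per output channel" means "per row"; the whole-tensor case is meant as an analogue of dim = NULL, not a reproduction of the torch code path. A per-row norm gives one magnitude per output unit and makes every row of the direction unit-length, while a whole-tensor norm gives a single scalar magnitude.

```r
# Base-R sketch: per-row magnitude vs one magnitude for the whole tensor.
# Assumption: w is (out_features x in_features); a channel is a row.
set.seed(2)
w <- matrix(rnorm(40 * 20), nrow = 40, ncol = 20)

# Per output channel: one norm per row, so g has one entry per output unit.
g_per_row <- sqrt(rowSums(w^2))

# Whole tensor (analogue of dim = NULL): one Frobenius norm, a scalar g.
g_whole <- sqrt(sum(w^2))

# Per-row direction: each row of v has unit norm.
v_unit_rows <- w / g_per_row
# Whole-tensor direction: only the tensor as a whole has unit norm.
v_unit_whole <- w / g_whole

# The two are related: the squared row norms sum to the squared
# Frobenius norm, so sqrt(sum(g_per_row^2)) equals g_whole.
```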
library(torch)

if (torch_is_installed()) {
  x = nn_linear(in_features = 20, out_features = 40)
  weight_norm = nn_utils_weight_norm$new(name = 'weight', dim = 2)
  # Replace x$weight by the parameters x$weight_g and x$weight_v.
  weight_norm$apply(x)
  x$weight_g$size()
  x$weight_v$size()
  x$weight

  # The recompute() method recomputes the weight using g and v. It must be
  # called explicitly inside `forward()`.
  weight_norm$recompute(x)
}
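The Note above says that recompute() has to be called inside forward(), but the example only calls it at top level. The sketch below shows one way this could look inside a custom module. The module name weight_normed_linear and its field names fc and wn are hypothetical; the calls to nn_utils_weight_norm$new(), $apply() and $recompute() follow the usage in the example, with the same name and dim arguments. Calling recompute() at the start of forward() means the weight used in each pass is rebuilt from the current weight_g and weight_v, which is what the PyTorch hook would otherwise do automatically.

```r
library(torch)

# Hypothetical module wrapping a weight-normalized linear layer.
weight_normed_linear <- nn_module(
  "weight_normed_linear",
  initialize = function(in_features, out_features) {
    self$fc <- nn_linear(in_features, out_features)
    # Keep the weight-norm object so forward() can call recompute() on it.
    self$wn <- nn_utils_weight_norm$new(name = "weight", dim = 2)
    # Replace fc$weight by the parameters fc$weight_g and fc$weight_v.
    self$wn$apply(self$fc)
  },
  forward = function(x) {
    # No hooks in torch for R: rebuild fc$weight from g and v explicitly
    # before the layer uses it.
    self$wn$recompute(self$fc)
    self$fc(x)
  }
)

if (torch_is_installed()) {
  model <- weight_normed_linear(in_features = 20, out_features = 40)
  out <- model(torch_randn(5, 20))
  out$shape
}
```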