nn_utils_weight_norm function

nn_utils_weight_norm

nn_utils_weight_norm

Applies weight normalization to a parameter in the given module.

Returns

The original module with the weight_v and weight_g paramters.

Details

\eqn{\mathbf{w} = g \dfrac{\mathbf{v}}{\|\mathbf{v}\|}}

Weight normalization is a reparameterization that decouples the magnitude of a weight tensor from its direction. This replaces the parameter specified by name (e.g. 'weight') with two parameters: one specifying the magnitude (e.g. 'weight_g') and one specifying the direction (e.g. 'weight_v').

Note

The pytorch Weight normalization is implemented via a hook that recomputes the weight tensor from the magnitude and direction before every forward()

call. Since torch for R still do not support hooks, the weight recomputation need to be done explicitly inside the forward() definition trough a call of the recompute() method. See examples.

By default, with dim = 0, the norm is computed independently per output channel/plane. To compute a norm over the entire weight tensor, use dim = NULL.

@references https://arxiv.org/abs/1602.07868

Examples

if (torch_is_installed()) { x = nn_linear(in_features = 20, out_features = 40) weight_norm = nn_utils_weight_norm$new(name = 'weight', dim = 2) weight_norm$apply(x) x$weight_g$size() x$weight_v$size() x$weight # the recompute() method recomputes the weight using g and v. It must be called # explicitly inside `forward()`. weight_norm$recompute(x) }

Methods

Public methods

Method new()

Usage

nn_utils_weight_norm$new(name, dim)

Arguments

  • name: (str, optional): name of weight parameter

  • dim: (int, optional): dimension over which to compute the norm

Method compute_weight()

Usage

nn_utils_weight_norm$compute_weight(module, name = NULL, dim = NULL)

Arguments

  • module: (Module): containing module

  • name: (str, optional): name of weight parameter

  • dim: (int, optional): dimension over which to compute the norm

Method apply()

Usage

nn_utils_weight_norm$apply(module, name = NULL, dim = NULL)

Arguments

  • module: (Module): containing module

  • name: (str, optional): name of weight parameter

  • dim: (int, optional): dimension over which to compute the norm

Method call()

Usage

nn_utils_weight_norm$call(module)

Arguments

  • module: (Module): containing module

Method recompute()

Usage

nn_utils_weight_norm$recompute(module)

Arguments

  • module: (Module): containing module

Method remove()

Usage

nn_utils_weight_norm$remove(module, name = NULL)

Arguments

  • module: (Module): containing module

  • name: (str, optional): name of weight parameter

Method clone()

The objects of this class are cloneable with this method.

Usage

nn_utils_weight_norm$clone(deep = FALSE)

Arguments

  • deep: Whether to make a deep clone.

  • Maintainer: Daniel Falbel
  • License: MIT + file LICENSE
  • Last published: 2025-02-14