nn_glu function

GLU module

Applies the gated linear unit function $\mathrm{GLU}(a, b) = a \otimes \sigma(b)$, where the input is split in half along `dim`, $a$ is the first half, $b$ is the second half, $\sigma$ is the sigmoid function, and $\otimes$ denotes the element-wise product.
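As a worked check of the formula, the base-R sketch below (no torch required) applies GLU to a plain numeric vector. The first half is $a$, the second half is $b$, and the result is $a$ multiplied element-wise by the logistic sigmoid of $b$. The helper name `glu_vec()` is hypothetical and is not part of the torch package.

```r
# Hypothetical helper, not part of torch: GLU on a plain numeric vector.
glu_vec <- function(x) {
  n <- length(x)
  stopifnot(n %% 2 == 0)            # the input must split into two equal halves
  a <- x[seq_len(n / 2)]            # first half
  b <- x[(n / 2 + 1):n]             # second half
  a * (1 / (1 + exp(-b)))           # element-wise a * sigmoid(b)
}

glu_vec(c(1, 2, 0, 0))
# [1] 0.5 1.0
```

Since $\sigma(0) = 0.5$, each element of $a = (1, 2)$ is halved.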

nn_glu(dim = -1)

Arguments

  • dim: (int) the dimension on which to split the input; negative values count from the end, so -1 is the last dimension. Default: -1

Shape

  • Input: $(\ast_1, N, \ast_2)$, where $\ast$ means any number of additional dimensions and $N$, the size of dimension `dim`, must be even
  • Output: $(\ast_1, M, \ast_2)$, where $M = N/2$
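The shape rule can be illustrated with a base-R sketch that splits an array along any dimension, resolving a negative `dim` from the end in the same way the default `dim = -1` selects the last dimension. The function `glu_array()` and its internal helper `take()` are hypothetical illustrations, not the torch implementation; they only reproduce the arithmetic and the resulting shape $(\ast_1, N/2, \ast_2)$.

```r
# Hypothetical helper, not part of torch: GLU along dimension `dim` of an array.
glu_array <- function(x, dim = -1) {
  x <- as.array(x)
  nd <- length(dim(x))
  d <- if (dim < 0) nd + dim + 1 else dim   # -1 -> last dimension, -2 -> second to last
  n <- dim(x)[d]
  stopifnot(n %% 2 == 0)                    # N must be even so the halves match

  # Select indices i along dimension `along`, keeping every dimension.
  take <- function(along, i) {
    idx <- lapply(dim(x), seq_len)
    idx[[along]] <- i
    do.call(`[`, c(list(x), idx, list(drop = FALSE)))
  }

  a <- take(d, seq_len(n / 2))              # first half along `dim`
  b <- take(d, (n / 2 + 1):n)               # second half along `dim`
  a * (1 / (1 + exp(-b)))                   # element-wise a * sigmoid(b)
}

x <- array(rnorm(3 * 4 * 6), dim = c(3, 4, 6))
dim(glu_array(x, dim = 2))   # [1] 3 2 6
dim(glu_array(x))            # [1] 3 4 3  (default dim = -1 halves the last dimension)
```

Only the split dimension changes size; all other dimensions pass through unchanged, matching $M = N/2$ above.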

Examples

if (torch_is_installed()) {
  m <- nn_glu()
  input <- torch_randn(4, 2)
  output <- m(input)  # the last dimension (size 2) is halved, so output has shape (4, 1)
}

  • Maintainer: Daniel Falbel
  • License: MIT + file LICENSE
  • Last published: 2025-02-14