nn_flatten function

Flattens a contiguous range of dims of a tensor into a single dim.

For use with nn_sequential.
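A typical place for nn_flatten() is between convolutional layers and a fully connected layer inside nn_sequential(). The sketch below is a hypothetical small model with illustrative layer sizes, not taken from the package; it assumes the torch package and its libtorch backend are installed, and is guarded the same way as the package's own example.

```r
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)

  # Hypothetical classifier: conv features are flattened per sample
  # (the batch dim 1 is kept) before the linear layer.
  model <- nn_sequential(
    nn_conv2d(in_channels = 1, out_channels = 4, kernel_size = 3), # 28x28 -> 26x26
    nn_relu(),
    nn_flatten(),                 # (N, 4, 26, 26) -> (N, 4 * 26 * 26)
    nn_linear(4 * 26 * 26, 10)
  )

  out <- model(torch_randn(8, 1, 28, 28))
  print(out$shape)                # expected shape: 8 10
}
```

Because start_dim defaults to 2, the batch dimension passes through untouched, which is why the linear layer's input size is just the per-sample product 4 * 26 * 26.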

nn_flatten(start_dim = 2, end_dim = -1)

Arguments

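  • start_dim: first dim to flatten (default = 2, so the first dim, usually the batch dim, is kept).
  • end_dim: last dim to flatten (default = -1, the last dim; negative values count back from the end).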
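The effect of start_dim and end_dim can be seen by flattening the same 4-d tensor with different ranges. This is a sketch assuming the torch package and libtorch are installed; the tensor sizes are arbitrary.

```r
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)
  x <- torch_randn(2, 3, 4, 5)

  # Defaults (start_dim = 2, end_dim = -1): dims 2 to 4 merge, 3 * 4 * 5 = 60.
  default_shape <- nn_flatten()(x)$shape
  print(default_shape)    # expected shape: 2 60

  # Only dims 2 and 3 merge (3 * 4 = 12); the last dim is kept.
  narrow <- nn_flatten(start_dim = 2, end_dim = 3)
  narrow_shape <- narrow(x)$shape
  print(narrow_shape)     # expected shape: 2 12 5

  # A negative end_dim counts back from the last dim: -2 is dim 3 here.
  same_as_narrow <- nn_flatten(start_dim = 2, end_dim = -2)(x)$shape
}
```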

Shape

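  • Input: (*, S_start, ..., S_i, ..., S_end, *), where S_i is the size at dimension i and * means any number of dimensions, including none.
  • Output: (*, S_start × ... × S_i × ... × S_end, *), i.e. dims start_dim through end_dim are replaced by a single dim whose size is their product.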
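The shape rule can be checked without torch. The helper flatten_shape() below is hypothetical (it is not part of the torch package); it assumes R torch's 1-based dims, where negative values count back from the end and -1 is the last dim, and only computes the resulting shape.

```r
# Hypothetical helper, not part of torch: returns the shape that
# nn_flatten(start_dim, end_dim) would produce for an input of `shape`.
flatten_shape <- function(shape, start_dim = 2, end_dim = -1) {
  n <- length(shape)
  # Map negative dims to positive 1-based positions (-1 -> n).
  resolve <- function(d) if (d < 0) n + d + 1 else d
  s <- resolve(start_dim)
  e <- resolve(end_dim)
  stopifnot(s >= 1, e <= n, s <= e)
  c(
    shape[seq_len(s - 1)],       # leading dims, kept as-is
    prod(shape[s:e]),            # flattened range: product of sizes
    shape[seq_len(n - e) + e]    # trailing dims, kept as-is
  )
}

flatten_shape(c(32, 1, 5, 5))
#> [1] 32 25
```

With the defaults, the (32, 1, 5, 5) input from the example below becomes (32, 1 × 5 × 5) = (32, 25).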

Examples

library(torch)

if (torch_is_installed()) {
  input <- torch_randn(32, 1, 5, 5)
  m <- nn_flatten()
  m(input)  # dims 2 to 4 are merged: output shape is 32 x 25
}

See Also

nn_unflatten

  • Maintainer: Daniel Falbel
  • License: MIT + file LICENSE
  • Last published: 2025-02-14