Selects values from input at the 1-dimensional indices given in indices, taken along the given dim.
torch_take_along_dim(self, indices, dim = NULL)
Arguments
self: the input tensor.
indices: the indices into self. Must have long (64-bit integer) dtype.
dim: the dimension to select along. Default is NULL.
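For a 2-d input, the selection rule can be pictured as out[i, j] = self[i, indices[i, j]] when selecting along dim 2, and out[i, j] = self[indices[i, j], j] when selecting along dim 1. The base-R sketch below mimics that rule on plain matrices without calling torch. The helper take_along_dim_matrix() is hypothetical and exists only for illustration. It assumes 1-based positions (R's convention) and an index matrix with the same shape as the input, and it does not attempt broadcasting.

```r
# Hypothetical helper, not part of torch: mimics take-along-dim for a plain
# R matrix `x` and an index matrix `idx` of the same shape.
# Assumes 1-based positions; broadcasting is not supported.
take_along_dim_matrix <- function(x, idx, dim) {
  stopifnot(is.matrix(x), is.matrix(idx),
            nrow(x) == nrow(idx), ncol(x) == ncol(idx),
            length(dim) == 1, dim %in% c(1, 2))
  picks <- if (dim == 2) {
    # out[i, j] = x[i, idx[i, j]]: keep the row, take the column from idx
    cbind(as.vector(row(idx)), as.vector(idx))
  } else {
    # out[i, j] = x[idx[i, j], j]: take the row from idx, keep the column
    cbind(as.vector(idx), as.vector(col(idx)))
  }
  # Matrix indexing returns values in the column-major order of idx,
  # so refill a matrix of idx's shape in that same order
  matrix(x[picks], nrow = nrow(idx), ncol = ncol(idx))
}

x   <- matrix(c(10, 30, 20, 60, 40, 50), nrow = 2)  # rows: 10 20 40 / 30 60 50
idx <- matrix(c(2L, 1L, 1L, 2L, 2L, 2L), nrow = 2)  # rows:  2  1  2 /  1  2  2

along_rows <- take_along_dim_matrix(x, idx, dim = 1)  # rows: 30 20 50 / 10 60 50
along_cols <- take_along_dim_matrix(x, idx, dim = 2)  # rows: 20 10 20 / 30 60 60
```

The output always has the shape of the index matrix. Each index only chooses a position along the selected dimension, and the other coordinate is inherited from where the index sits.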
Note
If dim is NULL, the input tensor is treated as if it had been flattened to 1-d, and indices selects positions in that flattened tensor.
Functions that return indices along a dimension, like torch_argmax() and torch_argsort(), are designed to work with this function. See the examples below.
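The pairing described in this note can be illustrated with a base-R analogy that does not call torch: which.max() stands in for torch_argmax(), order() stands in for torch_argsort(), and ordinary indexing stands in for torch_take_along_dim(). The sketch below is only an analogy built on two assumptions. Positions are 1-based, and the dim = NULL case flattens the matrix row by row, as a tensor does, whereas R stores matrices column by column.

```r
# Base-R analogy (no torch): an index-returning function paired with
# selection at those indices. Positions are 1-based, as in R.
x <- matrix(c(10, 30, 20, 60, 40, 50), nrow = 2)  # rows: 10 20 40 / 30 60 50

# dim = NULL: treat x as flattened. Transpose first so that as.vector()
# reads the values row by row (row-major) instead of R's column-major order.
flat    <- as.vector(t(x))   # 10 20 40 30 60 50
max_idx <- which.max(flat)   # plays the role of an argmax over the flattened input
max_val <- flat[max_idx]     # selecting at that position recovers the maximum

# dim = 2: per-row sort order (the role of an argsort along dim 2)...
sorted_idx <- t(apply(x, 1, order))  # one row of positions per row of x
# ...then select along dim 2: keep each row, take columns from sorted_idx
picks       <- cbind(as.vector(row(sorted_idx)), as.vector(sorted_idx))
sorted_rows <- matrix(x[picks], nrow = nrow(x))  # each row of x in ascending order
```

Because the index-returning step and the selection step agree on the dimension (or on the flattened layout), their composition gives the extreme value or the sorted values directly. Neither step needs to know the other's internals.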