distr_categorical function

Creates a categorical distribution parameterized by either probs or logits (but not both).


distr_categorical(probs = NULL, logits = NULL, validate_args = NULL)

Arguments

  • probs: (Tensor): event probabilities
  • logits: (Tensor): event log probabilities (unnormalized)
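  • validate_args: (logical or NULL) whether to check the arguments against the distribution's constraints; NULL keeps the default behaviour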
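A minimal sketch of the two parameterizations, assuming the torch package and its LibTorch backend are installed (the block does nothing otherwise). It builds the same three-class distribution once from relative weights and once from their logarithms, then compares the normalized values exposed by the $probs and $logits fields. Per the "but not both" rule, passing both arguments is expected to raise an error; the exact error message is not assumed here.

```r
# Runs only when torch and its backend are available.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)

  w <- c(1, 1, 2)  # relative weights, not yet normalized

  # Parameterize by probabilities: normalized to c(0.25, 0.25, 0.5).
  m_probs <- distr_categorical(probs = torch_tensor(w))

  # Parameterize by unnormalized log probabilities of the same weights.
  m_logits <- distr_categorical(logits = torch_tensor(log(w)))

  p_from_probs <- as.numeric(as_array(m_probs$probs))
  # $logits holds normalized log probabilities, so exp() recovers probabilities.
  p_from_logits <- as.numeric(as_array(torch_exp(m_logits$logits)))

  print(p_from_probs)
  print(p_from_logits)

  # Supplying both parameterizations at once should be rejected.
  both_failed <- tryCatch({
    distr_categorical(probs = torch_tensor(w), logits = torch_tensor(log(w)))
    FALSE
  }, error = function(e) TRUE)
  print(both_failed)
}
```

Both routes are expected to print the same probabilities, up to float32 rounding.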

Note

It is equivalent to the distribution that torch_multinomial() samples from.
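To illustrate this equivalence, the sketch below draws many indices directly with torch_multinomial() and compares their empirical frequencies with the normalized $probs of a distr_categorical() built from the same vector. It assumes torch is installed; the seed, the sample size n and the tolerance used in checking are arbitrary illustrative choices. Classes are counted in sorted index order, so the comparison does not depend on where index numbering starts.

```r
# Runs only when torch and its backend are available.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)
  torch_manual_seed(1)

  p <- torch_tensor(c(0.2, 0.3, 0.5))
  n <- 10000

  # n independent draws, with replacement, from the probabilities p.
  idx <- as.numeric(as_array(torch_multinomial(p, num_samples = n, replacement = TRUE)))

  # Relative frequency of each class, in sorted index order.
  freq <- as.numeric(table(idx)) / n

  # The categorical distribution built from the same probabilities.
  m <- distr_categorical(probs = p)
  target <- as.numeric(as_array(m$probs))

  print(round(freq, 3))
  print(round(target, 3))
}
```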

Samples are integers from $\{0, \ldots, K-1\}$, where K is probs$size(-1).

If probs is 1-dimensional with length K, each element is the relative probability of sampling the class at that index.

If probs is N-dimensional, the first N-1 dimensions are treated as a batch of relative probability vectors.
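A short sketch of the batch behaviour, assuming torch is installed: a 2 x 3 probs tensor is treated as two probability vectors over K = 3 classes, so a single call to $sample() is expected to return one class index per row. The second row deliberately does not sum to 1, since each row is normalized on its own.

```r
# Runs only when torch and its backend are available.
if (requireNamespace("torch", quietly = TRUE) && torch::torch_is_installed()) {
  library(torch)

  # Two rows = a batch of two categorical distributions over three classes.
  batch_probs <- torch_tensor(rbind(
    c(0.1, 0.2, 0.7),
    c(5, 3, 2)   # relative weights; normalized to c(0.5, 0.3, 0.2)
  ))

  m_batch <- distr_categorical(probs = batch_probs)

  # Each row is normalized separately along the last dimension.
  batch_row_sums <- as.numeric(as_array(m_batch$probs$sum(dim = -1)))

  # One draw per batch element.
  batch_sample <- m_batch$sample()
  batch_sample_shape <- batch_sample$shape

  print(batch_row_sums)      # both approximately 1
  print(batch_sample_shape)  # one entry per row of batch_probs
}
```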

The probs argument must be non-negative, finite and have a non-zero sum; it will be normalized to sum to 1 along the last dimension, and the $probs field will return this normalized value.

The logits argument will be interpreted as unnormalized log probabilities and can therefore be any real number. It will likewise be normalized so that the resulting probabilities sum to 1 along the last dimension, and the $logits field will return this normalized value.
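The two normalizations can be written out in base R, without torch. This is an illustrative sketch of the math rather than the package's internal code, and the helper names normalize_probs() and normalize_logits() are hypothetical: probs are divided by their sum, while logits have their log-sum-exp subtracted (equivalent to a softmax). The log-sum-exp factors out the maximum logit, a standard trick that avoids overflow for large values.

```r
# Hypothetical helpers illustrating the normalization of probs and logits.
normalize_probs <- function(p) {
  # Same requirements as the probs argument: finite, non-negative, non-zero sum.
  stopifnot(all(is.finite(p)), all(p >= 0), sum(p) > 0)
  p / sum(p)
}

normalize_logits <- function(l) {
  # log(sum(exp(l))), computed stably by factoring out max(l)
  lse <- max(l) + log(sum(exp(l - max(l))))
  l - lse  # normalized log probabilities
}

probs_in  <- c(2, 1, 1)
# Log weights shifted by a large constant; a naive exp() would overflow.
logits_in <- log(probs_in) + 1000

p_norm <- normalize_probs(probs_in)
l_norm <- normalize_logits(logits_in)

print(p_norm)            # 0.50 0.25 0.25
print(exp(l_norm))       # same probabilities as p_norm
print(sum(exp(l_norm)))  # 1
```

Adding the same constant to every logit leaves the normalized result unchanged, which is why logits may be any real numbers.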

See also: torch_multinomial()

Examples

if (torch_is_installed()) {
  m <- distr_categorical(torch_tensor(c(0.25, 0.25, 0.25, 0.25)))
  m$sample() # equal probability of 1, 2, 3, 4
}

  • Maintainer: Daniel Falbel
  • License: MIT + file LICENSE
  • Last published: 2025-02-14