Function that takes in a tensor where the first half of the columns contains the means of the normal distributions, while the second half of the columns contains the parameters that are mapped to standard deviations through a softplus. The standard deviations are clamped from below with min_sigma to ensure stable results. If params has dimensions batch_size x 8, the function will create 4 independent normal distributions for each observation (batch_size observations in total).
params: Tensor of dimension batch_size x 2*n_features containing the means and standard deviation parameters to be used in the normal distributions for each of the batch_size observations.
min_sigma: Lower bound on the standard deviations. For numerical stability, it is desirable that the smallest sigma is not too close to zero.
Returns
A torch::distr_normal() distribution with the provided means and standard deviations.
Details
Takes a tensor (e.g., neural network output) and returns a torch::distr_normal() distribution. This normal distribution is component-wise independent, and its dimensionality depends on the input shape. The first half of the channels is the mean (μ) of the distribution, and the softplus of the second half is the standard deviation (σ), so there are no restrictions on the values of the input tensor. min_sigma is the minimal value of σ: if the softplus output is less than min_sigma, σ is clipped from below to min_sigma. This regularization is required for numerical stability and may be considered a neural network architecture choice that does not change the probabilistic model.
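The parsing described above can be sketched as plain arithmetic. The following is a minimal illustration in Python using only the standard library, not the package's R code: the names parse_normal_params and softplus and the default min_sigma = 1e-4 are assumptions made for this example, and the sketch returns the (μ, σ) lists per observation instead of constructing a torch::distr_normal() object.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)); always strictly positive in exact arithmetic.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def parse_normal_params(params, min_sigma=1e-4):
    """Split each row into means and clamped standard deviations.

    params: list of rows, each of length 2 * n_features (hypothetical stand-in
    for the batch_size x 2*n_features tensor).
    min_sigma: assumed default, chosen for illustration only.
    Returns one (mu, sigma) pair of lists per row.
    """
    parsed = []
    for row in params:
        if len(row) % 2 != 0:
            raise ValueError("each row must have an even number of columns")
        d = len(row) // 2
        # First half of the columns: means, used as-is.
        mu = list(row[:d])
        # Second half: softplus gives a positive value, then clip from below.
        sigma = [max(softplus(v), min_sigma) for v in row[d:]]
        parsed.append((mu, sigma))
    return parsed


# One observation with 8 columns yields 4 independent normal distributions.
mu, sigma = parse_normal_params(
    [[0.0, 1.0, -2.0, 3.0, 0.0, 5.0, -50.0, -1.0]], min_sigma=0.01
)[0]
```

In this example the third standard deviation parameter (-50.0) has a softplus far below min_sigma, so the corresponding σ is clipped to 0.01, while the other three are left at their softplus values.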