vae_loss_correlated function

A custom loss function for a variational autoencoder (VAE) that learns a multivariate normal distribution with a full (non-diagonal) covariance matrix.
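The implementation of vae_loss_correlated is not shown here, so the following is only a minimal, framework-free sketch of the kind of terms such a loss typically combines. It assumes the decoder outputs a mean vector and a lower-triangular Cholesky factor L (so the covariance is L Lᵀ, which is positive definite whenever L has a positive diagonal), and that the latent posterior is a Gaussian N(z_mu, z_L z_Lᵀ) regularized toward N(0, I). The helper names mvn_nll, kl_to_standard_normal and vae_loss_correlated_sketch, and the beta weight, are hypothetical and are not taken from the original function.

```python
import math
from typing import Sequence


def _check_cholesky(L: Sequence[Sequence[float]], k: int) -> None:
    """Validate that L is k x k, lower triangular, with a strictly positive diagonal."""
    if len(L) != k or any(len(row) != k for row in L):
        raise ValueError("Cholesky factor must be k x k")
    for i in range(k):
        if L[i][i] <= 0.0:
            raise ValueError("Cholesky diagonal must be strictly positive")
        for j in range(i + 1, k):
            if L[i][j] != 0.0:
                raise ValueError("Cholesky factor must be lower triangular")


def mvn_nll(x: Sequence[float], mu: Sequence[float],
            L: Sequence[Sequence[float]]) -> float:
    """Negative log-density of x under N(mu, L @ L^T) (hypothetical helper)."""
    k = len(mu)
    _check_cholesky(L, k)
    # Solve L z = (x - mu) by forward substitution, so that
    # z . z equals (x - mu)^T Sigma^{-1} (x - mu) without inverting Sigma.
    z = [0.0] * k
    for i in range(k):
        acc = (x[i] - mu[i]) - sum(L[i][j] * z[j] for j in range(i))
        z[i] = acc / L[i][i]
    mahalanobis_sq = sum(v * v for v in z)
    # log det Sigma = 2 * sum(log diag(L)), so half of it is sum(log diag(L)).
    half_log_det = sum(math.log(L[i][i]) for i in range(k))
    return 0.5 * mahalanobis_sq + half_log_det + 0.5 * k * math.log(2.0 * math.pi)


def kl_to_standard_normal(mu: Sequence[float],
                          L: Sequence[Sequence[float]]) -> float:
    """Closed-form KL( N(mu, L @ L^T) || N(0, I) ) (hypothetical helper)."""
    k = len(mu)
    _check_cholesky(L, k)
    # trace(L L^T) is the sum of squared entries of L (lower triangle only).
    trace_sigma = sum(L[i][j] ** 2 for i in range(k) for j in range(i + 1))
    mu_sq = sum(m * m for m in mu)
    log_det = 2.0 * sum(math.log(L[i][i]) for i in range(k))
    return 0.5 * (trace_sigma + mu_sq - k - log_det)


def vae_loss_correlated_sketch(x, x_mu, x_L, z_mu, z_L, beta: float = 1.0) -> float:
    """Hypothetical per-sample loss: full-covariance reconstruction NLL + beta * KL."""
    return mvn_nll(x, x_mu, x_L) + beta * kl_to_standard_normal(z_mu, z_L)
```

The Cholesky parameterization is a common design choice because any lower-triangular matrix with a positive diagonal yields a valid covariance, so a network can emit the off-diagonal entries freely and pass the diagonal through a positive transform such as exp or softplus. A real training loss would apply the same formulas with batched tensor operations in whatever framework the model uses.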