vaeac_get_val_iwae function

Compute the Importance Sampling Estimator (Validation Error)


Compute the Importance Sampling Estimator which the vaeac model uses to evaluate its performance on the validation data.

vaeac_get_val_iwae(val_dataloader, mask_generator, batch_size, vaeac_model, val_iwae_n_samples)

Arguments

  • val_dataloader: A torch dataloader which loads the validation data.
  • mask_generator: A mask generator object that generates the masks.
  • batch_size: Integer. The number of samples to include in each batch.
  • vaeac_model: The vaeac model.
  • val_iwae_n_samples: Integer. The number of importance samples to generate when computing the IWAE for each validation instance.

Returns

The average IWAE over all instances in the validation data set.
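Because the last batch from a dataloader is often smaller than batch_size, averaging "over all instances" means summing the per-instance values and dividing by the total count, not averaging the batch means. The sketch below illustrates that accumulation in Python/NumPy rather than the package's own R/torch code; `average_val_iwae` and the `per_instance_iwae` callable are hypothetical names, and mask generation is assumed to happen inside that callable.

```python
import numpy as np


def average_val_iwae(batches, per_instance_iwae):
    """Average per-instance IWAE values over every instance in all batches.

    `batches` is any iterable of arrays with one row per validation instance.
    `per_instance_iwae` is a hypothetical callable returning one IWAE value per
    row (in vaeac this step would also draw masks and importance samples).
    """
    total, count = 0.0, 0
    for batch in batches:
        values = np.asarray(per_instance_iwae(batch), dtype=float)
        total += values.sum()     # sum (not mean) so a short final batch is weighted correctly
        count += values.shape[0]  # number of instances seen so far
    return total / count
```

For example, with batches of sizes 3 and 1 whose per-instance values are 1, 2, 3 and 10, the instance-level average is 16 / 4 = 4, whereas the mean of the two batch means would be 6.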

Details

Computes the mean IWAE estimate of the log-likelihood over the validation set. IWAE is an abbreviation for Importance Weighted Autoencoder, and the estimate is the importance sampling approximation

$$\log p_{\theta,\psi}(x \mid y) \approx \log \frac{1}{S}\sum_{i=1}^{S} \frac{p_\theta(x \mid z_i, y)\, p_\psi(z_i \mid y)}{q_\phi(z_i \mid x, y)},$$

where $z_i \sim q_\phi(z \mid x, y)$. For more details, see Olsen et al. (2022).
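Evaluating the formula directly would multiply very small densities and underflow, so such estimators are normally computed in log space with the log-sum-exp trick. The following is a minimal Python/NumPy sketch of the formula for a single instance, assuming the three log-densities have already been evaluated at the same $S$ samples $z_i \sim q_\phi(z \mid x, y)$; the function name `iwae_log_likelihood` is hypothetical and this is not the package's implementation.

```python
import numpy as np


def iwae_log_likelihood(log_p_x_given_z, log_p_z_prior, log_q_z):
    """IWAE estimate of log p(x | y) from S importance samples z_1, ..., z_S.

    Each argument holds S log-densities evaluated at the same samples:
      log_p_x_given_z[i] = log p_theta(x | z_i, y)
      log_p_z_prior[i]   = log p_psi(z_i | y)
      log_q_z[i]         = log q_phi(z_i | x, y)
    """
    # Log importance weights: log p(x | z_i, y) + log p(z_i | y) - log q(z_i | x, y).
    log_w = (np.asarray(log_p_x_given_z, dtype=float)
             + np.asarray(log_p_z_prior, dtype=float)
             - np.asarray(log_q_z, dtype=float))
    # log((1/S) * sum_i exp(log_w_i)), shifted by the max so exp() cannot overflow.
    m = log_w.max()
    return m + np.log(np.mean(np.exp(log_w - m)))
```

With two samples whose importance weights are 1 and 3, the estimate is log of their mean, log 2. Increasing $S$ generally tightens the bound on the true log-likelihood at the cost of more forward passes, which is the trade-off val_iwae_n_samples controls.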

Author(s)

Lars Henry Berge Olsen