RHVAESampler

Implementation of the manifold sampler proposed in https://arxiv.org/abs/2105.00026.

Available models:

  • RHVAE – Riemannian Hamiltonian VAE model.

class pythae.samplers.RHVAESamplerConfig(mcmc_steps_nbr=100, n_lf=15, eps_lf=0.03, beta_zero=1.0)[source]

RHVAESampler config class.

Parameters
  • mcmc_steps_nbr (int) – The number of MCMC steps to run in the latent-space HMC sampler. Default: 100

  • n_lf (int) – The number of leapfrog steps used by the integrator of the HMC sampler. Default: 15

  • eps_lf (float) – The leapfrog step size used by the integrator of the HMC sampler. Default: 3e-2

  • beta_zero (float) – The initial tempering factor β₀ of the HMC sampler. With the default value 1.0, no tempering is applied. Default: 1.0

The number of samples and the batch size are not part of this configuration; they are arguments of RHVAESampler.sample(). Each Markov chain is started on one of the metric centroids.
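The configuration parameters (the number of MCMC steps, n_lf, eps_lf and the tempering factor beta_zero from the signature) drive a tempered Hamiltonian Monte Carlo loop in the latent space. The NumPy sketch below shows how they interact; it is not the pythae source. As a simplifying assumption, every learned matrix of the RHVAE metric is replaced by the identity, so the target log-density and its gradient have closed forms. The momentum is rescaled with the quadratic tempering schedule used by Hamiltonian VAEs, where the square root of β goes from √β₀ at the first leapfrog step to 1 at the last. The names log_pi_and_grad, tempering and tempered_hmc are hypothetical.

```python
import numpy as np


def log_pi_and_grad(z, centroids, temperature=1.0, lam=1e-3):
    """Toy target log sqrt(det G^{-1}(z)) and its gradient for a batch z of shape (n, d).

    Simplifying assumption: every L_i of the RHVAE metric is the identity, so
    G^{-1}(z) = (sum_i exp(-||z - c_i||^2 / T^2) + lam) * I_d.
    """
    d = z.shape[1]
    diff = z[:, None, :] - centroids[None, :, :]                # (n, m, d)
    w = np.exp(-np.sum(diff ** 2, axis=-1) / temperature ** 2)  # (n, m)
    s = w.sum(axis=1) + lam                                     # (n,)
    log_pi = 0.5 * d * np.log(s)                                # 0.5 * log det(s * I_d)
    grad_s = (-2.0 / temperature ** 2) * np.sum(w[:, :, None] * diff, axis=1)
    grad = 0.5 * d * grad_s / s[:, None]                        # (n, d)
    return log_pi, grad


def tempering(k, n_lf, beta_zero_sqrt):
    """sqrt(beta_k): equals sqrt(beta_zero) at k = 0 and 1 at k = n_lf."""
    inv = (1.0 - 1.0 / beta_zero_sqrt) * (k / n_lf) ** 2 + 1.0 / beta_zero_sqrt
    return 1.0 / inv


def tempered_hmc(z0, centroids, mcmc_steps_nbr=100, n_lf=15, eps_lf=0.03,
                 beta_zero=1.0, seed=0):
    rng = np.random.default_rng(seed)
    beta_zero_sqrt = np.sqrt(beta_zero)
    z = z0.copy()
    for _ in range(mcmc_steps_nbr):
        # Fresh momentum, drawn at the initial (hottest) temperature.
        rho = rng.standard_normal(z.shape) / beta_zero_sqrt
        log_pi, grad = log_pi_and_grad(z, centroids)
        h0 = -log_pi + 0.5 * np.sum(rho ** 2, axis=1)
        z_new, beta_sqrt_old = z.copy(), beta_zero_sqrt
        for k in range(n_lf):
            # Leapfrog: half momentum step, full position step, half momentum step.
            rho = rho + 0.5 * eps_lf * grad
            z_new = z_new + eps_lf * rho
            log_pi, grad = log_pi_and_grad(z_new, centroids)
            rho = rho + 0.5 * eps_lf * grad
            # Tempering: shrink the momentum as the temperature cools towards 1.
            beta_sqrt = tempering(k + 1, n_lf, beta_zero_sqrt)
            rho = (beta_sqrt_old / beta_sqrt) * rho
            beta_sqrt_old = beta_sqrt
        h = -log_pi + 0.5 * np.sum(rho ** 2, axis=1)
        # Metropolis-Hastings correction, applied per chain in log space.
        accept = np.log(rng.uniform(size=len(z))) < h0 - h
        z = np.where(accept[:, None], z_new, z)
    return z


# Usage: two toy centroids in a 2-D latent space, chains started on the centroids.
centroids = np.array([[-2.0, 0.0], [2.0, 0.0]])
rng = np.random.default_rng(1)
z0 = centroids[rng.integers(len(centroids), size=64)]
samples = tempered_hmc(z0, centroids, mcmc_steps_nbr=20, beta_zero=0.3)
print(samples.shape)  # (64, 2)
```

A smaller beta_zero gives larger initial momenta and therefore longer exploratory moves early in each trajectory; the tempering then cools the momentum so that the trajectory ends at temperature 1.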

class pythae.samplers.RHVAESampler(model, sampler_config=None)[source]

Samples from the distribution whose density is proportional to the inverse of the Riemannian metric volume element of an RHVAE model.
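For reference, the RHVAE line of work defines the inverse metric as below, with N centroids c_i (latent codes of the training data), learned lower-triangular matrices L_{ψ_i}, a temperature T and a regularization factor λ > 0. The sampler targets the density proportional to √det G⁻¹(z); far from the centroids this tends to the constant λ^{d/2}. This summarizes the published formulation and is not transcribed from the pythae source.

```latex
\mathbf{G}^{-1}(z) = \sum_{i=1}^{N} L_{\psi_i} L_{\psi_i}^{\top}
  \exp\!\left(-\frac{\lVert z - c_i \rVert_2^2}{T^2}\right) + \lambda I_d,
\qquad
p(z) \propto \sqrt{\det \mathbf{G}^{-1}(z)},
\qquad
\log p(z) = \tfrac{1}{2} \log \det \mathbf{G}^{-1}(z) + \text{const}.
```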

Parameters
  • model (RHVAE) – The trained RHVAE model to sample from

  • sampler_config (RHVAESamplerConfig) – An RHVAESamplerConfig instance containing the main parameters of the sampler. If None, the default configuration RHVAESamplerConfig() is used. Default: None

sample(num_samples=1, batch_size=500, output_dir=None, return_gen=True, save_sampler_config=False)[source]

Main sampling function of the sampler.

Parameters
  • num_samples (int) – The number of samples to generate. Default: 1

  • batch_size (int) – The number of samples generated per batch. Default: 500

  • output_dir (str) – The directory where the generated images are saved. If it does not exist, it is created. If None, the images are not saved. Default: None

  • return_gen (bool) – Whether the sampler should directly return a tensor of generated data. Default: True

  • save_sampler_config (bool) – Whether to save the sampler config in output_dir. Default: False
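num_samples and batch_size interact in the usual way for batched generation: the request is split into full batches of batch_size plus one smaller remainder batch, and the decoded batches are concatenated. The sketch below illustrates that idea under this assumption; draw_latents and decode are hypothetical stand-ins for the HMC sampler and the RHVAE decoder, and saving to output_dir is left out.

```python
import numpy as np


def batch_sizes(num_samples, batch_size):
    """Split num_samples into full batches of batch_size plus one smaller remainder batch."""
    full, rest = divmod(num_samples, batch_size)
    return [batch_size] * full + ([rest] if rest > 0 else [])


def sample_in_batches(draw_latents, decode, num_samples=1, batch_size=500, return_gen=True):
    """Generate num_samples outputs batch by batch (hypothetical helper).

    draw_latents(n) returns n latent codes; decode(z) maps them to data space.
    """
    outputs = []
    for n in batch_sizes(num_samples, batch_size):
        outputs.append(decode(draw_latents(n)))
    if return_gen:
        # One array holding every generated sample, in generation order.
        return np.concatenate(outputs, axis=0)
    return None


rng = np.random.default_rng(0)
gen = sample_in_batches(lambda n: rng.standard_normal((n, 2)), np.tanh,
                        num_samples=1200, batch_size=500)
print(batch_sizes(1200, 500))  # [500, 500, 200]
print(gen.shape)               # (1200, 2)
```

Keeping batch_size bounded caps peak memory during decoding, which is why large num_samples values are split rather than generated in one pass.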

Returns

The generated images, concatenated into a single tensor (returned only when return_gen is True)

Return type

Tensor