Samplers

This page lists the main samplers used to generate new data from the models implemented in pythae.models.

By convention, each sampler is implemented in its own folder inside pythae.samplers, named after the sampler. This folder contains the following modules:

  • samplername_config.py: Contains a SamplerNameConfig class, inheriting from BaseSamplerConfig, that stores the sampler configuration.
  • samplername_sampler.py: Contains the implementation of the sampler, inheriting from BaseSampler.
  • samplername_utils.py (optional): A module storing utility functions.
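To make the convention concrete, here is a toy, self-contained sketch of the config/sampler split. BaseSamplerConfig and BaseSampler below are minimal hypothetical stand-ins written for illustration, not pythae's actual classes, and UniformSampler is an invented example rather than a pythae sampler.

```python
import random
from dataclasses import dataclass


# --- Hypothetical stand-ins for the base classes (NOT pythae's real ones) ---
@dataclass
class BaseSamplerConfig:
    pass


class BaseSampler:
    def __init__(self, model, sampler_config=None):
        self.model = model
        self.sampler_config = sampler_config

    def fit(self, *args, **kwargs):
        # Samplers that need no fitting simply keep this no-op
        pass

    def sample(self, num_samples=1):
        raise NotImplementedError


# --- What would live in samplername_config.py ---
@dataclass
class UniformSamplerConfig(BaseSamplerConfig):
    low: float = -1.0
    high: float = 1.0


# --- What would live in samplername_sampler.py ---
class UniformSampler(BaseSampler):
    def __init__(self, model, sampler_config=None):
        if sampler_config is None:
            sampler_config = UniformSamplerConfig()
        super().__init__(model, sampler_config)

    def sample(self, num_samples=1):
        cfg = self.sampler_config
        # Draw a scalar "latent code" uniformly, then decode it with the model
        return [self.model(random.uniform(cfg.low, cfg.high)) for _ in range(num_samples)]


# Usage: a hypothetical "decoder" that doubles its input
sampler = UniformSampler(model=lambda z: 2 * z, sampler_config=UniformSamplerConfig(low=0.0, high=1.0))
samples = sampler.sample(num_samples=5)
```

Keeping the configuration in its own dataclass lets a sampler be rebuilt from a stored config without touching its code.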

BaseSampler

Base class for the samplers used to generate data from trained VAE models.

NormalSampler

Samples from a standard normal distribution in the Autoencoder’s latent space.
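Conceptually, normal sampling draws latent codes z ~ N(0, I) and decodes them, batch by batch. The NumPy sketch below illustrates this under stated assumptions: the decoder is a hypothetical linear map standing in for a trained VAE decoder, and normal_sample is an illustrative helper, not pythae's implementation.

```python
import numpy as np


def normal_sample(decoder, latent_dim, num_samples=50, batch_size=10, seed=0):
    """Draw z ~ N(0, I) in batches of batch_size and decode each batch."""
    rng = np.random.default_rng(seed)
    full_batches, remainder = divmod(num_samples, batch_size)
    # Full batches first, then a smaller last batch if num_samples is not a multiple
    batch_sizes = [batch_size] * full_batches + ([remainder] if remainder else [])
    outputs = []
    for n in batch_sizes:
        z = rng.standard_normal((n, latent_dim))
        outputs.append(decoder(z))
    return np.concatenate(outputs, axis=0)


# Hypothetical stand-in for a trained decoder: a fixed linear map from R^2 to R^4
W = np.arange(8, dtype=float).reshape(2, 4)


def decoder(z):
    return z @ W


gen_data = normal_sample(decoder, latent_dim=2, num_samples=50, batch_size=10)
```

With num_samples=50 and batch_size=10, the decoder is called on five batches of ten latent codes.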

GaussianMixtureSampler

Fits a Gaussian Mixture in the Autoencoder’s latent space.
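The idea behind this sampler is to replace the fixed prior with a density estimated from the latent codes of the training data, then sample from that density. The sketch below re-implements the idea in plain NumPy with a small diagonal-covariance EM loop; fit_diag_gmm and sample_diag_gmm are hypothetical helpers, and the covariance type, initialisation and number of iterations are assumptions, not pythae's settings.

```python
import numpy as np


def fit_diag_gmm(z, n_components, n_iter=100, seed=0):
    """Fit a diagonal-covariance Gaussian mixture to latent codes z (n, d) with EM."""
    rng = np.random.default_rng(seed)
    n, d = z.shape
    # Initialise means on randomly chosen latent codes, unit variances, uniform weights
    means = z[rng.choice(n, size=n_components, replace=False)].copy()
    variances = np.ones((n_components, d))
    weights = np.full(n_components, 1.0 / n_components)
    for _ in range(n_iter):
        # E-step: log pi_k + log N(z_i | mu_k, diag(var_k)), shape (n, K)
        diff = z[:, None, :] - means[None, :, :]
        log_prob = -0.5 * (
            np.sum(diff**2 / variances[None, :, :], axis=2)
            + np.sum(np.log(2 * np.pi * variances), axis=1)[None, :]
        ) + np.log(weights)[None, :]
        # Responsibilities, normalised with log-sum-exp so each row sums to 1
        resp = np.exp(log_prob - np.logaddexp.reduce(log_prob, axis=1, keepdims=True))
        # M-step: re-estimate weights, means and diagonal variances
        nk = resp.sum(axis=0) + 1e-10
        weights = nk / nk.sum()
        means = (resp.T @ z) / nk[:, None]
        diff = z[:, None, :] - means[None, :, :]
        variances = np.einsum("nk,nkd->kd", resp, diff**2) / nk[:, None] + 1e-6
    return weights, means, variances


def sample_diag_gmm(weights, means, variances, num_samples, seed=0):
    """Pick a component for each sample, then draw from that component's Gaussian."""
    rng = np.random.default_rng(seed)
    comps = rng.choice(len(weights), size=num_samples, p=weights)
    eps = rng.standard_normal((num_samples, means.shape[1]))
    return means[comps] + np.sqrt(variances[comps]) * eps


# Toy latent codes: two well-separated clusters in a 2-D latent space
rng = np.random.default_rng(0)
z_train = np.concatenate([rng.normal(-3.0, 0.5, (200, 2)), rng.normal(3.0, 0.5, (200, 2))])
weights, means, variances = fit_diag_gmm(z_train, n_components=2)
z_new = sample_diag_gmm(weights, means, variances, num_samples=50)
```

In a real pipeline, z_new would then be passed through the trained decoder to obtain generated data.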

TwoStageVAESampler

Fits a second VAE in the Autoencoder’s latent space.

HypersphereUniformSampler

Samples from a uniform distribution on the hypersphere.
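A standard way to draw uniform samples on a hypersphere is to normalise isotropic Gaussian vectors. The sketch assumes a unit-radius sphere in R^d; hypersphere_uniform is a hypothetical helper and not necessarily how pythae implements this sampler.

```python
import numpy as np


def hypersphere_uniform(num_samples, latent_dim, seed=0):
    """Uniform samples on the unit hypersphere S^(latent_dim - 1) embedded in R^latent_dim."""
    rng = np.random.default_rng(seed)
    g = rng.standard_normal((num_samples, latent_dim))
    # An isotropic Gaussian is rotation-invariant, so its direction is uniform on the sphere
    return g / np.linalg.norm(g, axis=1, keepdims=True)


z = hypersphere_uniform(10_000, latent_dim=3)
```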

PoincareDiskSampler

Samples from the Poincaré Disk using either a Wrapped Riemannian or a Riemannian Gaussian distribution.
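To illustrate the wrapped construction, the sketch below samples a wrapped normal centred at the origin of a Poincaré disk with curvature -c: tangent vectors v ~ N(0, σ²I) are pushed onto the disk with the exponential map at the origin, exp_0(v) = tanh(√c‖v‖) v / (√c‖v‖). It covers only the origin-centred wrapped case (no parallel transport to another mean, no Riemannian Gaussian), and wrapped_normal_origin is a hypothetical helper.

```python
import numpy as np


def wrapped_normal_origin(num_samples, dim=2, sigma=1.0, c=1.0, seed=0):
    """Wrapped normal at the origin of the Poincaré ball of curvature -c (radius 1/sqrt(c))."""
    rng = np.random.default_rng(seed)
    # Gaussian tangent vectors at the origin
    v = sigma * rng.standard_normal((num_samples, dim))
    norm = np.maximum(np.linalg.norm(v, axis=1, keepdims=True), 1e-12)
    sqrt_c = np.sqrt(c)
    # Exponential map at the origin: keeps the direction, squashes the norm with tanh
    return np.tanh(sqrt_c * norm) * v / (sqrt_c * norm)


x = wrapped_normal_origin(1000)
```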

VAMPSampler

Samples from the VAMP prior.
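The VAMP prior is a uniform mixture of the encoder's variational posteriors evaluated at K learned pseudo-inputs, p(z) = (1/K) Σ_k q(z | u_k). Sampling therefore picks a pseudo-input uniformly at random and draws from its Gaussian posterior. In the sketch, the linear encoder and the pseudo-inputs are hypothetical stand-ins for a trained model, and vamp_sample is an illustrative helper.

```python
import numpy as np


def vamp_sample(encoder, pseudo_inputs, num_samples, seed=0):
    """Sample z from the uniform mixture of q(z | u_k) over the pseudo-inputs u_k."""
    rng = np.random.default_rng(seed)
    mu, log_var = encoder(pseudo_inputs)  # each of shape (K, latent_dim)
    # Choose a mixture component uniformly for every sample
    k = rng.integers(0, len(pseudo_inputs), size=num_samples)
    eps = rng.standard_normal((num_samples, mu.shape[1]))
    # Reparameterised draw from N(mu_k, diag(exp(log_var_k)))
    return mu[k] + np.exp(0.5 * log_var[k]) * eps, k


# Hypothetical encoder: linear mean from R^3 to R^2 and a fixed small variance
A = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])


def encoder(x):
    mu = x @ A
    log_var = np.full_like(mu, np.log(1e-4))
    return mu, log_var


pseudo_inputs = np.array([[5.0, 0.0, 0.0], [0.0, -5.0, 0.0]])
z, k = vamp_sample(encoder, pseudo_inputs, num_samples=100)
```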

RHVAESampler

Samples from the inverse of the metric volume element of an RHVAE model.

MAFSampler

Fits a Masked Autoregressive Flow in the Autoencoder’s latent space.

IAFSampler

Fits an Inverse Autoregressive Flow in the Autoencoder’s latent space.

PixelCNNSampler

Fits a PixelCNN in the VQVAE’s latent space.

Basic Examples

To launch the data generation process from a trained model, you only need to build a sampler and call its sample method. For instance, the following snippets generate new data with a NormalSampler and with a GaussianMixtureSampler.

Normal sampling

>>> from pythae.models import VAE
>>> from pythae.samplers import NormalSampler
>>> # Retrieve the trained model
>>> my_trained_vae = VAE.load_from_folder(
...     'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> my_sampler = NormalSampler(
...     model=my_trained_vae
... )
>>> # Generate samples
>>> gen_data = my_sampler.sample(
...     num_samples=50,
...     batch_size=10,
...     output_dir=None,  # do not save the generated samples to disk
...     return_gen=True  # return the generated samples
... )

Gaussian mixture sampling

>>> from pythae.models import VAE
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Retrieve the trained model
>>> my_trained_vae = VAE.load_from_folder(
...     'path/to/your/trained/model'
... )
>>> # Define your sampler
>>> gmm_sampler_config = GaussianMixtureSamplerConfig(
...     n_components=10
... )
>>> my_sampler = GaussianMixtureSampler(
...     sampler_config=gmm_sampler_config,
...     model=my_trained_vae
... )
>>> # Fit the sampler on the training data (train_dataset is assumed to be
>>> # defined, e.g. the torch.Tensor the VAE was trained on)
>>> my_sampler.fit(train_dataset)
>>> # Generate samples
>>> gen_data = my_sampler.sample(
...     num_samples=50,
...     batch_size=10,
...     output_dir=None,
...     return_gen=True
... )

See also the tutorials.