WAE

This module implements the Wasserstein Autoencoder proposed in https://arxiv.org/abs/1711.01558.
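For orientation, this is the WAE objective as stated in the paper linked above, together with the batch statistic the paper uses for the MMD-based penalty. Here $c$ is a reconstruction cost, $G$ the decoder, $Q(Z \mid X)$ the encoder, $P_Z$ the prior, $Q_Z$ the aggregated posterior and $k$ a positive-definite kernel. How pythae maps its configuration onto it is an assumption: we take $\lambda$ to correspond to `reg_weight`.

```latex
\min_{Q(Z \mid X)} \;
  \mathbb{E}_{P_X}\,\mathbb{E}_{Q(Z \mid X)}\big[c\big(X, G(Z)\big)\big]
  + \lambda\, \mathcal{D}_Z(Q_Z, P_Z)

% MMD penalty on a batch: z_1..z_n are encoded codes, \tilde z_1..\tilde z_n ~ P_Z.
% This is an estimate of the squared MMD between Q_Z and P_Z.
\widehat{\mathrm{MMD}}^2_k
  = \frac{1}{n(n-1)} \sum_{l \neq j} k(z_l, z_j)
  + \frac{1}{n(n-1)} \sum_{l \neq j} k(\tilde z_l, \tilde z_j)
  - \frac{2}{n^2} \sum_{l, j} k(z_l, \tilde z_j)
```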

Available samplers

NormalSampler

Samples from a standard normal distribution in the Autoencoder’s latent space.

GaussianMixtureSampler

Fits a Gaussian Mixture in the Autoencoder’s latent space.

MAFSampler

Fits a Masked Autoregressive Flow in the Autoencoder’s latent space.

IAFSampler

Fits an Inverse Autoregressive Flow in the Autoencoder’s latent space.
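The samplers above differ only in which distribution they draw latent codes from before decoding. The sketch below illustrates the idea with NumPy under stated assumptions: `decode` is a hypothetical stand-in for a trained decoder, and `single_gaussian_sample` fits one full-covariance Gaussian to encoded training data, a one-component simplification of a Gaussian-mixture sampler rather than pythae's `GaussianMixtureSampler` itself.

```python
import numpy as np


def normal_sample(decode, latent_dim, num_samples, seed=None):
    """Draw z ~ N(0, I) in the latent space and decode it (the NormalSampler idea)."""
    rng = np.random.default_rng(seed)
    z = rng.standard_normal((num_samples, latent_dim))
    return decode(z)


def single_gaussian_sample(decode, latent_codes, num_samples, seed=None):
    """Fit ONE full-covariance Gaussian to encoded training data, then sample and decode.

    A one-component simplification: a real Gaussian-mixture sampler fits several
    components (e.g. with EM) to better follow the shape of the latent codes.
    """
    mean = latent_codes.mean(axis=0)
    cov = np.cov(latent_codes, rowvar=False)  # [latent_dim, latent_dim]
    rng = np.random.default_rng(seed)
    z = rng.multivariate_normal(mean, cov, size=num_samples)
    return decode(z)


# Hypothetical "decoder": a fixed random linear map from 10 latent dims to 784 pixels.
W = np.random.default_rng(0).standard_normal((10, 784))


def decode(z):
    return z @ W


print(normal_sample(decode, latent_dim=10, num_samples=5, seed=1).shape)

# Pretend these are latent codes of the training set produced by a trained encoder.
codes = np.random.default_rng(2).normal(loc=3.0, scale=0.5, size=(200, 10))
print(single_gaussian_sample(decode, codes, num_samples=5, seed=3).shape)
```

Fitting a distribution to the encoded data helps when the encoder's codes do not match the prior well; sampling from the prior alone ignores that mismatch.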

class pythae.models.WAE_MMD_Config(input_dim=None, latent_dim=10, uses_default_encoder=True, uses_default_decoder=True, kernel_choice='imq', reg_weight=0.03, kernel_bandwidth=1.0, scales=<factory>, reconstruction_loss_scale=1.0)

Wasserstein autoencoder model config class.

Parameters
  • input_dim (tuple) – The input data dimension. Default: None.

  • latent_dim (int) – The latent space dimension. Default: 10.

  • kernel_choice (str) – The kernel to use. Available options are [‘rbf’, ‘imq’], i.e. the radial basis function kernel or the inverse multiquadratic kernel. Default: ‘imq’.

  • reg_weight (float) – The weight applied to the Maximum Mean Discrepancy (MMD) term relative to the reconstruction term. Default: 3e-2

  • kernel_bandwidth (float) – The kernel bandwidth. Default: 1

  • scales (list) – The scales to apply when using a multi-scale imq kernel. If None, a single imq kernel is used. Default: [.1, .2, .5, 1., 2., 5, 10.].

  • reconstruction_loss_scale (float) – Parameter scaling the reconstruction loss. Default: 1
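To make the interplay of `reg_weight` and `reconstruction_loss_scale` concrete, here is a minimal sketch assuming the total loss is a plain weighted sum of the two terms, as the parameter descriptions suggest. `wae_mmd_total_loss` is a hypothetical helper, not part of pythae.

```python
def wae_mmd_total_loss(recon_loss, mmd_loss, reg_weight=3e-2, reconstruction_loss_scale=1.0):
    """Hypothetical helper: weighted sum of the two WAE-MMD loss terms.

    Assumes reconstruction_loss_scale multiplies the reconstruction term and
    reg_weight multiplies the MMD term.
    """
    return reconstruction_loss_scale * recon_loss + reg_weight * mmd_loss


# With the defaults, an MMD of 2.0 adds only 0.06 to a reconstruction loss of 10.0.
print(wae_mmd_total_loss(10.0, 2.0))

# Raising reg_weight pushes the latent codes harder towards the prior.
print(wae_mmd_total_loss(10.0, 2.0, reg_weight=1.0))
```

Because the MMD term is usually small compared with a summed reconstruction error, its weight matters: too small and the codes drift away from the prior, which degrades samples drawn with a NormalSampler.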

class pythae.models.WAE_MMD(model_config, encoder=None, decoder=None)

Wasserstein Autoencoder model.

Parameters
  • model_config (WAE_MMD_Config) – The Autoencoder configuration setting the main parameters of the model.

  • encoder (BaseEncoder) – An instance of BaseEncoder (inheriting from torch.nn.Module) which plays the role of encoder. This argument allows you to use your own neural network architecture if desired. If None is provided, a simple Multi-Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

  • decoder (BaseDecoder) – An instance of BaseDecoder (inheriting from torch.nn.Module) which plays the role of decoder. This argument allows you to use your own neural network architecture if desired. If None is provided, a simple Multi-Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

Note

For high-dimensional data, we advise you to provide your own network architectures. With the provided MLP, you may end up with a MemoryError.

forward(inputs, **kwargs)

The input data is encoded and then decoded.

Parameters

inputs (BaseDataset) – An instance of one of pythae’s datasets.

Returns

An instance of ModelOutput containing all the relevant parameters.

Return type

ModelOutput
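The sketch below shows, in NumPy, what one WAE-MMD forward pass plausibly computes: encode, decode, score the reconstruction, draw prior samples, and add the weighted MMD penalty. It is an illustration under assumptions, not pythae's code: the batch is a plain array `x` rather than a `BaseDataset`, `encode`/`decode` are hypothetical callables, the reconstruction cost is squared error summed over features, the kernel is an RBF with C = 2 * latent_dim * bandwidth**2, and the returned dict only imitates the fields a `ModelOutput` might hold.

```python
import numpy as np


def rbf(z1, z2, bandwidth=1.0):
    """Pairwise RBF kernel, C = 2 * latent_dim * bandwidth**2 (assumed constant)."""
    sq_dist = ((z1[:, None, :] - z2[None, :, :]) ** 2).sum(axis=-1)  # [n1, n2]
    return np.exp(-sq_dist / (2.0 * z1.shape[1] * bandwidth**2))


def mmd_estimate(z, z_prior, kernel):
    """Batch MMD statistic from the WAE paper: same-sample sums skip l == j."""
    n = z.shape[0]
    k_zz, k_pp, k_zp = kernel(z, z), kernel(z_prior, z_prior), kernel(z, z_prior)
    term_z = (k_zz.sum() - np.trace(k_zz)) / (n * (n - 1))
    term_p = (k_pp.sum() - np.trace(k_pp)) / (n * (n - 1))
    return term_z + term_p - 2.0 * k_zp.mean()  # mean == sum / n**2


def forward(x, encode, decode, reg_weight=3e-2, reconstruction_loss_scale=1.0, seed=None):
    """Encode, decode and score one batch x of shape [batch, features]."""
    z = encode(x)  # deterministic latent codes, [batch, latent_dim]
    recon_x = decode(z)
    # Squared error summed over features, averaged over the batch.
    recon_loss = ((recon_x - x) ** 2).sum(axis=1).mean()
    # Prior samples ~ N(0, I) with the same shape as z.
    z_prior = np.random.default_rng(seed).standard_normal(z.shape)
    mmd_loss = mmd_estimate(z, z_prior, rbf)
    loss = reconstruction_loss_scale * recon_loss + reg_weight * mmd_loss
    return {"loss": loss, "recon_loss": recon_loss, "mmd_loss": mmd_loss,
            "recon_x": recon_x, "z": z}


# Toy usage with hypothetical, untrained linear encoder/decoder weights.
rng = np.random.default_rng(0)
x = rng.standard_normal((16, 8))
E = rng.standard_normal((8, 2))
D = rng.standard_normal((2, 8))
out = forward(x, lambda b: b @ E, lambda z: z @ D, seed=1)
print(out["z"].shape, out["recon_x"].shape, out["loss"])
```

Unlike a VAE, the encoder here is deterministic and there is no per-sample KL term; the only pressure towards the prior is the MMD between the whole batch of codes and the prior samples.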

imq_kernel(z1, z2)

Returns a matrix of shape [batch x batch] containing the pairwise inverse multiquadratic (imq) kernel values between z1 and z2.
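A minimal NumPy sketch of a multi-scale IMQ kernel, k(x, y) = C / (C + ||x - y||^2) summed over scales. The base constant C = 2 * latent_dim * bandwidth**2 follows the WAE paper's choice C = 2 d_z sigma^2; treating `kernel_bandwidth` as sigma, multiplying C by each entry of `scales`, and falling back to a single kernel with scale 1.0 when `scales` is None are assumptions, not confirmed pythae behavior.

```python
import numpy as np

DEFAULT_SCALES = (0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0)


def imq_kernel(z1, z2, latent_dim, bandwidth=1.0, scales=DEFAULT_SCALES):
    """Pairwise multi-scale IMQ kernel between rows of z1 [n1, d] and z2 [n2, d].

    Returns an [n1, n2] matrix: sum over s of C_s / (C_s + ||z1_i - z2_j||^2),
    with C_s = s * 2 * latent_dim * bandwidth**2.
    """
    # Broadcast to [n1, n2, d], then reduce to squared Euclidean distances [n1, n2].
    sq_dist = ((z1[:, None, :] - z2[None, :, :]) ** 2).sum(axis=-1)
    base = 2.0 * latent_dim * bandwidth**2
    if scales is None:
        scales = (1.0,)  # a single IMQ kernel
    k = np.zeros_like(sq_dist)
    for s in scales:
        c = s * base
        k += c / (c + sq_dist)
    return k


# Single kernel, latent_dim=2: C = 4, ||x - y||^2 = 4, so k = 4 / 8.
print(imq_kernel(np.array([[0.0, 0.0]]), np.array([[2.0, 0.0]]), latent_dim=2, scales=None))  # [[0.5]]
```

Summing several scales makes the kernel sensitive to discrepancies at several distance ranges at once, so the choice of a single bandwidth matters less.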

rbf_kernel(z1, z2)

Returns a matrix of shape [batch x batch] containing the pairwise radial basis function (rbf) kernel values between z1 and z2.
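A minimal NumPy sketch of the pairwise Gaussian (RBF) kernel exp(-||x - y||^2 / C). The normalization C = 2 * latent_dim * bandwidth**2 is an assumption chosen to mirror the IMQ base constant; it is not taken from pythae's source.

```python
import numpy as np


def rbf_kernel(z1, z2, latent_dim, bandwidth=1.0):
    """Pairwise RBF kernel between rows of z1 [n1, d] and z2 [n2, d], returned as [n1, n2].

    C = 2 * latent_dim * bandwidth**2 is an assumed normalization.
    """
    sq_dist = ((z1[:, None, :] - z2[None, :, :]) ** 2).sum(axis=-1)  # [n1, n2]
    c = 2.0 * latent_dim * bandwidth**2
    return np.exp(-sq_dist / c)


z1 = np.array([[0.0, 0.0]])
z2 = np.array([[2.0, 0.0], [20.0, 0.0]])  # one nearby point, one far outlier
print(rbf_kernel(z1, z2, latent_dim=2))
```

With C = 4 the nearby point gets exp(-1), about 0.37, while the outlier at squared distance 400 gets exp(-100), which is effectively zero; an IMQ kernel with the same C would still give 4 / 404, about 0.0099. The WAE paper reports that this fast tail decay makes the RBF kernel poor at penalizing outliers of Q_Z, which is why imq is the default here.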