Adversarial Autoencoder

Implementation of the Adversarial Autoencoder model proposed in (https://arxiv.org/abs/1511.05644). The model uses adversarial training to make the aggregated posterior distribution of the latent codes match the prior.
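
The adversarial objective mentioned above can be sketched with the standard library alone. This is a minimal illustration assuming a standard GAN-style binary cross-entropy formulation, in which a discriminator scores latent codes in (0, 1); the function names `bce`, `discriminator_loss` and `encoder_adversarial_loss` are hypothetical, and the code is not taken from pythae.

```python
import math

def bce(score, label, eps=1e-12):
    """Binary cross-entropy for one discriminator score in (0, 1)."""
    score = min(max(score, eps), 1.0 - eps)  # clip to avoid log(0)
    return -(label * math.log(score) + (1.0 - label) * math.log(1.0 - score))

def discriminator_loss(prior_scores, posterior_scores):
    """Discriminator: label prior samples 1 and encoded samples 0."""
    real = sum(bce(s, 1.0) for s in prior_scores) / len(prior_scores)
    fake = sum(bce(s, 0.0) for s in posterior_scores) / len(posterior_scores)
    return real + fake

def encoder_adversarial_loss(posterior_scores):
    """Encoder: try to make encoded samples look like prior samples."""
    return sum(bce(s, 1.0) for s in posterior_scores) / len(posterior_scores)

# Hypothetical scores from a discriminator that already separates the two sets.
d_loss = discriminator_loss(prior_scores=[0.9, 0.8], posterior_scores=[0.2, 0.1])
e_loss = encoder_adversarial_loss(posterior_scores=[0.2, 0.1])
```

With these made-up scores the discriminator loss is small and the encoder loss is large, which is the signal that pushes the encoder to produce latent codes closer to the prior.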

Available samplers

NormalSampler

Samples from a standard normal distribution in the Autoencoder’s latent space.

GaussianMixtureSampler

Fits a Gaussian Mixture in the Autoencoder’s latent space.

TwoStageVAESampler

Fits a second VAE in the Autoencoder’s latent space.

MAFSampler

Fits a Masked Autoregressive Flow in the Autoencoder’s latent space.

IAFSampler

Fits an Inverse Autoregressive Flow in the Autoencoder’s latent space.
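
As a rough illustration of the simplest case listed above (NormalSampler), drawing from a standard normal distribution in the latent space amounts to sampling each latent coordinate independently from N(0, 1). The sketch below uses only the standard library; the function name `sample_standard_normal` and its arguments are hypothetical and not part of pythae. The other samplers first fit a density model to the latent codes and draw from that instead.

```python
import random

def sample_standard_normal(num_samples, latent_dim, seed=None):
    """Draw `num_samples` latent vectors with i.i.d. N(0, 1) coordinates."""
    rng = random.Random(seed)  # local generator so the seed is reproducible
    return [
        [rng.gauss(0.0, 1.0) for _ in range(latent_dim)]
        for _ in range(num_samples)
    ]

# Four latent vectors in a 10-dimensional latent space.
latents = sample_standard_normal(num_samples=4, latent_dim=10, seed=0)
```

To obtain generated data, such latent vectors would then be passed through the trained decoder.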

class pythae.models.Adversarial_AE_Config(input_dim=None, latent_dim=10, uses_default_encoder=True, uses_default_decoder=True, reconstruction_loss='mse', adversarial_loss_scale=0.5, reconstruction_loss_scale=1.0, deterministic_posterior=False, uses_default_discriminator=True, discriminator_input_dim=None)[source]

Adversarial AE model config class.

Parameters
  • input_dim (tuple) – The input_data dimension.

  • latent_dim (int) – The latent space dimension. Default: 10.

  • reconstruction_loss (str) – The reconstruction loss to use [‘bce’, ‘mse’]. Default: ‘mse’

  • adversarial_loss_scale (float) – Parameter scaling the adversarial loss. Default: 0.5

  • reconstruction_loss_scale (float) – Parameter scaling the reconstruction loss. Default: 1

  • deterministic_posterior (bool) – Whether to use a deterministic posterior (Dirac). Default: False
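
The effect of deterministic_posterior and of the two loss scales can be sketched as follows. This is a minimal illustration under two assumptions that the text above does not spell out: the non-deterministic posterior is a diagonal Gaussian parameterized by a mean and a log-variance, and the two loss terms are combined as a weighted sum. The function names `sample_latent` and `weighted_autoencoder_loss` are hypothetical and not part of pythae.

```python
import math
import random

def sample_latent(mu, log_var, deterministic_posterior=False, rng=random):
    """Return a latent code from a diagonal Gaussian posterior, or its mean."""
    if deterministic_posterior:
        # Dirac posterior: the latent code is the encoder mean itself.
        return list(mu)
    # Otherwise draw z = mu + sigma * eps with eps ~ N(0, 1).
    return [
        m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
        for m, lv in zip(mu, log_var)
    ]

def weighted_autoencoder_loss(recon_loss, adv_loss,
                              reconstruction_loss_scale=1.0,
                              adversarial_loss_scale=0.5):
    """Combine the two terms as a weighted sum (assumed form)."""
    return (reconstruction_loss_scale * recon_loss
            + adversarial_loss_scale * adv_loss)

# Hypothetical encoder outputs and loss values, for illustration only.
z_det = sample_latent([0.1, -0.3], [0.0, 0.0], deterministic_posterior=True)
z_rand = sample_latent([0.1, -0.3], [0.0, 0.0], rng=random.Random(0))
total = weighted_autoencoder_loss(recon_loss=2.0, adv_loss=0.6)
```

Raising adversarial_loss_scale relative to reconstruction_loss_scale puts more weight on matching the prior and less on reconstruction quality.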

class pythae.models.Adversarial_AE(model_config, encoder=None, decoder=None, discriminator=None)[source]

Adversarial Autoencoder model.

Parameters
  • model_config (Adversarial_AE_Config) – The Autoencoder configuration setting the main parameters of the model.

  • encoder (BaseEncoder) – An instance of BaseEncoder (inheriting from torch.nn.Module) which plays the role of encoder. This argument allows you to use your own neural network architecture if desired. If None is provided, a simple Multi Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

  • decoder (BaseDecoder) – An instance of BaseDecoder (inheriting from torch.nn.Module) which plays the role of decoder. This argument allows you to use your own neural network architecture if desired. If None is provided, a simple Multi Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

  • discriminator (BaseDiscriminator) – An instance of BaseDiscriminator (inheriting from torch.nn.Module) which plays the role of discriminator. This argument allows you to use your own neural network architecture if desired. If None is provided, a simple Multi Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

Note

For high-dimensional data we advise you to provide your own network architectures. With the provided MLP you may end up with a MemoryError.

forward(inputs, **kwargs)[source]

The input data is encoded and decoded.

Parameters

inputs (BaseDataset) – An instance of pythae’s datasets

Returns

An instance of ModelOutput containing all the relevant parameters

Return type

ModelOutput

classmethod load_from_folder(dir_path)[source]

Class method to be used to load the model from a specific folder

Parameters

dir_path (str) – The path to the folder where the model was saved.

Note

This function requires the folder to contain:

  • a model_config.json and a model.pt if no custom architectures were provided

or

  • a model_config.json, a model.pt and an encoder.pkl (resp. decoder.pkl) if a custom encoder (resp. decoder) was provided

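
The folder requirements in the note above can be checked before attempting to load. The sketch below is a small standard-library helper, not part of pythae; the name `missing_model_files` and its flags are hypothetical, and the file names are the ones listed in the note.

```python
import os
import tempfile

def missing_model_files(dir_path, custom_encoder=False, custom_decoder=False):
    """List the expected files that are absent from `dir_path`."""
    required = ["model_config.json", "model.pt"]
    if custom_encoder:
        required.append("encoder.pkl")  # only needed for a custom encoder
    if custom_decoder:
        required.append("decoder.pkl")  # only needed for a custom decoder
    return [
        name for name in required
        if not os.path.isfile(os.path.join(dir_path, name))
    ]

# Demo on a temporary folder that only contains the config file.
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "model_config.json"), "w").close()
    demo_missing = missing_model_files(tmp, custom_encoder=True)
```

An empty list means the folder contains every file named in the note for the chosen setup.
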
classmethod load_from_hf_hub(hf_hub_path, allow_pickle=False)[source]

Class method to be used to load a pretrained model from the Hugging Face hub

Parameters

hf_hub_path (str) – The path where the model was saved on the Hugging Face Hub.

Note

This function requires the folder to contain:

  • a model_config.json and a model.pt if no custom architectures were provided

or

  • a model_config.json, a model.pt and an encoder.pkl (resp. decoder.pkl and discriminator) if a custom encoder (resp. decoder and/or discriminator) was provided

save(dir_path)[source]

Method to save the model at a specific location

Parameters

dir_path (str) – The path where the model should be saved. If the path does not exist, a folder will be created at the provided location.

set_discriminator(discriminator)[source]

This method is called to set the discriminator network

Parameters

discriminator (BaseDiscriminator) – The discriminator module that needs to be set to the model.