VAEGAN
This module implements the VAE-GAN model proposed in https://arxiv.org/abs/1512.09300.
Available samplers
Samples from a standard normal distribution in the Autoencoder’s latent space.
Fits a Gaussian Mixture in the Autoencoder’s latent space.
Fits a second VAE in the Autoencoder’s latent space.
Fits a Masked Autoregressive Flow in the Autoencoder’s latent space.
Fits an Inverse Autoregressive Flow in the Autoencoder’s latent space.
- class pythae.models.VAEGANConfig(input_dim=None, latent_dim=10, uses_default_encoder=True, uses_default_decoder=True, reconstruction_loss='mse', adversarial_loss_scale=0.5, reconstruction_layer=-1, uses_default_discriminator=True, discriminator_input_dim=None, equilibrium=0.68, margin=0.4)
VAEGAN model config class.
- Parameters
input_dim (tuple) – The input_data dimension.
latent_dim (int) – The latent space dimension. Default: 10.
reconstruction_loss (str) – The reconstruction loss to use [‘bce’, ‘mse’]. Default: ‘mse’
adversarial_loss_scale (float) – Parameter scaling the adversarial loss. Default: 0.5.
reconstruction_layer (int) – The depth of the discriminator layer used for the reconstruction metric. Default: -1.
- class pythae.models.VAEGAN(model_config, encoder=None, decoder=None, discriminator=None)
Variational Autoencoder using Adversarial reconstruction loss model.
- Parameters
model_config (VAEGANConfig) – The Autoencoder configuration setting the main parameters of the model.
encoder (BaseEncoder) – An instance of BaseEncoder (inheriting from torch.nn.Module) which plays the role of encoder. This argument allows you to use your own neural network architectures if desired. If None is provided, a simple Multi-Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.
decoder (BaseDecoder) – An instance of BaseDecoder (inheriting from torch.nn.Module) which plays the role of decoder. This argument allows you to use your own neural network architectures if desired. If None is provided, a simple Multi-Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.
discriminator (BaseDiscriminator) – An instance of BaseDiscriminator (inheriting from torch.nn.Module) which plays the role of discriminator. This argument allows you to use your own neural network architectures if desired. If None is provided, a simple Multi-Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.
Note
For high-dimensional data, we advise you to provide your own network architectures. With the provided MLP you may end up with a MemoryError.
- forward(inputs, **kwargs)
The input data is encoded and decoded.
- Parameters
inputs (BaseDataset) – An instance of pythae’s datasets
- Returns
An instance of ModelOutput containing all the relevant parameters
- Return type
ModelOutput
- classmethod load_from_folder(dir_path)
Class method used to load the model from a specific folder.
- Parameters
dir_path (str) – The path to the folder where the model was saved.
Note
This function requires the folder to contain:
- a model_config.json and a model.pt if no custom architectures were provided
or
- a model_config.json, a model.pt and an encoder.pkl (resp. decoder.pkl) if a custom encoder (resp. decoder) was provided
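The file requirements in the note above can be checked before loading. The sketch below uses only the standard library; expected_files and missing_files are hypothetical helpers written for illustration and are not part of pythae.

```python
import os
import tempfile


def expected_files(custom_encoder=False, custom_decoder=False):
    """Return the file names the note above lists for a saved model folder."""
    names = ["model_config.json", "model.pt"]
    if custom_encoder:
        names.append("encoder.pkl")
    if custom_decoder:
        names.append("decoder.pkl")
    return names


def missing_files(dir_path, **custom):
    """List the expected files that are absent from dir_path."""
    return [
        name
        for name in expected_files(**custom)
        if not os.path.isfile(os.path.join(dir_path, name))
    ]


# Usage: a folder that only holds the config is missing the weights
# and the custom encoder pickle.
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "model_config.json"), "w").close()
    print(missing_files(tmp, custom_encoder=True))  # ['model.pt', 'encoder.pkl']
```

Running such a check first gives a clearer error than a failed load when a pickle file for a custom architecture was not copied along with the model.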
- classmethod load_from_hf_hub(hf_hub_path, allow_pickle=False)
Class method used to load a pretrained model from the Hugging Face Hub.
- Parameters
hf_hub_path (str) – The path where the model was saved on the Hugging Face Hub.
Note
This function requires the folder to contain:
- a model_config.json and a model.pt if no custom architectures were provided
or
- a model_config.json, a model.pt and an encoder.pkl (resp. decoder.pkl and discriminator.pkl) if a custom encoder (resp. decoder and/or discriminator) was provided
- save(dir_path)
Method to save the model at a specific location.
- Parameters
dir_path (str) – The path where the model should be saved. If the path does not exist, a folder will be created at the provided location.
- set_discriminator(discriminator)
This method is called to set the discriminator network of the model.
- Parameters
discriminator (BaseDiscriminator) – The discriminator module that needs to be set on the model.