FactorVAE
This module implements FactorVAE, proposed in (https://arxiv.org/abs/1802.05983). The model adds a Total Correlation (TC) penalty to the VAE loss function. A new parameter, gamma, sets the weight of the TC term relative to the reconstruction term, and the TC itself is estimated with a discriminator network.
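Following the paper, the objective maximized over a dataset of $N$ points can be written as below. Here $q(z)$ is the aggregate posterior, $\bar q(z)=\prod_j q(z_j)$ is the product of its marginals, and $D$ is the discriminator, which outputs the probability that $z$ was drawn from $q(z)$ rather than from $\bar q(z)$. The exact reductions and sign conventions used by this module's loss are not stated on this page and may differ.

```latex
\mathcal{L}(\theta,\phi) = \frac{1}{N}\sum_{i=1}^{N}\Big[\mathbb{E}_{q_\phi(z\mid x_i)}\big[\log p_\theta(x_i\mid z)\big] - \mathrm{KL}\big(q_\phi(z\mid x_i)\,\|\,p(z)\big)\Big] - \gamma\,\mathrm{KL}\big(q(z)\,\|\,\bar q(z)\big)

\mathrm{TC}(z) = \mathrm{KL}\big(q(z)\,\|\,\bar q(z)\big) \approx \mathbb{E}_{q(z)}\left[\log\frac{D(z)}{1-D(z)}\right]
```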
Available samplers
- Samples from a standard normal distribution in the Autoencoder's latent space.
- Fits a Gaussian Mixture in the Autoencoder's latent space.
- Fits a second VAE in the Autoencoder's latent space.
- Fits a Masked Autoregressive Flow in the Autoencoder's latent space.
- Fits an Inverse Autoregressive Flow in the Autoencoder's latent space.
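The first sampler listed above amounts to drawing latent codes from the standard normal prior and decoding them. A minimal pure-Python sketch of the drawing step follows; `sample_standard_normal` is a hypothetical helper, not part of pythae, and the decoding step is only indicated in a comment.

```python
import random
from typing import List


def sample_standard_normal(num_samples: int, latent_dim: int, seed: int = 0) -> List[List[float]]:
    """Draw num_samples latent vectors z ~ N(0, I), each of size latent_dim."""
    rng = random.Random(seed)  # local generator so results are reproducible
    return [[rng.gauss(0.0, 1.0) for _ in range(latent_dim)] for _ in range(num_samples)]


# In an actual sampler these codes would then be passed through the trained decoder.
z = sample_standard_normal(num_samples=4, latent_dim=10)
print(len(z), len(z[0]))  # 4 10
```

The other samplers differ only in the distribution they fit to the latent codes of the training data before drawing from it.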
- class pythae.models.FactorVAEConfig(input_dim=None, latent_dim=10, uses_default_encoder=True, uses_default_decoder=True, reconstruction_loss='mse', gamma=2.0, uses_default_discriminator=True)
FactorVAE model configuration class. The gamma parameter is the weight of the Total Correlation term in the loss.
- class pythae.models.FactorVAE(model_config, encoder=None, decoder=None)
FactorVAE model.
- Parameters
model_config (FactorVAEConfig) – The Variational Autoencoder configuration setting the main parameters of the model.
encoder (BaseEncoder) – An instance of BaseEncoder (inheriting from torch.nn.Module) which plays the role of the encoder. This argument allows you to use your own neural network architectures if desired. If None is provided, a simple Multi Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.
decoder (BaseDecoder) – An instance of BaseDecoder (inheriting from torch.nn.Module) which plays the role of the decoder. This argument allows you to use your own neural network architectures if desired. If None is provided, a simple Multi Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.
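To see why a default MLP can exhaust memory on high-dimensional inputs, note that its first fully connected layer alone holds input_dim × hidden_units weights. The arithmetic below uses a hypothetical 3 × 256 × 256 input and an assumed hidden width of 512; the actual layer sizes of pythae's default networks are not given on this page.

```python
def dense_layer_params(in_features: int, out_features: int) -> int:
    """Number of weights plus biases in one fully connected layer."""
    return in_features * out_features + out_features


# Hypothetical example: a 3 x 256 x 256 image flattened into a 512-unit layer.
input_dim = 3 * 256 * 256        # 196,608 input features
hidden_units = 512               # assumed width, not taken from pythae's default MLP
params = dense_layer_params(input_dim, hidden_units)
megabytes = params * 4 / 2**20   # float32 weights take 4 bytes each
print(params, round(megabytes, 1))  # 100663808 384.0
```

Gradients and optimizer state (Adam, for instance, keeps two extra buffers per weight) multiply this figure further, and the decoder usually mirrors the encoder's size.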
Note
For high-dimensional data, we advise you to provide your own network architectures. With the provided MLP you may end up with a MemoryError.
- forward(inputs, **kwargs)
The forward pass of the model.
- Parameters
inputs (BaseDataset) – The training dataset with labels
- Returns
An instance of ModelOutput containing all the relevant parameters
- Return type
ModelOutput
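A VAE forward pass typically encodes each input into a Gaussian mean and log-variance, samples a latent code with the reparameterization trick, and adds the KL divergence to the standard normal prior to the reconstruction loss. The pure-Python sketch below illustrates those two steps for a single sample. The helper names are hypothetical and this is not pythae's code, which works on batched tensors and, for this model, also involves the discriminator-based Total Correlation term.

```python
import math
import random
from typing import List


def reparameterize(mu: List[float], log_var: List[float], rng: random.Random) -> List[float]:
    """z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * log_var)."""
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0) for m, lv in zip(mu, log_var)]


def kl_to_standard_normal(mu: List[float], log_var: List[float]) -> float:
    """Closed-form KL(N(mu, diag(exp(log_var))) || N(0, I)), summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, log_var))


rng = random.Random(0)
mu, log_var = [0.5, -1.0], [0.0, math.log(0.25)]
z = reparameterize(mu, log_var, rng)  # one latent sample, differentiable w.r.t. mu and log_var in a real framework
print(round(kl_to_standard_normal(mu, log_var), 4))  # 0.9431
```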
- interpolate(starting_inputs, ending_inputs, granularity=10)
This function performs a linear interpolation in the latent space of the autoencoder from starting inputs to ending inputs. It returns the interpolation trajectories.
- Parameters
starting_inputs (torch.Tensor) – The starting inputs in the interpolation of shape [B x input_dim]
ending_inputs (torch.Tensor) – The ending inputs in the interpolation of shape [B x input_dim]
granularity (int) – The number of points along each interpolation trajectory.
- Returns
A tensor of shape [B x granularity x input_dim] containing the interpolation trajectories.
- Return type
torch.Tensor
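The interpolation happens between latent codes, and each point on the path is then decoded back to input space. The sketch below shows only the latent-space step for one pair of codes. It assumes the trajectory includes both endpoints, which is consistent with the [B x granularity x input_dim] output shape but is not confirmed on this page; `linear_interpolation` is a hypothetical helper.

```python
from typing import List


def linear_interpolation(z_start: List[float], z_end: List[float], granularity: int = 10) -> List[List[float]]:
    """Return `granularity` evenly spaced points from z_start to z_end, endpoints included."""
    if granularity < 2:
        raise ValueError("granularity must be at least 2 to include both endpoints")
    steps = []
    for k in range(granularity):
        t = k / (granularity - 1)  # t runs from 0.0 to 1.0
        steps.append([(1.0 - t) * a + t * b for a, b in zip(z_start, z_end)])
    return steps


path = linear_interpolation([0.0, 2.0], [1.0, 0.0], granularity=3)
print(path)  # [[0.0, 2.0], [0.5, 1.0], [1.0, 0.0]]
```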
- reconstruct(inputs)
This function returns the reconstructions of the given input data.
- Parameters
inputs (torch.Tensor) – The input data to be reconstructed, of shape [B x input_dim]
- Returns
A tensor of shape [B x input_dim] containing the reconstructed samples.
- Return type
torch.Tensor
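Because reconstruct returns a tensor with the same [B x input_dim] shape as its inputs, the two can be compared sample by sample. The sketch below computes a per-sample sum of squared errors, in the spirit of the 'mse' reconstruction_loss option of the configuration; whether this module sums or averages over input dimensions is not stated here, so the reduction is an assumption, and `per_sample_sse` is a hypothetical helper.

```python
from typing import List


def per_sample_sse(inputs: List[List[float]], recons: List[List[float]]) -> List[float]:
    """Sum of squared errors between each input and its reconstruction."""
    return [sum((x - r) ** 2 for x, r in zip(xs, rs)) for xs, rs in zip(inputs, recons)]


x = [[1.0, 0.0, 0.5], [0.2, 0.2, 0.2]]
x_hat = [[0.9, 0.1, 0.5], [0.2, 0.4, 0.0]]
print([round(e, 3) for e in per_sample_sse(x, x_hat)])  # [0.02, 0.08]
```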
- set_discriminator(discriminator)
This method sets the discriminator network of the model.
- Parameters
discriminator (BaseDiscriminator) – The discriminator module that needs to be set to the model.
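In the paper, the discriminator learns to tell samples of q(z) apart from samples of the product of its marginals, and the latter are obtained by shuffling each latent dimension independently across a batch. The pure-Python sketch below shows that shuffling step and the resulting TC estimate. It assumes the discriminator outputs a logit whose sigmoid is the probability that a code came from q(z); how this module's BaseDiscriminator exposes its output is not documented here, and both helper names are hypothetical.

```python
import random
from typing import List


def permute_dims(z: List[List[float]], rng: random.Random) -> List[List[float]]:
    """Shuffle each latent dimension independently across the batch.

    The result approximates samples from the product of the marginals of q(z)."""
    batch, dim = len(z), len(z[0])
    columns = []
    for j in range(dim):
        col = [z[i][j] for i in range(batch)]
        rng.shuffle(col)  # independent permutation for every dimension
        columns.append(col)
    return [[columns[j][i] for j in range(dim)] for i in range(batch)]


def tc_estimate(logits: List[float]) -> float:
    """Density-ratio TC estimate: mean of log(D / (1 - D)) over codes drawn from q(z).

    With D = sigmoid(logit), log(D / (1 - D)) is exactly the logit."""
    return sum(logits) / len(logits)


rng = random.Random(0)
z = [[0.1, 1.0], [0.2, 2.0], [0.3, 3.0]]
z_perm = permute_dims(z, rng)  # "negative" batch for training the discriminator
```

Shuffling per dimension keeps every marginal distribution intact while breaking the dependencies between dimensions, which is exactly what separates q(z) from the product of its marginals.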