Riemannian Hamiltonian VAE

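This is an implementation of the Riemannian Hamiltonian VAE model proposed in (https://arxiv.org/abs/2105.00026). This model learns the Riemannian latent structure of a given data set through a parametrized Riemannian metric of the form \(\mathbf{G}^{-1}(z) = \sum\limits_{i=1}^N L_{\psi_i} L_{\psi_i}^{\top} \exp \Big(-\frac{\lVert z - c_i \rVert_2^2}{T^2} \Big) + \lambda I_d\), where the \(c_i\) are centroids given by the encoder means \(\mu(x_i)\) of the \(N\) training points, the \(L_{\psi_i}\) are lower-triangular matrices output by the metric network, \(T\) is the temperature and \(\lambda\) is the regularization factor.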
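To make the metric formula concrete, here is a minimal numpy sketch that evaluates \(\mathbf{G}^{-1}(z)\) for a single latent point. It assumes the centroids and the \(L_{\psi_i}\) are already available as plain arrays; the function name `inverse_metric` is hypothetical and not part of pythae.

```python
import numpy as np


def inverse_metric(z, centroids, L, temperature=1.5, regularization=0.01):
    """Evaluate G^{-1}(z) = sum_i L_i L_i^T exp(-||z - c_i||^2 / T^2) + lambda * I_d.

    z: (d,) latent point; centroids: (N, d) the c_i; L: (N, d, d) the L_psi_i.
    Hypothetical helper for illustration, not the pythae implementation.
    """
    d = z.shape[0]
    sq_dist = np.sum((centroids - z) ** 2, axis=1)      # ||z - c_i||^2, shape (N,)
    weights = np.exp(-sq_dist / temperature ** 2)      # Gaussian kernel weights, shape (N,)
    LLt = L @ np.transpose(L, (0, 2, 1))               # L_i L_i^T, shape (N, d, d)
    # Weighted sum of the N matrices plus the lambda * I_d regularization term
    return np.einsum("n,nij->ij", weights, LLt) + regularization * np.eye(d)
```

Far from every centroid the kernel weights vanish and \(\mathbf{G}^{-1}(z)\) falls back to \(\lambda I_d\); since every term is positive semi-definite, the eigenvalues of \(\mathbf{G}^{-1}(z)\) never drop below \(\lambda\). The metric itself is obtained as \(\mathbf{G}(z) = (\mathbf{G}^{-1}(z))^{-1}\), e.g. with `np.linalg.inv`.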

It is particularly well suited for high-dimensional data with a low number of samples, and has proven relevant for data augmentation, as shown in (https://arxiv.org/abs/2105.00026).

Available samplers

RHVAESampler

Samples from a distribution whose density is proportional to the inverse of the metric volume element of an RHVAE model.
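The idea behind this sampler can be sketched as follows: the inverse volume element is \(\sqrt{\det \mathbf{G}^{-1}(z)}\), which is large near the centroids and flattens to the constant \(\lambda^{d/2}\) far from them, so it is not normalizable on all of \(\mathbb{R}^d\). The sketch below therefore restricts sampling to a bounding box around the centroids (the `margin` parameter is an assumption) and uses plain random-walk Metropolis for brevity, rather than the Hamiltonian Monte Carlo scheme described in the paper. All names are hypothetical; this is not the pythae RHVAESampler code.

```python
import numpy as np


def log_inv_volume(z, centroids, L, temperature=1.5, regularization=0.01):
    """Unnormalized log-density log sqrt(det G^{-1}(z))."""
    w = np.exp(-np.sum((centroids - z) ** 2, axis=1) / temperature ** 2)
    # sum_n w_n L_n L_n^T + lambda * I_d
    G_inv = np.einsum("n,nij,nkj->ik", w, L, L) + regularization * np.eye(z.size)
    return 0.5 * np.linalg.slogdet(G_inv)[1]


def metropolis_sample(n_samples, centroids, L, step=0.5, burn_in=500,
                      margin=2.0, rng=None):
    """Random-walk Metropolis restricted to a box around the centroids."""
    rng = np.random.default_rng() if rng is None else rng
    lower = centroids.min(axis=0) - margin
    upper = centroids.max(axis=0) + margin
    z = centroids[rng.integers(len(centroids))].copy()  # start at a random centroid
    log_p = log_inv_volume(z, centroids, L)
    samples = []
    for t in range(burn_in + n_samples):
        proposal = z + step * rng.standard_normal(z.size)
        inside = np.all((proposal >= lower) & (proposal <= upper))
        if inside:
            log_p_prop = log_inv_volume(proposal, centroids, L)
            # Accept with probability min(1, p(proposal) / p(z))
            if np.log(rng.uniform()) < log_p_prop - log_p:
                z, log_p = proposal, log_p_prop
        if t >= burn_in:
            samples.append(z.copy())
    return np.stack(samples)
```

In the real model the resulting latent samples would then be passed through the decoder to generate new data points.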

NormalSampler

Samples from a standard normal distribution in the Autoencoder’s latent space.

GaussianMixtureSampler

Fits a Gaussian Mixture in the Autoencoder’s latent space.

TwoStageVAESampler

Fits a second VAE in the Autoencoder’s latent space.

MAFSampler

Fits a Masked Autoregressive Flow in the Autoencoder’s latent space.

IAFSampler

Fits an Inverse Autoregressive Flow in the Autoencoder’s latent space.

class pythae.models.RHVAEConfig(input_dim=None, latent_dim=10, uses_default_encoder=True, uses_default_decoder=True, reconstruction_loss='mse', n_lf=3, eps_lf=0.001, beta_zero=0.3, temperature=1.5, regularization=0.01, uses_default_metric=True)[source]

RHVAE config class.

Parameters
  • latent_dim (int) – The latent dimension used for the latent space. Default: 10

  • n_lf (int) – The number of leapfrog steps to use in the integrator. Default: 3

  • eps_lf (float) – The leapfrog step size. Default: 1e-3

  • beta_zero (float) – The tempering factor in the Riemannian Hamiltonian Monte Carlo sampler. Default: 0.3

  • temperature (float) – The metric temperature \(T\). Default: 1.5

  • regularization (float) – The metric regularization factor \(\lambda\). Default: 0.01
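As an illustration of how beta_zero and n_lf may interact, here is a sketch of the quadratic tempering schedule of the Hamiltonian VAE (Caterini et al., 2018), \(1/\sqrt{\beta_k} = (1 - 1/\sqrt{\beta_0})\,k^2/K^2 + 1/\sqrt{\beta_0}\) with \(K\) = n_lf. We assume, without confirmation from this page, that the RHVAE tempering follows this schedule, with the momentum rescaled by \(\sqrt{\beta_{k-1}/\beta_k}\) after leapfrog step \(k\). Function names are hypothetical.

```python
import math


def tempering_schedule(beta_zero, n_lf):
    """Return sqrt(beta_k) for k = 0..n_lf under a quadratic schedule (assumed)."""
    inv_sqrt_beta0 = 1.0 / math.sqrt(beta_zero)
    return [
        1.0 / ((1.0 - inv_sqrt_beta0) * (k / n_lf) ** 2 + inv_sqrt_beta0)
        for k in range(n_lf + 1)
    ]


def momentum_rescaling_factors(schedule):
    """Factor sqrt(beta_{k-1} / beta_k) applied to the momentum after step k."""
    return [schedule[k - 1] / schedule[k] for k in range(1, len(schedule))]
```

The schedule starts at \(\sqrt{\beta_0}\) and ends at 1; with beta_zero < 1 every rescaling factor is below 1, so the momentum is progressively cooled and the product of all factors equals \(\sqrt{\beta_0}\).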

class pythae.models.RHVAE(model_config, encoder=None, decoder=None, metric=None)[source]

Riemannian Hamiltonian VAE model.

Parameters
  • model_config (RHVAEConfig) – A model configuration setting the main parameters of the model.

  • encoder (BaseEncoder) – An instance of BaseEncoder (inheriting from torch.nn.Module) which plays the role of encoder. This argument allows you to use your own neural network architectures if desired. If None is provided, a simple Multi Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

  • decoder (BaseDecoder) – An instance of BaseDecoder (inheriting from torch.nn.Module) which plays the role of decoder. This argument allows you to use your own neural network architectures if desired. If None is provided, a simple Multi Layer Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.

  • metric (BaseMetric) – An instance of BaseMetric (inheriting from torch.nn.Module) which plays the role of the metric network outputting the \(L_{\psi_i}\). This argument allows you to use your own neural network architecture if desired. If None is provided, a simple Multi Layer Perceptron is used. Default: None.

Note

For high-dimensional data we advise you to provide your own network architectures. With the default MLP you may run into a MemoryError.

forward(inputs, **kwargs)[source]

The input data is first encoded. The reparametrization trick is used to produce a sample \(z_0\) from the approximate posterior \(q_{\phi}(z|x)\). Then the Riemannian Hamiltonian equations are solved using the generalized leapfrog integrator. Meanwhile, the input data \(x\) is fed to the metric network outputting the matrices \(L_{\psi}\). The metric is computed from these matrices and used by the integrator.
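The generalized leapfrog integrator mentioned above handles a position-dependent metric by solving two implicit updates. Below is a minimal numpy sketch of one step for the Hamiltonian \(H(z, \rho) = U(z) + \frac{1}{2}\log\det \mathbf{G}(z) + \frac{1}{2}\rho^{\top}\mathbf{G}^{-1}(z)\rho\) (additive constants dropped), where \(U\) stands in for \(-\log p_{\theta}(x, z)\). For illustration only, the implicit equations are solved with a fixed number of fixed-point iterations (`n_fixed_point`, an assumption) and \(\nabla_z H\) is taken by finite differences instead of automatic differentiation; the toy potential and metric in the usage lines are hypothetical.

```python
import numpy as np


def hamiltonian(z, rho, potential, g_inv):
    """H(z, rho) = U(z) + 0.5 log det G(z) + 0.5 rho^T G^{-1}(z) rho (constants dropped)."""
    G_inv = g_inv(z)
    _, logdet_inv = np.linalg.slogdet(G_inv)      # log det G^{-1}(z) = -log det G(z)
    return potential(z) - 0.5 * logdet_inv + 0.5 * rho @ G_inv @ rho


def grad_z(z, rho, potential, g_inv, h=1e-5):
    """Central finite-difference gradient of H with respect to z."""
    g = np.zeros_like(z)
    for j in range(z.size):
        e = np.zeros_like(z)
        e[j] = h
        g[j] = (hamiltonian(z + e, rho, potential, g_inv)
                - hamiltonian(z - e, rho, potential, g_inv)) / (2 * h)
    return g


def generalized_leapfrog_step(z, rho, eps, potential, g_inv, n_fixed_point=5):
    # 1) Implicit half step on the momentum: rho' = rho - eps/2 * grad_z H(z, rho')
    rho_half = rho.copy()
    for _ in range(n_fixed_point):
        rho_half = rho - 0.5 * eps * grad_z(z, rho_half, potential, g_inv)
    # 2) Implicit full step on the position, using grad_rho H = G^{-1}(z) rho
    z_new = z.copy()
    for _ in range(n_fixed_point):
        z_new = z + 0.5 * eps * (g_inv(z) @ rho_half + g_inv(z_new) @ rho_half)
    # 3) Explicit half step on the momentum at the new position
    rho_new = rho_half - 0.5 * eps * grad_z(z_new, rho_half, potential, g_inv)
    return z_new, rho_new


# Hypothetical toy problem: standard Gaussian potential, conformal metric G(z) = (1 + |z|^2) I
def potential_toy(z):
    return 0.5 * z @ z


def g_inv_toy(z):
    return np.eye(z.size) / (1.0 + z @ z)
```

Because the scheme is symplectic and time-reversible, the Hamiltonian is nearly conserved along the trajectory and integrating back with the momentum flipped returns to the starting point.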

Parameters

inputs (BaseDataset) – The training data with labels

Returns

An instance of ModelOutput containing all the relevant parameters

Return type

ModelOutput

get_nll(data, n_samples=1, batch_size=100)[source]

Computes an estimate of the model's negative log-likelihood. It uses importance sampling with the approximate posterior distribution as the proposal. This may take a while.

Parameters
  • data (torch.Tensor) – The input data from which the log-likelihood should be estimated. Data must be of shape [Batch x n_channels x …]

  • n_samples (int) – The number of importance samples to use for estimation

  • batch_size (int) – The batch size to use, to avoid memory issues
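The importance sampling estimate behind get_nll can be illustrated on a one-dimensional toy model where the exact answer is known: prior \(p(z) = \mathcal{N}(0, 1)\), likelihood \(p(x|z) = \mathcal{N}(z, 1)\), hence \(p(x) = \mathcal{N}(0, 2)\). The sketch computes \(-\log \frac{1}{K}\sum_k \frac{p(x|z_k)\,p(z_k)}{q(z_k|x)}\) with \(z_k \sim q(z|x)\) and \(K\) = n_samples. The Gaussian proposal and all names are assumptions for illustration; in the model, the proposal is the encoder's approximate posterior and the likelihood comes from the decoder.

```python
import numpy as np


def log_normal(x, mean, var):
    """Log-density of N(mean, var) evaluated at x (element-wise)."""
    return -0.5 * (np.log(2 * np.pi * var) + (x - mean) ** 2 / var)


def logsumexp(a):
    """Numerically stable log(sum(exp(a)))."""
    m = np.max(a)
    return m + np.log(np.sum(np.exp(a - m)))


def importance_sampled_nll(x, q_mean, q_var, n_samples, rng):
    """Estimate -log p(x) using n_samples draws from the proposal q(z|x) = N(q_mean, q_var)."""
    z = rng.normal(q_mean, np.sqrt(q_var), size=n_samples)
    # log importance weights: log p(x|z) + log p(z) - log q(z|x)
    log_w = log_normal(x, z, 1.0) + log_normal(z, 0.0, 1.0) - log_normal(z, q_mean, q_var)
    return -(logsumexp(log_w) - np.log(n_samples))
```

When the proposal equals the true posterior (here \(\mathcal{N}(x/2, 1/2)\)) every weight equals \(p(x)\) and a single sample is exact; a poorer proposal such as the prior still converges but needs many more samples, which is why a larger n_samples makes get_nll slower but more accurate.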

classmethod load_from_folder(dir_path)[source]

Class method to be used to load the model from a specific folder

Parameters

dir_path (str) – The path where the model should have been saved.

Note

This function requires the folder to contain:

  • a model_config.json and a model.pt if no custom architectures were provided

or

  • a model_config.json, a model.pt and an encoder.pkl (resp. decoder.pkl and/or metric.pkl) if a custom encoder (resp. decoder and/or metric) was provided

classmethod load_from_hf_hub(hf_hub_path, allow_pickle=False)[source]

Class method to be used to load a pretrained model from the Hugging Face hub

Parameters

hf_hub_path (str) – The path where the model should have been saved on the Hugging Face Hub.

Note

This function requires the folder to contain:

  • a model_config.json and a model.pt if no custom architectures were provided

or

  • a model_config.json, a model.pt and an encoder.pkl (resp. decoder.pkl and/or metric.pkl) if a custom encoder (resp. decoder and/or metric) was provided

predict(inputs)[source]

The input data is encoded and decoded without computing the loss.

Parameters

inputs (torch.Tensor) – The input data to be reconstructed, as well as to generate the embedding.

Returns

An instance of ModelOutput containing reconstruction, raw embedding (output of encoder), and the final embedding (output of metric)

Return type

ModelOutput

save(dir_path)[source]

Method to save the model at a specific location

Parameters

dir_path (str) – The path where the model should be saved. If the path does not exist, a folder will be created at the provided location.

set_metric(metric)[source]

This method is called to set the metric network outputting the matrices \(L_{\psi_i}\) of the metric.

Parameters

metric (BaseMetric) – The metric module that needs to be set on the model.

update()[source]

Once the model has seen all the data points (i.e. at the end of an epoch), the final metric function is updated using the encoder means \(\mu(x_i)\) as centroids \(c_i\).
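The update step described above can be pictured as an accumulate-then-freeze pattern: during an epoch the encoder means \(\mu(x_i)\) and the matrices \(L_{\psi_i}\) of each batch are collected, and at the end of the epoch they are stacked to define the closed-form \(\mathbf{G}^{-1}(z)\). The class below is a hypothetical numpy sketch of that bookkeeping, not the pythae implementation, which works with torch tensors inside the model.

```python
import numpy as np


class MetricUpdater:
    """Hypothetical helper: collect mu(x_i) and L_psi_i per batch, then freeze G^{-1}."""

    def __init__(self, temperature=1.5, regularization=0.01):
        self.temperature = temperature
        self.regularization = regularization
        self._mu_buffer, self._L_buffer = [], []
        self.centroids = None   # (N, d) the c_i, set by update()
        self.M = None           # (N, d, d) the L_psi_i, set by update()

    def collect(self, mu_batch, L_batch):
        """Store one batch of encoder means (B, d) and metric matrices (B, d, d)."""
        self._mu_buffer.append(mu_batch)
        self._L_buffer.append(L_batch)

    def update(self):
        """End of epoch: stack all collected points and reset the buffers."""
        self.centroids = np.concatenate(self._mu_buffer)
        self.M = np.concatenate(self._L_buffer)
        self._mu_buffer, self._L_buffer = [], []

    def G_inv(self, z):
        """Closed-form inverse metric built from the frozen centroids."""
        w = np.exp(-np.sum((self.centroids - z) ** 2, axis=1) / self.temperature ** 2)
        return (np.einsum("n,nij,nkj->ik", w, self.M, self.M)
                + self.regularization * np.eye(z.size))
```

Freezing the centroids once per epoch keeps the metric fixed while it is used, for example by the sampler after training.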