Source code for pythae.models.hvae.hvae_model

import warnings
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad

from ...data.datasets import BaseDataset
from ..base.base_utils import ModelOutput
from ..nn import BaseDecoder, BaseEncoder
from ..vae import VAE
from .hvae_config import HVAEConfig


class HVAE(VAE):
    r"""
    Hamiltonian VAE.

    Args:
        model_config (HVAEConfig): A model configuration setting the main parameters of the
            model.

        encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module`)
            which plays the role of encoder. This argument allows you to use your own neural
            network architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

        decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module`)
            which plays the role of decoder. This argument allows you to use your own neural
            network architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

    .. note::
        For high dimensional data we advise you to provide your own network architectures.
        With the provided MLP you may end up with a ``MemoryError``.
    """

    def __init__(
        self,
        model_config: HVAEConfig,
        encoder: Optional[BaseEncoder] = None,
        decoder: Optional[BaseDecoder] = None,
    ):
        VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

        self.model_name = "HVAE"

        # Number of leapfrog steps performed by the Hamiltonian flow.
        self.n_lf = model_config.n_lf

        # Leapfrog step size; only trained when the config asks for it.
        self.eps_lf = nn.Parameter(
            torch.tensor([model_config.eps_lf]),
            requires_grad=bool(model_config.learn_eps_lf),
        )

        # Square root of the initial inverse temperature beta_0; optionally trained.
        self.beta_zero_sqrt = nn.Parameter(
            torch.tensor([model_config.beta_zero]) ** 0.5,
            requires_grad=bool(model_config.learn_beta_zero),
        )

        if model_config.reconstruction_loss == "bce":
            # The Bernoulli logits are read from the *decoder* output
            # (see `_log_p_x_given_z`), hence the wording of the message.
            warnings.warn(
                "Carefull, this model expects the decoder to give the *logits* of the Bernouilli "
                "distribution. Make sure the decoder actually outputs the logits."
            )

    def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
        r"""
        The input data is first encoded. The reparametrization is used to produce a sample
        :math:`z_0` from the approximate posterior :math:`q_{\phi}(z|x)`. Then Hamiltonian
        equations are solved using the leapfrog integrator.

        Args:
            inputs (BaseDataset): The training data with labels. Only ``inputs["data"]`` is
                read here.

        Returns:
            output (ModelOutput): An instance of ModelOutput containing ``loss``, ``recon_x``,
            ``z``, ``z0``, ``rho``, ``eps0``, ``gamma``, ``mu`` and ``log_var``.
        """
        x = inputs["data"]

        encoder_output = self.encoder(x)
        mu, log_var = encoder_output.embedding, encoder_output.log_covariance

        # The encoder output is treated as a log-variance: this matches `get_nll`
        # and the `-0.5 * log_var` term used for log q(z_0|x) in `loss_function`.
        std = torch.exp(0.5 * log_var)
        z0, eps0 = self._sample_gauss(mu, std)

        # Initial momentum rho_0 = gamma / sqrt(beta_0) with gamma ~ N(0, I).
        gamma = torch.randn_like(z0, device=x.device)
        rho0 = gamma / self.beta_zero_sqrt

        # create_graph=True so that the loss can be back-propagated through the
        # gradients of the potential used inside the leapfrog steps.
        z, rho = self._leapfrog_flow(z0, rho0, x, create_graph=True)

        recon_x = self.decoder(z)["reconstruction"].reshape_as(x)

        loss = self.loss_function(x, z, rho, z0, mu, log_var)

        output = ModelOutput(
            loss=loss,
            recon_x=recon_x,
            z=z,
            z0=z0,
            rho=rho,
            eps0=eps0,
            gamma=gamma,
            mu=mu,
            log_var=log_var,
        )

        return output

    def _leapfrog_flow(self, z, rho, x, create_graph=True):
        """Run ``self.n_lf`` tempered leapfrog steps and return the final ``(z, rho)``.

        Args:
            z (torch.Tensor): Initial position (latent sample).
            rho (torch.Tensor): Initial momentum.
            x (torch.Tensor): Observations the potential is conditioned on.
            create_graph (bool): Forwarded to `_dU_dz`; set to True when the result must be
                differentiated through the potential gradients (training).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Position and tempered momentum after the flow.
        """
        beta_sqrt_old = self.beta_zero_sqrt

        for k in range(self.n_lf):
            # 1st leapfrog step: half step on the momentum
            rho_half = rho - (self.eps_lf / 2) * self._dU_dz(
                z, x, create_graph=create_graph
            )

            # 2nd leapfrog step: full step on the position, driven by the
            # half-step momentum (not the momentum from the previous iteration)
            z = z + self.eps_lf * rho_half

            # 3rd leapfrog step: second half step on the momentum
            rho_full = rho_half - (self.eps_lf / 2) * self._dU_dz(
                z, x, create_graph=create_graph
            )

            # tempering step: rescale the momentum by sqrt(beta_{k-1}) / sqrt(beta_k)
            beta_sqrt = self._tempering(k + 1, self.n_lf)
            rho = (beta_sqrt_old / beta_sqrt) * rho_full
            beta_sqrt_old = beta_sqrt

        return z, rho

    def _dU_dz(self, z, x, create_graph=True):
        """Gradient with respect to ``z`` of the potential ``U = -log p(x|z) - log p(z)``.

        The ``-log p(x|z)`` part is differentiated with autograd; the standard normal prior
        contributes ``z`` analytically.

        Args:
            z (torch.Tensor): Latent position; must be part of an autograd graph.
            x (torch.Tensor): Observations.
            create_graph (bool): Whether the returned gradient should itself be
                differentiable. Default: True.
        """
        net_out = self.decoder(z)["reconstruction"].reshape(x.shape[0], -1)
        U = -self._log_p_x_given_z(net_out, x).sum()
        g = grad(U, z, create_graph=create_graph)[0]
        return g + z

    def loss_function(self, x, zK, rhoK, z0, mu, log_var):
        """Negative evidence lower bound averaged over the batch.

        Additive terms that do not involve ``zK``, ``rhoK`` or ``log_var`` are dropped.
        ``z0`` and ``mu`` are accepted for interface compatibility but are not used.

        Args:
            x (torch.Tensor): Observations.
            zK (torch.Tensor): Final position of the Hamiltonian flow.
            rhoK (torch.Tensor): Final momentum of the Hamiltonian flow.
            z0 (torch.Tensor): Initial latent sample (unused).
            mu (torch.Tensor): Posterior mean from the encoder (unused).
            log_var (torch.Tensor): Posterior log-variance from the encoder.

        Returns:
            torch.Tensor: The batch-averaged loss.
        """
        recon_x = self.decoder(zK)["reconstruction"]

        logpx_given_z = self._log_p_x_given_z(recon_x, x)  # log p(x|z_K)
        log_zk = -0.5 * torch.pow(zK, 2).sum(dim=-1)  # log p(z_K)
        logrhoK = -0.5 * torch.pow(rhoK, 2).sum(dim=-1)  # log p(rho_K)
        logp = logpx_given_z + logrhoK + log_zk

        # log q(z_0|x) restricted to its log_var-dependent part
        logq = -0.5 * log_var.sum(dim=-1)

        return -(logp - logq).mean(dim=0)

    def _tempering(self, k, K):
        """Perform tempering step.

        Returns ``1 / ((1 - 1/sqrt(beta_0)) * (k / K) ** 2 + 1/sqrt(beta_0))``, which the
        callers use as the square root of the inverse temperature at step ``k``.
        It equals ``sqrt(beta_0)`` at ``k = 0`` and 1 at ``k = K``.
        """
        inv_beta_sqrt = (
            (1 - 1 / self.beta_zero_sqrt) * (k / K) ** 2
        ) + 1 / self.beta_zero_sqrt

        return 1 / inv_beta_sqrt

    def _log_p_x_given_z(self, recon_x, x):
        """Per-sample log-likelihood ``log p(x|z)`` given the decoder output ``recon_x``.

        Both tensors are flattened to ``(x.shape[0], -1)`` and summed over the last
        dimension.

        Raises:
            ValueError: If ``model_config.reconstruction_loss`` is neither ``"mse"`` nor
                ``"bce"``.
        """
        batch = x.shape[0]
        flat_recon = recon_x.reshape(batch, -1)
        flat_x = x.reshape(batch, -1)

        loss_name = self.model_config.reconstruction_loss

        if loss_name == "mse":
            # Gaussian likelihood with identity covariance; the additive
            # log(2 * pi) normalisation constant is left out.
            return -0.5 * F.mse_loss(flat_recon, flat_x, reduction="none").sum(dim=-1)

        if loss_name == "bce":
            # `recon_x` is interpreted as Bernoulli logits.
            return (
                torch.distributions.Bernoulli(logits=flat_recon)
                .log_prob(flat_x)
                .sum(dim=-1)
            )

        raise ValueError(
            f"Unsupported reconstruction_loss {loss_name!r}: expected 'mse' or 'bce'."
        )

    def get_nll(self, data, n_samples=1, batch_size=100):
        """
        Estimate the log-likelihood of the model by importance sampling with the
        approximate posterior followed by the Hamiltonian flow. This may take a while.

        Args:
            data (torch.Tensor): The input data from which the log-likelihood should be
                estimated. Data must be of shape [Batch x n_channels x ...]
            n_samples (int): The number of importance samples to use for estimation.
                When larger than ``batch_size``, ``n_samples // batch_size`` chunks of
                ``batch_size`` samples are used (the remainder is dropped).
            batch_size (int): The batchsize to use to avoid memory issues

        Returns:
            float: The mean over ``data`` of the estimated ``log p(x)``. Note that the sign
            is *not* flipped: negate the value to obtain a negative log-likelihood.
        """
        if n_samples <= batch_size:
            n_full_batch = 1
        else:
            n_full_batch = n_samples // batch_size
            n_samples = batch_size

        log_p = []

        for i, sample in enumerate(data):
            x = sample.unsqueeze(0)

            log_p_x = []

            for _ in range(n_full_batch):
                # Replicate the observation once per importance sample, keeping
                # whatever per-sample shape the data has.
                x_rep = torch.cat(n_samples * [x])

                encoder_output = self.encoder(x_rep)
                mu, log_var = encoder_output.embedding, encoder_output.log_covariance

                std = torch.exp(0.5 * log_var)
                z0, _ = self._sample_gauss(mu, std)

                gamma = torch.randn_like(z0, device=x.device)
                rho0 = gamma / self.beta_zero_sqrt

                # Nothing is back-propagated here, so the leapfrog gradients do
                # not need to be differentiable themselves.
                z, rho = self._leapfrog_flow(z0, rho0, x_rep, create_graph=False)

                log_q_z0_given_x = -0.5 * (
                    log_var + (z0 - mu) ** 2 / torch.exp(log_var)
                ).sum(dim=-1)
                log_p_z = -0.5 * (z ** 2).sum(dim=-1)
                log_p_rho = -0.5 * (rho ** 2).sum(dim=-1)

                # Density of the *initial* momentum rho_0 ~ N(0, I / beta_0),
                # combined with the log-Jacobian of the tempering steps: the
                # beta_0 normalisation cancels and -0.5 * ||gamma||^2 remains.
                log_q_rho0 = -0.5 * (gamma ** 2).sum(dim=-1)

                recon_x = self.decoder(z)["reconstruction"]
                log_p_x_given_z = self._log_p_x_given_z(recon_x, x_rep)

                # N*log(2*pi) simplifies in prior and posterior
                log_p_x.append(
                    (
                        log_p_x_given_z
                        + log_p_z
                        + log_p_rho
                        - log_q_rho0
                        - log_q_z0_given_x
                    ).detach()
                )

            log_p_x = torch.cat(log_p_x)

            # log of the average importance weight
            log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item())

            if i % 50 == 0:
                print(f"Current nll at {i}: {np.mean(log_p)}")

        return np.mean(log_p)