Source code for pythae.models.vamp.vamp_model

from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...data.datasets import BaseDataset
from ..base.base_utils import ModelOutput
from ..nn import BaseDecoder, BaseEncoder
from ..vae import VAE
from .vamp_config import VAMPConfig


[docs]class VAMP(VAE): """Variational Mixture of Posteriors (VAMP) VAE model Args: model_config (VAEConfig): The Variational Autoencoder configuration setting the main parameters of the model. encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module` which plays the role of encoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module` which plays the role of encoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. .. note:: For high dimensional data we advice you to provide you own network architectures. With the provided MLP you may end up with a ``MemoryError``. """ def __init__( self, model_config: VAMPConfig, encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "VAMP" self.number_components = model_config.number_components if model_config.input_dim is None: raise AttributeError("Provide input dim to build pseudo input network") linear_layer = nn.Linear( model_config.number_components, int(np.prod(model_config.input_dim)) ) self.pseudo_inputs = nn.Sequential(linear_layer, nn.Hardtanh(0.0, 1.0)) self.idle_input = torch.eye( model_config.number_components, requires_grad=False ).to(self.device) self.linear_scheduling = self.model_config.linear_scheduling_steps
[docs] def forward(self, inputs: BaseDataset, **kwargs): """ The VAE model Args: inputs (BaseDataset): The training dataset with labels Returns: ModelOutput: An instance of ModelOutput containing all the relevant parameters """ x = inputs["data"] epoch = kwargs.pop("epoch", 100) encoder_output = self.encoder(x) mu, log_var = ( encoder_output.embedding, encoder_output.log_covariance, ) std = torch.exp(0.5 * log_var) z, _ = self._sample_gauss(mu, std) recon_x = self.decoder(z)["reconstruction"] loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z, epoch) output = ModelOutput( recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, z=z, ) return output
def loss_function(self, recon_x, x, mu, log_var, z, epoch): if self.model_config.reconstruction_loss == "mse": recon_loss = ( 0.5 * F.mse_loss( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none", ).sum(dim=-1) ) elif self.model_config.reconstruction_loss == "bce": recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none", ).sum(dim=-1) log_p_z = self._log_p_z(z) log_q_z = (-0.5 * (log_var + torch.pow(z - mu, 2) / log_var.exp())).sum(dim=1) KLD = -(log_p_z - log_q_z) if self.linear_scheduling > 0: beta = 1.0 * epoch / self.linear_scheduling if beta > 1 or not self.training: beta = 1.0 else: beta = 1.0 return ( (recon_loss + beta * KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0), ) def _log_p_z(self, z): """Computation of the log prob of the VAMP""" C = self.number_components x = self.pseudo_inputs(self.idle_input.to(z.device)).reshape( (C,) + self.model_config.input_dim ) # we bound log_var to avoid unbounded optim encoder_output = self.encoder(x) prior_mu, prior_log_var = ( encoder_output.embedding, encoder_output.log_covariance, ) z_expand = z.unsqueeze(1) prior_mu = prior_mu.unsqueeze(0) prior_log_var = prior_log_var.unsqueeze(0) log_p_z = ( torch.sum( -0.5 * ( prior_log_var + (z_expand - prior_mu) ** 2 / torch.exp(prior_log_var) ), dim=2, ) - torch.log(torch.tensor(C).type(torch.float)) ) log_p_z = torch.logsumexp(log_p_z, dim=1) return log_p_z def _sample_gauss(self, mu, std): # Reparametrization trick # Sample N(0, I) eps = torch.randn_like(std) return mu + eps * std, eps
[docs] def get_nll(self, data, n_samples=1, batch_size=100): """ Function computed the estimate negative log-likelihood of the model. It uses importance sampling method with the approximate posterior distribution. This may take a while. Args: data (torch.Tensor): The input data from which the log-likelihood should be estimated. Data must be of shape [Batch x n_channels x ...] n_samples (int): The number of importance samples to use for estimation batch_size (int): The batchsize to use to avoid memory issues """ if n_samples <= batch_size: n_full_batch = 1 else: n_full_batch = n_samples // batch_size n_samples = batch_size log_p = [] for i in range(len(data)): x = data[i].unsqueeze(0) encoder_output = self.encoder(x) mu, log_var = encoder_output.embedding, encoder_output.log_covariance log_p_x = [] for j in range(n_full_batch): x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) mu, log_var = encoder_output.embedding, encoder_output.log_covariance std = torch.exp(0.5 * log_var) z, eps = self._sample_gauss(mu, std) log_q_z_given_x = -0.5 * ( log_var + (z - mu) ** 2 / torch.exp(log_var) ).sum(dim=-1) log_p_z = self._log_p_z(z) recon_x = self.decoder(z)["reconstruction"] if self.model_config.reconstruction_loss == "mse": log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), reduction="none", ).sum(dim=-1) - torch.tensor( [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] ).to( data.device ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), reduction="none", ).sum(dim=-1) log_p_x.append( log_p_x_given_z + log_p_z - log_q_z_given_x ) # log(2*pi) simplifies log_p_x = torch.cat(log_p_x) log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) if i % 1000 == 0: print(f"Current nll at {i}: {np.mean(log_p)}") return np.mean(log_p)