Source code for pythae.models.svae.svae_model

from typing import Optional

import numpy as np
import torch
import torch.distributions as dist
import torch.nn.functional as F

from ...data.datasets import BaseDataset
from ..base.base_utils import ModelOutput
from ..nn import BaseDecoder, BaseEncoder
from ..nn.default_architectures import Encoder_SVAE_MLP
from ..vae import VAE
from .svae_config import SVAEConfig
from .svae_utils import ive


[docs]class SVAE(VAE): r""" :math:`\mathcal{S}`-VAE model. Args: model_config (SVAEConfig): The Variational Autoencoder configuration setting the main parameters of the model. encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module` which plays the role of encoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module` which plays the role of decoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. .. note:: For high dimensional data we advice you to provide you own network architectures. With the provided MLP you may end up with a ``MemoryError``. """ def __init__( self, model_config: SVAEConfig, encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.model_name = "SVAE" if encoder is None: encoder = Encoder_SVAE_MLP(model_config) self.model_config.uses_default_encoder = True else: self.model_config.uses_default_encoder = False self.set_encoder(encoder)
[docs] def forward(self, inputs: BaseDataset, **kwargs): """ The VAE model Args: inputs (BaseDataset): The training dataset with labels Returns: ModelOutput: An instance of ModelOutput containing all the relevant parameters """ x = inputs["data"] encoder_output = self.encoder(x) loc, log_concentration = ( encoder_output.embedding, encoder_output.log_concentration, ) # normalize mean loc = loc / loc.norm(dim=-1, keepdim=True) concentration = torch.nn.functional.softplus(log_concentration) + 1 z = self._sample_von_mises(loc, concentration) recon_x = self.decoder(z)["reconstruction"] loss, recon_loss, kld = self.loss_function(recon_x, x, loc, concentration, z) output = ModelOutput( recon_loss=recon_loss, reg_loss=kld, loss=loss, recon_x=recon_x, z=z, ) return output
def loss_function(self, recon_x, x, loc, concentration, z): if self.model_config.reconstruction_loss == "mse": recon_loss = ( 0.5 * F.mse_loss( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none", ).sum(dim=-1) ) elif self.model_config.reconstruction_loss == "bce": recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none", ).sum(dim=-1) KLD = self._compute_kl(m=loc.shape[-1], concentration=concentration) return (recon_loss + KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0) def _compute_kl(self, m, concentration): term1 = concentration * ( ive(m / 2, concentration) / (ive(m / 2 - 1, concentration)) ) # good term2 = ( (m / 2 - 1) * concentration.log() - torch.tensor([2 * np.pi]).to(concentration.device).log() * (m / 2) - (ive(m / 2 - 1, concentration)).log() - concentration ) # good term3 = ( -torch.lgamma(torch.tensor([m / 2]).to(concentration.device)) + torch.tensor([2]).to(concentration.device).log() + torch.tensor([np.pi]).to(concentration.device).log() * (m / 2) ) # good return (term1 + term2 + term3).squeeze(-1) def _sample_von_mises(self, loc, concentration): # Generate uniformly on sphere v = torch.randn_like(loc[:, 1:]) v = v / v.norm(dim=-1, keepdim=True) w = self._acc_rej_steps(m=loc.shape[-1], k=concentration) w_ = torch.sqrt(torch.clamp(1 - (w ** 2), 1e-10)) z = torch.cat((w, w_ * v), dim=-1) return self._householder_rotation(loc, z) def _householder_rotation(self, loc, z): e1 = torch.zeros(z.shape[-1]).to(z.device) e1[0] = 1 u = e1 - loc u = u / (u.norm(dim=-1, keepdim=True) + 1e-8) return z - 2 * u * (u * z).sum(dim=-1, keepdim=True) def _acc_rej_steps(self, m: int, k: torch.Tensor, device: str = "cpu"): batch_size = k.shape[0] c = torch.sqrt(4 * k ** 2 + (m - 1) ** 2) b = (-2 * k + c) / (m - 1) a = (m - 1 + 2 * k + c) / 4 d = (4 * a * b) / (1 + b) - (m - 1) * np.log(m - 1) d.to(k.device) b.to(k.device) w = torch.zeros_like(k) stopping_mask = torch.ones_like(torch.tensor(b)).type(torch.bool) i = 0 while stopping_mask.sum() > 0 and i < 100: i += 1 eps = ( dist.Beta( torch.tensor(0.5 * (m - 1)).type(torch.float), torch.tensor(0.5 * (m - 1)).type(torch.float), ) .sample((batch_size, 1)) .to(k.device) ) w_ = (1 - (1 + b) * eps) / (1 - (1 - b) * eps) t = 2 * a * b / (1 - (1 - b) * eps) u = dist.Uniform(0, 1).sample((batch_size, 1)).to(k.device) acc = (m - 1) * t.log() - t + d > u.log() w[acc * stopping_mask] = w_[acc * stopping_mask] stopping_mask[acc * stopping_mask] = ~acc[acc * stopping_mask] return w
[docs] def get_nll(self, data, n_samples=1, batch_size=100): """ Function computed the estimate negative log-likelihood of the model. It uses importance sampling method with the approximate posterior distribution. This may take a while. Args: data (torch.Tensor): The input data from which the log-likelihood should be estimated. Data must be of shape [Batch x n_channels x ...] n_samples (int): The number of importance samples to use for estimation batch_size (int): The batchsize to use to avoid memory issues """ if n_samples <= batch_size: n_full_batch = 1 else: n_full_batch = n_samples // batch_size n_samples = batch_size log_p = [] for i in range(len(data)): x = data[i].unsqueeze(0) log_p_x = [] for j in range(n_full_batch): x_rep = torch.cat(batch_size * [x]) encoder_output = self.encoder(x_rep) loc, log_concentration = ( encoder_output.embedding, encoder_output.log_concentration, ) # normalize mean loc = loc / loc.norm(dim=-1, keepdim=True) concentration = torch.nn.functional.softplus(log_concentration) z = self._sample_von_mises(loc, concentration) recon_x = self.decoder(z)["reconstruction"] m = loc.shape[-1] term1 = concentration * (loc * z).sum(dim=-1, keepdim=True) term2 = ( (m / 2 - 1) * concentration.log() - torch.tensor([2 * np.pi]).to(concentration.device).log() * (m / 2) - (ive(m / 2 - 1, concentration)).log() - concentration ) log_q_z_given_x = (term1 + term2).reshape(-1) # VMF log-density log_p_z = -torch.ones_like(log_q_z_given_x) * ( -torch.lgamma(torch.tensor([m / 2]).to(concentration.device)) + torch.tensor([2]).to(concentration.device).log() + torch.tensor([np.pi]).to(concentration.device).log() * (m / 2) ) recon_x = self.decoder(z)["reconstruction"] if self.model_config.reconstruction_loss == "mse": log_p_x_given_z = -0.5 * F.mse_loss( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), reduction="none", ).sum(dim=-1) - torch.tensor( [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)] ).to( data.device ) # decoding distribution is assumed unit variance N(mu, I) elif self.model_config.reconstruction_loss == "bce": log_p_x_given_z = -F.binary_cross_entropy( recon_x.reshape(x_rep.shape[0], -1), x_rep.reshape(x_rep.shape[0], -1), reduction="none", ).sum(dim=-1) log_p_x.append( log_p_x_given_z + log_p_z - log_q_z_given_x ) # log(2*pi) simplifies log_p_x = torch.cat(log_p_x) log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item()) if i % 100 == 0: print(f"Current nll at {i}: {np.mean(log_p)}") return np.mean(log_p)