Source code for pythae.models.beta_tc_vae.beta_tc_vae_model

import os
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F

from ...data.datasets import BaseDataset
from ..base.base_utils import ModelOutput
from ..nn import BaseDecoder, BaseEncoder
from ..vae import VAE
from .beta_tc_vae_config import BetaTCVAEConfig


class BetaTCVAE(VAE):
    r"""
    :math:`\beta`-TCVAE model.

    Args:
        model_config (BetaTCVAEConfig): The Variational Autoencoder configuration setting the
            main parameters of the model.

        encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module`)
            which plays the role of encoder. This argument allows you to use your own neural
            network architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

        decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module`)
            which plays the role of decoder. This argument allows you to use your own neural
            network architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

    .. note::
        For high dimensional data we advise you to provide your own network architectures.
        With the provided MLP you may end up with a ``MemoryError``.
    """

    def __init__(
        self,
        model_config: BetaTCVAEConfig,
        encoder: Optional[BaseEncoder] = None,
        decoder: Optional[BaseDecoder] = None,
    ):
        VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

        self.model_name = "BetaTCVAE"

        # Weights of the three terms of the decomposed regulariser
        # (see ``loss_function``) and the choice of q(z) estimator.
        self.alpha = model_config.alpha
        self.beta = model_config.beta
        self.gamma = model_config.gamma
        self.use_mss = model_config.use_mss

    def forward(self, inputs: BaseDataset, **kwargs):
        """
        Encode a batch, sample a latent code, decode it and compute the
        :math:`\\beta`-TCVAE loss.

        Args:
            inputs (BaseDataset): The batch to process; only ``inputs["data"]`` is read.
            **kwargs: ``dataset_size`` may be supplied to set the dataset size used by the
                aggregate-posterior estimators. It defaults to the batch size.

        Returns:
            ModelOutput: An instance of ModelOutput holding ``recon_loss``, ``reg_loss``,
            ``loss``, ``recon_x`` and ``z``.
        """
        x = inputs["data"]

        dataset_size = kwargs.pop("dataset_size", x.shape[0])

        encoder_output = self.encoder(x)
        mu, log_var = encoder_output.embedding, encoder_output.log_covariance

        std = torch.exp(0.5 * log_var)
        z, _ = self._sample_gauss(mu, std)
        recon_x = self.decoder(z)["reconstruction"]

        loss, recon_loss, kld = self.loss_function(
            recon_x, x, mu, log_var, z, dataset_size
        )

        return ModelOutput(
            recon_loss=recon_loss,
            reg_loss=kld,
            loss=loss,
            recon_x=recon_x,
            z=z,
        )

    def loss_function(self, recon_x, x, mu, log_var, z, dataset_size):
        """
        Compute the batch-averaged loss and its two components.

        The regulariser is the weighted sum
        ``alpha * mutual_info + beta * total_correlation + gamma * dimension_wise_KL``
        where the aggregate posterior terms are estimated on the batch either with
        the MSS estimator (``use_mss=True``) or the MWS estimator (otherwise).

        Args:
            recon_x: Decoder output; flattened to ``[B, -1]`` before comparison with ``x``.
            x: Input batch; flattened to ``[B, -1]``.
            mu: Posterior means; reshaped to ``[1, B, -1]`` for the pairwise densities.
            log_var: Posterior log-variances; reshaped like ``mu``.
            z: Sampled latent codes; reshaped to ``[B, 1, -1]`` for the pairwise densities.
            dataset_size: Dataset size used to normalise the estimators.

        Returns:
            tuple: ``(loss, recon_loss, reg_loss)``, each averaged over the batch dimension.

        Raises:
            ValueError: If ``model_config.reconstruction_loss`` is neither ``"mse"`` nor
                ``"bce"``.
        """
        flat_recon = recon_x.reshape(x.shape[0], -1)
        flat_x = x.reshape(x.shape[0], -1)

        reconstruction_loss = self.model_config.reconstruction_loss
        if reconstruction_loss == "mse":
            recon_loss = 0.5 * F.mse_loss(flat_recon, flat_x, reduction="none").sum(
                dim=-1
            )
        elif reconstruction_loss == "bce":
            recon_loss = F.binary_cross_entropy(
                flat_recon, flat_x, reduction="none"
            ).sum(dim=-1)
        else:
            # Fail here with a clear message instead of an UnboundLocalError
            # further down when ``recon_loss`` is first used.
            raise ValueError(
                f"Unsupported reconstruction_loss {reconstruction_loss!r}; "
                "expected 'mse' or 'bce'."
            )

        batch_size = z.shape[0]

        log_q_z_given_x = self._compute_log_gauss_density(z, mu, log_var).sum(
            dim=-1
        )  # [B]

        # Zero mean and zero log-variance, i.e. a standard normal prior.
        log_prior = self._compute_log_gauss_density(
            z, torch.zeros_like(z), torch.zeros_like(z)
        ).sum(
            dim=-1
        )  # [B]

        # Entry [i, j, d] is the log-density of z_i[d] under the posterior of sample j.
        log_q_batch_perm = self._compute_log_gauss_density(
            z.reshape(batch_size, 1, -1),
            mu.reshape(1, batch_size, -1),
            log_var.reshape(1, batch_size, -1),
        )  # [B x B x Latent_dim]

        if self.use_mss:
            logiw_mat = self._log_importance_weight_matrix(
                batch_size, dataset_size
            ).to(z.device)

            log_q_z = torch.logsumexp(
                logiw_mat + log_q_batch_perm.sum(dim=-1), dim=-1
            )  # MSS [B]

            log_prod_q_z = torch.logsumexp(
                logiw_mat.reshape(batch_size, batch_size, -1) + log_q_batch_perm,
                dim=1,
            ).sum(
                dim=-1
            )  # MSS [B]

        else:
            # log(B * N): shared normaliser of both MWS estimates, built once
            # and directly on the target device.
            log_norm = torch.log(
                torch.tensor([batch_size * dataset_size], device=z.device)
            )

            log_q_z = (
                torch.logsumexp(log_q_batch_perm.sum(dim=-1), dim=-1) - log_norm
            )  # MWS [B]

            log_prod_q_z = (torch.logsumexp(log_q_batch_perm, dim=1) - log_norm).sum(
                dim=-1
            )  # MWS [B]

        mutual_info_loss = log_q_z_given_x - log_q_z
        TC_loss = log_q_z - log_prod_q_z
        dimension_wise_KL = log_prod_q_z - log_prior

        reg_loss = (
            self.alpha * mutual_info_loss
            + self.beta * TC_loss
            + self.gamma * dimension_wise_KL
        )

        return (
            (recon_loss + reg_loss).mean(dim=0),
            recon_loss.mean(dim=0),
            reg_loss.mean(dim=0),
        )

    def _sample_gauss(self, mu, std):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` and return ``(sample, eps)``."""
        # Reparametrization trick
        # Sample N(0, I)
        eps = torch.randn_like(std)
        return mu + eps * std, eps

    def _compute_log_gauss_density(self, z, mu, log_var):
        """
        Element-wise log-density of ``z`` under a Gaussian with mean ``mu`` and
        log-variance ``log_var``. No reduction is applied; inputs broadcast.
        """
        # Created directly on the device of ``z`` to avoid a CPU allocation and copy.
        log_two_pi = torch.log(torch.tensor([2 * np.pi], device=z.device))
        return -0.5 * (log_two_pi + log_var + (z - mu) ** 2 * torch.exp(-log_var))

    def _log_importance_weight_matrix(self, batch_size, dataset_size):
        """
        Compute the log importance weight matrix for MSS.

        Code from (https://github.com/rtqichen/beta-tcvae/blob/master/vae_quant.py)

        Args:
            batch_size: Number of samples B in the batch.
            dataset_size: Dataset size N.

        Returns:
            torch.Tensor: A ``[B, B]`` CPU tensor of log-weights.
        """
        N = dataset_size
        M = batch_size - 1

        if M == 0:
            # A single-sample batch would divide by zero below. Only the 1/N
            # term remains, which is the same weight 1 / (B * N) that the MWS
            # branch of ``loss_function`` uses when B == 1.
            return torch.full((1, 1), 1 / N).log()

        strat_weight = (N - M) / (N * M)

        W = torch.full((batch_size, batch_size), 1 / M)
        # With a stride of M + 1 == batch_size over the flattened matrix, these
        # two writes fill column 0 with 1 / N and column 1 with strat_weight.
        # NOTE(review): a per-sample weighting would be expected to put 1 / N on
        # the diagonal; behaviour is kept as in the reference code - confirm.
        W.view(-1)[:: M + 1] = 1 / N
        W.view(-1)[1 :: M + 1] = strat_weight
        W[M - 1, 0] = strat_weight

        return W.log()