Source code for pythae.models.factor_vae.factor_vae_model

import os
from typing import Optional

import torch
import torch.nn.functional as F

from ...customexception import BadInheritanceError
from ...data.datasets import BaseDataset
from ..base.base_utils import ModelOutput
from ..nn import BaseDecoder, BaseDiscriminator, BaseEncoder
from ..vae import VAE
from .factor_vae_config import FactorVAEConfig
from .factor_vae_utils import FactorVAEDiscriminator


[docs]class FactorVAE(VAE): """ FactorVAE model. Args: model_config (FactorVAEConfig): The Variational Autoencoder configuration setting the main parameters of the model. encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module` which plays the role of encoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module` which plays the role of decoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. .. note:: For high dimensional data we advice you to provide you own network architectures. With the provided MLP you may end up with a ``MemoryError``. """ def __init__( self, model_config: FactorVAEConfig, encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder) self.discriminator = FactorVAEDiscriminator(latent_dim=model_config.latent_dim) self.model_name = "FactorVAE" self.gamma = model_config.gamma
[docs] def set_discriminator(self, discriminator: BaseDiscriminator) -> None: r"""This method is called to set the discriminator network Args: discriminator (BaseDiscriminator): The discriminator module that needs to be set to the model. """ if not issubclass(type(discriminator), BaseDiscriminator): raise BadInheritanceError( ( "Discriminator must inherit from BaseDiscriminator class from " "pythae.models.base_architectures.BaseDiscriminator. Refer to documentation." ) ) self.discriminator = discriminator
[docs] def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: """ The VAE model Args: inputs (BaseDataset): The training dataset with labels Returns: ModelOutput: An instance of ModelOutput containing all the relevant parameters """ x_in = inputs["data"] if x_in.shape[0] <= 1: raise ArithmeticError( "At least 2 samples in a batch are required for the `FactorVAE` model" ) idx = torch.randperm(x_in.shape[0]) idx_1 = idx[int(x_in.shape[0] / 2) :] idx_2 = idx[: int(x_in.shape[0] / 2)] # first batch x = inputs["data"][idx_1] encoder_output = self.encoder(x) mu, log_var = encoder_output.embedding, encoder_output.log_covariance std = torch.exp(0.5 * log_var) z, _ = self._sample_gauss(mu, std) recon_x = self.decoder(z)["reconstruction"] # second batch x_bis = inputs["data"][idx_2] encoder_output = self.encoder(x_bis) mu_bis, log_var_bis = encoder_output.embedding, encoder_output.log_covariance std_bis = torch.exp(0.5 * log_var_bis) z_bis, _ = self._sample_gauss(mu_bis, std_bis) z_bis_permuted = self._permute_dims(z_bis).detach() recon_loss, autoencoder_loss, discriminator_loss = self.loss_function( recon_x, x, mu, log_var, z, z_bis_permuted ) loss = autoencoder_loss + discriminator_loss output = ModelOutput( loss=loss, recon_loss=recon_loss, autoencoder_loss=autoencoder_loss, discriminator_loss=discriminator_loss, recon_x=recon_x, recon_x_indices=idx_1, z=z, z_bis_permuted=z_bis_permuted, ) return output
def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted): N = z.shape[0] # batch size if self.model_config.reconstruction_loss == "mse": recon_loss = ( 0.5 * F.mse_loss( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none", ).sum(dim=-1) ) elif self.model_config.reconstruction_loss == "bce": recon_loss = F.binary_cross_entropy( recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none", ).sum(dim=-1) KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1) latent_adversarial_score = self.discriminator(z) TC = (latent_adversarial_score[:, 0] - latent_adversarial_score[:, 1]).mean() autoencoder_loss = recon_loss + KLD + self.gamma * TC # discriminator loss permuted_latent_adversarial_score = self.discriminator(z_bis_permuted) true_labels = ( torch.ones(z_bis_permuted.shape[0], requires_grad=False) .type(torch.LongTensor) .to(z.device) ) fake_labels = ( torch.zeros(z.shape[0], requires_grad=False) .type(torch.LongTensor) .to(z.device) ) TC_permuted = F.cross_entropy( latent_adversarial_score, fake_labels ) + F.cross_entropy(permuted_latent_adversarial_score, true_labels) discriminator_loss = 0.5 * TC_permuted return ( (recon_loss).mean(dim=0), (autoencoder_loss).mean(dim=0), (discriminator_loss).mean(dim=0), )
[docs] def reconstruct(self, inputs: torch.Tensor): """This function returns the reconstructions of given input data. Args: inputs (torch.Tensor): The inputs data to be reconstructed of shape [B x input_dim] ending_inputs (torch.Tensor): The starting inputs in the interpolation of shape Returns: torch.Tensor: A tensor of shape [B x input_dim] containing the reconstructed samples. """ encoder_output = self.encoder(inputs) mu, log_var = encoder_output.embedding, encoder_output.log_covariance std = torch.exp(0.5 * log_var) z, _ = self._sample_gauss(mu, std) recon_x = self.decoder(z)["reconstruction"] return recon_x
[docs] def interpolate( self, starting_inputs: torch.Tensor, ending_inputs: torch.Tensor, granularity: int = 10, ): """This function performs a linear interpolation in the latent space of the autoencoder from starting inputs to ending inputs. It returns the interpolation trajectories. Args: starting_inputs (torch.Tensor): The starting inputs in the interpolation of shape [B x input_dim] ending_inputs (torch.Tensor): The starting inputs in the interpolation of shape [B x input_dim] granularity (int): The granularity of the interpolation. Returns: torch.Tensor: A tensor of shape [B x granularity x input_dim] containing the interpolation trajectories. """ assert starting_inputs.shape[0] == ending_inputs.shape[0], ( "The number of starting_inputs should equal the number of ending_inputs. Got " f"{starting_inputs.shape[0]} sampler for starting_inputs and {ending_inputs.shape[0]} " "for endinging_inputs." ) encoder_output = self.encoder(starting_inputs) mu, log_var = encoder_output.embedding, encoder_output.log_covariance std = torch.exp(0.5 * log_var) starting_z, _ = self._sample_gauss(mu, std) encoder_output = self.encoder(ending_inputs) mu, log_var = encoder_output.embedding, encoder_output.log_covariance std = torch.exp(0.5 * log_var) ending_z, _ = self._sample_gauss(mu, std) t = torch.linspace(0, 1, granularity).to(starting_inputs.device) intep_line = ( torch.kron( starting_z.reshape(starting_z.shape[0], -1), (1 - t).unsqueeze(-1) ) + torch.kron(ending_z.reshape(ending_z.shape[0], -1), t.unsqueeze(-1)) ).reshape((starting_z.shape[0] * t.shape[0],) + (starting_z.shape[1:])) decoded_line = self.decoder(intep_line).reconstruction.reshape( ( starting_inputs.shape[0], t.shape[0], ) + (starting_inputs.shape[1:]) ) return decoded_line
def _sample_gauss(self, mu, std): # Reparametrization trick # Sample N(0, I) eps = torch.randn_like(std) return mu + eps * std, eps def _permute_dims(self, z): permuted = torch.zeros_like(z) for i in range(z.shape[-1]): perms = torch.randperm(z.shape[0]).to(z.device) permuted[:, i] = z[perms, i] return permuted