# Source code for pythae.models.vq_vae.vq_vae_model

import os
from typing import Optional

import torch
import torch.nn.functional as F

from ...data.datasets import BaseDataset
from ..ae import AE
from ..base.base_utils import ModelOutput
from ..nn import BaseDecoder, BaseEncoder
from .vq_vae_config import VQVAEConfig
from .vq_vae_utils import Quantizer, QuantizerEMA


class VQVAE(AE):
    r"""
    Vector Quantized-VAE model.

    Args:
        model_config (VQVAEConfig): The Vector Quantized-VAE configuration setting the main
            parameters of the model.

        encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module`)
            which plays the role of encoder. This argument allows you to use your own neural
            networks architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

        decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module`)
            which plays the role of decoder. This argument allows you to use your own neural
            networks architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

    .. note::
        For high dimensional data we advise you to provide your own network architectures.
        With the provided MLP you may end up with a ``MemoryError``.
    """

    def __init__(
        self,
        model_config: VQVAEConfig,
        encoder: Optional[BaseEncoder] = None,
        decoder: Optional[BaseDecoder] = None,
    ):
        AE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)

        # The quantizer needs the encoder's output width, so it can only be
        # built once the parent constructor has set up ``self.encoder``.
        self._set_quantizer(model_config)

        self.model_name = "VQVAE"

    def _set_quantizer(self, model_config):
        """Build ``self.quantizer`` after probing the encoder for its embedding size.

        A dummy batch of two random samples of shape ``input_dim`` is passed through
        the encoder. The size of the last axis of the (permuted) embedding is stored
        in ``self.model_config.embedding_dim`` before the quantizer is created.

        Args:
            model_config (VQVAEConfig): The model configuration. ``input_dim`` must be
                set; ``use_ema`` selects between ``QuantizerEMA`` and ``Quantizer``.

        Raises:
            AttributeError: If ``model_config.input_dim`` is None.
        """
        if model_config.input_dim is None:
            raise AttributeError(
                "No input dimension provided !"
                "'input_dim' parameter of VQVAEConfig instance must be set to 'data_shape' where "
                "the shape of the data is (C, H, W ..). Unable to set quantizer."
            )

        # Create the probe on the same device as the encoder weights so that a
        # user-supplied encoder already moved off the default device still works.
        first_param = next(self.encoder.parameters(), None)
        probe_device = first_param.device if first_param is not None else None

        # Only the output *shape* is needed: skip autograd bookkeeping.
        with torch.no_grad():
            # Star-unpacking accepts ``input_dim`` given as a tuple or a list.
            x = torch.randn((2, *self.model_config.input_dim), device=probe_device)
            z = self.encoder(x).embedding

        if len(z.shape) == 2:
            # Flat (batch, dim) embeddings are viewed as (batch, 1, 1, dim).
            z = z.reshape(z.shape[0], 1, 1, -1)

        # Move axis 1 to the end; this requires a 4-dimensional embedding.
        # NOTE: for a flat embedding the result is (batch, 1, dim, 1), so
        # ``embedding_dim`` is set to 1 in that case.
        z = z.permute(0, 2, 3, 1)

        self.model_config.embedding_dim = z.shape[-1]

        if model_config.use_ema:
            self.quantizer = QuantizerEMA(model_config=model_config)
        else:
            self.quantizer = Quantizer(model_config=model_config)

    def forward(self, inputs: BaseDataset, **kwargs):
        """
        The VQ-VAE forward pass: encode, quantize, decode and compute the losses.

        Args:
            inputs (BaseDataset): The batch to process; the input tensor is read from
                ``inputs["data"]``.
            **kwargs: ``uses_ddp`` (default ``False``) is popped and forwarded to the
                quantizer.

        Returns:
            ModelOutput: An instance of ModelOutput with the fields ``recon_loss``,
            ``vq_loss``, ``loss``, ``recon_x``, ``z`` (the quantized embedding passed
            to the decoder) and ``quantized_indices``.
        """
        x = inputs["data"]

        uses_ddp = kwargs.pop("uses_ddp", False)

        encoder_output = self.encoder(x)

        embeddings = encoder_output.embedding

        # Flat embeddings are temporarily viewed as 4-d for the quantizer and
        # flattened again before decoding.
        reshape_for_decoding = False

        if len(embeddings.shape) == 2:
            embeddings = embeddings.reshape(embeddings.shape[0], 1, 1, -1)
            reshape_for_decoding = True

        # Same axis layout as the one used in ``_set_quantizer``.
        embeddings = embeddings.permute(0, 2, 3, 1)

        quantizer_output = self.quantizer(embeddings, uses_ddp=uses_ddp)

        quantized_embed = quantizer_output.quantized_vector
        quantized_indices = quantizer_output.quantized_indices

        if reshape_for_decoding:
            quantized_embed = quantized_embed.reshape(embeddings.shape[0], -1)

        recon_x = self.decoder(quantized_embed).reconstruction

        loss, recon_loss, vq_loss = self.loss_function(recon_x, x, quantizer_output)

        output = ModelOutput(
            recon_loss=recon_loss,
            vq_loss=vq_loss,
            loss=loss,
            recon_x=recon_x,
            z=quantized_embed,
            quantized_indices=quantized_indices,
        )

        return output

    def loss_function(self, recon_x, x, quantizer_output):
        """Compute the total, reconstruction and quantization losses.

        Args:
            recon_x (torch.Tensor): The reconstruction produced by the decoder.
            x (torch.Tensor): The original input; its first axis is the batch axis.
            quantizer_output: The quantizer output; its ``loss`` attribute is used as
                the quantization loss.

        Returns:
            tuple: ``(loss, recon_loss, vq_loss)``, each averaged over axis 0, where
            ``loss`` is the mean of ``recon_loss + vq_loss``.
        """
        # Squared error summed over every non-batch element: one value per sample.
        recon_loss = F.mse_loss(
            recon_x.reshape(x.shape[0], -1), x.reshape(x.shape[0], -1), reduction="none"
        ).sum(dim=-1)

        vq_loss = quantizer_output.loss

        return (
            (recon_loss + vq_loss).mean(dim=0),
            recon_loss.mean(dim=0),
            vq_loss.mean(dim=0),
        )

    def _sample_gauss(self, mu, std):
        """Draw ``mu + eps * std`` with ``eps ~ N(0, I)`` (reparametrization trick).

        Args:
            mu (torch.Tensor): The mean of the distribution.
            std (torch.Tensor): The standard deviation; ``eps`` has its shape.

        Returns:
            tuple: ``(sample, eps)``.
        """
        # Reparametrization trick
        # Sample N(0, I)
        eps = torch.randn_like(std)
        return mu + eps * std, eps