Source code for pythae.models.ae.ae_model

import os
from typing import Optional

import torch.nn.functional as F

from ...data.datasets import BaseDataset
from ..base import BaseAE
from ..base.base_utils import ModelOutput
from ..nn import BaseDecoder, BaseEncoder
from ..nn.default_architectures import Encoder_AE_MLP
from .ae_config import AEConfig


class AE(BaseAE):
    """Vanilla Autoencoder model.

    Args:
        model_config (AEConfig): The Autoencoder configuration setting the main
            parameters of the model.

        encoder (BaseEncoder): An instance of BaseEncoder (inheriting from
            `torch.nn.Module`) which plays the role of encoder. This argument allows
            you to use your own neural network architectures if desired. If None is
            provided, a simple Multi Layer Perceptron
            (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

        decoder (BaseDecoder): An instance of BaseDecoder (inheriting from
            `torch.nn.Module`) which plays the role of decoder. This argument allows
            you to use your own neural network architectures if desired. If None is
            provided, a simple Multi Layer Perceptron
            (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

    Raises:
        AttributeError: If no encoder is given and ``model_config.input_dim`` is
            None, since the default encoder cannot be built without it.

    .. note::
        For high dimensional data we advise you to provide your own network
        architectures. With the provided MLP you may end up with a ``MemoryError``.
    """

    def __init__(
        self,
        model_config: AEConfig,
        encoder: Optional[BaseEncoder] = None,
        decoder: Optional[BaseDecoder] = None,
    ):
        # The decoder is forwarded to the base class; the encoder is handled below.
        BaseAE.__init__(self, model_config=model_config, decoder=decoder)

        self.model_name = "AE"

        if encoder is None:
            if model_config.input_dim is None:
                # Adjacent literals are concatenated verbatim, so each fragment
                # must carry its own trailing space.
                raise AttributeError(
                    "No input dimension provided ! "
                    "'input_dim' parameter of BaseAEConfig instance must be set to 'data_shape' where "
                    "the shape of the data is (C, H, W ..). Unable to build encoder "
                    "automatically"
                )

            encoder = Encoder_AE_MLP(model_config)
            # Record on the config whether the built-in encoder is in use.
            self.model_config.uses_default_encoder = True

        else:
            self.model_config.uses_default_encoder = False

        self.set_encoder(encoder)

    def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
        """The input data is encoded and decoded.

        Args:
            inputs (BaseDataset): An instance of pythae's datasets. Its ``"data"``
                entry is the batch fed to the encoder.

        Returns:
            ModelOutput: An instance of ModelOutput holding the reconstruction
            loss (``loss``), the reconstruction (``recon_x``) and the latent
            embedding (``z``).
        """
        x = inputs["data"]

        z = self.encoder(x).embedding
        recon_x = self.decoder(z)["reconstruction"]

        loss = self.loss_function(recon_x, x)

        output = ModelOutput(loss=loss, recon_x=recon_x, z=z)

        return output

    def loss_function(self, recon_x, x):
        """Compute the reconstruction loss of a batch.

        Both tensors are flattened to ``(x.shape[0], -1)``; the squared errors are
        summed per sample and the result is averaged over the first dimension.

        Args:
            recon_x: The reconstruction produced by the decoder.
            x: The original input batch.

        Returns:
            The batch mean of the per-sample summed squared errors.
        """
        MSE = F.mse_loss(
            recon_x.reshape(x.shape[0], -1),
            x.reshape(x.shape[0], -1),
            reduction="none",
        ).sum(dim=-1)

        return MSE.mean(dim=0)