Source code for pythae.models.vae_iaf.vae_iaf_config

from pydantic.dataclasses import dataclass

from ..vae import VAEConfig


[docs]@dataclass class VAE_IAF_Config(VAEConfig): """VAE with Inverse Autoregressive Normalizing Flow config class. Parameters: input_dim (int): The input_data dimension latent_dim (int): The latent space dimension. Default: None. reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'. n_made_blocks (int): The number of :class:`~pythae.models.normalizing_flows.MADE` models to consider in the IAF used in the VAE. Default: 2. n_hidden_in_made (int): The number of hidden layers in the :class:`~pythae.models.normalizing_flows.MADE` models composing the IAF in the VAE. Default: 3. hidden_size (list): The number of unit in each hidder layerin the :class:`~pythae.models.normalizing_flows.MADE` model. The same number of units is used across the `n_hidden_in_made` and `n_made_blocks`. Default: 128. """ n_made_blocks: int = 2 n_hidden_in_made: int = 3 hidden_size: int = 128