Source code for pythae.models.ciwae.ciwae_config

from pydantic.dataclasses import dataclass

from ..vae import VAEConfig


@dataclass
class CIWAEConfig(VAEConfig):
    """Combination IWAE model config class.

    Parameters:
        input_dim (tuple): The input_data dimension.
        latent_dim (int): The latent space dimension. Default: None.
        reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse'].
            Default: 'mse'
        number_samples (int): Number of samples to use on the Monte-Carlo
            estimation. Default: 10.
        beta (float): The value of the factor in the convex combination of the
            VAE and IWAE ELBO. Must lie in the closed interval [0, 1].
            Default: 0.5.

    Raises:
        AssertionError: If ``beta`` is not within [0, 1] (this includes NaN,
            since NaN fails the chained comparison).
    """

    number_samples: int = 10
    beta: float = 0.5

    def __post_init__(self):
        super().__post_init__()
        # Use an explicit raise instead of an ``assert`` statement so the check
        # is not stripped when Python runs with ``-O``. AssertionError is kept
        # (rather than ValueError) so existing callers catching it still work.
        if not 0 <= self.beta <= 1:
            raise AssertionError(f"Beta parameter must be in [0-1]. Got {self.beta}.")