Source code for pythae.models.vae_lin_nf.vae_lin_nf_config

from dataclasses import field
from typing import List

from pydantic.dataclasses import dataclass

from ..vae import VAEConfig


@dataclass
class VAE_LinNF_Config(VAEConfig):
    """VAE with linear Normalizing Flow config class.

    Parameters:
        input_dim (int): The input_data dimension
        latent_dim (int): The latent space dimension. Default: None.
        reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse']. Default: 'mse'
        flows (List[str]): A list of strings corresponding to the class of each flow to be
            applied. Default: ['Planar', 'Planar']. Flow choices: ['Planar', 'Radial'].

    Raises:
        AssertionError: From ``__post_init__`` when an entry of ``flows`` is neither
            ``'Planar'`` nor ``'Radial'``.
    """

    # default_factory so each config instance gets its own list (no shared mutable default)
    flows: List[str] = field(default_factory=lambda: ["Planar", "Planar"])

    def __post_init__(self):
        super().__post_init__()
        # Explicit raise instead of ``assert``: an ``assert`` is stripped under
        # ``python -O``, which would silently accept unknown flow names. The
        # AssertionError type and message are kept so existing callers still work.
        for i, f in enumerate(self.flows):
            if f not in ("Planar", "Radial"):
                raise AssertionError(
                    f"Flow name number {i+1}: '{f}' doesn't correspond "
                    "to ones of the classes. Available linear flows ['Planar', 'Radial']"
                )