# Source code for pythae.models.pvae.pvae_config

from pydantic.dataclasses import dataclass
from typing_extensions import Literal

from ..vae import VAEConfig


@dataclass
class PoincareVAEConfig(VAEConfig):
    """Poincaré VAE config class.

    Extends :class:`VAEConfig` with the choice of prior and posterior
    distributions and the curvature of the manifold.

    Parameters:
        input_dim (tuple): The input_data dimension.
        latent_dim (int): The latent space dimension. Default: None.
        reconstruction_loss (str): The reconstruction loss to use ['bce', 'mse'].
            Default: 'mse'
        prior_distribution (str): The distribution to use as prior
            ["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal"
        posterior_distribution (str): The distribution to use as posterior
            ["wrapped_normal", "riemannian_normal"]. Default: "wrapped_normal"
        curvature (float): The curvature of the manifold. Default: 1
    """

    # Only the two listed distribution names are accepted for the prior.
    prior_distribution: Literal[
        "wrapped_normal", "riemannian_normal"
    ] = "wrapped_normal"
    # Same set of accepted names for the posterior.
    posterior_distribution: Literal[
        "wrapped_normal", "riemannian_normal"
    ] = "wrapped_normal"
    # Annotated as float (not int, as an earlier docstring stated).
    curvature: float = 1