Source code for pythae.models.wae_mmd.wae_mmd_config

from dataclasses import field
from typing import List, Union

from pydantic.dataclasses import dataclass
from typing_extensions import Literal

from ..ae import AEConfig


@dataclass
class WAE_MMD_Config(AEConfig):
    """Configuration of the Wasserstein Autoencoder regularized with an MMD penalty.

    Parameters:
        input_dim (tuple): Dimension of the input data (inherited from AEConfig).
        latent_dim (int): Dimension of the latent space (inherited from AEConfig).
        kernel_choice (str): Kernel to use, one of ['rbf', 'imq'], i.e. a radial
            basis function kernel or an inverse multiquadric kernel.
            Default: 'imq'.
        reg_weight (float): Weight balancing the reconstruction term against the
            Maximum Mean Discrepancy term. Default: 3e-2.
        kernel_bandwidth (float): Bandwidth of the kernel. Default: 1.
        scales (list): Scales applied when using a multi-scale imq kernel. Set to
            None to use a single imq kernel instead.
            Default: [.1, .2, .5, 1., 2., 5, 10.].
        reconstruction_loss_scale (float): Multiplier applied to the
            reconstruction loss. Default: 1.
    """

    # Field order is significant: the generated __init__ accepts these
    # positionally after the fields inherited from AEConfig.
    kernel_choice: Literal["rbf", "imq"] = "imq"
    reg_weight: float = 0.03
    kernel_bandwidth: float = 1.0
    # A list default must be built per instance, hence default_factory rather
    # than a shared literal.
    scales: Union[List[float], None] = field(
        default_factory=lambda: [0.1, 0.2, 0.5, 1.0, 2.0, 5, 10.0]
    )
    reconstruction_loss_scale: float = 1.0