Source code for pythae.models.auto_model.auto_model

import json
import logging
import os

import torch.nn as nn

from ..base.base_utils import hf_hub_is_available

logger = logging.getLogger(__name__)
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)


[docs]class AutoModel(nn.Module): "Utils class allowing to reload any :class:`pythae.models` automatically" def __init__(self) -> None: super().__init__()
[docs] @classmethod def load_from_folder(cls, dir_path: str): """Class method to be used to load the model from a specific folder Args: dir_path (str): The path where the model should have been be saved. .. note:: This function requires the folder to contain: - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided **or** - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. ``decoder.pkl``) if a custom encoder (resp. decoder) was provided """ with open(os.path.join(dir_path, "model_config.json")) as f: model_name = json.load(f)["name"] if model_name == "Adversarial_AE_Config": from ..adversarial_ae import Adversarial_AE model = Adversarial_AE.load_from_folder(dir_path=dir_path) elif model_name == "AEConfig": from ..ae import AE model = AE.load_from_folder(dir_path=dir_path) elif model_name == "BetaTCVAEConfig": from ..beta_tc_vae import BetaTCVAE model = BetaTCVAE.load_from_folder(dir_path=dir_path) elif model_name == "BetaVAEConfig": from ..beta_vae import BetaVAE model = BetaVAE.load_from_folder(dir_path=dir_path) elif model_name == "DisentangledBetaVAEConfig": from ..disentangled_beta_vae import DisentangledBetaVAE model = DisentangledBetaVAE.load_from_folder(dir_path=dir_path) elif model_name == "FactorVAEConfig": from ..factor_vae import FactorVAE model = FactorVAE.load_from_folder(dir_path=dir_path) elif model_name == "HVAEConfig": from ..hvae import HVAE model = HVAE.load_from_folder(dir_path=dir_path) elif model_name == "HRQVAEConfig": from ..hrq_vae import HRQVAE model = HRQVAE.load_from_folder(dir_path=dir_path) elif model_name == "INFOVAE_MMD_Config": from ..info_vae import INFOVAE_MMD model = INFOVAE_MMD.load_from_folder(dir_path=dir_path) elif model_name == "IWAEConfig": from ..iwae import IWAE model = IWAE.load_from_folder(dir_path=dir_path) elif model_name == "MSSSIM_VAEConfig": from ..msssim_vae import MSSSIM_VAE model = MSSSIM_VAE.load_from_folder(dir_path=dir_path) elif model_name == "RAE_GP_Config": from ..rae_gp import RAE_GP model = RAE_GP.load_from_folder(dir_path=dir_path) elif model_name == "RAE_L2_Config": from ..rae_l2 import RAE_L2 model = RAE_L2.load_from_folder(dir_path=dir_path) elif model_name == "RHVAEConfig": from ..rhvae import RHVAE model = RHVAE.load_from_folder(dir_path=dir_path) elif model_name == "SVAEConfig": from ..svae import SVAE model = SVAE.load_from_folder(dir_path=dir_path) elif model_name == "VAEConfig": from ..vae import VAE model = VAE.load_from_folder(dir_path=dir_path) elif model_name == "VAEGANConfig": from ..vae_gan import VAEGAN model = VAEGAN.load_from_folder(dir_path=dir_path) elif model_name == "VAE_IAF_Config": from ..vae_iaf import VAE_IAF model = VAE_IAF.load_from_folder(dir_path=dir_path) elif model_name == "VAE_LinNF_Config": from ..vae_lin_nf import VAE_LinNF model = VAE_LinNF.load_from_folder(dir_path=dir_path) elif model_name == "VAMPConfig": from ..vamp import VAMP model = VAMP.load_from_folder(dir_path=dir_path) elif model_name == "VQVAEConfig": from ..vq_vae import VQVAE model = VQVAE.load_from_folder(dir_path=dir_path) elif model_name == "WAE_MMD_Config": from ..wae_mmd import WAE_MMD model = WAE_MMD.load_from_folder(dir_path=dir_path) elif model_name == "MAFConfig": from ..normalizing_flows import MAF model = MAF.load_from_folder(dir_path=dir_path) elif model_name == "IAFConfig": from ..normalizing_flows import IAF model = IAF.load_from_folder(dir_path=dir_path) elif model_name == "PlanarFlowConfig": from ..normalizing_flows import PlanarFlow model = PlanarFlow.load_from_folder(dir_path=dir_path) elif model_name == "RadialFlowConfig": from ..normalizing_flows import RadialFlow model = RadialFlow.load_from_folder(dir_path=dir_path) elif model_name == "MADEConfig": from ..normalizing_flows import MADE model = MADE.load_from_folder(dir_path=dir_path) elif model_name == "PixelCNNConfig": from ..normalizing_flows import PixelCNN model = PixelCNN.load_from_folder(dir_path=dir_path) elif model_name == "PoincareVAEConfig": from ..pvae import PoincareVAE model = PoincareVAE.load_from_folder(dir_path=dir_path) elif model_name == "CIWAEConfig": from ..ciwae import CIWAE model = CIWAE.load_from_folder(dir_path=dir_path) elif model_name == "MIWAEConfig": from ..miwae import MIWAE model = MIWAE.load_from_folder(dir_path=dir_path) elif model_name == "PIWAEConfig": from ..piwae import PIWAE model = PIWAE.load_from_folder(dir_path=dir_path) else: raise NameError( "Cannot reload automatically the model... " f"The model name in the `model_config.json may be corrupted. Got {model_name}" ) return model
[docs] @classmethod def load_from_hf_hub( cls, hf_hub_path: str, allow_pickle: bool = False ): # pragma: no cover """Class method to be used to load a automaticaly a pretrained model from the Hugging Face hub Args: hf_hub_path (str): The path where the model should have been be saved on the hugginface hub. .. note:: This function requires the folder to contain: - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided **or** - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. ``decoder.pkl``) if a custom encoder (resp. decoder) was provided """ if not hf_hub_is_available(): raise ModuleNotFoundError( "`huggingface_hub` package must be installed to load models from the HF hub. " "Run `python -m pip install huggingface_hub` and log in to your account with " "`huggingface-cli login`." ) else: from huggingface_hub import hf_hub_download logger.info(f"Downloading config file ...") config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json") dir_path = os.path.dirname(config_path) with open(os.path.join(dir_path, "model_config.json")) as f: model_name = json.load(f)["name"] if model_name == "Adversarial_AE_Config": from ..adversarial_ae import Adversarial_AE model = Adversarial_AE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "AEConfig": from ..ae import AE model = AE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "BetaTCVAEConfig": from ..beta_tc_vae import BetaTCVAE model = BetaTCVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "BetaVAEConfig": from ..beta_vae import BetaVAE model = BetaVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "DisentangledBetaVAEConfig": from ..disentangled_beta_vae import DisentangledBetaVAE model = DisentangledBetaVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "FactorVAEConfig": from ..factor_vae import FactorVAE model = FactorVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "HVAEConfig": from ..hvae import HVAE model = HVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "INFOVAE_MMD_Config": from ..info_vae import INFOVAE_MMD model = INFOVAE_MMD.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "IWAEConfig": from ..iwae import IWAE model = IWAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "MSSSIM_VAEConfig": from ..msssim_vae import MSSSIM_VAE model = MSSSIM_VAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "RAE_GP_Config": from ..rae_gp import RAE_GP model = RAE_GP.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "RAE_L2_Config": from ..rae_l2 import RAE_L2 model = RAE_L2.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "RHVAEConfig": from ..rhvae import RHVAE model = RHVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "SVAEConfig": from ..svae import SVAE model = SVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "VAEConfig": from ..vae import VAE model = VAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "VAEGANConfig": from ..vae_gan import VAEGAN model = VAEGAN.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "VAE_IAF_Config": from ..vae_iaf import VAE_IAF model = VAE_IAF.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "VAE_LinNF_Config": from ..vae_lin_nf import VAE_LinNF model = VAE_LinNF.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "VAMPConfig": from ..vamp import VAMP model = VAMP.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "VQVAEConfig": from ..vq_vae import VQVAE model = VQVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "WAE_MMD_Config": from ..wae_mmd import WAE_MMD model = WAE_MMD.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "MAFConfig": from ..normalizing_flows import MAF model = MAF.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "IAFConfig": from ..normalizing_flows import IAF model = IAF.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "PlanarFlowConfig": from ..normalizing_flows import PlanarFlow model = PlanarFlow.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "RadialFlowConfig": from ..normalizing_flows import RadialFlow model = RadialFlow.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "MADEConfig": from ..normalizing_flows import MADE model = MADE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "PixelCNNConfig": from ..normalizing_flows import PixelCNN model = PixelCNN.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "PoincareVAEConfig": from ..pvae import PoincareVAE model = PoincareVAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "CIWAEConfig": from ..ciwae import CIWAE model = CIWAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "MIWAEConfig": from ..miwae import MIWAE model = MIWAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) elif model_name == "PIWAEConfig": from ..piwae import PIWAE model = PIWAE.load_from_hf_hub( hf_hub_path=hf_hub_path, allow_pickle=allow_pickle ) else: raise NameError( "Cannot reload automatically the model... " f"The model name in the `model_config.json may be corrupted. Got {model_name}" ) return model