Source code for pythae.models.base.base_model

import inspect
import logging
import os
import shutil
import sys
import tempfile
import warnings
from copy import deepcopy
from http.cookiejar import LoadError
from typing import Optional

import cloudpickle
import torch
import torch.nn as nn

from ...customexception import BadInheritanceError
from ...data.datasets import BaseDataset, DatasetOutput
from ..auto_model import AutoConfig
from ..nn import BaseDecoder, BaseEncoder
from ..nn.default_architectures import Decoder_AE_MLP
from .base_config import BaseAEConfig, EnvironmentConfig
from .base_utils import (
    CPU_Unpickler,
    ModelOutput,
    hf_hub_is_available,
    model_card_template,
)

logger = logging.getLogger(__name__)
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)


[docs]class BaseAE(nn.Module): """Base class for Autoencoder based models. Args: model_config (BaseAEConfig): An instance of BaseAEConfig in which any model's parameters is made available. encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module` which plays the role of encoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module` which plays the role of decoder. This argument allows you to use your own neural networks architectures if desired. If None is provided, a simple Multi Layer Preception (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None. .. note:: For high dimensional data we advice you to provide you own network architectures. With the provided MLP you may end up with a ``MemoryError``. """ def __init__( self, model_config: BaseAEConfig, encoder: Optional[BaseEncoder] = None, decoder: Optional[BaseDecoder] = None, ): nn.Module.__init__(self) self.model_name = "BaseAE" self.input_dim = model_config.input_dim self.latent_dim = model_config.latent_dim self.model_config = model_config if decoder is None: if model_config.input_dim is None: raise AttributeError( "No input dimension provided !" "'input_dim' parameter of BaseAEConfig instance must be set to 'data_shape' where " "the shape of the data is (C, H, W ..)]. Unable to build decoder" "automatically" ) decoder = Decoder_AE_MLP(model_config) self.model_config.uses_default_decoder = True else: self.model_config.uses_default_decoder = False self.set_decoder(decoder) self.device = None
[docs] def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: """Main forward pass outputing the VAE outputs This function should output a :class:`~pythae.models.base.base_utils.ModelOutput` instance gathering all the model outputs Args: inputs (BaseDataset): The training data with labels, masks etc... Returns: ModelOutput: A ModelOutput instance providing the outputs of the model. .. note:: The loss must be computed in this forward pass and accessed through ``loss = model_output.loss``""" raise NotImplementedError()
[docs] def reconstruct(self, inputs: torch.Tensor): """This function returns the reconstructions of given input data. Args: inputs (torch.Tensor): The inputs data to be reconstructed of shape [B x input_dim] ending_inputs (torch.Tensor): The starting inputs in the interpolation of shape Returns: torch.Tensor: A tensor of shape [B x input_dim] containing the reconstructed samples. """ return self(DatasetOutput(data=inputs)).recon_x
[docs] def embed(self, inputs: torch.Tensor) -> torch.Tensor: """Return the embeddings of the input data. Args: inputs (torch.Tensor): The input data to be embedded, of shape [B x input_dim]. Returns: torch.Tensor: A tensor of shape [B x latent_dim] containing the embeddings. """ return self(DatasetOutput(data=inputs)).z
[docs] def predict(self, inputs: torch.Tensor) -> ModelOutput: """The input data is encoded and decoded without computing loss Args: inputs (torch.Tensor): The input data to be reconstructed, as well as to generate the embedding. Returns: ModelOutput: An instance of ModelOutput containing reconstruction and embedding """ z = self.encoder(inputs).embedding recon_x = self.decoder(z)["reconstruction"] output = ModelOutput( recon_x=recon_x, embedding=z, ) return output
[docs] def interpolate( self, starting_inputs: torch.Tensor, ending_inputs: torch.Tensor, granularity: int = 10, ): """This function performs a linear interpolation in the latent space of the autoencoder from starting inputs to ending inputs. It returns the interpolation trajectories. Args: starting_inputs (torch.Tensor): The starting inputs in the interpolation of shape [B x input_dim] ending_inputs (torch.Tensor): The starting inputs in the interpolation of shape [B x input_dim] granularity (int): The granularity of the interpolation. Returns: torch.Tensor: A tensor of shape [B x granularity x input_dim] containing the interpolation trajectories. """ assert starting_inputs.shape[0] == ending_inputs.shape[0], ( "The number of starting_inputs should equal the number of ending_inputs. Got " f"{starting_inputs.shape[0]} sampler for starting_inputs and {ending_inputs.shape[0]} " "for endinging_inputs." ) starting_z = self(DatasetOutput(data=starting_inputs)).z ending_z = self(DatasetOutput(data=ending_inputs)).z t = torch.linspace(0, 1, granularity).to(starting_inputs.device) intep_line = ( torch.kron( starting_z.reshape(starting_z.shape[0], -1), (1 - t).unsqueeze(-1) ) + torch.kron(ending_z.reshape(ending_z.shape[0], -1), t.unsqueeze(-1)) ).reshape((starting_z.shape[0] * t.shape[0],) + (starting_z.shape[1:])) decoded_line = self.decoder(intep_line).reconstruction.reshape( ( starting_inputs.shape[0], t.shape[0], ) + (starting_inputs.shape[1:]) ) return decoded_line
[docs] def update(self): """Method that allows model update during the training (at the end of a training epoch) If needed, this method must be implemented in a child class. By default, it does nothing. """
[docs] def save(self, dir_path: str): """Method to save the model at a specific location. It saves, the model weights as a ``models.pt`` file along with the model config as a ``model_config.json`` file. If the model to save used custom encoder (resp. decoder) provided by the user, these are also saved as ``decoder.pkl`` (resp. ``decoder.pkl``). Args: dir_path (str): The path where the model should be saved. If the path path does not exist a folder will be created at the provided location. """ env_spec = EnvironmentConfig( python_version=f"{sys.version_info[0]}.{sys.version_info[1]}" ) model_dict = {"model_state_dict": deepcopy(self.state_dict())} if not os.path.exists(dir_path): try: os.makedirs(dir_path) except FileNotFoundError as e: raise e env_spec.save_json(dir_path, "environment") self.model_config.save_json(dir_path, "model_config") # only save .pkl if custom architecture provided if not self.model_config.uses_default_encoder: with open(os.path.join(dir_path, "encoder.pkl"), "wb") as fp: cloudpickle.register_pickle_by_value(inspect.getmodule(self.encoder)) cloudpickle.dump(self.encoder, fp) if not self.model_config.uses_default_decoder: with open(os.path.join(dir_path, "decoder.pkl"), "wb") as fp: cloudpickle.register_pickle_by_value(inspect.getmodule(self.decoder)) cloudpickle.dump(self.decoder, fp) torch.save(model_dict, os.path.join(dir_path, "model.pt"))
[docs] def push_to_hf_hub(self, hf_hub_path: str): # pragma: no cover """Method allowing to save your model directly on the Hugging Face hub. You will need to have the `huggingface_hub` package installed and a valid Hugging Face account. You can install the package using .. code-block:: bash python -m pip install huggingface_hub end then login using .. code-block:: bash huggingface-cli login Args: hf_hub_path (str): path to your repo on the Hugging Face hub. """ if not hf_hub_is_available(): raise ModuleNotFoundError( "`huggingface_hub` package must be installed to push your model to the HF hub. " "Run `python -m pip install huggingface_hub` and log in to your account with " "`huggingface-cli login`." ) else: from huggingface_hub import CommitOperationAdd, HfApi logger.info( f"Uploading {self.model_name} model to {hf_hub_path} repo in HF hub..." ) tempdir = tempfile.mkdtemp() self.save(tempdir) model_files = os.listdir(tempdir) api = HfApi() hf_operations = [] for file in model_files: hf_operations.append( CommitOperationAdd( path_in_repo=file, path_or_fileobj=f"{str(os.path.join(tempdir, file))}", ) ) with open(os.path.join(tempdir, "model_card.md"), "w") as f: f.write(model_card_template) hf_operations.append( CommitOperationAdd( path_in_repo="README.md", path_or_fileobj=os.path.join(tempdir, "model_card.md"), ) ) try: api.create_commit( commit_message=f"Uploading {self.model_name} in {hf_hub_path}", repo_id=hf_hub_path, operations=hf_operations, ) logger.info( f"Successfully uploaded {self.model_name} to {hf_hub_path} repo in HF hub!" ) except: from huggingface_hub import create_repo repo_name = os.path.basename(os.path.normpath(hf_hub_path)) logger.info( f"Creating {repo_name} in the HF hub since it does not exist..." ) create_repo(repo_id=repo_name) logger.info(f"Successfully created {repo_name} in the HF hub!") api.create_commit( commit_message=f"Uploading {self.model_name} in {hf_hub_path}", repo_id=hf_hub_path, operations=hf_operations, ) shutil.rmtree(tempdir)
@classmethod def _load_model_config_from_folder(cls, dir_path): file_list = os.listdir(dir_path) if "model_config.json" not in file_list: raise FileNotFoundError( f"Missing model config file ('model_config.json') in" f"{dir_path}... Cannot perform model building." ) path_to_model_config = os.path.join(dir_path, "model_config.json") model_config = AutoConfig.from_json_file(path_to_model_config) return model_config @classmethod def _load_model_weights_from_folder(cls, dir_path): file_list = os.listdir(dir_path) if "model.pt" not in file_list: raise FileNotFoundError( f"Missing model weights file ('model.pt') file in" f"{dir_path}... Cannot perform model building." ) path_to_model_weights = os.path.join(dir_path, "model.pt") try: model_weights = torch.load(path_to_model_weights, map_location="cpu") except RuntimeError: RuntimeError( "Enable to load model weights. Ensure they are saves in a '.pt' format." ) if "model_state_dict" not in model_weights.keys(): raise KeyError( "Model state dict is not available in 'model.pt' file. Got keys:" f"{model_weights.keys()}" ) model_weights = model_weights["model_state_dict"] return model_weights @classmethod def _load_custom_encoder_from_folder(cls, dir_path): file_list = os.listdir(dir_path) cls._check_python_version_from_folder(dir_path=dir_path) if "encoder.pkl" not in file_list: raise FileNotFoundError( f"Missing encoder pkl file ('encoder.pkl') in" f"{dir_path}... This file is needed to rebuild custom encoders." " Cannot perform model building." ) else: with open(os.path.join(dir_path, "encoder.pkl"), "rb") as fp: encoder = CPU_Unpickler(fp).load() return encoder @classmethod def _load_custom_decoder_from_folder(cls, dir_path): file_list = os.listdir(dir_path) cls._check_python_version_from_folder(dir_path=dir_path) if "decoder.pkl" not in file_list: raise FileNotFoundError( f"Missing decoder pkl file ('decoder.pkl') in" f"{dir_path}... This file is needed to rebuild custom decoders." " Cannot perform model building." ) else: with open(os.path.join(dir_path, "decoder.pkl"), "rb") as fp: decoder = CPU_Unpickler(fp).load() return decoder
[docs] @classmethod def load_from_folder(cls, dir_path): """Class method to be used to load the model from a specific folder Args: dir_path (str): The path where the model should have been be saved. .. note:: This function requires the folder to contain: - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided **or** - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. ``decoder.pkl``) if a custom encoder (resp. decoder) was provided """ model_config = cls._load_model_config_from_folder(dir_path) model_weights = cls._load_model_weights_from_folder(dir_path) if not model_config.uses_default_encoder: encoder = cls._load_custom_encoder_from_folder(dir_path) else: encoder = None if not model_config.uses_default_decoder: decoder = cls._load_custom_decoder_from_folder(dir_path) else: decoder = None model = cls(model_config, encoder=encoder, decoder=decoder) model.load_state_dict(model_weights) return model
[docs] @classmethod def load_from_hf_hub(cls, hf_hub_path: str, allow_pickle=False): # pragma: no cover """Class method to be used to load a pretrained model from the Hugging Face hub Args: hf_hub_path (str): The path where the model should have been be saved on the hugginface hub. .. note:: This function requires the folder to contain: - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided **or** - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp. ``decoder.pkl``) if a custom encoder (resp. decoder) was provided """ if not hf_hub_is_available(): raise ModuleNotFoundError( "`huggingface_hub` package must be installed to load models from the HF hub. " "Run `python -m pip install huggingface_hub` and log in to your account with " "`huggingface-cli login`." ) else: from huggingface_hub import hf_hub_download logger.info(f"Downloading {cls.__name__} files for rebuilding...") _ = hf_hub_download(repo_id=hf_hub_path, filename="environment.json") config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json") dir_path = os.path.dirname(config_path) _ = hf_hub_download(repo_id=hf_hub_path, filename="model.pt") model_config = cls._load_model_config_from_folder(dir_path) if ( cls.__name__ + "Config" != model_config.name and cls.__name__ + "_Config" != model_config.name ): warnings.warn( f"You are trying to load a " f"`{ cls.__name__}` while a " f"`{model_config.name}` is given." ) model_weights = cls._load_model_weights_from_folder(dir_path) if ( not model_config.uses_default_encoder or not model_config.uses_default_decoder ) and not allow_pickle: warnings.warn( "You are about to download pickled files from the HF hub that may have " "been created by a third party and so could potentially harm your computer. If you " "are sure that you want to download them set `allow_pickle=true`." ) else: if not model_config.uses_default_encoder: _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl") encoder = cls._load_custom_encoder_from_folder(dir_path) else: encoder = None if not model_config.uses_default_decoder: _ = hf_hub_download(repo_id=hf_hub_path, filename="decoder.pkl") decoder = cls._load_custom_decoder_from_folder(dir_path) else: decoder = None logger.info(f"Successfully downloaded {cls.__name__} model!") model = cls(model_config, encoder=encoder, decoder=decoder) model.load_state_dict(model_weights) return model
[docs] def set_encoder(self, encoder: BaseEncoder) -> None: """Set the encoder of the model""" if not issubclass(type(encoder), BaseEncoder): raise BadInheritanceError( ( "Encoder must inherit from BaseEncoder class from " "pythae.models.base_architectures.BaseEncoder. Refer to documentation." ) ) self.encoder = encoder
[docs] def set_decoder(self, decoder: BaseDecoder) -> None: """Set the decoder of the model""" if not issubclass(type(decoder), BaseDecoder): raise BadInheritanceError( ( "Decoder must inherit from BaseDecoder class from " "pythae.models.base_architectures.BaseDecoder. Refer to documentation." ) ) self.decoder = decoder
@classmethod def _check_python_version_from_folder(cls, dir_path: str): if "environment.json" in os.listdir(dir_path): env_spec = EnvironmentConfig.from_json_file( os.path.join(dir_path, "environment.json") ) python_version = env_spec.python_version python_version_minor = python_version.split(".")[1] if python_version_minor == "7" and sys.version_info[1] > 7: raise LoadError( "Trying to reload a model saved with python3.7 with python3.8+. " "Please create a virtual env with python 3.7 to reload this model." ) elif int(python_version_minor) >= 8 and sys.version_info[1] == 7: raise LoadError( "Trying to reload a model saved with python3.8+ with python3.7. " "Please create a virtual env with python 3.8+ to reload this model." )