Source code for pythae.models.normalizing_flows.base.base_nf_model

import os
import sys
from copy import deepcopy

import numpy as np
import torch
import torch.nn as nn

from ....data.datasets import BaseDataset
from ...auto_model import AutoConfig
from ...base.base_config import EnvironmentConfig
from ...base.base_utils import ModelOutput
from .base_nf_config import BaseNFConfig


[docs]class BaseNF(nn.Module): """Base Class from Normalizing flows Args: model_config (BaseNFConfig): The configuration setting the main parameters of the model. """ def __init__(self, model_config: BaseNFConfig): nn.Module.__init__(self) if model_config.input_dim is None: raise AttributeError( "No input dimension provided !" "'input_dim' parameter of MADEConfig instance must be set to 'data_shape' " "where the shape of the data is (C, H, W ..)]. Unable to build network" "automatically" ) self.model_config = model_config self.input_dim = np.prod(model_config.input_dim)
[docs] def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput: """Main forward pass mapping the data towards the prior This function should output a :class:`~pythae.models.base.base_utils.ModelOutput` instance gathering all the model outputs Args: x (torch.Tensor): The training data. Returns: ModelOutput: A ModelOutput instance providing the outputs of the model. """ raise NotImplementedError()
[docs] def inverse(self, y: torch.Tensor, **kwargs) -> ModelOutput: """Main inverse pass mapping the prior toward the data This function should output a :class:`~pythae.models.base.base_utils.ModelOutput` instance gathering all the model outputs Args: inputs (torch.Tensor): Data from the prior. Returns: ModelOutput: A ModelOutput instance providing the outputs of the model. """ raise NotImplementedError()
[docs] def update(self): """Method that allows model update during the training (at the end of a training epoch) If needed, this method must be implemented in a child class. By default, it does nothing. """
[docs] def save(self, dir_path): """Method to save the model at a specific location. It saves, the model weights as a ``models.pt`` file along with the model config as a ``model_config.json`` file. Args: dir_path (str): The path where the model should be saved. If the path path does not exist a folder will be created at the provided location. """ env_spec = EnvironmentConfig( python_version=f"{sys.version_info[0]}.{sys.version_info[1]}" ) model_dict = {"model_state_dict": deepcopy(self.state_dict())} if not os.path.exists(dir_path): try: os.makedirs(dir_path) except (FileNotFoundError, TypeError) as e: raise e env_spec.save_json(dir_path, "environment") self.model_config.save_json(dir_path, "model_config") torch.save(model_dict, os.path.join(dir_path, "model.pt"))
@classmethod def _load_model_config_from_folder(cls, dir_path): file_list = os.listdir(dir_path) if "model_config.json" not in file_list: raise FileNotFoundError( f"Missing model config file ('model_config.json') in" f"{dir_path}... Cannot perform model building." ) path_to_model_config = os.path.join(dir_path, "model_config.json") model_config = AutoConfig.from_json_file(path_to_model_config) return model_config @classmethod def _load_model_weights_from_folder(cls, dir_path): file_list = os.listdir(dir_path) if "model.pt" not in file_list: raise FileNotFoundError( f"Missing model weights file ('model.pt') file in" f"{dir_path}... Cannot perform model building." ) path_to_model_weights = os.path.join(dir_path, "model.pt") try: model_weights = torch.load(path_to_model_weights, map_location="cpu") except RuntimeError: RuntimeError( "Enable to load model weights. Ensure they are saves in a '.pt' format." ) if "model_state_dict" not in model_weights.keys(): raise KeyError( "Model state dict is not available in 'model.pt' file. Got keys:" f"{model_weights.keys()}" ) model_weights = model_weights["model_state_dict"] return model_weights
[docs] @classmethod def load_from_folder(cls, dir_path): """Class method to be used to load the model from a specific folder Args: dir_path (str): The path where the model should have been be saved. .. note:: This function requires the folder to contain: - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided """ model_config = cls._load_model_config_from_folder(dir_path) model_weights = cls._load_model_weights_from_folder(dir_path) model = cls(model_config) model.load_state_dict(model_weights) return model
class NFModel(nn.Module): """Class wrapping the normalizing flows so it can articulate with :class:`~pythae.trainers.BaseTrainer` """ def __init__(self, prior: torch.distributions, flow: BaseNF): nn.Module.__init__(self) self.prior = prior self.flow = flow @property def model_config(self): return self.flow.model_config @property def model_name(self): return self.flow.model_name def forward(self, x: BaseDataset, **kwargs): x = x["data"] flow_output = self.flow(x, **kwargs) y = flow_output.out log_abs_det_jac = flow_output.log_abs_det_jac log_prob_prior = self.prior.log_prob(y).reshape(y.shape[0]) output = ModelOutput(loss=-(log_prob_prior + log_abs_det_jac).sum()) return output def update(self): self.flow.update() def save(self, dir_path): """Method to save the model at a specific location. It saves, the model weights as a ``models.pt`` file along with the model config as a ``model_config.json`` file. Args: dir_path (str): The path where the model should be saved. If the path path does not exist a folder will be created at the provided location. """ self.flow.save(dir_path=dir_path)