import inspect
import logging
import os
import warnings
from collections import deque
from copy import deepcopy
from typing import Optional
import cloudpickle
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from ...customexception import BadInheritanceError
from ...data.datasets import BaseDataset
from ..base.base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available
from ..nn import BaseDecoder, BaseEncoder, BaseMetric
from ..nn.default_architectures import Metric_MLP
from ..vae import VAE
from .rhvae_config import RHVAEConfig
from .rhvae_utils import create_inverse_metric, create_metric
# Module-level logger for this file. A dedicated stream handler is attached and
# the level is set to INFO so that the informational messages emitted below
# (e.g. while downloading files from the HF hub) are actually displayed.
logger = logging.getLogger(__name__)
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)
[docs]class RHVAE(VAE):
r"""
Riemannian Hamiltonian VAE model.
Args:
model_config (RHVAEConfig): A model configuration setting the main parameters of the model.
encoder (BaseEncoder): An instance of BaseEncoder (inheriting from `torch.nn.Module`) which
plays the role of encoder. This argument allows you to use your own neural network
architectures if desired. If None is provided, a simple Multi Layer Perceptron
(https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.
decoder (BaseDecoder): An instance of BaseDecoder (inheriting from `torch.nn.Module`) which
plays the role of decoder. This argument allows you to use your own neural network
architectures if desired. If None is provided, a simple Multi Layer Perceptron
(https://en.wikipedia.org/wiki/Multilayer_perceptron) is used. Default: None.
metric (BaseMetric): An instance of BaseMetric which plays the role of the metric network.
If None is provided, the default ``Metric_MLP`` is used. Default: None.
.. note::
For high dimensional data we advise you to provide your own network architectures. With the
provided MLP you may end up with a ``MemoryError``.
"""
def __init__(
    self,
    model_config: RHVAEConfig,
    encoder: Optional[BaseEncoder] = None,
    decoder: Optional[BaseDecoder] = None,
    metric: Optional[BaseMetric] = None,
):
    """Build the RHVAE on top of the base VAE and set up its metric.

    Args:
        model_config (RHVAEConfig): Configuration providing ``temperature``,
            ``regularization``, ``beta_zero``, ``n_lf``, ``eps_lf`` and
            ``latent_dim``.
        encoder (BaseEncoder): Optional encoder forwarded to ``VAE.__init__``.
        decoder (BaseDecoder): Optional decoder forwarded to ``VAE.__init__``.
        metric (BaseMetric): Optional metric network. When ``None`` a
            ``Metric_MLP`` is created from ``model_config``.
    """
    VAE.__init__(self, model_config=model_config, encoder=encoder, decoder=decoder)
    self.model_name = "RHVAE"

    # Remember in the config whether the built-in metric network is used, so
    # that saving/loading knows if a pickled metric is required.
    default_metric = metric is None
    if default_metric:
        metric = Metric_MLP(model_config)
    self.model_config.uses_default_metric = default_metric
    self.set_metric(metric)

    def frozen_scalar(value):
        # One-element parameter that is stored with the model but never trained.
        return nn.Parameter(torch.Tensor([value]), requires_grad=False)

    self.temperature = frozen_scalar(model_config.temperature)
    self.lbd = frozen_scalar(model_config.regularization)
    self.beta_zero_sqrt = frozen_scalar(model_config.beta_zero)

    self.n_lf = model_config.n_lf
    self.eps_lf = model_config.eps_lf

    # Rolling buffers of the matrices M = L L^T and of the centroids mu(x_i)
    # collected during training; they feed the metric update.
    self.M = deque(maxlen=100)
    self.centroids = deque(maxlen=100)

    latent_dim = self.model_config.latent_dim
    self.M_tens = torch.randn(1, latent_dim, latent_dim)
    self.centroids_tens = torch.randn(1, latent_dim)

    # Starting metric: a single centroid at the origin with M = I_d.
    def G_inv(z):
        identity = torch.eye(self.latent_dim, device=z.device)
        weights = torch.exp(-torch.norm(z.unsqueeze(1), dim=-1) ** 2)
        kernel_sum = (
            identity.unsqueeze(0) * weights.unsqueeze(-1).unsqueeze(-1)
        ).sum(dim=1)
        return kernel_sum + self.lbd * identity

    def G(z):
        # The metric is the inverse of the regularised kernel sum above.
        return torch.inverse(G_inv(z))

    self.G = G
    self.G_inv = G_inv
def update(self):
    r"""Rebuild the final metric from the statistics gathered during training.

    Meant to be called once the model has seen all the data points (i.e. at
    the end of one loop over the data): the stored \mu(x_i) become the
    centroids of the metric. The work is delegated to ``_update_metric``.
    """
    self._update_metric()
def set_metric(self, metric: BaseMetric) -> None:
    r"""Attach the metric network outputting the :math:`L_{\psi_i}` of the
    metric matrices.

    Args:
        metric (BaseMetric): The metric module that needs to be set on the model.

    Raises:
        BadInheritanceError: If the type of ``metric`` does not derive from
            ``BaseMetric``.
    """
    if issubclass(type(metric), BaseMetric):
        self.metric = metric
        return

    raise BadInheritanceError(
        (
            "Metric must inherit from BaseMetric class from "
            "pythae.models.base_architectures.BaseMetric. Refer to documentation."
        )
    )
def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
    r"""Encode, run the tempered leapfrog integrator, decode and compute the loss.

    The input data is first encoded and the reparametrization trick produces a
    sample :math:`z_0` from the approximate posterior :math:`q_{\phi}(z|x)`.
    The Riemannian Hamiltonian equations are then solved with the generalized
    leapfrog integrator. In training mode the input :math:`x` is also fed to
    the metric network outputting the matrices :math:`L_{\psi}`; the resulting
    batch metric is used by the integrator and stored for the next metric
    update. In evaluation mode the current ``self.G`` / ``self.G_inv`` are used.

    Args:
        inputs (BaseDataset): The input batch; only ``inputs["data"]`` is read.

    Returns:
        ModelOutput: ``loss``, ``recon_x``, ``z``, ``z0``, ``rho``, ``eps0``,
        ``gamma``, ``mu``, ``log_var``, ``G_inv`` and ``G_log_det``.
    """
    x = inputs["data"]

    encoder_output = self.encoder(x)
    mu, log_var = encoder_output.embedding, encoder_output.log_covariance

    std = torch.exp(0.5 * log_var)
    z0, eps0 = self._sample_gauss(mu, std)
    z = z0

    if self.training:
        # Batch metric: M_i = L_i L_i^T with the batch means as centroids.
        L = self.metric(x)["L"]
        M = L @ torch.transpose(L, 1, 2)

        # Store LL^T and mu(x_i) so the final metric can be rebuilt later.
        self.M.append(M.detach().clone())
        self.centroids.append(mu.detach().clone())

        def metric_inv(z_):
            # Inverse metric at z_ built from the current batch only.
            return (
                M.unsqueeze(0)
                * torch.exp(
                    -torch.norm(mu.unsqueeze(0) - z_.unsqueeze(1), dim=-1) ** 2
                    / (self.temperature**2)
                )
                .unsqueeze(-1)
                .unsqueeze(-1)
            ).sum(dim=1) + self.lbd * torch.eye(self.latent_dim).to(x.device)

    else:
        metric_inv = self.G_inv
        # G itself is only needed once, to draw the initial momentum.
        L = torch.linalg.cholesky(self.G(z))

    G_inv = metric_inv(z)
    G_log_det = -torch.logdet(G_inv)

    gamma = torch.randn_like(z0, device=x.device)
    rho = gamma / self.beta_zero_sqrt
    beta_sqrt_old = self.beta_zero_sqrt

    # sample \rho from N(0, G)
    rho = (L @ rho.unsqueeze(-1)).squeeze(-1)

    recon_x = self.decoder(z)["reconstruction"]

    for k in range(self.n_lf):
        # leapfrog step 1: half update of the momentum
        rho_ = self._leap_step_1(recon_x, x, z, rho, G_inv, G_log_det)

        # leapfrog step 2: full update of the position
        z = self._leap_step_2(recon_x, x, z, rho_, G_inv, G_log_det)

        recon_x = self.decoder(z)["reconstruction"]

        # Re-evaluate the metric at the new position. Only its inverse is
        # needed from here on, so G is not recomputed.
        G_inv = metric_inv(z)
        G_log_det = -torch.logdet(G_inv)

        # leapfrog step 3: second half update of the momentum
        rho__ = self._leap_step_3(recon_x, x, z, rho_, G_inv, G_log_det)

        # tempering
        beta_sqrt = self._tempering(k + 1, self.n_lf)
        rho = (beta_sqrt_old / beta_sqrt) * rho__
        beta_sqrt_old = beta_sqrt

    loss = self.loss_function(
        recon_x, x, z0, z, rho, eps0, gamma, mu, log_var, G_inv, G_log_det
    )

    output = ModelOutput(
        loss=loss,
        recon_x=recon_x,
        z=z,
        z0=z0,
        rho=rho,
        eps0=eps0,
        gamma=gamma,
        mu=mu,
        log_var=log_var,
        G_inv=G_inv,
        G_log_det=G_log_det,
    )

    return output
def predict(self, inputs: torch.Tensor) -> ModelOutput:
    """Encode and decode the input data without computing the loss.

    The sample drawn from the encoder is moved by ``n_lf`` leapfrog steps
    using the current ``self.G`` / ``self.G_inv`` before being decoded.

    Args:
        inputs (torch.Tensor): The input data to be reconstructed, as well as
            to generate the embedding.

    Returns:
        ModelOutput: An instance of ModelOutput containing ``recon_x`` (the
        reconstruction), ``raw_embedding`` (the embedding output by the
        encoder) and ``embedding`` (the latent code after the leapfrog steps,
        or the encoder embedding when ``n_lf`` is 0).
    """
    # NOTE(review): the leapfrog helpers are called with ``z`` as the variable
    # to differentiate, so this presumably must not run under
    # ``torch.no_grad()`` when ``n_lf > 0`` - confirm against callers.
    encoder_output = self.encoder(inputs)
    mu, log_var = encoder_output.embedding, encoder_output.log_covariance

    std = torch.exp(0.5 * log_var)
    z0, _ = self._sample_gauss(mu, std)
    z = z0

    # G itself is only needed once, to draw the initial momentum.
    G = self.G(z)
    G_inv = self.G_inv(z)
    L = torch.linalg.cholesky(G)
    G_log_det = -torch.logdet(G_inv)

    gamma = torch.randn_like(z0, device=inputs.device)
    rho = gamma / self.beta_zero_sqrt
    beta_sqrt_old = self.beta_zero_sqrt

    # sample \\rho from N(0, G)
    rho = (L @ rho.unsqueeze(-1)).squeeze(-1)

    recon_x = self.decoder(z)["reconstruction"]

    for k in range(self.n_lf):
        # leapfrog step 1: half update of the momentum
        rho_ = self._leap_step_1(recon_x, inputs, z, rho, G_inv, G_log_det)

        # leapfrog step 2: full update of the position
        z = self._leap_step_2(recon_x, inputs, z, rho_, G_inv, G_log_det)

        recon_x = self.decoder(z)["reconstruction"]

        # Metric at the new z. Only the inverse is used afterwards, so the
        # (more expensive) G is not recomputed.
        G_inv = self.G_inv(z)
        G_log_det = -torch.logdet(G_inv)

        # leapfrog step 3: second half update of the momentum
        rho__ = self._leap_step_3(recon_x, inputs, z, rho_, G_inv, G_log_det)

        # tempering
        beta_sqrt = self._tempering(k + 1, self.n_lf)
        rho = (beta_sqrt_old / beta_sqrt) * rho__
        beta_sqrt_old = beta_sqrt

    output = ModelOutput(
        recon_x=recon_x,
        raw_embedding=encoder_output.embedding,
        embedding=z if self.n_lf > 0 else encoder_output.embedding,
    )

    return output
def _leap_step_1(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
"""
Resolves first equation of generalized leapfrog integrator
using fixed point iterations
"""
def f_(rho_):
H = self._hamiltonian(recon_x, x, z, rho_, G_inv, G_log_det)
gz = grad(H, z, retain_graph=True)[0]
return rho - 0.5 * self.eps_lf * gz
rho_ = rho.clone()
for _ in range(steps):
rho_ = f_(rho_)
return rho_
def _leap_step_2(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
"""
Resolves second equation of generalized leapfrog integrator
using fixed point iterations
"""
H0 = self._hamiltonian(recon_x, x, z, rho, G_inv, G_log_det)
grho_0 = grad(H0, rho)[0]
def f_(z_):
H = self._hamiltonian(recon_x, x, z_, rho, G_inv, G_log_det)
grho = grad(H, rho, retain_graph=True)[0]
return z + 0.5 * self.eps_lf * (grho_0 + grho)
z_ = z.clone()
for _ in range(steps):
z_ = f_(z_)
return z_
def _leap_step_3(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
"""
Resolves third equation of generalized leapfrog integrator
"""
H = self._hamiltonian(recon_x, x, z, rho, G_inv, G_log_det)
gz = grad(H, z, create_graph=True)[0]
return rho - 0.5 * self.eps_lf * gz
def _hamiltonian(self, recon_x, x, z, rho, G_inv=None, G_log_det=None):
"""
Computes the Hamiltonian function.
used for RHVAE
"""
norm = (
torch.transpose(rho.unsqueeze(-1), 1, 2) @ G_inv @ rho.unsqueeze(-1)
).sum()
return -self._log_p_xz(recon_x, x, z).sum() + 0.5 * norm + 0.5 * G_log_det.sum()
def _update_metric(self):
# convert to 1 big tensor
self.M_tens = torch.cat(list(self.M))
self.centroids_tens = torch.cat(list(self.centroids))
# define new metric
def G(z):
return torch.inverse(
(
self.M_tens.unsqueeze(0)
* torch.exp(
-torch.norm(
self.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1
)
** 2
/ (self.temperature**2)
)
.unsqueeze(-1)
.unsqueeze(-1)
).sum(dim=1)
+ self.lbd * torch.eye(self.latent_dim).to(z.device)
)
def G_inv(z):
return (
self.M_tens.unsqueeze(0)
* torch.exp(
-torch.norm(
self.centroids_tens.unsqueeze(0) - z.unsqueeze(1), dim=-1
)
** 2
/ (self.temperature**2)
)
.unsqueeze(-1)
.unsqueeze(-1)
).sum(dim=1) + self.lbd * torch.eye(self.latent_dim).to(z.device)
self.G = G
self.G_inv = G_inv
self.M = deque(maxlen=100)
self.centroids = deque(maxlen=100)
def loss_function(
    self, recon_x, x, z0, zK, rhoK, eps0, gamma, mu, log_var, G_inv, G_log_det
):
    """Return the batch-averaged negative of ``log p(x, z_K, rho_K) - log q(z_0|x)``.

    ``z0``, ``gamma`` and ``mu`` are accepted but not used in the computation.
    """
    # log p(x, z_K)
    log_joint_xz = self._log_p_xz(recon_x, x, zK)

    # log p(\\rho_K); the Gaussian normalisation constant is left out.
    momentum = rhoK.unsqueeze(-1)
    quad_form = (torch.transpose(momentum, 1, 2) @ G_inv @ momentum).squeeze()
    log_momentum = -0.5 * quad_form - 0.5 * G_log_det

    log_target = log_joint_xz + log_momentum

    # log q(z_0|x), evaluated through the N(0, I) noise used for sampling.
    standard_normal = torch.distributions.MultivariateNormal(
        loc=torch.zeros(self.latent_dim).to(x.device),
        covariance_matrix=torch.eye(self.latent_dim).to(x.device),
    )
    log_posterior = standard_normal.log_prob(eps0) - 0.5 * log_var.sum(dim=1)

    return -(log_target - log_posterior).mean(dim=0)
def _sample_gauss(self, mu, std):
# Reparametrization trick
# Sample N(0, I)
eps = torch.randn_like(std)
return mu + eps * std, eps
def _tempering(self, k, K):
"""Perform tempering step"""
beta_k = (
(1 - 1 / self.beta_zero_sqrt) * (k / K) ** 2
) + 1 / self.beta_zero_sqrt
return 1 / beta_k
def _log_p_x_given_z(self, recon_x, x):
if self.model_config.reconstruction_loss == "mse":
# sigma is taken as I_D
recon_loss = -0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
-torch.log(torch.tensor([2 * np.pi]).to(x.device)) * np.prod(
self.input_dim
) / 2
elif self.model_config.reconstruction_loss == "bce":
recon_loss = -F.binary_cross_entropy(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
return recon_loss
def _log_z(self, z):
"""
Return Normal density function as prior on z
"""
# define a N(0, I) distribution
normal = torch.distributions.MultivariateNormal(
loc=torch.zeros(self.latent_dim).to(z.device),
covariance_matrix=torch.eye(self.latent_dim).to(z.device),
)
return normal.log_prob(z)
def _log_p_xz(self, recon_x, x, z):
"""
Estimate log(p(x, z)) using Bayes rule
"""
logpxz = self._log_p_x_given_z(recon_x, x)
logpz = self._log_z(z)
return logpxz + logpz
def get_nll(self, data, n_samples=1, batch_size=100):
    """Estimate the log-likelihood of the model by importance sampling.

    For every data point, samples are drawn from the approximate posterior,
    moved by ``n_lf`` leapfrog steps and combined into an importance-weighted
    estimate. This may take a while.

    Args:
        data (torch.Tensor): The input data from which the log-likelihood should
            be estimated. Data must be of shape [Batch x n_channels x ...]
        n_samples (int): The number of importance samples to use for estimation.
            When it exceeds ``batch_size`` it is rounded down to a multiple of
            ``batch_size``.
        batch_size (int): The batchsize to use to avoid memory issues.

    Returns:
        The mean over the data points of the per-point estimate
        ``logsumexp(log_weights) - log(n)``. NOTE(review): no sign flip is
        applied, so despite the method name this is the log-likelihood
        estimate itself - confirm what callers expect.

    Raises:
        ValueError: If ``n_samples`` or ``batch_size`` is smaller than 1.
    """
    if n_samples < 1 or batch_size < 1:
        raise ValueError("`n_samples` and `batch_size` must be positive integers.")

    normal = torch.distributions.MultivariateNormal(
        loc=torch.zeros(self.model_config.latent_dim).to(data.device),
        covariance_matrix=torch.eye(self.model_config.latent_dim).to(data.device),
    )

    if n_samples <= batch_size:
        n_full_batch = 1
    else:
        n_full_batch = n_samples // batch_size
        n_samples = batch_size

    log_p = []

    for sample in data:
        x = sample.unsqueeze(0)

        log_p_x = []

        for _ in range(n_full_batch):
            # Draw exactly ``n_samples`` importance samples per pass (this used
            # to be ``batch_size``, which ignored small ``n_samples`` values).
            x_rep = torch.cat(n_samples * [x])

            encoder_output = self.encoder(x_rep)
            mu, log_var = encoder_output.embedding, encoder_output.log_covariance

            std = torch.exp(0.5 * log_var)
            z0, _ = self._sample_gauss(mu, std)
            z = z0

            G = self.G(z0)
            G_inv = self.G_inv(z0)
            G_log_det = -torch.logdet(G_inv)
            L = torch.linalg.cholesky(G)

            # initialization of the momentum and of the tempering schedule
            gamma = torch.randn_like(z0, device=z.device)
            rho = gamma / self.beta_zero_sqrt
            beta_sqrt_old = self.beta_zero_sqrt

            # sample from the multivariate N(0, G)
            rho = (L @ rho.unsqueeze(-1)).squeeze(-1)

            recon_x = self.decoder(z)["reconstruction"]

            for k in range(self.n_lf):
                # leapfrog step 1
                rho_ = self._leap_step_1(recon_x, x_rep, z, rho, G_inv, G_log_det)

                # leapfrog step 2
                z = self._leap_step_2(recon_x, x_rep, z, rho_, G_inv, G_log_det)

                recon_x = self.decoder(z)["reconstruction"]

                G_inv = self.G_inv(z)
                G_log_det = -torch.logdet(G_inv)

                # leapfrog step 3
                rho__ = self._leap_step_3(recon_x, x_rep, z, rho_, G_inv, G_log_det)

                # tempering steps
                beta_sqrt = self._tempering(k + 1, self.n_lf)
                rho = (beta_sqrt_old / beta_sqrt) * rho__
                beta_sqrt_old = beta_sqrt

            log_q_z0_given_x = -0.5 * (
                log_var + (z0 - mu) ** 2 / torch.exp(log_var)
            ).sum(dim=-1)
            log_p_z = -0.5 * (z**2).sum(dim=-1)

            log_p_rho0 = normal.log_prob(gamma) - torch.logdet(
                L / self.beta_zero_sqrt
            )  # rho0 ~ N(0, 1/beta_0 * G(z0))

            # Squeeze only the two trailing singleton dims so a batch of one
            # sample keeps its batch dimension.
            quad_form = (
                (
                    torch.transpose(rho.unsqueeze(-1), 1, 2)
                    @ G_inv
                    @ rho.unsqueeze(-1)
                )
                .squeeze(-1)
                .squeeze(-1)
            )
            log_p_rho = (-0.5 * quad_form - 0.5 * G_log_det) - torch.log(
                torch.tensor([2 * np.pi]).to(z.device)
            ) * self.latent_dim / 2  # rho ~ N(0, G(z))

            if self.model_config.reconstruction_loss == "mse":
                log_p_x_given_z = -0.5 * F.mse_loss(
                    recon_x.reshape(x_rep.shape[0], -1),
                    x_rep.reshape(x_rep.shape[0], -1),
                    reduction="none",
                ).sum(dim=-1) - torch.tensor(
                    [np.prod(self.input_dim) / 2 * np.log(np.pi * 2)]
                ).to(
                    data.device
                )  # decoding distribution is assumed unit variance N(mu, I)

            elif self.model_config.reconstruction_loss == "bce":
                log_p_x_given_z = -F.binary_cross_entropy(
                    recon_x.reshape(x_rep.shape[0], -1),
                    x_rep.reshape(x_rep.shape[0], -1),
                    reduction="none",
                ).sum(dim=-1)

            # N*log(2*pi) simplifies in prior and posterior. Only the values
            # are needed below, so the autograd graph is dropped right away.
            log_p_x.append(
                (
                    log_p_x_given_z
                    + log_p_z
                    + log_p_rho
                    - log_p_rho0
                    - log_q_z0_given_x
                ).detach()
            )

        log_p_x = torch.cat(log_p_x)

        log_p.append((torch.logsumexp(log_p_x, 0) - np.log(len(log_p_x))).item())

    return np.mean(log_p)
def save(self, dir_path: str):
    """Save the model, its metric tensors and (if custom) its metric network.

    Args:
        dir_path (str): The path where the model should be saved. If the
            path does not exist a folder will be created at the provided location.
    """
    # The parent implementation creates the directory when it is missing.
    super().save(dir_path)

    checkpoint = {
        "M": deepcopy(self.M_tens.clone().detach()),
        "centroids": deepcopy(self.centroids_tens.clone().detach()),
        "model_state_dict": deepcopy(self.state_dict()),
    }

    # A user-provided metric network is pickled next to the weights.
    if not self.model_config.uses_default_metric:
        metric_file = os.path.join(dir_path, "metric.pkl")
        with open(metric_file, "wb") as fp:
            cloudpickle.register_pickle_by_value(inspect.getmodule(self.metric))
            cloudpickle.dump(self.metric, fp)

    weights_file = os.path.join(dir_path, "model.pt")
    torch.save(checkpoint, weights_file)
@classmethod
def _load_custom_metric_from_folder(cls, dir_path):
file_list = os.listdir(dir_path)
cls._check_python_version_from_folder(dir_path=dir_path)
if "metric.pkl" not in file_list:
raise FileNotFoundError(
f"Missing metric pkl file ('metric.pkl') in"
f"{dir_path}... This file is needed to rebuild custom metrics."
" Cannot perform model building."
)
else:
with open(os.path.join(dir_path, "metric.pkl"), "rb") as fp:
metric = CPU_Unpickler(fp).load()
return metric
@classmethod
def _load_metric_matrices_and_centroids(cls, dir_path):
"""this function can be called safely since it is called after
_load_model_weights_from_folder which handles FileNotFoundError and
loading issues"""
path_to_model_weights = os.path.join(dir_path, "model.pt")
model_weights = torch.load(path_to_model_weights, map_location="cpu")
if "M" not in model_weights.keys():
raise KeyError(
"Metric M matrices are not available in 'model.pt' file. Got keys:"
f"{model_weights.keys()}. These are needed to build the metric."
)
metric_M = model_weights["M"]
if "centroids" not in model_weights.keys():
raise KeyError(
"Metric centroids are not available in 'model.pt' file. Got keys:"
f"{model_weights.keys()}. These are needed to build the metric."
)
metric_centroids = model_weights["centroids"]
return metric_M, metric_centroids
@classmethod
def load_from_folder(cls, dir_path):
    """Class method to be used to load the model from a specific folder

    Args:
        dir_path (str): The path where the model should have been saved.

    .. note::
        This function requires the folder to contain:

        - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided

        **or**

        - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp.
            ``decoder.pkl`` or/and ``metric.pkl``) if a custom encoder (resp. decoder or/and
            metric) was provided
    """
    model_config = cls._load_model_config_from_folder(dir_path)
    model_weights = cls._load_model_weights_from_folder(dir_path)

    # Custom networks are only loaded when the config says they were used.
    encoder = (
        None
        if model_config.uses_default_encoder
        else cls._load_custom_encoder_from_folder(dir_path)
    )
    decoder = (
        None
        if model_config.uses_default_decoder
        else cls._load_custom_decoder_from_folder(dir_path)
    )
    metric = (
        None
        if model_config.uses_default_metric
        else cls._load_custom_metric_from_folder(dir_path)
    )

    model = cls(model_config, encoder=encoder, decoder=decoder, metric=metric)

    # Restore the stored metric tensors and rebuild the metric functions from
    # them before loading the weights.
    stored_M, stored_centroids = cls._load_metric_matrices_and_centroids(dir_path)
    model.M_tens = stored_M
    model.centroids_tens = stored_centroids
    model.G = create_metric(model)
    model.G_inv = create_inverse_metric(model)

    model.load_state_dict(model_weights)

    return model
@classmethod
def load_from_hf_hub(
    cls, hf_hub_path: str, allow_pickle: bool = False
):  # pragma: no cover
    """Class method to be used to load a pretrained model from the Hugging Face hub

    Args:
        hf_hub_path (str): The path where the model should have been saved on the
            Hugging Face hub.
        allow_pickle (bool): Whether pickled custom architectures may be
            downloaded and loaded. When the model needs them and this is False,
            a warning is emitted and ``None`` is returned. Default: False.

    .. note::
        This function requires the folder to contain:

        - | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided

        **or**

        - | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp.
            ``decoder.pkl`` and ``metric.pkl``) if a custom encoder (resp. decoder and/or
            metric) was provided
    """
    if not hf_hub_is_available():
        raise ModuleNotFoundError(
            "`huggingface_hub` package must be installed to load models from the HF hub. "
            "Run `python -m pip install huggingface_hub` and log in to your account with "
            "`huggingface-cli login`."
        )

    from huggingface_hub import hf_hub_download

    logger.info(f"Downloading {cls.__name__} files for rebuilding...")

    config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json")
    dir_path = os.path.dirname(config_path)

    _ = hf_hub_download(repo_id=hf_hub_path, filename="model.pt")

    model_config = cls._load_model_config_from_folder(dir_path)

    accepted_config_names = (cls.__name__ + "Config", cls.__name__ + "_Config")
    if model_config.name not in accepted_config_names:
        warnings.warn(
            f"You are trying to load a "
            f"`{ cls.__name__}` while a "
            f"`{model_config.name}` is given."
        )

    model_weights = cls._load_model_weights_from_folder(dir_path)

    needs_pickled_files = not (
        model_config.uses_default_encoder
        and model_config.uses_default_decoder
        and model_config.uses_default_metric
    )

    if needs_pickled_files and not allow_pickle:
        # Refuse to fetch third-party pickles unless explicitly allowed; no
        # model is built in that case.
        warnings.warn(
            "You are about to download pickled files from the HF hub that may have "
            "been created by a third party and so could potentially harm your computer. If you "
            "are sure that you want to download them set `allow_pickle=true`."
        )
        return None

    encoder = None
    if not model_config.uses_default_encoder:
        _ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pkl")
        encoder = cls._load_custom_encoder_from_folder(dir_path)

    decoder = None
    if not model_config.uses_default_decoder:
        _ = hf_hub_download(repo_id=hf_hub_path, filename="decoder.pkl")
        decoder = cls._load_custom_decoder_from_folder(dir_path)

    metric = None
    if not model_config.uses_default_metric:
        _ = hf_hub_download(repo_id=hf_hub_path, filename="metric.pkl")
        metric = cls._load_custom_metric_from_folder(dir_path)

    logger.info(f"Successfully downloaded {cls.__name__} model!")

    model = cls(model_config, encoder=encoder, decoder=decoder, metric=metric)

    # Restore the stored metric tensors and rebuild the metric functions from
    # them before loading the weights.
    stored_M, stored_centroids = cls._load_metric_matrices_and_centroids(dir_path)
    model.M_tens = stored_M
    model.centroids_tens = stored_centroids
    model.G = create_metric(model)
    model.G_inv = create_inverse_metric(model)

    model.load_state_dict(model_weights)

    return model