Source code for pythae.samplers.manifold_sampler.rhvae_sampler

import torch

from ...models import RHVAE
from ..base import BaseSampler
from .rhvae_sampler_config import RHVAESamplerConfig


class RHVAESampler(BaseSampler):
    """Sampling from the inverse of the metric volume element of a
    :class:`~pythae.models.RHVAE` model.

    Latent codes are drawn with a tempered Hamiltonian Monte Carlo chain whose
    target log-density is :meth:`log_sqrt_det_G_inv`, started from centroids
    picked at random from ``model.centroids_tens``, and then decoded with
    ``model.decoder``.

    Args:
        model (RHVAE): The VAE model to sample from
        sampler_config (RHVAESamplerConfig): A RHVAESamplerConfig instance containing
            the main parameters of the sampler. If None, a pre-defined configuration is used.
            Default: None
    """

    def __init__(self, model: RHVAE, sampler_config: RHVAESamplerConfig = None):

        if sampler_config is None:
            sampler_config = RHVAESamplerConfig()

        BaseSampler.__init__(self, model=model, sampler_config=sampler_config)

        # The metric tensors are read at every HMC step; move them to the
        # sampler's device once up front.
        self.model.M_tens = self.model.M_tens.to(self.device)
        self.model.centroids_tens = self.model.centroids_tens.to(self.device)

        self.mcmc_steps_nbr = sampler_config.mcmc_steps_nbr
        self.n_lf = torch.tensor([sampler_config.n_lf]).to(self.device)
        self.eps_lf = torch.tensor([sampler_config.eps_lf]).to(self.device)
        self.beta_zero_sqrt = (
            torch.tensor([sampler_config.beta_zero]).to(self.device).sqrt()
        )

        # Plain functions stored on the instance (not bound methods): they are
        # called as f(z, model).
        self.log_pi = RHVAESampler.log_sqrt_det_G_inv
        self.grad_func = RHVAESampler.grad_log_prop

    def sample(
        self,
        num_samples: int = 1,
        batch_size: int = 500,
        output_dir: str = None,
        return_gen: bool = True,
        save_sampler_config: bool = False,
    ) -> torch.Tensor:
        """Main sampling function of the sampler.

        Latent codes are generated by :meth:`hmc_sampling` in batches of at
        most ``batch_size`` and decoded with ``self.model.decoder``.

        Args:
            num_samples (int): The number of samples to generate
            batch_size (int): The batch size to use during sampling
            output_dir (str): The directory where the images will be saved. If does not exist
                the folder is created. If None: the images are not saved. Defaults: None.
            return_gen (bool): Whether the sampler should directly return a tensor of generated
                data. Default: True.
            save_sampler_config (bool): Whether to save the sampler config. It is saved in
                output_dir

        Returns:
            ~torch.Tensor: The generated images (only when ``return_gen`` is True)
        """
        full_batch_nbr = num_samples // batch_size
        last_batch_samples_nbr = num_samples % batch_size

        # Full batches first, then the (possibly empty) remainder.
        batch_sizes = [batch_size] * full_batch_nbr
        if last_batch_samples_nbr > 0:
            batch_sizes.append(last_batch_samples_nbr)

        x_gen_list = []
        img_offset = 0  # running index so saved file names are globally unique
        for n_batch in batch_sizes:
            samples = self.hmc_sampling(n_batch)
            x_gen = self.model.decoder(z=samples)["reconstruction"].detach()

            if output_dir is not None:
                for j in range(n_batch):
                    self.save_img(
                        x_gen[j], output_dir, "%08d.png" % (img_offset + j)
                    )

            img_offset += n_batch
            x_gen_list.append(x_gen)

        if save_sampler_config:
            self.save(output_dir)

        if return_gen:
            return torch.cat(x_gen_list, dim=0)

    def hmc_sampling(self, n_samples: int):
        """Run ``n_samples`` parallel tempered HMC chains in the latent space.

        Each chain starts at a centroid drawn uniformly from
        ``self.model.centroids_tens``. Each of the ``mcmc_steps_nbr``
        iterations does three things:

        - resamples a momentum;
        - integrates ``n_lf`` leapfrog steps of size ``eps_lf``, rescaling the
          momentum with the tempering schedule after each step;
        - applies a per-chain Metropolis-Hastings accept/reject test.

        Args:
            n_samples (int): Number of chains, i.e. of latent codes returned.

        Returns:
            ~torch.Tensor: The final state of every chain, one latent code per
            row.
        """
        with torch.no_grad():

            idx = torch.randint(len(self.model.centroids_tens), (n_samples,))
            z0 = self.model.centroids_tens[idx]
            z = z0

            for _ in range(self.mcmc_steps_nbr):

                # The tempering schedule restarts at sqrt(beta_0) for every
                # trajectory: tempering(K, K, .) == 1, so carrying the value
                # over from the previous trajectory would mis-scale the first
                # momentum update.
                beta_sqrt_old = self.beta_zero_sqrt

                gamma = torch.randn_like(z, device=self.device)
                rho = gamma / self.beta_zero_sqrt

                H0 = -self.log_pi(z, self.model) + 0.5 * torch.norm(rho, dim=1) ** 2

                for k in range(self.n_lf):
                    # g = -grad log_pi, i.e. the gradient of the potential
                    # energy -log_pi at the current position.
                    g = -self.grad_func(z, self.model).reshape(
                        n_samples, self.model.latent_dim
                    )
                    # step 1: momentum half step
                    rho_ = rho - (self.eps_lf / 2) * g

                    # step 2: position full step
                    z = z + self.eps_lf * rho_

                    # step 3: momentum half step at the new position
                    g = -self.grad_func(z, self.model).reshape(
                        n_samples, self.model.latent_dim
                    )
                    rho__ = rho_ - (self.eps_lf / 2) * g

                    # tempering: rescale the momentum by
                    # sqrt(beta_k) / sqrt(beta_{k+1})
                    beta_sqrt = RHVAESampler.tempering(
                        k + 1, self.n_lf, self.beta_zero_sqrt
                    )
                    rho = (beta_sqrt_old / beta_sqrt) * rho__
                    beta_sqrt_old = beta_sqrt

                H = -self.log_pi(z, self.model) + 0.5 * torch.norm(rho, dim=1) ** 2

                # exp(H0 - H) equals exp(-H) / exp(-H0) but cannot produce
                # inf/inf or 0/0 (NaN) when the energies are large.
                alpha = torch.exp(H0 - H)
                acc = torch.rand(n_samples, device=self.device)
                moves = (acc < alpha).reshape(n_samples, 1)

                # Keep the proposal where accepted, the previous state otherwise.
                z = torch.where(moves, z, z0)
                z0 = z

            return z

    @staticmethod
    def tempering(k, K, beta_zero_sqrt):
        """Return sqrt(beta_k) for leapfrog step ``k`` out of ``K``.

        ``1 / sqrt(beta_k)`` is interpolated quadratically in ``k / K``. The
        schedule goes from ``beta_zero_sqrt`` at ``k = 0`` to 1 at ``k = K``.
        """
        beta_k = ((1 - 1 / beta_zero_sqrt) * (k / K) ** 2) + 1 / beta_zero_sqrt

        return 1 / beta_k

    @staticmethod
    def log_sqrt_det_G_inv(z, model):
        """Log-target of the sampler: ``log(sqrt(det(G^{-1}(z))) + 1e-10)``.

        The ``1e-10`` keeps the logarithm finite when the determinant is 0.
        """
        return torch.log(torch.sqrt(torch.det(model.G_inv(z))) + 1e-10)

    @staticmethod
    def grad_log_sqrt_det_G_inv(z, model):
        """Closed-form gradient term w.r.t. ``z`` used to drive the leapfrog
        integrator.

        It is built from ``model.G(z)``, ``model.M_tens``,
        ``model.centroids_tens`` and ``model.temperature``. The result is
        reshaped by the caller to ``(n_samples, latent_dim)``.
        """
        return (
            -0.5
            * torch.transpose(model.G(z), 1, 2)
            @ torch.transpose(
                (
                    -2
                    / (model.temperature ** 2)
                    * (model.centroids_tens.unsqueeze(0) - z.unsqueeze(1)).unsqueeze(2)
                    @ (
                        model.M_tens.unsqueeze(0)
                        * torch.exp(
                            -torch.norm(
                                model.centroids_tens.unsqueeze(0) - z.unsqueeze(1),
                                dim=-1,
                            )
                            ** 2
                            / (model.temperature ** 2)
                        )
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                ).sum(dim=1),
                1,
                2,
            )
        )

    @staticmethod
    def grad_log_prop(z, model):
        """Gradient of the sampler's log-target (see :meth:`grad_log_sqrt_det_G_inv`)."""
        return RHVAESampler.grad_log_sqrt_det_G_inv(z, model)