Source code for pythae.samplers.vamp_sampler.vamp_sampler

import torch

from ...models import VAMP
from ..base import BaseSampler
from .vamp_sampler_config import VAMPSamplerConfig


[docs]class VAMPSampler(BaseSampler): """Sampling from the VAMP prior. Args: model (VAMP): The vae model to sample from. sampler_config (BaseSamplerConfig): An instance of BaseSamplerConfig in which any sampler's parameters is made available. If None a default configuration is used. Default: None """ def __init__(self, model: VAMP, sampler_config: VAMPSamplerConfig = None): assert isinstance(model, VAMP), "This sampler is only suitable for VAMP model" if sampler_config is None: sampler_config = VAMPSamplerConfig() BaseSampler.__init__(self, model=model, sampler_config=sampler_config)
[docs] def sample( self, num_samples: int = 1, batch_size: int = 500, output_dir: str = None, return_gen: bool = True, save_sampler_config: bool = False, ) -> torch.Tensor: """Main sampling function of the sampler. Args: num_samples (int): The number of samples to generate batch_size (int): The batch size to use during sampling output_dir (str): The directory where the images will be saved. If does not exist the folder is created. If None: the images are not saved. Defaults: None. return_gen (bool): Whether the sampler should directly return a tensor of generated data. Default: True. save_sampler_config (bool): Whether to save the sampler config. It is saved in output_dir Returns: ~torch.Tensor: The generated images """ batch_size = min(self.model.idle_input.shape[0], batch_size) full_batch_nbr = int(num_samples / batch_size) last_batch_samples_nbr = num_samples % batch_size x_gen_list = [] for i in range(full_batch_nbr): means = self.model.pseudo_inputs(self.model.idle_input.to(self.device))[ :batch_size ].reshape((batch_size,) + self.model.model_config.input_dim) encoder_output = self.model.encoder(means) mu, log_var = ( encoder_output.embedding, torch.tanh(encoder_output.log_covariance), ) std = torch.exp(0.5 * log_var) eps = torch.randn_like(std) z = mu + eps * std x_gen = self.model.decoder(z)["reconstruction"].detach() if output_dir is not None: for j in range(batch_size): self.save_img( x_gen[j], output_dir, "%08d.png" % int(batch_size * i + j) ) x_gen_list.append(x_gen) if last_batch_samples_nbr > 0: means = self.model.pseudo_inputs(self.model.idle_input.to(self.device))[ :last_batch_samples_nbr ].reshape((last_batch_samples_nbr,) + self.model.model_config.input_dim) encoder_output = self.model.encoder(means) mu, log_var = ( encoder_output.embedding, torch.tanh(encoder_output.log_covariance), ) std = torch.exp(0.5 * log_var) eps = torch.randn_like(std) z = mu + eps * std x_gen = self.model.decoder(z)["reconstruction"].detach() if output_dir is not None: for j in range(last_batch_samples_nbr): self.save_img( x_gen[j], output_dir, "%08d.png" % int(batch_size * full_batch_nbr + j), ) x_gen_list.append(x_gen) if save_sampler_config: self.save(output_dir) if return_gen: return torch.cat(x_gen_list, dim=0)