Source code for pythae.samplers.base.base_sampler

import logging
import os

import numpy as np
import torch
from imageio import imwrite

from ...models import BaseAE
from .base_sampler_config import BaseSamplerConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach a console handler so this module's messages reach the terminal
# (logging.StreamHandler writes to sys.stderr when no stream is given).
console = logging.StreamHandler()
logger.addHandler(console)


class BaseSampler:
    """Base class for samplers used to generate from the VAEs models.

    Args:
        model (BaseAE): The vae model to sample from.

        sampler_config (BaseSamplerConfig): An instance of BaseSamplerConfig in which any
            sampler's parameters are made available. If None a default configuration is
            used. Default: None
    """

    def __init__(self, model: BaseAE, sampler_config: BaseSamplerConfig = None):
        if sampler_config is None:
            sampler_config = BaseSamplerConfig()

        self.model = model
        # The sampler only uses the model for generation, so put it in eval mode.
        self.model.eval()
        self.sampler_config = sampler_config
        self.is_fitted = False

        # Use the GPU when one is available, otherwise fall back to the CPU. The chosen
        # device is recorded on both the sampler and the model, and the model is moved.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.model.device = device
        self.model.to(device)

    def fit(self, *args, **kwargs):
        """Function to be called to fit the sampler before sampling.

        The base implementation does nothing; subclasses needing a fitting step
        override it.
        """
        pass

    def sample(
        self,
        num_samples: int = 1,
        batch_size: int = 500,
        output_dir: str = None,
        return_gen: bool = True,
    ):
        """Main sampling function of the sampler.

        Args:
            num_samples (int): The number of samples to generate
            batch_size (int): The batch size to use during sampling
            output_dir (str): The directory where the images will be saved. If it does
                not exist the folder is created. If None: the images are not saved.
                Defaults: None.
            return_gen (bool): Whether the sampler should directly return a tensor of
                generated data. Default: True.

        Returns:
            ~torch.Tensor: The generated images

        Raises:
            NotImplementedError: Always, in the base class; subclasses must override.
        """
        raise NotImplementedError()

    def save(self, dir_path):
        """Method to save the sampler config. The config is saved as a
        ``sampler_config.json`` file in ``dir_path``."""
        self.sampler_config.save_json(dir_path, "sampler_config")

    def save_img(self, img_tensor: torch.Tensor, dir_path: str, img_name: str):
        """Save a single image tensor to ``dir_path/img_name``.

        Args:
            img_tensor (torch.Tensor): The image of shape CxHxW in the range [0-1].
                Values outside this range are clipped before conversion to 8-bit.
            dir_path (str): The folder in which the image must be saved. It is
                created if it does not exist.
            img_name (str): The name of the file containing the image, extension
                included (e.g. ``"00000001.png"``); it is used verbatim as the file name.
        """
        if not os.path.exists(dir_path):
            # exist_ok avoids a crash if another process creates the folder between
            # the existence check and this call.
            os.makedirs(dir_path, exist_ok=True)
            logger.info("--> Created folder %s. Images will be saved here", dir_path)

        # CxHxW -> HxWxC (channel-last) before handing the array to imwrite.
        # Detach first so the CPU copy is not tracked by autograd.
        img = 255.0 * torch.movedim(img_tensor, 0, 2).detach().cpu().numpy()

        # Single-channel images are expanded to three identical channels.
        if img.shape[-1] == 1:
            img = np.repeat(img, repeats=3, axis=-1)

        # Casting floats outside [0, 255] to uint8 is undefined in NumPy (values may
        # wrap around), so clamp to the valid pixel range first.
        img = np.clip(img, 0.0, 255.0).astype("uint8")
        imwrite(os.path.join(dir_path, str(img_name)), img)