BaseSampler

Abstract class

This is the base sampler module from which all new samplers should inherit.

class pythae.samplers.BaseSamplerConfig

BaseSampler config class.

class pythae.samplers.BaseSampler(model, sampler_config=None)

Base class for samplers used to generate data from VAE models.

Parameters
  • model (BaseAE) – The VAE model to sample from.

  • sampler_config (BaseSamplerConfig) – An instance of BaseSamplerConfig containing the sampler’s parameters. If None, a default configuration is used. Default: None.
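The inheritance pattern described above can be sketched in plain Python. This is a minimal sketch, not pythae's implementation: `SketchSamplerConfig`, `SketchBaseSampler` and `ConstantSampler` are hypothetical names, the model is replaced by a plain dict so the example runs without pythae or PyTorch, and raising `NotImplementedError` in `sample` is an assumed way of marking the method as one subclasses must override.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SketchSamplerConfig:
    """Hypothetical stand-in for BaseSamplerConfig (no parameters)."""


class SketchBaseSampler:
    """Hypothetical stand-in mirroring the documented BaseSampler interface."""

    def __init__(self, model, sampler_config: Optional[SketchSamplerConfig] = None):
        # Fall back to a default configuration when none is given,
        # matching the documented behaviour for sampler_config=None.
        if sampler_config is None:
            sampler_config = SketchSamplerConfig()
        self.model = model
        self.sampler_config = sampler_config

    def fit(self, *args, **kwargs):
        # Samplers that need no fitting step can keep this no-op.
        pass

    def sample(self, num_samples=1, batch_size=500, output_dir=None, return_gen=True):
        # Concrete samplers are expected to override this (assumption of the sketch).
        raise NotImplementedError


class ConstantSampler(SketchBaseSampler):
    """Toy subclass that 'generates' copies of a value stored on the model."""

    def sample(self, num_samples=1, batch_size=500, output_dir=None, return_gen=True):
        samples = [self.model["value"]] * num_samples
        return samples if return_gen else None


sampler = ConstantSampler(model={"value": 0.5})
print(type(sampler.sampler_config).__name__)  # SketchSamplerConfig
print(sampler.sample(num_samples=3))  # [0.5, 0.5, 0.5]
```

Subclasses only override the methods they need; the constructor's default-config handling is inherited unchanged.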

fit(*args, **kwargs)

Fits the sampler. To be called before sampling.
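Some samplers need a fitting step before they can generate, for example to estimate a distribution over latent codes. The sketch below uses a hypothetical `GaussianLatentSampler` that fits a one-dimensional Gaussian with the standard library `statistics` module and refuses to sample until `fit()` has run; the `is_fitted` flag and the `RuntimeError` are illustrative choices, not pythae's API.

```python
import random
import statistics
from typing import List, Optional, Sequence


class GaussianLatentSampler:
    """Hypothetical sampler that must be fitted before it can sample."""

    def __init__(self, seed: Optional[int] = None):
        self.rng = random.Random(seed)
        self.mu: Optional[float] = None
        self.sigma: Optional[float] = None
        self.is_fitted = False

    def fit(self, latent_codes: Sequence[float]) -> None:
        # Estimate the mean and (population) standard deviation of the codes.
        self.mu = statistics.fmean(latent_codes)
        self.sigma = statistics.pstdev(latent_codes)
        self.is_fitted = True

    def sample(self, num_samples: int = 1) -> List[float]:
        # Guard: sampling from an unfitted sampler is an error.
        if not self.is_fitted:
            raise RuntimeError("fit() must be called before sample()")
        return [self.rng.gauss(self.mu, self.sigma) for _ in range(num_samples)]


sampler = GaussianLatentSampler(seed=0)
sampler.fit([0.0, 1.0, 2.0, 3.0])
print(sampler.mu)  # 1.5
print(len(sampler.sample(num_samples=5)))  # 5
```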

sample(num_samples=1, batch_size=500, output_dir=None, return_gen=True)

Main sampling function of the sampler.

Parameters
  • num_samples (int) – The number of samples to generate.

  • batch_size (int) – The batch size to use during sampling.

  • output_dir (str) – The directory where the images will be saved. If it does not exist, the folder is created. If None, the images are not saved. Default: None.

  • return_gen (bool) – Whether the sampler should directly return a tensor of generated data. Default: True.

Returns

The generated images

Return type

Tensor
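The interplay of `num_samples` and `batch_size` can be illustrated with a small helper. This sketch assumes generation proceeds in full batches of `batch_size` followed by one smaller remainder batch; `batch_sizes`, `sample_in_batches` and the `generate_batch` callable (standing in for the model's decoder) are hypothetical, plain lists replace tensors, and saving to `output_dir` is left out.

```python
from typing import Callable, List, Optional


def batch_sizes(num_samples: int, batch_size: int) -> List[int]:
    """Split num_samples into full batches of batch_size plus a smaller remainder."""
    if num_samples < 0 or batch_size <= 0:
        raise ValueError("num_samples must be >= 0 and batch_size must be > 0")
    full, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full
    if remainder:
        sizes.append(remainder)
    return sizes


def sample_in_batches(
    generate_batch: Callable[[int], list],
    num_samples: int = 1,
    batch_size: int = 500,
    return_gen: bool = True,
) -> Optional[list]:
    """Call generate_batch(n) once per batch and concatenate the results.

    Returning None when return_gen is False is an assumption of this sketch.
    """
    generated = []
    for n in batch_sizes(num_samples, batch_size):
        generated.extend(generate_batch(n))
    return generated if return_gen else None


print(batch_sizes(1200, 500))  # [500, 500, 200]
samples = sample_in_batches(lambda n: [0.0] * n, num_samples=1200, batch_size=500)
print(len(samples))  # 1200
```

Batching bounds peak memory: only `batch_size` items are decoded at once, however large `num_samples` is.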

save(dir_path)

Method to save the sampler config. The config is saved as a sampler_config.json file in dir_path.
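Saving a configuration as `sampler_config.json` can be sketched with the standard library. `ToySamplerConfig` and `save_config` are hypothetical and do not reproduce pythae's own serialization format; creating `dir_path` when it is missing is a choice of this sketch.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class ToySamplerConfig:
    """Hypothetical sampler config with a single illustrative field."""

    n_components: int = 10


def save_config(config: ToySamplerConfig, dir_path: str) -> str:
    """Write config to dir_path/sampler_config.json and return the file path."""
    os.makedirs(dir_path, exist_ok=True)  # create the folder if missing (sketch choice)
    path = os.path.join(dir_path, "sampler_config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(asdict(config), f, indent=2)
    return path


with tempfile.TemporaryDirectory() as tmp:
    path = save_config(ToySamplerConfig(n_components=3), tmp)
    with open(path, encoding="utf-8") as f:
        print(json.load(f))  # {'n_components': 3}
```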

save_img(img_tensor, dir_path, img_name)

Saves a data point as a .png file named img_name in dir_path.

Parameters
  • img_tensor (torch.Tensor) – The image of shape CxHxW with values in the range [0, 1].

  • dir_path (str) – The folder in which the image must be saved.

  • img_name (str) – The name of the file containing the image.