PixelCNNSampler

Sampler fitting a PixelCNN in the VQVAE’s latent space.

Available models:

  • VQVAE – Vector Quantized-VAE model.

class pythae.samplers.PixelCNNSamplerConfig(n_layers=10, kernel_size=5)[source]

This is the PixelCNN sampler configuration instance.

Parameters
  • input_dim (tuple) – The input data dimension. Default: None.

  • n_layers (int) – The number of convolutional layers in the model. Default: 10.

  • kernel_size (int) – The kernel size in the convolutional layers. It must be odd. Default: 5.
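The requirement that kernel_size be odd follows from how PixelCNN layers are masked. The kernel needs a well-defined centre so that each position only sees positions above it and to its left. The sketch below builds the standard PixelCNN type A and type B masks in plain Python. We assume pythae's masked convolutions follow this scheme; `pixelcnn_mask` is a hypothetical helper, not part of the pythae API.

```python
from typing import List


def pixelcnn_mask(kernel_size: int, mask_type: str) -> List[List[int]]:
    """Build a causal PixelCNN mask for a square kernel (hypothetical helper).

    Positions after the centre in raster order are always hidden (0).
    Type "A" also hides the centre itself; type "B" keeps it.
    """
    if kernel_size % 2 == 0:
        raise ValueError("kernel_size must be odd so the kernel has a centre position")
    if mask_type not in ("A", "B"):
        raise ValueError("mask_type must be 'A' or 'B'")
    c = kernel_size // 2  # index of the centre row and column
    mask = [[1] * kernel_size for _ in range(kernel_size)]
    for i in range(kernel_size):
        for j in range(kernel_size):
            after_centre = i > c or (i == c and j > c)
            at_centre = i == c and j == c
            if after_centre or (mask_type == "A" and at_centre):
                mask[i][j] = 0
    return mask


if __name__ == "__main__":
    for row in pixelcnn_mask(5, "A"):  # default kernel_size of the config
        print(row)
```

Type A is normally used in the first layer so that a position never sees its own value; type B is used in the later layers, where the centre only carries information already derived from earlier positions.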

class pythae.samplers.PixelCNNSampler(model, sampler_config=None)[source]

Fits a PixelCNN in the VQVAE’s latent space.

Parameters
  • model (VQVAE) – The VQVAE model to sample from.

  • sampler_config (PixelCNNSamplerConfig) – A PixelCNNSamplerConfig instance containing the main parameters of the sampler. If None, a pre-defined configuration is used. Default: None

Note

The fit method must be called before sampling.
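Once fitted, a PixelCNN generates a grid of latent codebook indices autoregressively: each position is drawn conditioned only on the positions before it in raster order (row by row, left to right). The stdlib-only sketch below illustrates that ordering. `sample_latent_grid` and the `conditional` callable are hypothetical: the callable stands in for a trained PixelCNN, and this is not pythae's implementation.

```python
import random
from typing import Callable, List, Optional

# Hypothetical types: a partially filled grid of codebook indices.
Grid = List[List[Optional[int]]]
Conditional = Callable[[Grid, int, int], List[float]]


def sample_latent_grid(
    height: int,
    width: int,
    num_codes: int,
    conditional: Conditional,
    rng: random.Random,
) -> List[List[int]]:
    """Fill a height x width grid of codebook indices in raster order.

    `conditional(grid, i, j)` returns num_codes non-negative weights for
    position (i, j); it may only rely on positions already filled.
    """
    grid: Grid = [[None] * width for _ in range(height)]
    for i in range(height):        # row by row ...
        for j in range(width):     # ... left to right
            weights = conditional(grid, i, j)
            if len(weights) != num_codes:
                raise ValueError("conditional must return one weight per code")
            grid[i][j] = rng.choices(range(num_codes), weights=weights, k=1)[0]
    return [[code for code in row if code is not None] for row in grid]


def repeat_left_neighbour(grid: Grid, i: int, j: int) -> List[float]:
    """Toy conditional over 4 codes that favours copying the left neighbour."""
    weights = [1.0] * 4
    left = grid[i][j - 1] if j > 0 else None
    if left is not None:
        weights[left] = 10.0
    return weights


if __name__ == "__main__":
    print(sample_latent_grid(3, 4, 4, repeat_left_neighbour, random.Random(0)))
```

In the actual sampler, the sampled indices would presumably be mapped back to codebook vectors and passed through the VQVAE decoder to produce images; that step is assumed here and not reproduced.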

fit(train_data, eval_data=None, training_config=None, batch_size=64)[source]

Fits the sampler on the training data.

Parameters
  • train_data (Union[torch.Tensor, np.ndarray, Dataset]) – The train data needed to retrieve the training embeddings and fit the PixelCNN model in the latent space.

  • eval_data (Union[torch.Tensor, np.ndarray, Dataset]) – The evaluation data needed to retrieve the evaluation embeddings used to evaluate the PixelCNN model during training. Default: None.

  • training_config (BaseTrainerConfig) – The training configuration used to fit the PixelCNN model. Default: None.

  • batch_size (int) – The batch size to use to retrieve the embeddings. Default: 64.
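fit first passes the data through the VQVAE to retrieve the embeddings, batch_size samples at a time, and then trains the PixelCNN on them. In a VQ-VAE each encoder output is replaced by its nearest codebook vector, so we assume the PixelCNN effectively models a grid of discrete codebook indices. The sketch below shows that retrieval step in plain Python under those assumptions. `batched`, `nearest_code` and `retrieve_indices` are hypothetical helpers, not pythae functions, and real latents would be torch tensors laid out on a spatial grid.

```python
from typing import Iterator, List, Sequence

Vector = Sequence[float]


def batched(items: Sequence, batch_size: int) -> Iterator[Sequence]:
    """Yield consecutive slices of at most batch_size items (hypothetical helper)."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def nearest_code(vector: Vector, codebook: Sequence[Vector]) -> int:
    """Index of the codebook entry closest to `vector` (squared Euclidean distance)."""
    return min(
        range(len(codebook)),
        key=lambda k: sum((a - b) ** 2 for a, b in zip(vector, codebook[k])),
    )


def retrieve_indices(
    latents: Sequence[Vector], codebook: Sequence[Vector], batch_size: int = 64
) -> List[int]:
    """Quantize encoder outputs batch by batch into discrete codebook indices.

    These indices are the targets a PixelCNN would then be trained to model.
    """
    indices: List[int] = []
    for batch in batched(latents, batch_size):
        indices.extend(nearest_code(vector, codebook) for vector in batch)
    return indices
```

The batch size only controls memory use during retrieval: the resulting indices are the same whatever batch_size is chosen.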

sample(num_samples=1, batch_size=500, output_dir=None, return_gen=True, save_sampler_config=False)[source]

Main sampling function of the sampler.

Parameters
  • num_samples (int) – The number of samples to generate. Default: 1.

  • batch_size (int) – The batch size to use during sampling. Default: 500.

  • output_dir (str) – The directory where the images will be saved. If it does not exist, the folder is created. If None, the images are not saved. Default: None.

  • return_gen (bool) – Whether the sampler should directly return a tensor of generated data. Default: True.

  • save_sampler_config (bool) – Whether to save the sampler config. If True, it is saved in output_dir. Default: False.

Returns

The generated images

Return type

Tensor
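The batch_size argument of sample bounds how many samples are generated in one pass. A natural way to honour it, assumed here rather than taken from pythae's source, is to issue full batches and then one smaller batch for the remainder. `sampling_batch_sizes` is a hypothetical helper whose defaults mirror the documented ones.

```python
from typing import List


def sampling_batch_sizes(num_samples: int = 1, batch_size: int = 500) -> List[int]:
    """Split num_samples into full batches of batch_size plus a smaller final batch."""
    if num_samples < 0 or batch_size <= 0:
        raise ValueError("num_samples must be >= 0 and batch_size must be positive")
    full_batches, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder:
        sizes.append(remainder)  # last, partially filled batch
    return sizes


if __name__ == "__main__":
    print(sampling_batch_sizes(1200, 500))
```

The batch sizes always sum to num_samples, so the sampler can concatenate the per-batch outputs into a single tensor of generated data.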