CoupledOptimizerTrainer

This module implements a trainer that uses two distinct optimizers, one for the encoder and one for the decoder. It is suitable for all models, and it must be used in particular to train a RAE_L2.
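To make the idea of distinct optimizers concrete, here is a minimal pure-Python sketch, not pythae's implementation. The scalar "encoder" and "decoder", the `SGD` class and the `train_step` helper are all hypothetical stand-ins. The point is only that each optimizer owns one group of parameters, so the learning rate and weight decay (an L2 penalty) can differ between encoder and decoder.

```python
class SGD:
    """Hypothetical minimal optimizer that owns one group of parameters."""

    def __init__(self, params, lr, weight_decay=0.0):
        self.params = params              # dict name -> float, updated in place
        self.lr = lr
        self.weight_decay = weight_decay

    def step(self, grads):
        for name, grad in grads.items():
            # Weight decay adds an L2 penalty gradient for this group only.
            grad = grad + self.weight_decay * self.params[name]
            self.params[name] -= self.lr * grad


def reconstruction_loss(enc, dec, data):
    # Toy model: z = w_enc * x, x_hat = w_dec * z.
    return sum((dec["w"] * enc["w"] * x - x) ** 2 for x in data) / len(data)


def train_step(enc, dec, enc_opt, dec_opt, data):
    n = len(data)
    # Gradients of the mean squared error, both computed before any update.
    g_enc = sum(2 * (dec["w"] * enc["w"] * x - x) * dec["w"] * x for x in data) / n
    g_dec = sum(2 * (dec["w"] * enc["w"] * x - x) * enc["w"] * x for x in data) / n
    enc_opt.step({"w": g_enc})            # encoder parameters, encoder settings
    dec_opt.step({"w": g_dec})            # decoder parameters, decoder settings
    return reconstruction_loss(enc, dec, data)


data = [0.5, -1.0, 1.5, 2.0]
encoder, decoder = {"w": 0.3}, {"w": 0.3}
encoder_optimizer = SGD(encoder, lr=0.05)
decoder_optimizer = SGD(decoder, lr=0.02, weight_decay=0.01)  # decay on decoder only

initial_loss = reconstruction_loss(encoder, decoder, data)
for _ in range(200):
    final_loss = train_step(encoder, decoder, encoder_optimizer, decoder_optimizer, data)
```

Applying weight decay to the decoder group alone is one plausible reason a model with L2 regularization on the decoder parameters would need separate optimizers; that reading is an assumption, not something stated in this reference.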

Available models:

AE

Vanilla Autoencoder model.

VAE

Vanilla Variational Autoencoder model.

BetaVAE

\(\beta\)-VAE model.

DisentangledBetaVAE

Disentangled \(\beta\)-VAE model.

BetaTCVAE

\(\beta\)-TCVAE model.

IWAE

Importance Weighted Autoencoder model.

MSSSIM_VAE

VAE using perceptual similarity metrics.

INFOVAE_MMD

Info Variational Autoencoder model.

WAE_MMD

Wasserstein Autoencoder model.

VAMP

Variational Mixture of Posteriors (VAMP) VAE model.

SVAE

\(\mathcal{S}\)-VAE model.

VQVAE

Vector Quantized-VAE model.

RAE_GP

Regularized Autoencoder with gradient penalty model.

RAE_L2

Regularized Autoencoder with L2 decoder params regularization model.

HVAE

Hamiltonian VAE.

RHVAE

Riemannian Hamiltonian VAE model.

class pythae.trainers.CoupledOptimizerTrainerConfig(output_dir=None, per_device_train_batch_size=64, per_device_eval_batch_size=64, num_epochs=100, train_dataloader_num_workers=0, eval_dataloader_num_workers=0, optimizer_cls='Adam', optimizer_params=None, scheduler_cls=None, scheduler_params=None, learning_rate=0.0001, steps_saving=None, steps_predict=None, keep_best_on_train=False, seed=8, no_cuda=False, world_size=-1, local_rank=-1, rank=-1, dist_backend='nccl', master_addr='localhost', master_port='12345', amp=False, encoder_optimizer_cls='Adam', encoder_optimizer_params=None, encoder_scheduler_cls=None, encoder_scheduler_params=None, decoder_optimizer_cls='Adam', decoder_optimizer_params=None, decoder_scheduler_cls=None, decoder_scheduler_params=None, encoder_learning_rate=0.0001, decoder_learning_rate=0.0001)

CoupledOptimizerTrainer config class.

Parameters
  • output_dir (str) – The directory where model checkpoints, configs and final model will be stored. Default: None.

  • per_device_train_batch_size (int) – The number of training samples per batch and per device. Default: 64

  • per_device_eval_batch_size (int) – The number of evaluation samples per batch and per device. Default: 64

  • num_epochs (int) – The maximal number of epochs for training. Default: 100

  • train_dataloader_num_workers (int) – Number of subprocesses to use for train data loading. 0 means that the data will be loaded in the main process. Default: 0

  • eval_dataloader_num_workers (int) – Number of subprocesses to use for evaluation data loading. 0 means that the data will be loaded in the main process. Default: 0

  • encoder_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the encoder. Default: Adam.

  • encoder_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the encoder. If None, uses the default parameters. Default: None.

  • encoder_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the encoder. Default: None.

  • encoder_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the encoder. If None, uses the default parameters. Default: None.

  • decoder_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the decoder. Default: Adam.

  • decoder_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the decoder. If None, uses the default parameters. Default: None.

  • decoder_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the decoder. Default: None.

  • decoder_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the decoder. If None, uses the default parameters. Default: None.

  • encoder_learning_rate (float) – The learning rate applied to the Optimizer for the encoder. Default: 1e-4

  • decoder_learning_rate (float) – The learning rate applied to the Optimizer for the decoder. Default: 1e-4

  • steps_saving (int) – A model checkpoint will be saved every steps_saving epochs. Default: None

  • steps_predict (int) – A prediction using the best model will be run every steps_predict epochs. Default: None

  • keep_best_on_train (bool) – Whether to keep the best model on the train set. Default: False

  • seed (int) – The random seed for reproducibility. Default: 8

  • no_cuda (bool) – Disable cuda training. Default: False

  • world_size (int) – The total number of processes to run. Default: -1

  • local_rank (int) – The rank of the node for distributed training. Default: -1

  • rank (int) – The rank of the process for distributed training. Default: -1

  • dist_backend (str) – The distributed backend to use. Default: ‘nccl’

  • master_addr (str) – The master address for distributed training. Default: ‘localhost’

  • master_port (str) – The master port for distributed training. Default: ‘12345’
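The encoder and decoder settings above come in parallel groups: a class name, an optional dict of parameters, and a learning rate. The sketch below shows how such a group might be turned into optimizer keyword arguments. `CoupledConfigSketch` and `optimizer_spec` are hypothetical stand-ins that mirror only a few documented fields and defaults. They are not pythae classes, and folding the learning rate into the keyword arguments as `lr` is an assumption.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class CoupledConfigSketch:
    """Hypothetical stand-in mirroring a few documented fields; not pythae's class."""

    encoder_optimizer_cls: str = "Adam"
    encoder_optimizer_params: Optional[dict] = None
    encoder_learning_rate: float = 1e-4
    decoder_optimizer_cls: str = "Adam"
    decoder_optimizer_params: Optional[dict] = None
    decoder_learning_rate: float = 1e-4


def optimizer_spec(config, part):
    """Return (class name, keyword arguments) for part 'encoder' or 'decoder'."""
    if part not in ("encoder", "decoder"):
        raise ValueError(f"unknown part: {part!r}")
    cls_name = getattr(config, f"{part}_optimizer_cls")
    # None means "use the optimizer's default parameters".
    params = getattr(config, f"{part}_optimizer_params") or {}
    kwargs = {"lr": getattr(config, f"{part}_learning_rate"), **params}
    return cls_name, kwargs


config = CoupledConfigSketch(
    encoder_learning_rate=1e-3,
    decoder_optimizer_cls="SGD",
    decoder_optimizer_params={"weight_decay": 1e-5, "momentum": 0.9},
)
encoder_spec = optimizer_spec(config, "encoder")
decoder_spec = optimizer_spec(config, "decoder")
```

Here the encoder keeps the default Adam with a custom learning rate, while the decoder switches to SGD with its own parameters; neither choice affects the other.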

class pythae.trainers.CoupledOptimizerTrainer(model, train_dataset, eval_dataset=None, training_config=None, callbacks=None)

Trainer using distinct optimizers for the encoder and decoder neural networks.

Parameters
  • model (BaseAE) – The model to train

  • train_dataset (BaseDataset) – The training dataset of type BaseDataset

  • training_config (CoupledOptimizerTrainerConfig) – The training arguments summarizing the main parameters used for training. If None, a default instance of CoupledOptimizerTrainerConfig is used. Default: None.

  • encoder_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the encoder. If None, an Adam optimizer is used. Default: None.

  • decoder_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the decoder. If None, an Adam optimizer is used. Default: None.

eval_step(epoch)

Performs an evaluation step.

Parameters

epoch (int) – The current epoch number

Returns

The evaluation loss

Return type

(torch.Tensor)

prepare_training()

Sets up the trainer for training.

save_checkpoint(model, dir_path, epoch)

Saves a checkpoint from which training can be restarted.

Parameters
  • dir_path (str) – The folder where the checkpoint should be saved

  • epoch (int) – The epoch number

train(log_output_dir=None)

The main training function.

Parameters

log_output_dir (str) – The path in which the log will be stored

train_step(epoch)

Performs one training loop over the train_loader.

Parameters

epoch (int) – The current epoch number

Returns

The step training loss

Return type

(torch.Tensor)
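The methods above suggest an epoch loop that trains, evaluates, and periodically saves a checkpoint. The sketch below shows one plausible shape for such a loop, using a hypothetical `StubTrainer` with fake losses. It borrows the documented method names and the steps_saving setting, but the loop body, the best-model tracking and the `"checkpoints"` folder are assumptions rather than pythae's actual code.

```python
class StubTrainer:
    """Hypothetical trainer reusing the documented method names with fake losses."""

    def __init__(self, num_epochs, steps_saving=None):
        self.num_epochs = num_epochs
        self.steps_saving = steps_saving
        self.saved_epochs = []

    def train_step(self, epoch):
        return 1.0 / epoch                # pretend the training loss shrinks

    def eval_step(self, epoch):
        return 1.0 / epoch + 0.1          # pretend the eval loss is slightly higher

    def save_checkpoint(self, model, dir_path, epoch):
        self.saved_epochs.append(epoch)   # a real trainer would write files here

    def train(self, log_output_dir=None):
        best_loss, best_epoch = float("inf"), None
        for epoch in range(1, self.num_epochs + 1):
            self.train_step(epoch)
            eval_loss = self.eval_step(epoch)
            if eval_loss < best_loss:     # remember the best epoch seen so far
                best_loss, best_epoch = eval_loss, epoch
            # Save every steps_saving epochs when the option is set.
            if self.steps_saving is not None and epoch % self.steps_saving == 0:
                self.save_checkpoint(None, "checkpoints", epoch)
        return best_epoch, best_loss


trainer = StubTrainer(num_epochs=10, steps_saving=3)
best_epoch, best_loss = trainer.train()
```

With steps_saving left at None, the loop never calls save_checkpoint, which matches the documented default.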