CoupledOptimizerAdversarialTrainer

This module implements the trainer to use with adversarial models. Unlike AdversarialTrainer, it uses three distinct optimizers: one for the encoder of the AE, one for its decoder and one for the discriminator. It is suitable for GAN-based models.
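The point of coupled optimizers is that each component owns a disjoint set of parameters and its own optimizer settings. The stdlib-only toy below illustrates that idea. It is not pythae's implementation: `Param` and `ToySGD` are hypothetical stand-ins for a tensor and a `torch.optim.Optimizer`, and the learning rates are arbitrary.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Param:
    """Toy scalar parameter with an attached gradient (stand-in for a tensor)."""
    value: float
    grad: float = 0.0


@dataclass
class ToySGD:
    """Hypothetical plain-SGD optimizer: value <- value - lr * grad."""
    params: List[Param]
    lr: float

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = 0.0

    def step(self) -> None:
        for p in self.params:
            p.value -= self.lr * p.grad


# One parameter group per component, as in the coupled setting.
encoder = [Param(1.0)]
decoder = [Param(2.0)]
discriminator = [Param(3.0)]

# Three distinct optimizers, each with its own learning rate.
encoder_opt = ToySGD(encoder, lr=0.1)
decoder_opt = ToySGD(decoder, lr=0.01)
discriminator_opt = ToySGD(discriminator, lr=0.5)

# Pretend every parameter received a gradient of 1.0.
for p in encoder + decoder + discriminator:
    p.grad = 1.0

# Only the encoder optimizer steps: decoder and discriminator values are untouched.
encoder_opt.step()
```

Because the optimizers never share parameters, each component can be given a different optimizer class, learning rate or scheduler without affecting the other two.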

Available models:

Adversarial_AE

Adversarial Autoencoder model.

VAEGAN

Variational Autoencoder model using an adversarial reconstruction loss.

class pythae.trainers.CoupledOptimizerAdversarialTrainerConfig(output_dir=None, per_device_train_batch_size=64, per_device_eval_batch_size=64, num_epochs=100, train_dataloader_num_workers=0, eval_dataloader_num_workers=0, optimizer_cls='Adam', optimizer_params=None, scheduler_cls=None, scheduler_params=None, learning_rate=0.0001, steps_saving=None, steps_predict=None, keep_best_on_train=False, seed=8, no_cuda=False, world_size=-1, local_rank=-1, rank=-1, dist_backend='nccl', master_addr='localhost', master_port='12345', amp=False, encoder_optimizer_cls='Adam', encoder_optimizer_params=None, encoder_scheduler_cls=None, encoder_scheduler_params=None, discriminator_optimizer_cls='Adam', decoder_optimizer_cls='Adam', decoder_optimizer_params=None, decoder_scheduler_cls=None, decoder_scheduler_params=None, discriminator_optimizer_params=None, discriminator_scheduler_cls=None, discriminator_scheduler_params=None, encoder_learning_rate=0.0001, decoder_learning_rate=0.0001, discriminator_learning_rate=0.0001)

CoupledOptimizerAdversarialTrainer config class.

Parameters
  • output_dir (str) – The directory where model checkpoints, configs and final model will be stored. Default: None.

  • per_device_train_batch_size (int) – The number of training samples per batch and per device. Default: 64

  • per_device_eval_batch_size (int) – The number of evaluation samples per batch and per device. Default: 64

  • num_epochs (int) – The maximal number of epochs for training. Default: 100

  • train_dataloader_num_workers (int) – Number of subprocesses to use for train data loading. 0 means that the data will be loaded in the main process. Default: 0

  • eval_dataloader_num_workers (int) – Number of subprocesses to use for evaluation data loading. 0 means that the data will be loaded in the main process. Default: 0

  • encoder_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the encoder. Default: Adam.

  • encoder_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the encoder. If None, uses the default parameters. Default: None.

  • encoder_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the encoder. Default: None.

  • encoder_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the encoder. If None, uses the default parameters. Default: None.

  • decoder_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the decoder. Default: Adam.

  • decoder_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the decoder. If None, uses the default parameters. Default: None.

  • decoder_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the decoder. Default: None.

  • decoder_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the decoder. If None, uses the default parameters. Default: None.

  • discriminator_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the discriminator. Default: Adam.

  • discriminator_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the discriminator. If None, uses the default parameters. Default: None.

  • discriminator_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the discriminator. Default: None.

  • discriminator_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the discriminator. If None, uses the default parameters. Default: None.

  • encoder_learning_rate (float) – The learning rate applied to the Optimizer for the encoder. Default: 1e-4

  • decoder_learning_rate (float) – The learning rate applied to the Optimizer for the decoder. Default: 1e-4

  • discriminator_learning_rate (float) – The learning rate applied to the Optimizer for the discriminator. Default: 1e-4

  • steps_saving (int) – A model checkpoint will be saved every steps_saving epochs. Default: None

  • steps_predict (int) – A prediction using the best model will be run every steps_predict epochs. Default: None

  • keep_best_on_train (bool) – Whether to keep the best model on the train set. Default: False

  • seed (int) – The random seed for reproducibility. Default: 8

  • no_cuda (bool) – Disable CUDA training. Default: False

  • world_size (int) – The total number of processes to run. Default: -1

  • local_rank (int) – The rank of the process on its local node for distributed training. Default: -1

  • rank (int) – The rank of the process for distributed training. Default: -1

  • dist_backend (str) – The distributed backend to use. Default: ‘nccl’

  • master_addr (str) – The master address for distributed training. Default: ‘localhost’

  • master_port (str) – The master port for distributed training. Default: ‘12345’
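Each component is configured by the same family of fields: `<component>_optimizer_cls`, `<component>_optimizer_params`, `<component>_scheduler_cls`, `<component>_scheduler_params` and `<component>_learning_rate`. The sketch below groups them per component. `CoupledConfigSubset` and `component_settings` are hypothetical helpers written for illustration; only the field names and defaults are taken from the signature above.

```python
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class CoupledConfigSubset:
    """Hypothetical mirror of the per-component fields, with the documented defaults."""
    encoder_optimizer_cls: str = "Adam"
    encoder_optimizer_params: Optional[Dict[str, Any]] = None
    encoder_scheduler_cls: Optional[str] = None
    encoder_scheduler_params: Optional[Dict[str, Any]] = None
    encoder_learning_rate: float = 1e-4
    decoder_optimizer_cls: str = "Adam"
    decoder_optimizer_params: Optional[Dict[str, Any]] = None
    decoder_scheduler_cls: Optional[str] = None
    decoder_scheduler_params: Optional[Dict[str, Any]] = None
    decoder_learning_rate: float = 1e-4
    discriminator_optimizer_cls: str = "Adam"
    discriminator_optimizer_params: Optional[Dict[str, Any]] = None
    discriminator_scheduler_cls: Optional[str] = None
    discriminator_scheduler_params: Optional[Dict[str, Any]] = None
    discriminator_learning_rate: float = 1e-4


def component_settings(config: CoupledConfigSubset, component: str) -> Dict[str, Any]:
    """Collect one component's optimizer/scheduler settings.

    None params become an empty dict, meaning "use the defaults".
    """
    if component not in ("encoder", "decoder", "discriminator"):
        raise ValueError(f"unknown component: {component!r}")

    def get(suffix: str) -> Any:
        return getattr(config, f"{component}_{suffix}")

    return {
        "optimizer_cls": get("optimizer_cls"),
        "optimizer_params": get("optimizer_params") or {},
        "lr": get("learning_rate"),
        "scheduler_cls": get("scheduler_cls"),
        "scheduler_params": get("scheduler_params") or {},
    }


# Example: custom Adam betas for the encoder, a larger discriminator learning rate.
config = CoupledConfigSubset(
    encoder_optimizer_params={"betas": (0.5, 0.999)},
    discriminator_learning_rate=5e-4,
)
```

With PyTorch, settings like these would plausibly be consumed as `getattr(torch.optim, optimizer_cls)(params, lr=lr, **optimizer_params)`; that mapping is an assumption about how the trainer uses the fields, not a statement of its code.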

class pythae.trainers.CoupledOptimizerAdversarialTrainer(model, train_dataset, eval_dataset=None, training_config=None, callbacks=None)

Trainer using distinct optimizers for the encoder, decoder and discriminator.

Parameters
  • model (BaseAE) – The model to train

  • train_dataset (BaseDataset) – The training dataset of type BaseDataset

  • eval_dataset (BaseDataset) – The evaluation dataset of type BaseDataset. Default: None.

  • training_config (CoupledOptimizerAdversarialTrainerConfig) – The training arguments summarizing the main parameters used for training. If None, a basic instance of CoupledOptimizerAdversarialTrainerConfig is used. Default: None.

  • callbacks (List[TrainingCallback]) – A list of callbacks to use during training. Default: None.

  • encoder_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the encoder. If None, an Adam optimizer is used. Default: None.

  • decoder_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the decoder. If None, an Adam optimizer is used. Default: None.

  • discriminator_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the discriminator. If None, an Adam optimizer is used. Default: None.

eval_step(epoch)

Performs an evaluation step.

Parameters

epoch (int) – The current epoch number

Returns

The evaluation loss

Return type

(torch.Tensor)
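An evaluation step runs forward passes only and reports an average loss over the evaluation batches. The hypothetical `eval_step_sketch` below shows that averaging with plain floats; it assumes the returned value is the mean of per-batch losses and does not reproduce the trainer's code.

```python
from typing import Callable, Sequence


def eval_step_sketch(batches: Sequence, compute_loss: Callable[[object], float]) -> float:
    """Hypothetical evaluation step: mean of per-batch losses, no optimizer updates."""
    if len(batches) == 0:
        raise ValueError("evaluation loader is empty")
    epoch_loss = 0.0
    for batch in batches:
        # Forward pass only: no zero_grad/backward/step during evaluation.
        epoch_loss += compute_loss(batch)
    return epoch_loss / len(batches)


# Two toy batches; the toy "loss" of a batch is its mean value.
batches = [[1.0, 3.0], [2.0, 4.0]]
mean_loss = eval_step_sketch(batches, lambda b: sum(b) / len(b))
```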

prepare_training()

Sets up the trainer for training

save_checkpoint(model, dir_path, epoch)

Saves a checkpoint allowing training to be restarted from this point.

Parameters
  • model (BaseAE) – The model to save

  • dir_path (str) – The folder where the checkpoint should be saved

  • epoch (int) – The epoch number
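Restarting a coupled-optimizer run requires the state of all three optimizers (and of any schedulers), not just the model weights. The sketch below writes one file per state into an epoch-specific folder. `save_checkpoint_sketch`, the `checkpoint_epoch_<epoch>` folder name and the file names are hypothetical; JSON keeps the example dependency-free, whereas a real trainer would serialise `state_dict()` objects, for example with `torch.save`.

```python
import json
import os
import tempfile
from typing import Any, Dict


def save_checkpoint_sketch(dir_path: str, epoch: int, states: Dict[str, Dict[str, Any]]) -> str:
    """Write one JSON file per named state into dir_path/checkpoint_epoch_<epoch>.

    `states` maps a name (e.g. "encoder_optimizer") to a JSON-serialisable dict.
    Returns the checkpoint directory.
    """
    checkpoint_dir = os.path.join(dir_path, f"checkpoint_epoch_{epoch}")
    os.makedirs(checkpoint_dir, exist_ok=True)
    for name, state in states.items():
        with open(os.path.join(checkpoint_dir, f"{name}.json"), "w") as f:
            json.dump(state, f)
    return checkpoint_dir


# Example: one state per optimizer, saved at epoch 5 in a fresh temporary folder.
root = tempfile.mkdtemp()
ckpt = save_checkpoint_sketch(
    root,
    5,
    {
        "encoder_optimizer": {"lr": 1e-4},
        "decoder_optimizer": {"lr": 1e-4},
        "discriminator_optimizer": {"lr": 1e-4},
    },
)
```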

train(log_output_dir=None)

Main training function.

Parameters

log_output_dir (str) – The path in which the log will be stored. Default: None
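The outer loop ties together train_step, eval_step, best-model tracking and the steps_saving schedule. The skeleton below is a hypothetical sketch of that control flow (`train_sketch` is not a pythae function). As a simplification, it selects the best model on the evaluation loss when an eval step is given and on the training loss otherwise, and it only records the epochs at which a checkpoint would be saved.

```python
from typing import Callable, Dict, List, Optional


def train_sketch(
    num_epochs: int,
    train_step: Callable[[int], float],
    eval_step: Optional[Callable[[int], float]] = None,
    steps_saving: Optional[int] = None,
) -> Dict[str, object]:
    """Hypothetical outer training loop over epochs 1..num_epochs."""
    best_loss = float("inf")
    best_epoch = None
    saved_epochs: List[int] = []
    for epoch in range(1, num_epochs + 1):
        train_loss = train_step(epoch)
        # Model selection uses the eval loss if available, else the train loss.
        loss = eval_step(epoch) if eval_step is not None else train_loss
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        # Checkpoint every `steps_saving` epochs (here we only record the epoch).
        if steps_saving is not None and epoch % steps_saving == 0:
            saved_epochs.append(epoch)
    return {"best_epoch": best_epoch, "best_loss": best_loss, "saved_epochs": saved_epochs}


# Ten epochs with a steadily decreasing toy loss and a checkpoint every 3 epochs.
result = train_sketch(10, train_step=lambda e: 1.0 / e, steps_saving=3)
```

With steps_saving=3 and num_epochs=10, checkpoints fall on epochs 3, 6 and 9; epoch 10 is not a multiple of 3, so the final model would have to be saved separately.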

train_step(epoch)

Performs the training loop over the train_loader for one epoch.

Parameters

epoch (int) – The current epoch number

Returns

The step training loss

Return type

(torch.Tensor)
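Within one epoch, each batch yields a loss per component, and each component's optimizer is updated from its own loss. The hypothetical `train_step_sketch` below shows that orchestration with plain callables: `compute_losses` stands in for the model's forward pass and `update` for one optimizer's zero_grad/backward/step sequence. Reporting the mean over batches of the summed component losses is an assumption made for the sketch.

```python
from typing import Callable, Dict, Sequence


def train_step_sketch(
    batches: Sequence,
    compute_losses: Callable[[object], Dict[str, float]],
    update: Callable[[str, float], None],
) -> float:
    """Hypothetical epoch loop with one optimizer update per component and batch.

    Returns the mean over batches of the summed component losses.
    """
    if len(batches) == 0:
        raise ValueError("train loader is empty")
    epoch_loss = 0.0
    for batch in batches:
        losses = compute_losses(batch)
        # Each component is updated from its own loss by its own optimizer.
        for component in ("encoder", "decoder", "discriminator"):
            update(component, losses[component])
        epoch_loss += sum(losses.values())
    return epoch_loss / len(batches)


# Record which optimizer is stepped, in order, over two toy batches.
calls = []
mean_loss = train_step_sketch(
    [1.0, 2.0],
    compute_losses=lambda b: {"encoder": b, "decoder": 2 * b, "discriminator": 0.5},
    update=lambda name, loss: calls.append(name),
)
```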