CoupledOptimizerAdversarialTrainer
This module implements the trainer to use with adversarial models. Unlike
AdversarialTrainer, it uses three distinct
optimizers: one for the encoder, one for the decoder of the AE and one for the discriminator.
It is suitable for GAN-based models.
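As a rough illustration of what "three distinct optimizers" means, the sketch below gives each component its own parameter group, its own optimizer and its own learning rate. `ToySGD`, the scalar parameters and the gradient values are all hypothetical and written in plain Python; this is not pythae's implementation, only a minimal picture of the idea.

```python
class ToySGD:
    """Hypothetical plain gradient descent over a dict of scalar parameters."""

    def __init__(self, params, lr):
        self.params = params  # dict: name -> float, updated in place
        self.lr = lr

    def step(self, grads):
        # Move every parameter against its gradient, scaled by this optimizer's lr.
        for name, grad in grads.items():
            self.params[name] -= self.lr * grad


# One (toy) parameter group per component.
encoder = {"w": 1.0}
decoder = {"w": 1.0}
discriminator = {"w": 1.0}

# One optimizer per component, each with its own learning rate.
optimizers = {
    "encoder": ToySGD(encoder, lr=0.1),
    "decoder": ToySGD(decoder, lr=0.01),
    "discriminator": ToySGD(discriminator, lr=0.001),
}

# Made-up gradients standing in for each component's own loss gradient.
grads = {
    "encoder": {"w": 1.0},
    "decoder": {"w": 1.0},
    "discriminator": {"w": 1.0},
}

# Each optimizer only touches the parameters it was built with.
for name, optimizer in optimizers.items():
    optimizer.step(grads[name])
```

Because the optimizers are independent, the same gradient moves each component by a different amount, which is the practical benefit of tuning the three learning rates separately.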
Available models:
Adversarial Autoencoder model.
Variational Autoencoder using Adversarial reconstruction loss model.
- class pythae.trainers.CoupledOptimizerAdversarialTrainerConfig(output_dir=None, per_device_train_batch_size=64, per_device_eval_batch_size=64, num_epochs=100, train_dataloader_num_workers=0, eval_dataloader_num_workers=0, optimizer_cls='Adam', optimizer_params=None, scheduler_cls=None, scheduler_params=None, learning_rate=0.0001, steps_saving=None, steps_predict=None, keep_best_on_train=False, seed=8, no_cuda=False, world_size=-1, local_rank=-1, rank=-1, dist_backend='nccl', master_addr='localhost', master_port='12345', amp=False, encoder_optimizer_cls='Adam', encoder_optimizer_params=None, encoder_scheduler_cls=None, encoder_scheduler_params=None, discriminator_optimizer_cls='Adam', decoder_optimizer_cls='Adam', decoder_optimizer_params=None, decoder_scheduler_cls=None, decoder_scheduler_params=None, discriminator_optimizer_params=None, discriminator_scheduler_cls=None, discriminator_scheduler_params=None, encoder_learning_rate=0.0001, decoder_learning_rate=0.0001, discriminator_learning_rate=0.0001)
CoupledOptimizerAdversarialTrainer config class.
- Parameters
output_dir (str) – The directory where model checkpoints, configs and final model will be stored. Default: None.
per_device_train_batch_size (int) – The number of training samples per batch and per device. Default: 64
per_device_eval_batch_size (int) – The number of evaluation samples per batch and per device. Default: 64
num_epochs (int) – The maximal number of epochs for training. Default: 100
train_dataloader_num_workers (int) – Number of subprocesses to use for train data loading. 0 means that the data will be loaded in the main process. Default: 0
eval_dataloader_num_workers (int) – Number of subprocesses to use for evaluation data loading. 0 means that the data will be loaded in the main process. Default: 0
encoder_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the encoder. Default: Adam.
encoder_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the encoder. If None, uses the default parameters. Default: None.
encoder_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the encoder. Default: None.
encoder_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the encoder. If None, uses the default parameters. Default: None.
decoder_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the decoder. Default: Adam.
decoder_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the decoder. If None, uses the default parameters. Default: None.
decoder_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the decoder. Default: None.
decoder_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the decoder. If None, uses the default parameters. Default: None.
discriminator_optimizer_cls (str) – The name of the torch.optim.Optimizer used for the training of the discriminator. Default: Adam.
discriminator_optimizer_params (dict) – A dict containing the parameters to use for the torch.optim.Optimizer for the discriminator. If None, uses the default parameters. Default: None.
discriminator_scheduler_cls (str) – The name of the torch.optim.lr_scheduler used for the training of the discriminator. Default: None.
discriminator_scheduler_params (dict) – A dict containing the parameters to use for the torch.optim.lr_scheduler for the discriminator. If None, uses the default parameters. Default: None.
encoder_learning_rate (float) – The learning rate applied to the Optimizer for the encoder. Default: 1e-4
decoder_learning_rate (float) – The learning rate applied to the Optimizer for the decoder. Default: 1e-4
discriminator_learning_rate (float) – The learning rate applied to the Optimizer for the discriminator. Default: 1e-4
steps_saving (int) – A model checkpoint will be saved every steps_saving epochs. Default: None
steps_predict (int) – A prediction using the best model will be run every steps_predict epochs. Default: None
keep_best_on_train (bool) – Whether to keep the best model on the train set. Default: False
seed (int) – The random seed for reproducibility. Default: 8
no_cuda (bool) – Disable cuda training. Default: False
world_size (int) – The total number of processes to run. Default: -1
local_rank (int) – The rank of the node for distributed training. Default: -1
rank (int) – The rank of the process for distributed training. Default: -1
dist_backend (str) – The distributed backend to use. Default: ‘nccl’
master_addr (str) – The master address for distributed training. Default: ‘localhost’
master_port (str) – The master port for distributed training. Default: ‘12345’
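The parameters above can be combined to give each component its own optimizer, scheduler and learning rate. The following is a minimal sketch: the helper names `coupled_config_kwargs` and `make_config`, the output directory and the chosen values are illustrative assumptions, while every keyword comes from the signature above. `StepLR` with `step_size` and `gamma` is a standard scheduler from `torch.optim.lr_scheduler`.

```python
def coupled_config_kwargs(output_dir, encoder_lr=1e-4, decoder_lr=1e-4,
                          discriminator_lr=1e-4):
    """Collect keyword arguments for CoupledOptimizerAdversarialTrainerConfig.

    Hypothetical helper: the keys are taken from the documented signature,
    the values are example choices only.
    """
    return {
        "output_dir": output_dir,
        "num_epochs": 100,
        "per_device_train_batch_size": 64,
        "per_device_eval_batch_size": 64,
        # Encoder: Adam with its own learning rate.
        "encoder_optimizer_cls": "Adam",
        "encoder_learning_rate": encoder_lr,
        # Decoder: Adam with its own learning rate.
        "decoder_optimizer_cls": "Adam",
        "decoder_learning_rate": decoder_lr,
        # Discriminator: Adam plus a step decay of its learning rate.
        "discriminator_optimizer_cls": "Adam",
        "discriminator_learning_rate": discriminator_lr,
        "discriminator_scheduler_cls": "StepLR",
        "discriminator_scheduler_params": {"step_size": 10, "gamma": 0.5},
        # Save a checkpoint every 10 epochs.
        "steps_saving": 10,
    }


def make_config(**kwargs):
    """Build the config object; pythae is imported lazily, only when called."""
    from pythae.trainers import CoupledOptimizerAdversarialTrainerConfig

    return CoupledOptimizerAdversarialTrainerConfig(**kwargs)


# Example: slow the discriminator down relative to the encoder and decoder.
kwargs = coupled_config_kwargs("my_model", discriminator_lr=5e-5)
```

With pythae installed, `make_config(**kwargs)` would return the config instance to pass to the trainer.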
- class pythae.trainers.CoupledOptimizerAdversarialTrainer(model, train_dataset, eval_dataset=None, training_config=None, callbacks=None)
Trainer using distinct optimizers for the encoder, decoder and discriminator.
- Parameters
model (BaseAE) – The model to train
train_dataset (BaseDataset) – The training dataset of type BaseDataset
training_config (CoupledOptimizerAdversarialTrainerConfig) – The training arguments summarizing the main parameters used for training. If None, a basic training instance of AdversarialTrainerConfig is used. Default: None.
encoder_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the encoder. If None, an Adam optimizer is used. Default: None.
decoder_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the decoder. If None, an Adam optimizer is used. Default: None.
discriminator_optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training the discriminator. If None, an Adam optimizer is used. Default: None.
- eval_step(epoch)
Perform an evaluation step
- Parameters
epoch (int) – The current epoch number
- Returns
The evaluation loss
- Return type
- save_checkpoint(model, dir_path, epoch)
Saves a checkpoint from which training can be restarted
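To make the interaction with `steps_saving` concrete, the sketch below lists the epochs at which a periodic checkpoint would be written. It assumes epochs are counted from 1 and that a checkpoint is triggered whenever the epoch number is a multiple of `steps_saving`; the helper `checkpoint_epochs` is hypothetical and not part of pythae.

```python
def checkpoint_epochs(num_epochs, steps_saving):
    """Return the 1-based epochs at which a periodic checkpoint would be saved.

    Assumption: a checkpoint is saved when `epoch % steps_saving == 0`.
    """
    if steps_saving is None:
        # No periodic checkpointing requested.
        return []
    if steps_saving <= 0:
        raise ValueError("steps_saving must be a positive integer or None")
    return [epoch for epoch in range(1, num_epochs + 1)
            if epoch % steps_saving == 0]


print(checkpoint_epochs(10, 3))  # [3, 6, 9]
```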
- train(log_output_dir=None)
This function is the main training function
- Parameters
log_output_dir (str) – The path in which the log will be stored
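Putting the pieces together, a training run might look like the sketch below. The wrapper `run_training` and its `trainer_cls` argument are hypothetical conveniences (the latter makes the function easy to exercise without pythae installed); the keyword arguments passed to the trainer and to `train` are the ones documented above. The model and datasets are assumed to be built elsewhere.

```python
def run_training(model, train_dataset, eval_dataset=None, training_config=None,
                 trainer_cls=None, log_output_dir=None):
    """Instantiate a trainer, run it, and return it.

    Hypothetical wrapper: by default it uses pythae's
    CoupledOptimizerAdversarialTrainer, imported lazily on first use.
    """
    if trainer_cls is None:
        from pythae.trainers import (
            CoupledOptimizerAdversarialTrainer as trainer_cls,
        )

    trainer = trainer_cls(
        model=model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        training_config=training_config,
    )
    # log_output_dir is the path in which the training log will be stored.
    trainer.train(log_output_dir=log_output_dir)
    return trainer
```

A call such as `run_training(model, train_dataset, eval_dataset, training_config=config, log_output_dir="logs")` would then train the model with the three optimizers described by `config`.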