Source code for pythae.trainers.coupled_optimizer_adversarial_trainer.coupled_optimizer_adversarial_trainer

import datetime
import dis
import logging
import os
from copy import deepcopy
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from ...data.datasets import BaseDataset
from ...models import BaseAE
from ..base_trainer import BaseTrainer
from ..trainer_utils import set_seed
from ..training_callbacks import TrainingCallback
from .coupled_optimizer_adversarial_trainer_config import (
    CoupledOptimizerAdversarialTrainerConfig,
)

logger = logging.getLogger(__name__)

# make it print to the console.
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)


[docs]class CoupledOptimizerAdversarialTrainer(BaseTrainer): """Trainer using distinct optimizers for the encoder, decoder and discriminator. Args: model (BaseAE): The model to train train_dataset (BaseDataset): The training dataset of type :class:`~pythae.data.dataset.BaseDataset` training_args (CoupledOptimizerAdversarialTrainerConfig): The training arguments summarizing the main parameters used for training. If None, a basic training instance of :class:`AdversarialTrainerConfig` is used. Default: None. encoder_optimizer (~torch.optim.Optimizer): An instance of `torch.optim.Optimizer` used for training the encoder. If None, a :class:`~torch.optim.Adam` optimizer is used. Default: None. decoder_optimizer (~torch.optim.Optimizer): An instance of `torch.optim.Optimizer` used for training the decoder. If None, a :class:`~torch.optim.Adam` optimizer is used. Default: None. discriminator_optimizer (~torch.optim.Optimizer): An instance of `torch.optim.Optimizer` used for training the discriminator. If None, a :class:`~torch.optim.Adam` optimizer is used. Default: None. """ def __init__( self, model: BaseAE, train_dataset: BaseDataset, eval_dataset: Optional[BaseDataset] = None, training_config: Optional[CoupledOptimizerAdversarialTrainerConfig] = None, callbacks: List[TrainingCallback] = None, ): BaseTrainer.__init__( self, model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, training_config=training_config, callbacks=callbacks, ) def set_encoder_optimizer(self): encoder_optimizer_cls = getattr( optim, self.training_config.encoder_optimizer_cls ) if self.training_config.encoder_optimizer_params is not None: if self.distributed: encoder_optimizer = encoder_optimizer_cls( self.model.module.encoder.parameters(), lr=self.training_config.encoder_learning_rate, **self.training_config.encoder_optimizer_params, ) else: encoder_optimizer = encoder_optimizer_cls( self.model.encoder.parameters(), lr=self.training_config.encoder_learning_rate, **self.training_config.encoder_optimizer_params, ) else: if self.distributed: encoder_optimizer = encoder_optimizer_cls( self.model.module.encoder.parameters(), lr=self.training_config.encoder_learning_rate, ) else: encoder_optimizer = encoder_optimizer_cls( self.model.encoder.parameters(), lr=self.training_config.encoder_learning_rate, ) self.encoder_optimizer = encoder_optimizer def set_encoder_scheduler(self): if self.training_config.encoder_scheduler_cls is not None: encoder_scheduler_cls = getattr( lr_scheduler, self.training_config.encoder_scheduler_cls ) if self.training_config.encoder_scheduler_params is not None: scheduler = encoder_scheduler_cls( self.encoder_optimizer, **self.training_config.encoder_scheduler_params, ) else: scheduler = encoder_scheduler_cls(self.encoder_optimizer) else: scheduler = None self.encoder_scheduler = scheduler def set_decoder_optimizer(self): decoder_optimizer_cls = getattr( optim, self.training_config.decoder_optimizer_cls ) if self.training_config.decoder_optimizer_params is not None: if self.distributed: decoder_optimizer = decoder_optimizer_cls( self.model.module.decoder.parameters(), lr=self.training_config.decoder_learning_rate, **self.training_config.decoder_optimizer_params, ) else: decoder_optimizer = decoder_optimizer_cls( self.model.decoder.parameters(), lr=self.training_config.decoder_learning_rate, **self.training_config.decoder_optimizer_params, ) else: if self.distributed: decoder_optimizer = decoder_optimizer_cls( self.model.module.decoder.parameters(), lr=self.training_config.decoder_learning_rate, ) else: decoder_optimizer = decoder_optimizer_cls( self.model.decoder.parameters(), lr=self.training_config.decoder_learning_rate, ) self.decoder_optimizer = decoder_optimizer def set_decoder_scheduler(self): if self.training_config.decoder_scheduler_cls is not None: decoder_scheduler_cls = getattr( lr_scheduler, self.training_config.decoder_scheduler_cls ) if self.training_config.decoder_scheduler_params is not None: scheduler = decoder_scheduler_cls( self.decoder_optimizer, **self.training_config.decoder_scheduler_params, ) else: scheduler = decoder_scheduler_cls(self.decoder_optimizer) else: scheduler = None self.decoder_scheduler = scheduler def set_discriminator_optimizer(self): discriminator_cls = getattr( optim, self.training_config.discriminator_optimizer_cls ) if self.training_config.discriminator_optimizer_params is not None: if self.distributed: discriminator_optimizer = discriminator_cls( self.model.module.discriminator.parameters(), lr=self.training_config.discriminator_learning_rate, **self.training_config.discriminator_optimizer_params, ) else: discriminator_optimizer = discriminator_cls( self.model.discriminator.parameters(), lr=self.training_config.discriminator_learning_rate, **self.training_config.discriminator_optimizer_params, ) else: if self.distributed: discriminator_optimizer = discriminator_cls( self.model.module.discriminator.parameters(), lr=self.training_config.discriminator_learning_rate, ) else: discriminator_optimizer = discriminator_cls( self.model.discriminator.parameters(), lr=self.training_config.discriminator_learning_rate, ) self.discriminator_optimizer = discriminator_optimizer def set_discriminator_scheduler(self) -> torch.optim.lr_scheduler: if self.training_config.discriminator_scheduler_cls is not None: discriminator_scheduler_cls = getattr( lr_scheduler, self.training_config.discriminator_scheduler_cls ) if self.training_config.discriminator_scheduler_params is not None: scheduler = discriminator_scheduler_cls( self.discriminator_optimizer, **self.training_config.discriminator_scheduler_params, ) else: scheduler = discriminator_scheduler_cls(self.discriminator_optimizer) else: scheduler = None self.discriminator_scheduler = scheduler def _optimizers_step(self, model_output): encoder_loss = model_output.encoder_loss decoder_loss = model_output.decoder_loss discriminator_loss = model_output.discriminator_loss # Reset optimizers if model_output.update_encoder: self.encoder_optimizer.zero_grad() encoder_loss.backward(retain_graph=True) if model_output.update_decoder: self.decoder_optimizer.zero_grad() decoder_loss.backward(retain_graph=True) if model_output.update_discriminator: self.discriminator_optimizer.zero_grad() discriminator_loss.backward() if model_output.update_encoder: self.encoder_optimizer.step() if model_output.update_decoder: self.decoder_optimizer.step() if model_output.update_discriminator: self.discriminator_optimizer.step() def _schedulers_step( self, encoder_metrics=None, decoder_metrics=None, discriminator_metrics=None ): if self.encoder_scheduler is None: pass elif isinstance(self.encoder_scheduler, lr_scheduler.ReduceLROnPlateau): self.encoder_scheduler.step(encoder_metrics) else: self.encoder_scheduler.step() if self.decoder_scheduler is None: pass elif isinstance(self.decoder_scheduler, lr_scheduler.ReduceLROnPlateau): self.decoder_scheduler.step(decoder_metrics) else: self.decoder_scheduler.step() if self.discriminator_scheduler is None: pass elif isinstance(self.discriminator_scheduler, lr_scheduler.ReduceLROnPlateau): self.discriminator_scheduler.step(discriminator_metrics) else: self.discriminator_scheduler.step()
[docs] def prepare_training(self): """Sets up the trainer for training""" # set random seed set_seed(self.training_config.seed) # set encoder optimizer and scheduler self.set_encoder_optimizer() self.set_encoder_scheduler() # set decoder optimizer and scheduler self.set_decoder_optimizer() self.set_decoder_scheduler() # set discriminator optimizer and scheduler self.set_discriminator_optimizer() self.set_discriminator_scheduler() # create foder for saving self._set_output_dir() # set callbacks self._setup_callbacks()
[docs] def train(self, log_output_dir: str = None): """This function is the main training function Args: log_output_dir (str): The path in which the log will be stored """ self.prepare_training() self.callback_handler.on_train_begin( training_config=self.training_config, model_config=self.model_config ) log_verbose = False msg = ( f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n" " - per_device_train_batch_size: " f"{self.training_config.per_device_train_batch_size}\n" " - per_device_eval_batch_size: " f"{self.training_config.per_device_eval_batch_size}\n" f" - checkpoint saving every: {self.training_config.steps_saving}\n" f"Encoder Optimizer: {self.encoder_optimizer}\n" f"Encoder Scheduler: {self.encoder_scheduler}\n" f"Decoder Optimizer: {self.decoder_optimizer}\n" f"Decoder Scheduler: {self.decoder_scheduler}\n" f"Discriminator Optimizer: {self.discriminator_optimizer}\n" f"Discriminator Scheduler: {self.discriminator_scheduler}\n" ) if self.is_main_process: logger.info(msg) # set up log file if log_output_dir is not None and self.is_main_process: log_verbose = True file_logger = self._get_file_logger(log_output_dir=log_output_dir) file_logger.info(msg) if self.is_main_process: logger.info("Successfully launched training !\n") # set best losses for early stopping best_train_loss = 1e10 best_eval_loss = 1e10 for epoch in range(1, self.training_config.num_epochs + 1): self.callback_handler.on_epoch_begin( training_config=self.training_config, epoch=epoch, train_loader=self.train_loader, eval_loader=self.eval_loader, ) metrics = {} train_losses = self.train_step(epoch) [ epoch_train_loss, epoch_train_encoder_loss, epoch_train_decoder_loss, epoch_train_discriminator_loss, ] = train_losses metrics["train_epoch_loss"] = epoch_train_loss metrics["train_encoder_loss"] = epoch_train_encoder_loss metrics["train_decoder_loss"] = epoch_train_decoder_loss metrics["train_discriminator_loss"] = epoch_train_discriminator_loss if self.eval_dataset is not None: eval_losses = self.eval_step(epoch) [ epoch_eval_loss, epoch_eval_encoder_loss, epoch_eval_decoder_loss, epoch_eval_discriminator_loss, ] = eval_losses metrics["eval_epoch_loss"] = epoch_eval_loss metrics["eval_encoder_loss"] = epoch_eval_encoder_loss metrics["eval_decoder_loss"] = epoch_eval_decoder_loss metrics["eval_discriminator_loss"] = epoch_eval_discriminator_loss self._schedulers_step( encoder_metrics=epoch_eval_encoder_loss, decoder_metrics=epoch_eval_decoder_loss, discriminator_metrics=epoch_eval_discriminator_loss, ) else: epoch_eval_loss = best_eval_loss self._schedulers_step( encoder_metrics=epoch_train_encoder_loss, decoder_metrics=epoch_train_decoder_loss, discriminator_metrics=epoch_train_discriminator_loss, ) if ( epoch_eval_loss < best_eval_loss and not self.training_config.keep_best_on_train ): best_eval_loss = epoch_eval_loss best_model = deepcopy(self.model) self._best_model = best_model elif ( epoch_train_loss < best_train_loss and self.training_config.keep_best_on_train ): best_train_loss = epoch_train_loss best_model = deepcopy(self.model) self._best_model = best_model if ( self.training_config.steps_predict is not None and epoch % self.training_config.steps_predict == 0 and self.is_main_process ): true_data, reconstructions, generations = self.predict(best_model) self.callback_handler.on_prediction_step( self.training_config, true_data=true_data, reconstructions=reconstructions, generations=generations, global_step=epoch, ) self.callback_handler.on_epoch_end(training_config=self.training_config) # save checkpoints if ( self.training_config.steps_saving is not None and epoch % self.training_config.steps_saving == 0 ): if self.is_main_process: self.save_checkpoint( model=best_model, dir_path=self.training_dir, epoch=epoch ) logger.info(f"Saved checkpoint at epoch {epoch}\n") if log_verbose: file_logger.info(f"Saved checkpoint at epoch {epoch}\n") self.callback_handler.on_log( self.training_config, metrics, logger=logger, global_step=epoch, rank=self.rank, ) final_dir = os.path.join(self.training_dir, "final_model") if self.is_main_process: self.save_model(best_model, dir_path=final_dir) logger.info("----------------------------------") logger.info("Training ended!") logger.info(f"Saved final model in {final_dir}") if self.distributed: dist.destroy_process_group() self.callback_handler.on_train_end(training_config=self.training_config)
[docs] def eval_step(self, epoch: int): """Perform an evaluation step Parameters: epoch (int): The current epoch number Returns: (torch.Tensor): The evaluation loss """ self.callback_handler.on_eval_step_begin( training_config=self.training_config, eval_loader=self.eval_loader, epoch=epoch, rank=self.rank, ) self.model.eval() epoch_encoder_loss = 0 epoch_decoder_loss = 0 epoch_discriminator_loss = 0 epoch_loss = 0 for inputs in self.eval_loader: inputs = self._set_inputs_to_device(inputs) try: with torch.no_grad(): model_output = self.model( inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset), uses_ddp=self.distributed, ) except RuntimeError: model_output = self.model( inputs, epoch=epoch, dataset_size=len(self.eval_loader.dataset), uses_ddp=self.distributed, ) encoder_loss = model_output.encoder_loss decoder_loss = model_output.decoder_loss discriminator_loss = model_output.discriminator_loss loss = encoder_loss + decoder_loss + discriminator_loss epoch_encoder_loss += encoder_loss.item() epoch_decoder_loss += decoder_loss.item() epoch_discriminator_loss += discriminator_loss.item() epoch_loss += loss.item() if epoch_loss != epoch_loss: raise ArithmeticError("NaN detected in eval loss") self.callback_handler.on_eval_step_end(training_config=self.training_config) epoch_encoder_loss /= len(self.eval_loader) epoch_decoder_loss /= len(self.eval_loader) epoch_discriminator_loss /= len(self.eval_loader) epoch_loss /= len(self.eval_loader) return ( epoch_loss, epoch_encoder_loss, epoch_decoder_loss, epoch_discriminator_loss, )
[docs] def train_step(self, epoch: int): """The trainer performs training loop over the train_loader. Parameters: epoch (int): The current epoch number Returns: (torch.Tensor): The step training loss """ self.callback_handler.on_train_step_begin( training_config=self.training_config, train_loader=self.train_loader, epoch=epoch, rank=self.rank, ) # set model in train model self.model.train() epoch_encoder_loss = 0 epoch_decoder_loss = 0 epoch_discriminator_loss = 0 epoch_loss = 0 for inputs in self.train_loader: inputs = self._set_inputs_to_device(inputs) model_output = self.model( inputs, epoch=epoch, dataset_size=len(self.train_loader.dataset), uses_ddp=self.distributed, ) self._optimizers_step(model_output) encoder_loss = model_output.encoder_loss decoder_loss = model_output.decoder_loss discriminator_loss = model_output.discriminator_loss loss = encoder_loss + decoder_loss + discriminator_loss epoch_encoder_loss += encoder_loss.item() epoch_decoder_loss += decoder_loss.item() epoch_discriminator_loss += discriminator_loss.item() epoch_loss += loss.item() self.callback_handler.on_train_step_end( training_config=self.training_config ) # Allows model updates if needed if self.distributed: self.model.module.update() else: self.model.update() epoch_encoder_loss /= len(self.train_loader) epoch_decoder_loss /= len(self.train_loader) epoch_discriminator_loss /= len(self.train_loader) epoch_loss /= len(self.train_loader) return ( epoch_loss, epoch_encoder_loss, epoch_decoder_loss, epoch_discriminator_loss, )
[docs] def save_checkpoint(self, model: BaseAE, dir_path, epoch: int): """Saves a checkpoint alowing to restart training from here Args: dir_path (str): The folder where the checkpoint should be saved epochs_signature (int): The epoch number""" checkpoint_dir = os.path.join(dir_path, f"checkpoint_epoch_{epoch}") if not os.path.exists(checkpoint_dir): os.makedirs(checkpoint_dir) # save optimizers torch.save( deepcopy(self.encoder_optimizer.state_dict()), os.path.join(checkpoint_dir, "encoder_optimizer.pt"), ) torch.save( deepcopy(self.decoder_optimizer.state_dict()), os.path.join(checkpoint_dir, "decoder_optimizer.pt"), ) torch.save( deepcopy(self.discriminator_optimizer.state_dict()), os.path.join(checkpoint_dir, "discriminator_optimizer.pt"), ) # save model if self.distributed: model.module.save(checkpoint_dir) else: model.save(checkpoint_dir) # save training config self.training_config.save_json(checkpoint_dir, "training_config")