import contextlib
import datetime
import logging
import os
from copy import deepcopy
from typing import Any, Dict, List, Optional
import torch
import torch.distributed as dist
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from ...customexception import ModelError
from ...data.datasets import BaseDataset, collate_dataset_output
from ...models import BaseAE
from ..trainer_utils import set_seed
from ..training_callbacks import (
CallbackHandler,
MetricConsolePrinterCallback,
ProgressBarCallback,
TrainingCallback,
)
from .base_training_config import BaseTrainerConfig
# Module-level logger for the trainer. A stream handler is attached so that
# records of level INFO and above are also printed to the console.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)
class BaseTrainer:
    """Base class to perform model training.

    Args:
        model (BaseAE): An instance of :class:`~pythae.models.BaseAE` to train

        train_dataset (BaseDataset): The training dataset of type
            :class:`~pythae.data.dataset.BaseDataset`

        eval_dataset (BaseDataset): The evaluation dataset of type
            :class:`~pythae.data.dataset.BaseDataset`

        training_config (BaseTrainerConfig): The training arguments summarizing the main
            parameters used for training. If None, a basic training instance of
            :class:`BaseTrainerConfig` is used. Default: None.

        callbacks (List[~pythae.trainers.training_callbacks.TrainingCallbacks]):
            A list of callbacks to use during training.
    """
def __init__(
    self,
    model: BaseAE,
    train_dataset: BaseDataset,
    eval_dataset: Optional[BaseDataset] = None,
    training_config: Optional[BaseTrainerConfig] = None,
    callbacks: List[TrainingCallback] = None,
):
    """Select the device, place the model on it, build the loaders and sanity-check the model.

    Args:
        model (BaseAE): The model to train. It is moved to the selected device and,
            when ``world_size > 1``, wrapped in ``DistributedDataParallel``.
        train_dataset (BaseDataset): The dataset used to build the training loader.
        eval_dataset (BaseDataset): Optional dataset used to build the evaluation
            loader. When None, ``training_config.keep_best_on_train`` is set to True.
        training_config (BaseTrainerConfig): The training configuration. When None a
            default :class:`BaseTrainerConfig` is created. Its ``output_dir`` is set to
            ``"dummy_output_dir"`` when left to None (the passed object is modified).
        callbacks (List[TrainingCallback]): Callbacks stored as given on the trainer.

    Raises:
        ModelError: Propagated from the sanity check if a forward pass on the first
            training batch fails.
    """
    # Fall back on a default configuration and a default output folder.
    if training_config is None:
        training_config = BaseTrainerConfig()
    if training_config.output_dir is None:
        training_config.output_dir = "dummy_output_dir"

    self.training_config = training_config
    self.model_config = model.model_config
    self.model_name = model.model_name

    # for distributed training
    self.world_size = training_config.world_size
    self.local_rank = training_config.local_rank
    self.rank = training_config.rank
    self.dist_backend = training_config.dist_backend
    self.distributed = bool(self.world_size > 1)

    # Distributed runs delegate the device choice to `_setup_devices`.
    if self.distributed:
        device = self._setup_devices()
    elif torch.cuda.is_available() and not training_config.no_cuda:
        device = "cuda"
    else:
        device = "cpu"

    # Autocast is only used when requested in the config, and is switched off
    # again when the model config declares a "bce" reconstruction loss
    # (presumably because that loss is not autocast-safe -- TODO confirm).
    amp_context = (
        torch.autocast("cuda") if training_config.amp else contextlib.nullcontext()
    )
    if getattr(model.model_config, "reconstruction_loss", None) == "bce":
        amp_context = contextlib.nullcontext()
    self.amp_context = amp_context

    self.device = device

    # place model on device
    model = model.to(device)
    model.device = device
    if self.distributed:
        model = DDP(model, device_ids=[self.local_rank])

    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset

    # Define the loaders
    train_loader = self.get_train_dataloader(train_dataset)
    if eval_dataset is None:
        logger.info(
            "! No eval dataset provided ! -> keeping best model on train.\n"
        )
        self.training_config.keep_best_on_train = True
        eval_loader = None
    else:
        eval_loader = self.get_eval_dataloader(eval_dataset)

    self.train_loader = train_loader
    self.eval_loader = eval_loader
    self.callbacks = callbacks

    # run sanity check on the model (raises ModelError on failure)
    self._run_model_sanity_check(model, train_loader)
    if self.is_main_process:
        logger.info("Model passed sanity check !\n" "Ready for training.\n")

    self.model = model
@property
def is_main_process(self):
    """bool: True when ``self.rank`` is 0 or -1, False for any other rank."""
    return self.rank in (0, -1)
def _setup_devices(self):
    """Sets up the devices to perform distributed training.

    Returns:
        The string ``"cpu"`` when ``training_config.no_cuda`` is set, otherwise
        ``torch.device("cuda", self.local_rank)``.
    """
    if dist.is_available() and dist.is_initialized() and self.local_rank == -1:
        logger.warning(
            "torch.distributed process group is initialized, but local_rank == -1. "
        )

    if self.training_config.no_cuda:
        self._n_gpus = 0
        return "cpu"

    # Bind this process to its own GPU before touching the process group.
    torch.cuda.set_device(self.local_rank)

    # NOTE(review): the flattened source does not show whether the process-group
    # initialisation is restricted to the CUDA branch; it is kept there -- confirm.
    if not dist.is_initialized():
        dist.init_process_group(
            backend=self.dist_backend,
            init_method="env://",
            world_size=self.world_size,
            rank=self.rank,
        )

    return torch.device("cuda", self.local_rank)
def get_train_dataloader(
    self, train_dataset: BaseDataset
) -> torch.utils.data.DataLoader:
    """Build the training data loader.

    Args:
        train_dataset (BaseDataset): The dataset to iterate over.

    Returns:
        torch.utils.data.DataLoader: A loader using the per-device train batch size and
        number of workers from the training config. In distributed mode a
        ``DistributedSampler`` is used; otherwise the loader shuffles the data itself.
    """
    train_sampler = (
        DistributedSampler(
            train_dataset, num_replicas=self.world_size, rank=self.rank
        )
        if self.distributed
        else None
    )

    config = self.training_config
    return DataLoader(
        dataset=train_dataset,
        batch_size=config.per_device_train_batch_size,
        num_workers=config.train_dataloader_num_workers,
        # `shuffle` and `sampler` are mutually exclusive in DataLoader
        shuffle=train_sampler is None,
        sampler=train_sampler,
        collate_fn=collate_dataset_output,
    )
def get_eval_dataloader(
    self, eval_dataset: BaseDataset
) -> torch.utils.data.DataLoader:
    """Build the evaluation data loader.

    Args:
        eval_dataset (BaseDataset): The dataset to iterate over.

    Returns:
        torch.utils.data.DataLoader: A loader using the per-device eval batch size and
        number of workers from the training config. In distributed mode a
        ``DistributedSampler`` is used; otherwise the loader shuffles the data itself.
    """
    eval_sampler = (
        DistributedSampler(
            eval_dataset, num_replicas=self.world_size, rank=self.rank
        )
        if self.distributed
        else None
    )

    config = self.training_config
    return DataLoader(
        dataset=eval_dataset,
        batch_size=config.per_device_eval_batch_size,
        num_workers=config.eval_dataloader_num_workers,
        # `shuffle` and `sampler` are mutually exclusive in DataLoader
        shuffle=eval_sampler is None,
        sampler=eval_sampler,
        collate_fn=collate_dataset_output,
    )
def set_optimizer(self):
    """Instantiate the optimizer named in the training config and store it.

    The class is looked up by name in ``torch.optim`` and built on the model
    parameters with the configured learning rate, plus any extra keyword arguments
    given in ``training_config.optimizer_params``. The result is stored in
    ``self.optimizer``.
    """
    config = self.training_config
    optimizer_cls = getattr(optim, config.optimizer_cls)
    extra_params = {} if config.optimizer_params is None else config.optimizer_params
    self.optimizer = optimizer_cls(
        self.model.parameters(), lr=config.learning_rate, **extra_params
    )
def set_scheduler(self):
    """Instantiate the learning-rate scheduler named in the training config.

    When ``training_config.scheduler_cls`` is None, ``self.scheduler`` is set to None.
    Otherwise the class is looked up by name in ``torch.optim.lr_scheduler`` and built
    on ``self.optimizer`` with the optional ``training_config.scheduler_params``.
    """
    config = self.training_config
    if config.scheduler_cls is None:
        self.scheduler = None
        return

    scheduler_cls = getattr(lr_scheduler, config.scheduler_cls)
    extra_params = {} if config.scheduler_params is None else config.scheduler_params
    self.scheduler = scheduler_cls(self.optimizer, **extra_params)
def _set_output_dir(self):
    """Compute the run-specific training folder and create it on the main process.

    Sets ``self._training_signature`` (current local time formatted as
    ``YYYY-MM-DD_HH-MM-SS``) and ``self.training_dir``
    (``<output_dir>/<model_name>_training_<signature>``). Folders are only created
    when ``self.is_main_process`` is true.
    """
    output_dir = self.training_config.output_dir

    # Create folder
    if self.is_main_process and not os.path.exists(output_dir):
        os.makedirs(output_dir, exist_ok=True)
        logger.info(f"Created {output_dir} folder since did not exist.\n")

    self._training_signature = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    training_dir = os.path.join(
        output_dir, f"{self.model_name}_training_{self._training_signature}"
    )
    self.training_dir = training_dir

    if self.is_main_process and not os.path.exists(training_dir):
        os.makedirs(training_dir, exist_ok=True)
        logger.info(
            f"Created {training_dir}. \n"
            "Training config, checkpoints and final model will be saved here.\n"
        )
def _get_file_logger(self, log_output_dir):
    """Create a logger writing INFO records to a run-specific file.

    Args:
        log_output_dir (str): Folder in which the log file is written. It is created
            on the main process when it does not exist.

    Returns:
        logging.Logger: A logger named ``training_logs_<signature>`` with a file
        handler attached and propagation to parent loggers disabled.
    """
    # if dir does not exist create it
    # NOTE(review): the flattened source does not show how many of the three
    # messages below belong to this branch; all are kept inside it -- confirm.
    if self.is_main_process and not os.path.exists(log_output_dir):
        os.makedirs(log_output_dir, exist_ok=True)
        logger.info(f"Created {log_output_dir} folder since did not exists.")
        logger.info("Training logs will be recodered here.\n")
        logger.info(" -> Training can be monitored here.\n")

    # create and set logger
    log_name = f"training_logs_{self._training_signature}"
    file_logger = logging.getLogger(log_name)
    file_logger.setLevel(logging.INFO)

    log_path = os.path.join(log_output_dir, f"{log_name}.log")
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.INFO)
    file_logger.addHandler(file_handler)

    # Do not output logs in the console
    file_logger.propagate = False

    return file_logger
def _setup_callbacks(self):
    """Build the callback handler from the user callbacks plus two built-in ones.

    When no callbacks were given, a single default ``TrainingCallback`` is used.
    A ``ProgressBarCallback`` and a ``MetricConsolePrinterCallback`` are then added
    to the handler, in that order.
    """
    if self.callbacks is None:
        self.callbacks = [TrainingCallback()]

    handler = CallbackHandler(callbacks=self.callbacks, model=self.model)
    self.callback_handler = handler

    for callback_cls in (ProgressBarCallback, MetricConsolePrinterCallback):
        handler.add_callback(callback_cls())
def _run_model_sanity_check(self, model, loader):
    """Run one forward pass on the first batch of ``loader`` to validate the model.

    Args:
        model: The callable model to check.
        loader: An iterable of batches; only its first batch is used.

    Raises:
        ModelError: If fetching the batch, moving it to the device or calling the
            model raises any exception. The original exception is chained.
    """
    try:
        first_batch = next(iter(loader))
        model(self._set_inputs_to_device(first_batch))
    except Exception as e:
        details = (
            "Error when calling forward method from model. Potential issues: \n"
            " - Wrong model architecture -> check encoder, decoder and metric architecture if "
            "you provide yours \n"
            " - The data input dimension provided is wrong -> when no encoder, decoder or metric "
            "provided, a network is built automatically but requires the shape of the flatten "
            "input data.\n"
            f"Exception raised: {type(e)} with message: {e!s}"
        )
        raise ModelError(details) from e
def _set_optimizer_on_device(self, optim, device):
    """Move the tensors held in an optimizer's state to ``device``.

    Args:
        optim: The optimizer whose ``state`` is walked. Entries that are tensors, or
            dicts containing tensors, are moved; anything else is left untouched.
        device: The target device passed to ``Tensor.to``.

    Returns:
        The same optimizer object, modified in place.
    """

    def _relocate(tensor):
        # Move the values, and the attached gradient when there is one.
        tensor.data = tensor.data.to(device)
        if tensor._grad is not None:
            tensor._grad.data = tensor._grad.data.to(device)

    for entry in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
        elif isinstance(entry, dict):
            for value in entry.values():
                if isinstance(value, torch.Tensor):
                    _relocate(value)

    return optim
def _set_inputs_to_device(self, inputs: Dict[str, Any]):
    """Move the tensor entries of a batch to the trainer's CUDA device.

    Args:
        inputs (Dict[str, Any]): The batch, as a mapping. Only values for which
            ``torch.is_tensor`` is true are moved; other values are passed through.

    Returns:
        The ``inputs`` object itself when the trainer device is not a CUDA device,
        otherwise a new plain ``dict`` with the same keys and the tensors moved to
        ``self.device``.
    """
    # `self.device` is either a plain string ("cuda" / "cpu") or the
    # `torch.device("cuda", local_rank)` returned by `_setup_devices`. A
    # `torch.device` does not compare equal to the bare string "cuda", so the
    # device is normalised before its type is checked.
    target = torch.device(self.device)
    if target.type != "cuda":
        return inputs

    return {
        key: inputs[key].to(target) if torch.is_tensor(inputs[key]) else inputs[key]
        for key in inputs.keys()
    }
def _optimizers_step(self, model_output=None):
    """Run one optimisation step on the loss carried by ``model_output``.

    Args:
        model_output: An object exposing a ``loss`` attribute supporting
            ``backward()``.
    """
    batch_loss = model_output.loss

    optimizer = self.optimizer
    optimizer.zero_grad()  # clear gradients left by the previous step
    batch_loss.backward()
    optimizer.step()
def _schedulers_step(self, metrics=None):
    """Advance the learning-rate scheduler, if one is set.

    Args:
        metrics: Value forwarded to ``step`` when the scheduler is a
            ``ReduceLROnPlateau``; not used for other schedulers.
    """
    scheduler = self.scheduler
    if scheduler is None:
        return

    if isinstance(scheduler, lr_scheduler.ReduceLROnPlateau):
        scheduler.step(metrics)
        return

    scheduler.step()
def prepare_training(self):
    """Sets up the trainer for training"""
    # Seed first so that everything built afterwards is reproducible.
    set_seed(self.training_config.seed)

    # Order matters: the scheduler is built on the optimizer, and the callbacks
    # are set up last.
    for setup_step in (
        self.set_optimizer,
        self.set_scheduler,
        self._set_output_dir,
        self._setup_callbacks,
    ):
        setup_step()
def train(self, log_output_dir: str = None):
    """This function is the main training function

    It prepares the trainer, then for each epoch runs a training step, an optional
    evaluation step and a scheduler step, keeps a deep copy of the best model seen so
    far, optionally runs predictions and saves checkpoints, and finally saves the
    best model under ``<training_dir>/final_model`` on the main process.

    Args:
        log_output_dir (str): The path in which the log will be stored
    """
    self.prepare_training()

    self.callback_handler.on_train_begin(
        training_config=self.training_config, model_config=self.model_config
    )

    log_verbose = False

    msg = (
        f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n"
        " - per_device_train_batch_size: "
        f"{self.training_config.per_device_train_batch_size}\n"
        " - per_device_eval_batch_size: "
        f"{self.training_config.per_device_eval_batch_size}\n"
        f" - checkpoint saving every: {self.training_config.steps_saving}\n"
        f"Optimizer: {self.optimizer}\n"
        f"Scheduler: {self.scheduler}\n"
    )

    if self.is_main_process:
        logger.info(msg)

    # set up log file (main process only)
    if log_output_dir is not None and self.is_main_process:
        log_verbose = True
        file_logger = self._get_file_logger(log_output_dir=log_output_dir)
        file_logger.info(msg)

    if self.is_main_process:
        logger.info("Successfully launched training !\n")

    # Running best losses used to select the best model
    best_train_loss = 1e10
    best_eval_loss = 1e10

    # Fallback so that `best_model` is always bound: without it, prediction,
    # checkpointing and the final save raise UnboundLocalError when no epoch
    # improves on the initial sentinel (or when `num_epochs` is 0).
    best_model = self.model
    self._best_model = best_model

    for epoch in range(1, self.training_config.num_epochs + 1):

        self.callback_handler.on_epoch_begin(
            training_config=self.training_config,
            epoch=epoch,
            train_loader=self.train_loader,
            eval_loader=self.eval_loader,
        )

        metrics = {}

        epoch_train_loss = self.train_step(epoch)
        metrics["train_epoch_loss"] = epoch_train_loss

        if self.eval_dataset is not None:
            epoch_eval_loss = self.eval_step(epoch)
            metrics["eval_epoch_loss"] = epoch_eval_loss
            self._schedulers_step(epoch_eval_loss)

        else:
            # No eval set: the scheduler follows the train loss instead
            epoch_eval_loss = best_eval_loss
            self._schedulers_step(epoch_train_loss)

        # Keep a snapshot of the model whenever the tracked loss improves
        if (
            epoch_eval_loss < best_eval_loss
            and not self.training_config.keep_best_on_train
        ):
            best_eval_loss = epoch_eval_loss
            best_model = deepcopy(self.model)
            self._best_model = best_model

        elif (
            epoch_train_loss < best_train_loss
            and self.training_config.keep_best_on_train
        ):
            best_train_loss = epoch_train_loss
            best_model = deepcopy(self.model)
            self._best_model = best_model

        if (
            self.training_config.steps_predict is not None
            and epoch % self.training_config.steps_predict == 0
            and self.is_main_process
        ):
            true_data, reconstructions, generations = self.predict(best_model)

            self.callback_handler.on_prediction_step(
                self.training_config,
                true_data=true_data,
                reconstructions=reconstructions,
                generations=generations,
                global_step=epoch,
            )

        self.callback_handler.on_epoch_end(training_config=self.training_config)

        # save checkpoints
        if (
            self.training_config.steps_saving is not None
            and epoch % self.training_config.steps_saving == 0
        ):
            if self.is_main_process:
                self.save_checkpoint(
                    model=best_model, dir_path=self.training_dir, epoch=epoch
                )
                logger.info(f"Saved checkpoint at epoch {epoch}\n")

                if log_verbose:
                    file_logger.info(f"Saved checkpoint at epoch {epoch}\n")

        self.callback_handler.on_log(
            self.training_config,
            metrics,
            logger=logger,
            global_step=epoch,
            rank=self.rank,
        )

    final_dir = os.path.join(self.training_dir, "final_model")

    if self.is_main_process:
        self.save_model(best_model, dir_path=final_dir)

        logger.info("Training ended!")
        logger.info(f"Saved final model in {final_dir}")

    if self.distributed:
        dist.destroy_process_group()

    self.callback_handler.on_train_end(self.training_config)
def eval_step(self, epoch: int):
    """Perform an evaluation step

    Parameters:
        epoch (int): The current epoch number

    Returns:
        (float): The evaluation loss, i.e. the sum of the per-batch losses divided by
        the number of batches in the eval loader.

    Raises:
        ArithmeticError: If the accumulated loss becomes NaN.
    """
    self.callback_handler.on_eval_step_begin(
        training_config=self.training_config,
        eval_loader=self.eval_loader,
        epoch=epoch,
        rank=self.rank,
    )

    self.model.eval()

    def _forward(batch):
        # Single place for the model call, shared by both attempts below.
        return self.model(
            batch,
            epoch=epoch,
            dataset_size=len(self.eval_loader.dataset),
            uses_ddp=self.distributed,
        )

    running_loss = 0

    with self.amp_context:
        for batch in self.eval_loader:
            batch = self._set_inputs_to_device(batch)

            try:
                with torch.no_grad():
                    output = _forward(batch)
            except RuntimeError:
                # Retry with gradients enabled (presumably for models whose
                # forward pass needs autograd -- TODO confirm).
                output = _forward(batch)

            running_loss += output.loss.item()

            # NaN is the only value that differs from itself
            if running_loss != running_loss:
                raise ArithmeticError("NaN detected in eval loss")

            self.callback_handler.on_eval_step_end(
                training_config=self.training_config
            )

    return running_loss / len(self.eval_loader)
def train_step(self, epoch: int):
    """The trainer performs training loop over the train_loader.

    Parameters:
        epoch (int): The current epoch number

    Returns:
        (float): The epoch training loss, i.e. the sum of the per-batch losses divided
        by the number of batches in the train loader.

    Raises:
        ArithmeticError: If the accumulated loss becomes NaN.
    """
    self.callback_handler.on_train_step_begin(
        training_config=self.training_config,
        train_loader=self.train_loader,
        epoch=epoch,
        rank=self.rank,
    )

    # A DistributedSampler derives its shuffling order from its epoch counter;
    # without this call every epoch would iterate the data in the same order.
    sampler = getattr(self.train_loader, "sampler", None)
    if isinstance(sampler, DistributedSampler):
        sampler.set_epoch(epoch)

    # set model in train mode
    self.model.train()

    epoch_loss = 0

    for inputs in self.train_loader:

        inputs = self._set_inputs_to_device(inputs)

        with self.amp_context:
            model_output = self.model(
                inputs,
                epoch=epoch,
                dataset_size=len(self.train_loader.dataset),
                uses_ddp=self.distributed,
            )

        self._optimizers_step(model_output)

        loss = model_output.loss

        epoch_loss += loss.item()

        # NaN is the only value that differs from itself
        if epoch_loss != epoch_loss:
            raise ArithmeticError("NaN detected in train loss")

        self.callback_handler.on_train_step_end(
            training_config=self.training_config
        )

    # Allows model updates if needed
    # NOTE(review): the flattened source does not show whether this runs once per
    # epoch or once per batch; once per epoch is assumed -- confirm.
    if self.distributed:
        self.model.module.update()
    else:
        self.model.update()

    epoch_loss /= len(self.train_loader)

    return epoch_loss
def save_model(self, model: BaseAE, dir_path: str):
    """This method saves the final model along with the config files

    Args:
        model (BaseAE): The model to be saved
        dir_path (str): The folder where the model and config files should be saved
    """
    # Race-free creation: no error if the folder already exists.
    os.makedirs(dir_path, exist_ok=True)

    # save model (the wrapped module is saved in distributed mode)
    if self.distributed:
        model.module.save(dir_path)
    else:
        model.save(dir_path)

    # save training config
    self.training_config.save_json(dir_path, "training_config")

    self.callback_handler.on_save(self.training_config)
def save_checkpoint(self, model: BaseAE, dir_path, epoch: int):
    """Saves a checkpoint allowing to restart training from here

    The optimizer state, the scheduler state (when a scheduler is set), the model and
    the training config are written to ``<dir_path>/checkpoint_epoch_<epoch>``.

    Args:
        model (BaseAE): The model to be saved
        dir_path (str): The folder where the checkpoint should be saved
        epoch (int): The epoch number
    """
    checkpoint_dir = os.path.join(dir_path, f"checkpoint_epoch_{epoch}")

    # Race-free creation: no error if the folder already exists.
    os.makedirs(checkpoint_dir, exist_ok=True)

    # save optimizer
    torch.save(
        deepcopy(self.optimizer.state_dict()),
        os.path.join(checkpoint_dir, "optimizer.pt"),
    )

    # save scheduler
    if self.scheduler is not None:
        torch.save(
            deepcopy(self.scheduler.state_dict()),
            os.path.join(checkpoint_dir, "scheduler.pt"),
        )

    # save model (the wrapped module is saved in distributed mode)
    if self.distributed:
        model.module.save(checkpoint_dir)
    else:
        model.save(checkpoint_dir)

    # save training config
    self.training_config.save_json(checkpoint_dir, "training_config")
def predict(self, model: BaseAE):
    """Compute reconstructions and random generations on one batch.

    The model is put in eval mode and run on the first batch of the eval loader
    (or of the train loader when the trainer has no eval loader). At most the first
    10 samples of the batch are kept.

    Args:
        model (BaseAE): The model used for prediction. Its output must expose
            ``recon_x`` and ``z``, and its decoder output a ``reconstruction``.

    Returns:
        tuple: ``(true_data, reconstructions, generations)`` where ``true_data`` is
        the kept slice of ``inputs["data"]``, ``reconstructions`` the matching slice
        of ``recon_x`` (detached, on CPU) and ``generations`` the decoder output for
        standard-normal latents shaped like the kept slice of ``z`` (detached, on CPU).
    """
    model.eval()

    # `self.eval_loader` is None when the trainer was built without an eval
    # dataset; iterating it would raise TypeError, so use the train loader then.
    loader = self.eval_loader if self.eval_loader is not None else self.train_loader

    with self.amp_context:
        inputs = next(iter(loader))
        inputs = self._set_inputs_to_device(inputs)
        n_samples = min(inputs["data"].shape[0], 10)

        model_out = model(inputs)
        reconstructions = model_out.recon_x.cpu().detach()[:n_samples]

        # Sample latents with the same shape, dtype and device as the encoded ones
        z_enc = model_out.z[:n_samples]
        z = torch.randn_like(z_enc)

        # In distributed mode the decoder lives on the wrapped module
        decoder = model.module.decoder if self.distributed else model.decoder
        normal_generation = decoder(z).reconstruction.detach().cpu()

    return (
        inputs["data"][:n_samples],
        reconstructions,
        normal_generation,
    )