"""Training Callbacks for training monitoring integrated in `pythae` (inspired from
https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_callback.py)"""
import importlib
import logging
import numpy as np
from tqdm.auto import tqdm
from ..models import BaseAEConfig
from .base_trainer.base_training_config import BaseTrainerConfig
logger = logging.getLogger(__name__)
def wandb_is_available():
    """Return ``True`` when the `wandb` package can be located on the import path."""
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule is loaded, so the helper is imported explicitly.
    from importlib.util import find_spec

    return find_spec("wandb") is not None
def mlflow_is_available():
    """Return ``True`` when the `mlflow` package can be located on the import path."""
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule is loaded, so the helper is imported explicitly.
    from importlib.util import find_spec

    return find_spec("mlflow") is not None
def comet_is_available():
    """Return ``True`` when the `comet_ml` package can be located on the import path."""
    # A bare `import importlib` does not guarantee that the `importlib.util`
    # submodule is loaded, so the helper is imported explicitly.
    from importlib.util import find_spec

    return find_spec("comet_ml") is not None
def rename_logs(logs):
    """Rewrite ``train_``/``eval_`` prefixed metric names as ``train/``/``eval/`` names.

    Args:
        logs (dict): Metrics keyed by name.

    Returns:
        dict: A new dictionary holding only the metrics whose name starts with
        ``train_`` or ``eval_``. The leading prefix is swapped for ``train/`` or
        ``eval/`` and the remainder of the name is left untouched. Metrics with
        any other name are dropped.
    """
    prefix_map = (("train_", "train/"), ("eval_", "eval/"))

    clean_logs = {}
    for metric_name, value in logs.items():
        for old_prefix, new_prefix in prefix_map:
            if metric_name.startswith(old_prefix):
                # Slice the prefix off instead of using `str.replace`, which
                # would also rewrite later occurrences inside the metric name
                # (e.g. "train_constrain_loss").
                clean_logs[new_prefix + metric_name[len(old_prefix):]] = value
                break

    return clean_logs
class TrainingCallback:
    """
    Base class from which training callbacks are derived.

    Every hook below is a no-op; subclasses override the ones they need.
    """

    def on_init_end(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked once the [`Trainer`] has finished initializing.
        """

    def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when training starts.
        """

    def on_train_end(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when training is over.
        """

    def on_epoch_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when an epoch starts.
        """

    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when an epoch is over.
        """

    def on_train_step_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when a training step starts.
        """

    def on_train_step_end(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when a training step is over.
        """

    def on_eval_step_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when an evaluation step starts.
        """

    def on_eval_step_end(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked when an evaluation step is over.
        """

    def on_evaluate(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked once an evaluation phase has completed.
        """

    def on_prediction_step(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked once a prediction phase has completed.
        """

    def on_save(self, training_config: BaseTrainerConfig, **kwargs):
        """
        Hook invoked once a checkpoint has been saved.
        """

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        """
        Hook invoked once the latest logs have been logged.
        """

    def __repr__(self) -> str:
        # The representation is simply the name of the concrete callback class.
        return type(self).__name__
class CallbackHandler:
    """
    Class to handle list of Callback.

    Each ``on_*`` method forwards the event to every registered callback, in
    registration order, and adds the handled model as the ``model`` keyword
    argument.

    Args:
        callbacks (list): Callback classes or callback instances. Classes are
            instantiated without arguments.
        model: The object passed to every callback as ``model``.
    """

    def __init__(self, callbacks, model):
        self.callbacks = []
        for cb in callbacks:
            self.add_callback(cb)
        self.model = model

    def add_callback(self, callback):
        """Register ``callback``, instantiating it first when a class is given.

        A warning is logged when a callback of the same class is already
        registered; the new callback is appended regardless.
        """
        is_class = isinstance(callback, type)
        cb = callback() if is_class else callback
        cb_class = callback if is_class else callback.__class__
        if cb_class in [c.__class__ for c in self.callbacks]:
            logger.warning(
                f"You are adding a {cb_class} to the callbacks but there one is already used."
                f" The current list of callbacks is\n: {self.callback_list}"
            )
        self.callbacks.append(cb)

    @property
    def callback_list(self):
        """Newline-separated class names of the registered callbacks."""
        return "\n".join(cb.__class__.__name__ for cb in self.callbacks)

    def on_init_end(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_init_end", training_config, **kwargs)

    def on_train_step_begin(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_train_step_begin", training_config, **kwargs)

    def on_train_step_end(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_train_step_end", training_config, **kwargs)

    def on_eval_step_begin(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_eval_step_begin", training_config, **kwargs)

    def on_eval_step_end(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_eval_step_end", training_config, **kwargs)

    def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_train_begin", training_config, **kwargs)

    def on_train_end(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_train_end", training_config, **kwargs)

    def on_epoch_begin(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_epoch_begin", training_config, **kwargs)

    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
        # Forward the keyword arguments like every other event does.
        self.call_event("on_epoch_end", training_config, **kwargs)

    def on_evaluate(self, training_config: BaseTrainerConfig, **kwargs):
        # `call_event` requires the training configuration as a positional
        # argument; omitting it raises a TypeError.
        self.call_event("on_evaluate", training_config, **kwargs)

    def on_save(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_save", training_config, **kwargs)

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        self.call_event("on_log", training_config, logs=logs, **kwargs)

    def on_prediction_step(self, training_config: BaseTrainerConfig, **kwargs):
        self.call_event("on_prediction_step", training_config, **kwargs)

    def call_event(self, event, training_config, **kwargs):
        """Invoke the method named ``event`` on every registered callback.

        Args:
            event (str): Name of the callback method to call.
            training_config (BaseTrainerConfig): Passed positionally to the callback.
            **kwargs: Forwarded to the callback together with ``model=self.model``.
        """
        for callback in self.callbacks:
            getattr(callback, event)(
                training_config,
                model=self.model,
                **kwargs,
            )
class MetricConsolePrinterCallback(TrainingCallback):
    """
    A :class:`TrainingCallback` printing the training logs in the console.
    """

    def __init__(self):
        self.logger = logging.getLogger(__name__)

        # make it print to the console.
        # `logging.getLogger(__name__)` returns the same logger object for every
        # instance, so a console handler is only attached when none is present
        # yet; otherwise each new instance would duplicate every printed line.
        # The exact type is checked on purpose: subclasses of `StreamHandler`
        # (e.g. `FileHandler`) do not write to the console.
        has_console_handler = any(
            type(handler) is logging.StreamHandler
            for handler in self.logger.handlers
        )
        if not has_console_handler:
            self.logger.addHandler(logging.StreamHandler())
        self.logger.setLevel(logging.INFO)

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        """Print the epoch train and eval losses found in ``logs``.

        Args:
            training_config (BaseTrainerConfig): The training configuration
                (not read by this method).
            logs (dict): Looked up for ``train_epoch_loss`` and ``eval_epoch_loss``;
                a missing entry is simply not printed.
            **kwargs: ``logger`` replaces the instance logger (``None`` disables
                printing) and ``rank`` restricts printing to rank ``-1`` or ``0``.
        """
        logger = kwargs.pop("logger", self.logger)
        rank = kwargs.pop("rank", -1)

        if logger is None or rank not in (-1, 0):
            return

        epoch_train_loss = logs.get("train_epoch_loss", None)
        epoch_eval_loss = logs.get("eval_epoch_loss", None)

        logger.info(
            "--------------------------------------------------------------------------"
        )
        if epoch_train_loss is not None:
            logger.info(f"Train loss: {np.round(epoch_train_loss, 4)}")
        if epoch_eval_loss is not None:
            logger.info(f"Eval loss: {np.round(epoch_eval_loss, 4)}")
        logger.info(
            "--------------------------------------------------------------------------"
        )
class TrainHistoryCallback(MetricConsolePrinterCallback):
    """
    Callback storing the epoch train and eval losses of each log call in
    ``self.history`` (keys ``train_loss`` and ``eval_loss``).
    """

    def __init__(self):
        self.history = {"train_loss": [], "eval_loss": []}
        super().__init__()

    def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """Start a fresh, empty history."""
        self.history = {"train_loss": [], "eval_loss": []}

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        """Append the losses found in ``logs`` to the history.

        A missing loss is recorded as ``None``. Nothing is recorded when the
        ``logger`` keyword argument is explicitly ``None``.
        """
        active_logger = kwargs.pop("logger", self.logger)
        if active_logger is None:
            return

        train_loss = logs.get("train_epoch_loss", None)
        eval_loss = logs.get("eval_epoch_loss", None)
        for history_key, loss in (("train_loss", train_loss), ("eval_loss", eval_loss)):
            self.history[history_key].append(loss)
class ProgressBarCallback(TrainingCallback):
    """
    A :class:`TrainingCallback` displaying progress bars for the training and
    evaluation loops.
    """

    def __init__(self):
        self.train_progress_bar = None
        self.eval_progress_bar = None

    def on_train_step_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """Create the training bar when a ``train_loader`` is given and rank is 0 or -1."""
        epoch = kwargs.pop("epoch", None)
        loader = kwargs.pop("train_loader", None)
        rank = kwargs.pop("rank", -1)

        if loader is None:
            return
        if rank == 0 or rank == -1:
            description = f"Training of epoch {epoch}/{training_config.num_epochs}"
            self.train_progress_bar = tqdm(
                total=len(loader), unit="batch", desc=description
            )

    def on_eval_step_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """Create the evaluation bar when an ``eval_loader`` is given and rank is 0 or -1."""
        epoch = kwargs.pop("epoch", None)
        loader = kwargs.pop("eval_loader", None)
        rank = kwargs.pop("rank", -1)

        if loader is None:
            return
        if rank == 0 or rank == -1:
            description = f"Eval of epoch {epoch}/{training_config.num_epochs}"
            self.eval_progress_bar = tqdm(
                total=len(loader), unit="batch", desc=description
            )

    def on_train_step_end(self, training_config: BaseTrainerConfig, **kwargs):
        """Advance the training bar by one, if it exists."""
        bar = self.train_progress_bar
        if bar is not None:
            bar.update(1)

    def on_eval_step_end(self, training_config: BaseTrainerConfig, **kwargs):
        """Advance the evaluation bar by one, if it exists."""
        bar = self.eval_progress_bar
        if bar is not None:
            bar.update(1)

    def on_epoch_end(self, training_config: BaseTrainerConfig, **kwargs):
        """Close whichever bars exist, training bar first."""
        for bar in (self.train_progress_bar, self.eval_progress_bar):
            if bar is not None:
                bar.close()
class WandbCallback(TrainingCallback):  # pragma: no cover
    """
    A :class:`TrainingCallback` integrating the experiment tracking tool
    `wandb` (https://wandb.ai/).

    It allows users to store their configs, monitor their trainings
    and compare runs through a graphic interface. To be able to use this feature you will need:

        - a valid `wandb` account
        - the package `wandb` installed in your virtual env. If not you can install it with

        .. code-block::

            $ pip install wandb

        - to be logged in to your wandb account using

        .. code-block::

            $ wandb login
    """

    def __init__(self):
        if not wandb_is_available():
            raise ModuleNotFoundError(
                "`wandb` package must be installed. Run `pip install wandb`"
            )

        import wandb

        self._wandb = wandb
        # `on_train_begin` and `on_train_end` read these attributes even when
        # `setup` was never called explicitly, so they must always exist.
        self.is_initialized = False
        self.run = None

    def setup(
        self,
        training_config: BaseTrainerConfig,
        model_config: BaseAEConfig = None,
        project_name: str = "pythae_experiment",
        entity_name: str = None,
        **kwargs,
    ):
        """
        Setup the WandbCallback.

        args:
            training_config (BaseTrainerConfig): The training configuration used in the run.

            model_config (BaseAEConfig): The model configuration used in the run.

            project_name (str): The name of the wandb project to use.

            entity_name (str): The name of the wandb entity to use.
        """

        self.is_initialized = True

        training_config_dict = training_config.to_dict()

        self.run = self._wandb.init(project=project_name, entity=entity_name)

        if model_config is not None:
            model_config_dict = model_config.to_dict()

            self._wandb.config.update(
                {
                    "training_config": training_config_dict,
                    "model_config": model_config_dict,
                }
            )

        else:
            self._wandb.config.update({**training_config_dict})

        # Every metric is plotted against the `train/global_step` metric.
        self._wandb.define_metric("train/global_step")
        self._wandb.define_metric("*", step_metric="train/global_step", step_sync=True)

    def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """Call :meth:`setup` with its default arguments if it was not called before."""
        model_config = kwargs.pop("model_config", None)
        if not self.is_initialized:
            self.setup(training_config, model_config=model_config)

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        """Log the renamed metrics together with the ``global_step`` keyword argument."""
        global_step = kwargs.pop("global_step", None)
        logs = rename_logs(logs)

        self._wandb.log({**logs, "train/global_step": global_step})

    @staticmethod
    def _first_axis_last(tensor):
        """Return ``tensor`` as a NumPy array with its first axis moved to the end."""
        return np.moveaxis(tensor.cpu().detach().numpy(), 0, -1)

    def on_prediction_step(self, training_config: BaseTrainerConfig, **kwargs):
        """Log a table of true data, reconstructions and generations.

        Nothing is logged unless the ``true_data``, ``reconstructions`` and
        ``generations`` keyword arguments are all provided.
        """
        column_names = ["images_id", "truth", "reconstruction", "normal_generation"]

        true_data = kwargs.pop("true_data", None)
        reconstructions = kwargs.pop("reconstructions", None)
        generations = kwargs.pop("generations", None)

        if true_data is None or reconstructions is None or generations is None:
            return

        data_to_log = []
        for i in range(len(true_data)):
            data_to_log.append(
                [
                    f"img_{i}",
                    self._wandb.Image(self._first_axis_last(true_data[i])),
                    self._wandb.Image(
                        np.clip(self._first_axis_last(reconstructions[i]), 0, 255.0)
                    ),
                    self._wandb.Image(
                        np.clip(self._first_axis_last(generations[i]), 0, 255.0)
                    ),
                ]
            )

        val_table = self._wandb.Table(data=data_to_log, columns=column_names)

        self._wandb.log({"my_val_table": val_table})

    def on_train_end(self, training_config: BaseTrainerConfig, **kwargs):
        """Finish the wandb run, if one was started by :meth:`setup`."""
        if self.run is not None:
            self.run.finish()
class MLFlowCallback(TrainingCallback):  # pragma: no cover
    """
    A :class:`TrainingCallback` integrating the experiment tracking tool
    `mlflow` (https://mlflow.org/).

    It allows users to store their configs, monitor their trainings
    and compare runs through a graphic interface. To be able to use this feature you will need:

        - the package `mlflow` installed in your virtual env. If not you can install it with

        .. code-block::

            $ pip install mlflow
    """

    def __init__(self):
        if not mlflow_is_available():
            raise ModuleNotFoundError(
                "`mlflow` package must be installed. Run `pip install mlflow`"
            )

        import mlflow

        self._mlflow = mlflow
        # `on_train_begin` reads this flag even when `setup` was never called
        # explicitly, so it must always exist.
        self.is_initialized = False

    def setup(
        self,
        training_config: BaseTrainerConfig,
        model_config: BaseAEConfig = None,
        run_name: str = None,
        **kwargs,
    ):
        """
        Setup the MLflowCallback.

        args:
            training_config (BaseTrainerConfig): The training configuration used in the run.

            model_config (BaseAEConfig): The model configuration used in the run.

            run_name (str): The name to apply to the current run.
        """
        self.is_initialized = True

        training_config_dict = training_config.to_dict()

        self._mlflow.start_run(run_name=run_name)

        logger.info(
            f"MLflow run started with run_id={self._mlflow.active_run().info.run_id}"
        )
        if model_config is not None:
            model_config_dict = model_config.to_dict()

            # Both configurations are merged into one flat parameter set; on a
            # key clash the model configuration value wins.
            self._mlflow.log_params(
                {
                    **training_config_dict,
                    **model_config_dict,
                }
            )

        else:
            self._mlflow.log_params({**training_config_dict})

    def on_train_begin(self, training_config, **kwargs):
        """Call :meth:`setup` with its default arguments if it was not called before."""
        model_config = kwargs.pop("model_config", None)
        if not self.is_initialized:
            self.setup(training_config, model_config=model_config)

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        """Log the renamed metrics whose value is an ``int`` or a ``float``."""
        global_step = kwargs.pop("global_step", None)

        logs = rename_logs(logs)
        metrics = {
            name: value
            for name, value in logs.items()
            if isinstance(value, (int, float))
        }

        self._mlflow.log_metrics(metrics=metrics, step=global_step)

    def on_train_end(self, training_config: BaseTrainerConfig, **kwargs):
        """End the current mlflow run."""
        self._mlflow.end_run()

    def __del__(self):
        # `_mlflow` is missing when `__init__` raised before assigning it.
        mlflow = getattr(self, "_mlflow", None)
        if mlflow is None:
            return

        # if the previous run is not terminated correctly, the fluent API will
        # not let you start a new run before the previous one is killed
        if (
            callable(getattr(mlflow, "active_run", None))
            and mlflow.active_run() is not None
        ):
            mlflow.end_run()
class CometCallback(TrainingCallback):  # pragma: no cover
    """
    A :class:`TrainingCallback` integrating the experiment tracking tool
    `comet_ml` (https://www.comet.com/site/).

    It allows users to store their configs, monitor
    their trainings and compare runs through a graphic interface. To be able to use this feature
    you will need:

    - the package `comet_ml` installed in your virtual env. If not you can install it with

    .. code-block::

        $ pip install comet_ml
    """

    def __init__(self):
        if not comet_is_available():
            raise ModuleNotFoundError(
                "`comet_ml` package must be installed. Run `pip install comet_ml`"
            )

        import comet_ml

        self._comet_ml = comet_ml
        # `on_train_begin` reads this flag even when `setup` was never called
        # explicitly, so it must always exist.
        self.is_initialized = False

    def setup(
        self,
        training_config: BaseTrainerConfig,
        model_config: BaseAEConfig = None,
        api_key: str = None,
        project_name: str = "pythae_experiment",
        workspace: str = None,
        offline_run: bool = False,
        offline_directory: str = "./",
        **kwargs,
    ):
        """
        Setup the CometCallback.

        args:
            training_config (BaseTrainerConfig): The training configuration used in the run.

            model_config (BaseAEConfig): The model configuration used in the run.

            api_key (str): Your personal comet-ml `api_key`.

            project_name (str): The name of the comet-ml project to use.

            workspace (str): The name of your comet-ml workspace

            offline_run: (bool): Whether to run comet-ml in offline mode.

            offline_directory (str): The path to store the offline runs. They can
                then be synchronized by running `comet upload`.
        """

        self.is_initialized = True

        training_config_dict = training_config.to_dict()

        if not offline_run:
            experiment = self._comet_ml.Experiment(
                api_key=api_key, project_name=project_name, workspace=workspace
            )
        else:
            experiment = self._comet_ml.OfflineExperiment(
                api_key=api_key,
                project_name=project_name,
                workspace=workspace,
                offline_directory=offline_directory,
            )

        experiment.log_other("Created from", "pythae")

        # Log plain dictionaries, as the wandb and mlflow callbacks do, and skip
        # the model configuration when none was provided.
        experiment.log_parameters(training_config_dict, prefix="training_config/")
        if model_config is not None:
            experiment.log_parameters(model_config.to_dict(), prefix="model_config/")

    def on_train_begin(self, training_config: BaseTrainerConfig, **kwargs):
        """Call :meth:`setup` with its default arguments if it was not called before."""
        model_config = kwargs.pop("model_config", None)
        if not self.is_initialized:
            self.setup(training_config, model_config=model_config)

    def on_log(self, training_config: BaseTrainerConfig, logs, **kwargs):
        """Log ``logs`` to the global experiment, using ``global_step`` as step and epoch."""
        global_step = kwargs.pop("global_step", None)

        experiment = self._comet_ml.get_global_experiment()
        experiment.log_metrics(logs, step=global_step, epoch=global_step)

    @staticmethod
    def _first_axis_last(tensor):
        """Return ``tensor`` as a NumPy array with its first axis moved to the end."""
        return np.moveaxis(tensor.cpu().detach().numpy(), 0, -1)

    def on_prediction_step(self, training_config: BaseTrainerConfig, **kwargs):
        """Log true data, reconstructions and generations as images.

        Nothing is logged unless the ``true_data``, ``reconstructions`` and
        ``generations`` keyword arguments are all provided.
        """
        global_step = kwargs.pop("global_step", None)

        true_data = kwargs.pop("true_data", None)
        reconstructions = kwargs.pop("reconstructions", None)
        generations = kwargs.pop("generations", None)

        experiment = self._comet_ml.get_global_experiment()

        if true_data is None or reconstructions is None or generations is None:
            return

        for i in range(len(true_data)):
            experiment.log_image(
                self._first_axis_last(true_data[i]),
                name=f"{i}_truth",
                step=global_step,
            )
            experiment.log_image(
                np.clip(self._first_axis_last(reconstructions[i]), 0, 255.0),
                name=f"{i}_reconstruction",
                step=global_step,
            )
            experiment.log_image(
                np.clip(self._first_axis_last(generations[i]), 0, 255.0),
                name=f"{i}_normal_generation",
                step=global_step,
            )

    def on_train_end(self, training_config: BaseTrainerConfig, **kwargs):
        """End the global experiment, if there is one."""
        # Use the same accessor as `on_log` and `on_prediction_step`.
        experiment = self._comet_ml.get_global_experiment()
        if experiment is not None:
            experiment.end()