Source code for pythae.trainers.coupled_optimizer_trainer.coupled_optimizer_trainer_config

from typing import Union

import torch.nn as nn
from pydantic.dataclasses import dataclass

from ..base_trainer import BaseTrainerConfig


@dataclass
class CoupledOptimizerTrainerConfig(BaseTrainerConfig):
    """
    CoupledOptimizerTrainer config class.

    Holds one optimizer/scheduler specification for the encoder and one for the
    decoder. All four specifications are checked at construction time (see
    :meth:`__post_init__`) so that a misspelled class name or an invalid keyword
    argument is reported before any training starts.

    Parameters:
        output_dir (str): The directory where model checkpoints, configs and final
            model will be stored. Default: None.
        per_device_train_batch_size (int): The number of training samples per batch
            and per device. Default 64
        per_device_eval_batch_size (int): The number of evaluation samples per batch
            and per device. Default 64
        num_epochs (int): The maximal number of epochs for training. Default: 100
        train_dataloader_num_workers (int): Number of subprocesses to use for train
            data loading. 0 means that the data will be loaded in the main process.
            Default: 0
        eval_dataloader_num_workers (int): Number of subprocesses to use for
            evaluation data loading. 0 means that the data will be loaded in the
            main process. Default: 0
        encoder_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for
            the training of the encoder. Default: :class:`~torch.optim.Adam`.
        encoder_optimizer_params (dict): A dict containing the parameters to use for
            the `torch.optim.Optimizer` for the encoder. If None, uses the default
            parameters. Default: None.
        encoder_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used
            for the training of the encoder. If None, no scheduler is checked.
            Default: None.
        encoder_scheduler_params (dict): A dict containing the parameters to use for
            the `torch.optim.lr_scheduler` for the encoder. If None, the scheduler
            is not instantiated during validation. Default: None.
        decoder_optimizer_cls (str): The name of the `torch.optim.Optimizer` used for
            the training of the decoder. Default: :class:`~torch.optim.Adam`.
        decoder_optimizer_params (dict): A dict containing the parameters to use for
            the `torch.optim.Optimizer` for the decoder. If None, uses the default
            parameters. Default: None.
        decoder_scheduler_cls (str): The name of the `torch.optim.lr_scheduler` used
            for the training of the decoder. If None, no scheduler is checked.
            Default: None.
        decoder_scheduler_params (dict): A dict containing the parameters to use for
            the `torch.optim.lr_scheduler` for the decoder. If None, the scheduler
            is not instantiated during validation. Default: None.
        encoder_learning_rate (float): The learning rate applied to the `Optimizer`
            for the encoder. Default: 1e-4
        decoder_learning_rate (float): The learning rate applied to the `Optimizer`
            for the decoder. Default: 1e-4
        steps_saving (int): A model checkpoint will be saved every `steps_saving`
            epoch. Default: None
        steps_predict (int): A prediction using the best model will be run every
            `steps_predict` epoch. Default: None
        keep_best_on_train (bool): Whether to keep the best model on the train set.
            Default: False
        seed (int): The random seed for reproducibility
        no_cuda (bool): Disable `cuda` training. Default: False
        world_size (int): The total number of process to run. Default: -1
        local_rank (int): The rank of the node for distributed training. Default: -1
        rank (int): The rank of the process for distributed training. Default: -1
        dist_backend (str): The distributed backend to use. Default: 'nccl'
        master_addr (str): The master address for distributed training.
            Default: 'localhost'
        master_port (str): The master port for distributed training.
            Default: '12345'
    """

    encoder_optimizer_cls: str = "Adam"
    encoder_optimizer_params: Union[dict, None] = None
    # The scheduler names default to None, so the annotation must allow None.
    encoder_scheduler_cls: Union[str, None] = None
    encoder_scheduler_params: Union[dict, None] = None
    decoder_optimizer_cls: str = "Adam"
    decoder_optimizer_params: Union[dict, None] = None
    decoder_scheduler_cls: Union[str, None] = None
    decoder_scheduler_params: Union[dict, None] = None
    encoder_learning_rate: float = 1e-4
    decoder_learning_rate: float = 1e-4

    def __post_init__(self):
        """Check compatibility of the optimizer and scheduler settings.

        Runs the parent checks first, then validates the encoder settings followed
        by the decoder settings.

        Raises:
            AttributeError: If an optimizer name is not found in ``torch.optim`` or a
                scheduler name is not found in ``torch.optim.lr_scheduler``.
            TypeError: If the provided parameters dict cannot be used to instantiate
                the corresponding optimizer or scheduler.
        """
        super().__post_init__()

        # encoder optimizer and scheduler
        self._check_optimizer_and_scheduler(
            role="encoder",
            optimizer_name=self.encoder_optimizer_cls,
            optimizer_params=self.encoder_optimizer_params,
            learning_rate=self.encoder_learning_rate,
            scheduler_name=self.encoder_scheduler_cls,
            scheduler_params=self.encoder_scheduler_params,
        )

        # decoder optimizer and scheduler
        self._check_optimizer_and_scheduler(
            role="decoder",
            optimizer_name=self.decoder_optimizer_cls,
            optimizer_params=self.decoder_optimizer_params,
            learning_rate=self.decoder_learning_rate,
            scheduler_name=self.decoder_scheduler_cls,
            scheduler_params=self.decoder_scheduler_params,
        )

    @staticmethod
    def _check_optimizer_and_scheduler(
        role: str,
        optimizer_name: str,
        optimizer_params: Union[dict, None],
        learning_rate: float,
        scheduler_name: Union[str, None],
        scheduler_params: Union[dict, None],
    ) -> None:
        """Validate one optimizer/scheduler specification by dry-run instantiation.

        Args:
            role: ``"encoder"`` or ``"decoder"``; only used in error messages.
            optimizer_name: Attribute name looked up in ``torch.optim``.
            optimizer_params: Extra keyword arguments for the optimizer, or None.
            learning_rate: Value passed as ``lr`` to the optimizer.
            scheduler_name: Attribute name looked up in
                ``torch.optim.lr_scheduler``, or None to skip the scheduler check.
            scheduler_params: Keyword arguments for the scheduler, or None to skip
                its instantiation.

        Raises:
            AttributeError: If the optimizer or scheduler name cannot be resolved.
            TypeError: If instantiation with the provided parameters fails.
        """
        import torch.optim as optim
        import torch.optim.lr_scheduler as schedulers

        try:
            optimizer_cls = getattr(optim, optimizer_name)
        except AttributeError as e:
            raise AttributeError(
                f"Unable to import `{optimizer_name}` {role} optimizer "
                "from 'torch.optim'. Check spelling and that it is part of "
                "'torch.optim.Optimizers.'"
            ) from e

        # A throwaway layer provides real parameters so the optimizer can be built.
        dummy_parameters = nn.Linear(2, 2).parameters()

        if optimizer_params is None:
            optimizer = optimizer_cls(dummy_parameters, lr=learning_rate)
        else:
            try:
                optimizer = optimizer_cls(
                    dummy_parameters,
                    lr=learning_rate,
                    **optimizer_params,
                )
            except TypeError as e:
                raise TypeError(
                    "Error in optimizer's parameters. Check that the provided dict contains only "
                    f"keys and values suitable for `{optimizer_cls}` optimizer. "
                    f"Got {optimizer_params} as parameters.\n"
                    f"Exception raised: {type(e)} with message: " + str(e)
                ) from e

        if scheduler_name is None:
            return

        try:
            scheduler_cls = getattr(schedulers, scheduler_name)
        except AttributeError as e:
            raise AttributeError(
                f"Unable to import `{scheduler_name}` {role} scheduler from "
                "'torch.optim.lr_scheduler'. Check spelling and that it is part of "
                "'torch.optim.lr_scheduler.'"
            ) from e

        # Without explicit parameters the scheduler is not instantiated: only the
        # name lookup above is checked in that case.
        if scheduler_params is None:
            return

        try:
            scheduler_cls(optimizer, **scheduler_params)
        except TypeError as e:
            raise TypeError(
                "Error in scheduler's parameters. Check that the provided dict contains only "
                f"keys and values suitable for `{scheduler_cls}` scheduler. "
                f"Got {scheduler_params} as parameters.\n"
                f"Exception raised: {type(e)} with message: " + str(e)
            ) from e