TrainingPipeline

class pythae.pipelines.TrainingPipeline(model, training_config=None)

This pipeline provides an end-to-end way to train your VAE model. The trained model is saved in the output_dir stated in the BaseTrainerConfig. A folder training_YYYY-MM-DD_hh-mm-ss is created there to hold the checkpoints and the final model. Checkpoints are saved in checkpoint_epoch_{epoch} folders (the optimizer and training config are saved as well, so training can be resumed if needed), and the final model is saved in a final_model folder. If output_dir is None, the folder dummy_output_dir/training_YYYY-MM-DD_hh-mm-ss is created and the data is saved there.
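The folder layout described above can be sketched with the standard library alone. This is an illustration of the naming convention, not pythae's own code: the helper name expected_layout and its arguments are hypothetical, and the timestamp is assumed to use a 24-hour clock.

```python
from datetime import datetime
from pathlib import PurePosixPath
from typing import Iterable, List, Optional


def expected_layout(
    output_dir: Optional[str],
    started_at: datetime,
    checkpoint_epochs: Iterable[int] = (),
) -> List[str]:
    """Return the folders a training run is described as creating.

    Hypothetical helper for illustration only; it mirrors the text above.
    """
    # Fall back to dummy_output_dir when no output_dir is given.
    base = output_dir if output_dir is not None else "dummy_output_dir"

    # training_YYYY-MM-DD_hh-mm-ss (24-hour clock is an assumption).
    stamp = started_at.strftime("%Y-%m-%d_%H-%M-%S")
    run_dir = PurePosixPath(base) / f"training_{stamp}"

    # One checkpoint_epoch_{epoch} folder per saved checkpoint ...
    paths = [str(run_dir / f"checkpoint_epoch_{epoch}") for epoch in checkpoint_epochs]
    # ... plus the final_model folder.
    paths.append(str(run_dir / "final_model"))
    return paths
```

Calling expected_layout(None, datetime(2024, 1, 2, 3, 4, 5), checkpoint_epochs=(1, 2)) lists two checkpoint folders and the final_model folder under dummy_output_dir/training_2024-01-02_03-04-05.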

Parameters
  • model (Optional[BaseAE]) – An instance of BaseAE you want to train. If None, a default VAE model is used. Default: None.

  • training_config (Optional[BaseTrainerConfig]) – An instance of BaseTrainerConfig stating the training parameters. If None, a default configuration is used.

__call__(train_data, eval_data=None, callbacks=None)

Launch the model training on the provided data.

Parameters
  • train_data (Union[ndarray, Tensor]) – The training data as a numpy.ndarray or torch.Tensor of shape (mini_batch x n_channels x …).

  • eval_data (Optional[Union[ndarray, Tensor]]) – The evaluation data as a numpy.ndarray or torch.Tensor of shape (mini_batch x n_channels x …). If None, only train_data is used for training. Default: None.

  • callbacks (List[TrainingCallbacks]) – A list of callbacks to use during training. Default: None.
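A minimal usage sketch of the constructor and __call__ documented above. It assumes pythae is installed and that VAE and VAEConfig are importable from pythae.models and BaseTrainerConfig from pythae.trainers; those import paths, the latent_dim value, and the wrapper name train_vae are assumptions not stated in this section. The pythae imports are deferred into the function body so that defining the helper does not itself require the library.

```python
def train_vae(train_data, eval_data=None, output_dir="my_model", num_epochs=10):
    """Train a VAE on train_data with TrainingPipeline (illustrative wrapper)."""
    # Deferred imports: pythae is only needed when the function is called.
    from pythae.models import VAE, VAEConfig          # assumed import path
    from pythae.pipelines import TrainingPipeline
    from pythae.trainers import BaseTrainerConfig     # assumed import path

    # The data has shape (mini_batch x n_channels x ...); the model input
    # dimension is that shape without the leading mini_batch axis.
    input_dim = tuple(train_data.shape[1:])
    model = VAE(model_config=VAEConfig(input_dim=input_dim, latent_dim=16))

    # Training parameters, including where the run folder is created.
    training_config = BaseTrainerConfig(output_dir=output_dir, num_epochs=num_epochs)

    # Build the pipeline, then launch training by calling it.
    pipeline = TrainingPipeline(model=model, training_config=training_config)
    pipeline(train_data=train_data, eval_data=eval_data)
```

With a numpy.ndarray or torch.Tensor of images, train_vae(images) would train on images only, and train_vae(images, eval_data=held_out) would also evaluate on held_out.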