BaseNF
Abstract class
Base module for Normalizing Flows implementation
- class pythae.models.normalizing_flows.BaseNF(model_config)
Base class for normalizing flows.
- Parameters
model_config (BaseNFConfig) – The configuration setting the main parameters of the model.
- forward(x, **kwargs)
Main forward pass mapping the data towards the prior. This function should output a ModelOutput instance gathering all the model outputs.
- Parameters
x (torch.Tensor) – The training data.
- Returns
A ModelOutput instance providing the outputs of the model.
- Return type
ModelOutput
- inverse(y, **kwargs)
Main inverse pass mapping the prior towards the data. This function should output a ModelOutput instance gathering all the model outputs.
- Parameters
y (torch.Tensor) – Data from the prior.
- Returns
A ModelOutput instance providing the outputs of the model.
- Return type
ModelOutput
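The forward/inverse contract described above can be illustrated with a toy example. This is only a sketch in plain Python, not pythae code: `ToyAffineFlow`, `ToyOutput`, and the field names `out` and `log_abs_det_jac` are hypothetical stand-ins, and a real subclass would operate on `torch.Tensor` inputs and return a `ModelOutput`.

```python
from dataclasses import dataclass
import math


@dataclass
class ToyOutput:
    """Hypothetical stand-in for ModelOutput: a container of named outputs."""
    out: list
    log_abs_det_jac: float


class ToyAffineFlow:
    """Illustrative elementwise flow y = (x - shift) / scale (not a pythae class)."""

    def __init__(self, shift: float, scale: float):
        self.shift = shift
        self.scale = scale

    def forward(self, x, **kwargs):
        # Data -> prior direction.
        y = [(v - self.shift) / self.scale for v in x]
        # For an elementwise affine map, log|det J| is -log|scale| per dimension.
        log_det = -len(x) * math.log(abs(self.scale))
        return ToyOutput(out=y, log_abs_det_jac=log_det)

    def inverse(self, y, **kwargs):
        # Prior -> data direction; undoes forward.
        x = [v * self.scale + self.shift for v in y]
        log_det = len(y) * math.log(abs(self.scale))
        return ToyOutput(out=x, log_abs_det_jac=log_det)


flow = ToyAffineFlow(shift=2.0, scale=4.0)
z = flow.forward([2.0, 6.0, -2.0])   # map data towards the prior
x_rec = flow.inverse(z.out)          # map back towards the data
```

The point of the sketch is that `inverse` undoes `forward`, and that both return a single output object rather than a bare value.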
- classmethod load_from_folder(dir_path)
Class method used to load the model from a specific folder.
- Parameters
dir_path (str) – The path where the model should have been saved.
Note
This function requires the folder to contain:
- a model_config.json and a model.pt if no custom architectures were provided
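Before loading, it can be handy to confirm that the folder holds the files listed in the note. The helper below is a hypothetical convenience, not part of pythae; it only checks for the two file names mentioned above.

```python
import os
import tempfile

# File names taken from the note above.
REQUIRED_FILES = ("model_config.json", "model.pt")


def missing_files(dir_path):
    """Return the required file names that are absent from dir_path."""
    present = set(os.listdir(dir_path))
    return [name for name in REQUIRED_FILES if name not in present]


# Demonstration on a temporary folder that only contains the config file.
with tempfile.TemporaryDirectory() as tmp:
    open(os.path.join(tmp, "model_config.json"), "w").close()
    missing = missing_files(tmp)
```

An empty list means both files are present and `load_from_folder` has what the note requires.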
- save(dir_path)
Method to save the model at a specific location. It saves the model weights as a model.pt file along with the model config as a model_config.json file.
- Parameters
dir_path (str) – The path where the model should be saved. If the path does not exist, a folder will be created at the provided location.
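The folder layout produced by save and consumed by load_from_folder can be sketched with the standard library alone. This is an illustration under assumptions, not the library's implementation: `save_sketch` and `load_sketch` are hypothetical names, `pickle` stands in for the torch serialization a real model would use, the `input_dim` config key is made up, and the weights file uses the `model.pt` name that load_from_folder expects.

```python
import json
import os
import pickle
import tempfile


def save_sketch(dir_path, config, weights):
    # Create the target folder when it does not exist yet.
    os.makedirs(dir_path, exist_ok=True)
    with open(os.path.join(dir_path, "model_config.json"), "w") as f:
        json.dump(config, f)
    # pickle is a stand-in here; a real model would serialize torch weights.
    with open(os.path.join(dir_path, "model.pt"), "wb") as f:
        pickle.dump(weights, f)


def load_sketch(dir_path):
    with open(os.path.join(dir_path, "model_config.json")) as f:
        config = json.load(f)
    with open(os.path.join(dir_path, "model.pt"), "rb") as f:
        weights = pickle.load(f)
    return config, weights


with tempfile.TemporaryDirectory() as tmp:
    # The nested path does not exist beforehand; save_sketch creates it.
    target = os.path.join(tmp, "new_folder", "my_flow")
    save_sketch(target, {"input_dim": [2]}, {"scale": 4.0})
    saved = sorted(os.listdir(target))
    config, weights = load_sketch(target)
```

Keeping the configuration in JSON and the weights in a separate file means the config can be inspected without deserializing the weights.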