BaseNF

Abstract class

Base module for the Normalizing Flows implementations.

class pythae.models.normalizing_flows.BaseNF(model_config)

Base class for Normalizing Flows.

Parameters

model_config (BaseNFConfig) – The configuration setting the main parameters of the model.
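
Because BaseNF is abstract, a concrete flow is expected to subclass it and provide its own forward and inverse passes. The sketch below mimics that contract with Python's abc module. The names FlowBaseSketch and IdentityFlow are hypothetical stand-ins, not pythae classes, and the sketch does not claim to reproduce how pythae itself enforces the contract.

```python
from abc import ABC, abstractmethod


class FlowBaseSketch(ABC):
    """Hypothetical stand-in for an abstract flow base class (not pythae's BaseNF)."""

    def __init__(self, model_config):
        # Keep the configuration object, mirroring BaseNF(model_config)
        self.model_config = model_config

    @abstractmethod
    def forward(self, x, **kwargs):
        """Map data towards the prior."""

    @abstractmethod
    def inverse(self, y, **kwargs):
        """Map prior samples back towards the data."""


class IdentityFlow(FlowBaseSketch):
    """Trivial concrete flow: both directions leave the input unchanged."""

    def forward(self, x, **kwargs):
        return {"out": x, "log_abs_det_jac": 0.0}

    def inverse(self, y, **kwargs):
        return {"out": y, "log_abs_det_jac": 0.0}


flow = IdentityFlow(model_config={"input_dim": (2,)})
print(flow.forward([1.0, 2.0])["out"])  # [1.0, 2.0]
```

Instantiating FlowBaseSketch directly raises a TypeError, which is how abc signals that forward and inverse must be implemented by a child class.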

forward(x, **kwargs)

Main forward pass, mapping the data towards the prior. This function should output a ModelOutput instance gathering all the model outputs.

Parameters

x (torch.Tensor) – The training data.

Returns

A ModelOutput instance providing the outputs of the model.

Return type

ModelOutput
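
As an illustration of what a forward pass can return, the sketch below implements an elementwise affine flow z_i = exp(s_i) * x_i + t_i, which also yields log|det J| = sum_i s_i for the change-of-variables formula. ModelOutputSketch is a hypothetical stand-in for pythae's ModelOutput (a dict with attribute access), and the field names out and log_abs_det_jac are assumptions made for this example.

```python
import math


class ModelOutputSketch(dict):
    """Hypothetical stand-in for ModelOutput: a dict with attribute access."""

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc


def affine_forward(x, log_scale, shift):
    """Elementwise affine flow z_i = exp(s_i) * x_i + t_i (data -> prior direction).

    Returns the transformed vector and log|det J| = sum_i s_i.
    """
    z = [math.exp(s) * xi + t for xi, s, t in zip(x, log_scale, shift)]
    return ModelOutputSketch(out=z, log_abs_det_jac=sum(log_scale))


output = affine_forward([1.0, 2.0], log_scale=[0.0, math.log(2.0)], shift=[1.0, 0.0])
print(output.out)              # approximately [2.0, 4.0]
print(output.log_abs_det_jac)  # log(2), approximately 0.693
```

The log-determinant is what lets a flow score data exactly: log p(x) = log p_z(z) + log|det J|, where z is the forward output.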

inverse(y, **kwargs)

Main inverse pass, mapping the prior towards the data. This function should output a ModelOutput instance gathering all the model outputs.

Parameters

y (torch.Tensor) – Data from the prior.

Returns

A ModelOutput instance providing the outputs of the model.

Return type

ModelOutput
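
The inverse pass must undo the forward pass. A minimal sketch, assuming the same hypothetical elementwise affine flow as a running example: the inverse computes x_i = (y_i - t_i) * exp(-s_i), its log-determinant is the negative of the forward one, and a round trip recovers the input up to floating-point rounding. Plain dicts stand in for ModelOutput here.

```python
import math


def affine_forward(x, log_scale, shift):
    """Data -> prior: z_i = exp(s_i) * x_i + t_i, with log|det J| = sum_i s_i."""
    z = [math.exp(s) * xi + t for xi, s, t in zip(x, log_scale, shift)]
    return {"out": z, "log_abs_det_jac": sum(log_scale)}


def affine_inverse(y, log_scale, shift):
    """Prior -> data: x_i = (y_i - t_i) * exp(-s_i), with log|det J| = -sum_i s_i."""
    x = [(yi - t) * math.exp(-s) for yi, s, t in zip(y, log_scale, shift)]
    return {"out": x, "log_abs_det_jac": -sum(log_scale)}


log_scale, shift = [0.5, -1.0], [1.0, 2.0]
x = [0.3, -1.2]
z = affine_forward(x, log_scale, shift)["out"]
x_rec = affine_inverse(z, log_scale, shift)["out"]
# The inverse pass should recover the original data up to float rounding
assert all(math.isclose(a, b) for a, b in zip(x, x_rec))
```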

classmethod load_from_folder(dir_path)

Class method used to load the model from a specific folder.

Parameters

dir_path (str) – The path where the model should have been saved.

Note

This function requires the folder to contain:

  • a model_config.json and a model.pt if no custom architectures were provided
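
The folder requirement above can be checked before loading. The helper check_flow_folder below is hypothetical (not part of pythae); it only assumes the default layout with no custom architectures, i.e. a model_config.json and a model.pt file, and reports which of them are missing.

```python
import os
import tempfile


def check_flow_folder(dir_path):
    """Hypothetical pre-flight check: return the expected files missing from dir_path.

    Assumes the default (no custom architecture) layout:
    a model_config.json and a model.pt file.
    """
    expected = ["model_config.json", "model.pt"]
    return [name for name in expected if not os.path.isfile(os.path.join(dir_path, name))]


with tempfile.TemporaryDirectory() as folder:
    missing_before = check_flow_folder(folder)
    # Create empty placeholder files to simulate a folder written by save()
    for name in ("model_config.json", "model.pt"):
        open(os.path.join(folder, name), "w").close()
    missing_after = check_flow_folder(folder)

print(missing_before)  # ['model_config.json', 'model.pt']
print(missing_after)   # []
```

The actual loading is still done by the class method itself; the helper only makes the expected folder contents explicit.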

save(dir_path)

Method to save the model at a specific location. It saves the model weights as a model.pt file along with the model config as a model_config.json file.

Parameters

dir_path (str) – The path where the model should be saved. If the path does not exist, a folder will be created at the provided location.
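
A minimal sketch of the documented save behaviour, assuming the model_config.json / model.pt layout that load_from_folder expects: the target folder is created if missing, the config is written as JSON, and the weights are written to model.pt. The function save_flow_sketch is hypothetical, and pickle is used only as a stdlib stand-in for torch.save.

```python
import json
import os
import pickle
import tempfile


def save_flow_sketch(dir_path, config, weights):
    """Hypothetical save routine mirroring the documented layout.

    - creates dir_path if it does not exist;
    - writes the config as model_config.json;
    - writes the weights as model.pt (pickled here as a stand-in for torch.save).
    """
    os.makedirs(dir_path, exist_ok=True)
    with open(os.path.join(dir_path, "model_config.json"), "w") as f:
        json.dump(config, f)
    with open(os.path.join(dir_path, "model.pt"), "wb") as f:
        pickle.dump(weights, f)


with tempfile.TemporaryDirectory() as root:
    target = os.path.join(root, "my_flow")  # does not exist yet
    save_flow_sketch(target, {"input_dim": [2]}, {"log_scale": [0.5, -1.0]})
    saved_files = sorted(os.listdir(target))
    with open(os.path.join(target, "model_config.json")) as f:
        reloaded_config = json.load(f)

print(saved_files)  # ['model.pt', 'model_config.json']
```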

update()

Method that allows the model to be updated during training (at the end of each training epoch).

If needed, this method must be implemented in a child class.

By default, it does nothing.
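
For illustration, a child class might use this hook to adjust an internal quantity once per epoch. Everything below is hypothetical: FlowWithSchedule, its beta coefficient and the train loop are made-up names that only show where an end-of-epoch update() call would sit.

```python
class FlowWithSchedule:
    """Hypothetical flow whose update() hook anneals a coefficient once per epoch."""

    def __init__(self, beta=0.0, step=0.25):
        self.beta = beta
        self.step = step

    def update(self):
        # Called once at the end of each training epoch
        # (the base-class update() does nothing by default)
        self.beta = min(1.0, self.beta + self.step)


def train(model, n_epochs):
    """Hypothetical training loop showing where update() is invoked."""
    for _ in range(n_epochs):
        # ... iterate over batches, compute the loss, step the optimizer ...
        model.update()  # end-of-epoch hook


model = FlowWithSchedule()
train(model, n_epochs=3)
print(model.beta)  # 0.75
```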

class pythae.models.normalizing_flows.BaseNFConfig(input_dim=None)

This is the base configuration class for Normalizing Flows.

Parameters

input_dim (tuple) – The input data dimension. Default: None.
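
A minimal sketch of a configuration holding input_dim and serialising it to JSON, as would end up in model_config.json. FlowConfigSketch is a hypothetical plain-dataclass stand-in; pythae's actual config class may add validation or extra fields not shown here.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional, Tuple


@dataclass
class FlowConfigSketch:
    """Hypothetical stand-in for BaseNFConfig using a plain dataclass."""

    input_dim: Optional[Tuple[int, ...]] = None  # e.g. (1, 28, 28) for single-channel 28x28 images


config = FlowConfigSketch(input_dim=(1, 28, 28))
as_json = json.dumps(asdict(config))
print(as_json)  # {"input_dim": [1, 28, 28]}
```

Note that JSON has no tuple type, so input_dim comes back as a list when the file is read and should be converted back to a tuple if the exact type matters.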