# Source code for pythae.models.normalizing_flows.maf.maf_model

import os

import torch
import torch.nn as nn

from ...base.base_utils import ModelOutput
from ..base import BaseNF
from ..layers import BatchNorm
from ..made import MADE, MADEConfig
from .maf_config import MAFConfig


class MAF(BaseNF):
    """Masked Autoregressive Flow.

    The flow is a stack of ``n_made_blocks`` :class:`MADE` blocks, each one
    optionally followed by a :class:`BatchNorm` layer. The feature dimension
    is reversed after every layer in :meth:`forward` (and before every layer
    in :meth:`inverse`).

    Args:
        model_config (MAFConfig): The MAF model configuration setting the main parameters of
            the model.
    """

    def __init__(self, model_config: MAFConfig):
        BaseNF.__init__(self, model_config=model_config)

        # Not read anywhere in this class; kept so the instance attributes are unchanged.
        self.m = {}
        self.model_config = model_config
        self.hidden_size = model_config.hidden_size
        self.model_name = "MAF"

        # NOTE(review): ``self.input_dim`` is not assigned in this class; presumably
        # ``BaseNF.__init__`` sets it to the flattened input size - confirm.
        made_config = MADEConfig(
            input_dim=(self.input_dim,),
            output_dim=(self.input_dim,),
            hidden_sizes=[self.hidden_size] * model_config.n_hidden_in_made,
            degrees_ordering="sequential",
        )

        # Every MADE block is built from the same config object, so all blocks share
        # the same architecture (but not the same parameters).
        layers = []
        for _ in range(model_config.n_made_blocks):
            layers.append(MADE(made_config))
            if model_config.include_batch_norm:
                layers.append(BatchNorm(self.input_dim))

        self.net = nn.ModuleList(layers)

    def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
        """The input data is transformed toward the prior.

        Args:
            x (torch.Tensor): An input tensor. It is flattened to
                ``(x.shape[0], -1)`` before going through the layers.

        Returns:
            ModelOutput: An instance of ModelOutput holding the transformed tensor under
            ``out`` and the accumulated log absolute determinant of the Jacobian (one value
            per element of the first dimension) under ``log_abs_det_jac``.
        """
        x = x.reshape(x.shape[0], -1)
        sum_log_abs_det_jac = torch.zeros(x.shape[0], device=x.device)

        for layer in self.net:
            layer_out = layer(x)

            if isinstance(layer, MADE):
                # Affine autoregressive step: u = (x - mu) * exp(-log_var).
                mu, log_var = layer_out.mu, layer_out.log_var
                x = (x - mu) * (-log_var).exp()
                sum_log_abs_det_jac += -log_var.sum(dim=-1)
            else:
                # Any other layer exposes its output and log-det directly.
                x = layer_out.out
                sum_log_abs_det_jac += layer_out.log_abs_det_jac

            # Reverse the feature order between consecutive layers.
            x = x.flip(dims=(1,))

        return ModelOutput(out=x, log_abs_det_jac=sum_log_abs_det_jac)

    def inverse(self, y: torch.Tensor, **kwargs) -> ModelOutput:
        """The prior is transformed toward the input data.

        Args:
            y (torch.Tensor): An input tensor. It is flattened to
                ``(y.shape[0], -1)`` before going through the layers.

        Returns:
            ModelOutput: An instance of ModelOutput holding the transformed tensor under
            ``out`` and the accumulated log absolute determinant of the Jacobian (one value
            per element of the first dimension) under ``log_abs_det_jac``.
        """
        y = y.reshape(y.shape[0], -1)
        sum_log_abs_det_jac = torch.zeros(y.shape[0], device=y.device)

        # Walk the layers in the opposite order to ``forward``.
        for layer in reversed(self.net):
            # Undo the flip that ``forward`` applies after each layer.
            y = y.flip(dims=(1,))

            if isinstance(layer, MADE):
                # Fill the output one feature at a time: the MADE block is re-run on
                # the partially filled tensor before each column is written.
                x = torch.zeros_like(y)
                for i in range(self.input_dim):
                    # The layer gets a copy so the in-place write below does not
                    # touch the tensor it was called on.
                    layer_out = layer(x.clone())
                    mu, log_var = layer_out.mu, layer_out.log_var
                    x[:, i] = y[:, i] * (log_var[:, i]).exp() + mu[:, i]
                    sum_log_abs_det_jac += log_var[:, i]
                y = x
            else:
                layer_out = layer.inverse(y)
                y = layer_out.out
                sum_log_abs_det_jac += layer_out.log_abs_det_jac

        return ModelOutput(out=y, log_abs_det_jac=sum_log_abs_det_jac)