Source code for pythae.models.normalizing_flows.planar_flow.planar_flow_model

import os

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...base.base_utils import ModelOutput
from ..base import BaseNF
from .planar_flow_config import PlanarFlowConfig

# Element-wise activation functions, selectable by name.
ACTIVATION = {"elu": F.elu, "tanh": torch.tanh, "linear": lambda x: x}

# Element-wise first derivatives of the activations above, keyed by the same
# names. Every entry returns a tensor shaped like its input so that the result
# can take part in tensor operations such as matrix products.
ACTIVATION_DERIVATIVES = {
    # d/dx elu(x) (default alpha = 1) is 1 for x >= 0 and exp(x) for x < 0,
    # i.e. exp(min(x, 0)). Clamping before the exponential avoids the
    # `inf * 0 = nan` produced by a masked `exp(x) * (x < 0)` when exp(x)
    # overflows for large positive x.
    "elu": lambda x: torch.exp(torch.clamp(x, max=0)),
    "tanh": lambda x: 1 - torch.tanh(x) ** 2,
    # Return a tensor rather than the Python int 1: an int cannot be used as
    # the left operand of `@` with a tensor.
    "linear": lambda x: torch.ones_like(x),
}


class PlanarFlow(BaseNF):
    """Planar Flow model.

    Applies the transformation ``f(x) = x + u * h(x @ w.T + b)`` where ``h`` is
    the activation selected by ``model_config.activation`` and ``w``, ``u``
    (each of shape ``[1, input_dim]``) and ``b`` (shape ``[1]``) are learnable
    parameters.

    Args:
        model_config (PlanarFlowConfig): The PlanarFlow model configuration
            setting the main parameters of the model.
    """

    def __init__(self, model_config: PlanarFlowConfig):
        BaseNF.__init__(self, model_config)

        # NOTE(review): `self.input_dim` is not set in this class; it is
        # presumably provided by `BaseNF.__init__` -- confirm.
        self.w = nn.Parameter(torch.randn(1, self.input_dim))
        self.u = nn.Parameter(torch.randn(1, self.input_dim))
        self.b = nn.Parameter(torch.randn(1))

        # Look up the activation and its derivative by the configured name;
        # an unknown name raises KeyError here.
        self.activation = ACTIVATION[model_config.activation]
        self.activation_derivative = ACTIVATION_DERIVATIVES[model_config.activation]

        self.model_name = "PlanarFlow"

        # Re-draw the parameters from a standard normal. This repeats the
        # `torch.randn` initialisation above; it is kept so that the sequence
        # of random draws under a fixed seed is unchanged.
        nn.init.normal_(self.w)
        nn.init.normal_(self.u)
        nn.init.normal_(self.b)

    def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
        """The input data is transformed toward the prior.

        Args:
            x (torch.Tensor): An input tensor. It is flattened to
                ``[B, -1]`` where ``B = x.shape[0]``.
            **kwargs: Accepted but not used by this method.

        Returns:
            ModelOutput: An instance of ModelOutput with ``out`` holding the
            transformed data (``[B x input_dim]``) and ``log_abs_det_jac``
            holding the log absolute determinant of the Jacobian (``[B]``).
        """
        x = x.reshape(x.shape[0], -1)

        lin = x @ self.w.T + self.b  # [B x 1]
        f = x + self.u * self.activation(lin)  # [B x input_dim]

        # phi = h'(lin) * w, so that det(df/dx) = 1 + phi @ u.T.
        phi = self.activation_derivative(lin) @ self.w  # [B x input_dim]

        # 1e-4 keeps the logarithm finite when the determinant is close to 0.
        # Squeeze only the trailing axis: a bare `.squeeze()` would also drop
        # the batch axis when B == 1 and return a 0-dim tensor.
        log_det = torch.log(torch.abs(1 + phi @ self.u.T) + 1e-4).squeeze(-1)  # [B]

        return ModelOutput(out=f, log_abs_det_jac=log_det)