Source code for pythae.models.normalizing_flows.made.made_model

import os

import numpy as np
import torch
import torch.nn as nn

from pythae.models.base.base_utils import ModelOutput

from ..base import BaseNF
from ..layers import MaskedLinear
from .made_config import MADEConfig


class MADE(BaseNF):
    """Masked Autoencoder for Distribution Estimation (MADE) model.

    Args:
        model_config (MADEConfig): The MADE model configuration setting the main
            parameters of the model.

    Raises:
        AttributeError: If ``model_config.input_dim`` or ``model_config.output_dim``
            is ``None``.
    """

    def __init__(self, model_config: MADEConfig):

        BaseNF.__init__(self, model_config=model_config)

        # Validate first: the dimensions are reduced with ``np.prod`` below, which
        # must not be attempted on a missing value.
        if model_config.input_dim is None:
            raise AttributeError(
                "No input dimension provided ! "
                "'input_dim' parameter of MADEConfig instance must be set to 'data_shape' "
                "where the shape of the data is (C, H, W ..). Unable to build network "
                "automatically"
            )

        if model_config.output_dim is None:
            raise AttributeError(
                "No output dimension provided ! "
                "'output_dim' parameter of MADEConfig instance must be set to 'data_shape' "
                "where the shape of the data is (C, H, W ..). Unable to build network "
                "automatically"
            )

        self.model_config = model_config
        self.model_name = "MADE"
        # Flattened sizes of the (possibly multi-dimensional) data shapes.
        self.input_dim = int(np.prod(model_config.input_dim))
        self.output_dim = int(np.prod(model_config.output_dim))
        self.hidden_sizes = list(model_config.hidden_sizes)
        # Degree assigned to every unit of every layer; key -1 holds the input
        # degrees, keys 0..len(hidden_sizes)-1 the hidden-layer degrees.
        self.m = {}

        masks = self._make_mask(ordering=self.model_config.degrees_ordering)

        # Widths of every layer: inputs, hidden layers, outputs.
        layer_sizes = [self.input_dim] + self.hidden_sizes + [self.output_dim]

        net = []
        # Hidden layers: masked linear maps, each followed by a ReLU.
        for in_features, out_features, mask in zip(
            layer_sizes[:-2], layer_sizes[1:-1], masks[:-1]
        ):
            net.extend([MaskedLinear(in_features, out_features, mask), nn.ReLU()])

        # Output layer emits the mean and the log-variance, hence the output mask is
        # stacked twice along the output axis. ``layer_sizes[-2]`` is the last hidden
        # width, or the input dimension when no hidden layers are configured.
        # NOTE(review): the output mask is built from the input degrees (shape
        # (input_dim, ·)), so this presumably requires output_dim == input_dim.
        net.append(
            MaskedLinear(layer_sizes[-2], 2 * self.output_dim, masks[-1].repeat(2, 1))
        )

        self.net = nn.Sequential(*net)

    def _make_mask(self, ordering="sequential"):
        """Assign a degree to every unit and build the autoregressive masks.

        Args:
            ordering (str): ``"sequential"`` gives the inputs degrees ``0..D-1`` in
                order and the hidden units cyclic degrees; any other value draws a
                random permutation of the inputs and random hidden degrees.

        Returns:
            list[torch.Tensor]: One float mask per layer, shaped
            ``(out_features, in_features)``. The last mask connects the last hidden
            layer (or the inputs if there is none) to ``input_dim`` outputs.
        """
        # Hidden degrees are drawn from [low, degree_bound) with degree_bound = D - 1,
        # so no hidden unit depends on every input. Clamping to 1 keeps D == 1 valid
        # (no modulo by zero, no empty randint range); the outputs then see nothing.
        degree_bound = max(self.input_dim - 1, 1)

        if ordering == "sequential":
            self.m[-1] = torch.arange(self.input_dim)
            for i, size in enumerate(self.hidden_sizes):
                self.m[i] = torch.arange(size) % degree_bound
        else:
            self.m[-1] = torch.randperm(self.input_dim)
            for i, size in enumerate(self.hidden_sizes):
                # Sample from the previous layer's minimum degree (self.m[-1] for the
                # first hidden layer) so every hidden unit keeps at least one
                # incoming connection.
                low = int(self.m[i - 1].min())
                self.m[i] = torch.randint(low, degree_bound, (size,))

        # Hidden unit k of layer i may see unit j of layer i-1 iff m_i[k] >= m_{i-1}[j].
        masks = [
            (self.m[i].unsqueeze(-1) >= self.m[i - 1].unsqueeze(0)).float()
            for i in range(len(self.hidden_sizes))
        ]
        # Output d may only see units whose degree is strictly smaller than the
        # degree of input d (strict inequality keeps the model autoregressive).
        last_degrees = self.m[len(self.hidden_sizes) - 1]
        masks.append(
            (last_degrees.unsqueeze(0) < self.m[-1].unsqueeze(-1)).float()
        )
        return masks

    def forward(self, x: torch.Tensor, **kwargs) -> ModelOutput:
        """Run the masked network on a batch and return the mean and log-variance.

        Args:
            x (torch.Tensor): Input batch. Each sample is flattened to a vector, whose
                size presumably has to equal ``input_dim``.

        Returns:
            ModelOutput: ``mu`` (the first ``input_dim`` output features) and
            ``log_var`` (the remaining features) — each of shape
            ``(batch_size, input_dim)`` when ``output_dim == input_dim``.
        """
        net_output = self.net(x.reshape(x.shape[0], -1))
        # The output layer stacks the mean features before the log-variance ones.
        mu = net_output[:, : self.input_dim]
        log_var = net_output[:, self.input_dim :]
        return ModelOutput(mu=mu, log_var=log_var)