# Source code for pythae.models.nn.base_architectures

import torch
import torch.nn as nn


class BaseEncoder(nn.Module):
    """This is a base class for encoder neural networks.

    Custom encoders must subclass this class and implement :meth:`forward`.
    """

    def __init__(self):
        # Initialise the underlying torch ``nn.Module`` so that child classes
        # can register their own layers after calling this constructor.
        nn.Module.__init__(self)

    def forward(self, x):
        r"""This function must be implemented in a child class.
        It takes the input data and returns an instance of
        :class:`~pythae.models.base.base_utils.ModelOutput`.
        If you decide to provide your own encoder network, you must make sure
        your model inherits from this class and then define your forward
        function as such:

        .. code-block::

            >>> from pythae.models.nn import BaseEncoder
            >>> from pythae.models.base.base_utils import ModelOutput
            ...
            >>> class My_Encoder(BaseEncoder):
            ...
            ...     def __init__(self):
            ...         BaseEncoder.__init__(self)
            ...         # your code
            ...
            ...     def forward(self, x: torch.Tensor):
            ...         # your code
            ...         output = ModelOutput(
            ...             embedding=embedding,
            ...             log_covariance=log_var # for VAE based models
            ...         )
            ...         return output

        Parameters:
            x (torch.Tensor): The input data that must be encoded

        Returns:
            output (~pythae.models.base.base_utils.ModelOutput): The output of the
            encoder

        Raises:
            NotImplementedError: Always, unless overridden in a child class.
        """
        raise NotImplementedError()


class BaseDecoder(nn.Module):
    """This is a base class for decoder neural networks.

    Custom decoders must subclass this class and implement :meth:`forward`.
    """

    def __init__(self):
        # Initialise the underlying torch ``nn.Module`` so that child classes
        # can register their own layers after calling this constructor.
        nn.Module.__init__(self)

    def forward(self, z: torch.Tensor):
        r"""This function must be implemented in a child class.
        It takes the input data and returns an instance of
        :class:`~pythae.models.base.base_utils.ModelOutput`.
        If you decide to provide your own decoder network, you must make sure
        your model inherits from this class and then define your forward
        function as such:

        .. code-block::

            >>> from pythae.models.nn import BaseDecoder
            >>> from pythae.models.base.base_utils import ModelOutput
            ...
            >>> class My_decoder(BaseDecoder):
            ...
            ...     def __init__(self):
            ...         BaseDecoder.__init__(self)
            ...         # your code
            ...
            ...     def forward(self, z: torch.Tensor):
            ...         # your code
            ...         output = ModelOutput(
            ...             reconstruction=reconstruction
            ...         )
            ...         return output

        Parameters:
            z (torch.Tensor): The latent data that must be decoded

        Returns:
            output (~pythae.models.base.base_utils.ModelOutput): The output of the
            decoder

        Raises:
            NotImplementedError: Always, unless overridden in a child class.

        .. note::

            By convention, the reconstruction tensors should be in [0, 1] and of
            shape BATCH x channels x ...
        """
        raise NotImplementedError()


class BaseMetric(nn.Module):
    """This is a base class for metric neural networks (only applicable for
    Riemannian-based VAEs).

    Custom metric networks must subclass this class and implement
    :meth:`forward`.
    """

    def __init__(self):
        # Initialise the underlying torch ``nn.Module`` so that child classes
        # can register their own layers after calling this constructor.
        nn.Module.__init__(self)

    def forward(self, x):
        r"""This function must be implemented in a child class.
        It takes the input data and returns an instance of
        :class:`~pythae.models.base.base_utils.ModelOutput`.
        If you decide to provide your own metric network, you must make sure
        your model inherits from this class and then define your forward
        function as such:

        .. code-block::

            >>> from pythae.models.nn import BaseMetric
            >>> from pythae.models.base.base_utils import ModelOutput
            ...
            >>> class My_Metric(BaseMetric):
            ...
            ...     def __init__(self):
            ...         BaseMetric.__init__(self)
            ...         # your code
            ...
            ...     def forward(self, x: torch.Tensor):
            ...         # your code
            ...         output = ModelOutput(
            ...             L=L # L matrices in the metric of Riemannian based VAE (see docs)
            ...         )
            ...         return output

        Parameters:
            x (torch.Tensor): The input data fed to the metric network

        Returns:
            output (~pythae.models.base.base_utils.ModelOutput): The output of the
            metric network

        Raises:
            NotImplementedError: Always, unless overridden in a child class.
        """
        raise NotImplementedError()


class BaseDiscriminator(nn.Module):
    """This is a base class for discriminator neural networks.

    Custom discriminators must subclass this class and implement
    :meth:`forward`.
    """

    def __init__(self):
        # Initialise the underlying torch ``nn.Module`` so that child classes
        # can register their own layers after calling this constructor.
        nn.Module.__init__(self)

    def forward(self, x):
        r"""This function must be implemented in a child class.
        It takes the input data and returns an instance of
        :class:`~pythae.models.base.base_utils.ModelOutput`.
        If you decide to provide your own discriminator network, you must make
        sure your model inherits from this class and then define your forward
        function as such:

        .. code-block::

            >>> from pythae.models.nn import BaseDiscriminator
            >>> from pythae.models.base.base_utils import ModelOutput
            ...
            >>> class My_Discriminator(BaseDiscriminator):
            ...
            ...     def __init__(self):
            ...         BaseDiscriminator.__init__(self)
            ...         # your code
            ...
            ...     def forward(self, x: torch.Tensor):
            ...         # your code
            ...         output = ModelOutput(
            ...             adversarial_cost=adversarial_cost
            ...         )
            ...         return output

        Parameters:
            x (torch.Tensor): The input data fed to the discriminator

        Returns:
            output (~pythae.models.base.base_utils.ModelOutput): The output of the
            discriminator

        Raises:
            NotImplementedError: Always, unless overridden in a child class.
        """
        raise NotImplementedError()