Base Neural Nets

These are the base neural networks from which all further networks must inherit. In particular, if you decide to provide your own networks, make them inherit from these classes. See the tutorials for further details.

class pythae.models.nn.BaseEncoder

This is a base class for encoder neural networks.

forward(x)

This function must be implemented in a child class. It takes the input data and returns an instance of ModelOutput. If you decide to provide your own encoder network, make sure your model inherits from this class and then define its forward method as follows:

>>> import torch
>>> from pythae.models.nn import BaseEncoder
>>> from pythae.models.base.base_utils import ModelOutput
...
>>> class My_Encoder(BaseEncoder):
...
...     def __init__(self):
...         BaseEncoder.__init__(self)
...         # your code: define the layers
...
...     def forward(self, x: torch.Tensor):
...         # your code: compute `embedding` (and `log_var` for VAE-based models) from x
...         output = ModelOutput(
...             embedding=embedding,
...             log_covariance=log_var  # for VAE-based models
...         )
...         return output
Parameters

x (torch.Tensor) – The input data that must be encoded

Returns

The output of the encoder

Return type

output (ModelOutput)

class pythae.models.nn.BaseDecoder

This is a base class for decoder neural networks.

forward(z)

This function must be implemented in a child class. It takes the latent codes and returns an instance of ModelOutput. If you decide to provide your own decoder network, make sure your model inherits from this class and then define its forward method as follows:

>>> import torch
>>> from pythae.models.nn import BaseDecoder
>>> from pythae.models.base.base_utils import ModelOutput
...
>>> class My_Decoder(BaseDecoder):
...
...     def __init__(self):
...         BaseDecoder.__init__(self)
...         # your code: define the layers
...
...     def forward(self, z: torch.Tensor):
...         # your code: compute `reconstruction` from z
...         output = ModelOutput(
...             reconstruction=reconstruction
...         )
...         return output
Parameters

z (torch.Tensor) – The latent data that must be decoded

Returns

The output of the decoder

Return type

output (ModelOutput)

Note

By convention, the reconstruction tensors should have values in [0, 1] and shape BATCH x channels x …

class pythae.models.nn.BaseMetric

This is a base class for metric neural networks (only applicable to Riemannian-based VAEs).

forward(x)

This function must be implemented in a child class. It takes the input data and returns an instance of ModelOutput. If you decide to provide your own metric network, make sure your model inherits from this class and then define its forward method as follows:

>>> import torch
>>> from pythae.models.nn import BaseMetric
>>> from pythae.models.base.base_utils import ModelOutput
...
>>> class My_Metric(BaseMetric):
...
...     def __init__(self):
...         BaseMetric.__init__(self)
...         # your code: define the layers
...
...     def forward(self, x: torch.Tensor):
...         # your code: compute `L` from x
...         output = ModelOutput(
...             L=L  # L matrices used in the metric of Riemannian-based VAEs (see docs)
...         )
...         return output
Parameters

x (torch.Tensor) – The input data fed to the metric network

Returns

The output of the metric

Return type

output (ModelOutput)

class pythae.models.nn.BaseDiscriminator

This is a base class for discriminator neural networks.

forward(x)

This function must be implemented in a child class. It takes the input data and returns an instance of ModelOutput. If you decide to provide your own discriminator network, make sure your model inherits from this class and then define its forward method as follows:

>>> import torch
>>> from pythae.models.nn import BaseDiscriminator
>>> from pythae.models.base.base_utils import ModelOutput
...
>>> class My_Discriminator(BaseDiscriminator):
...
...     def __init__(self):
...         BaseDiscriminator.__init__(self)
...         # your code: define the layers
...
...     def forward(self, x: torch.Tensor):
...         # your code: compute `adversarial_cost` from x
...         output = ModelOutput(
...             adversarial_cost=adversarial_cost
...         )
...         return output
Parameters

x (torch.Tensor) – The input data to be scored by the discriminator

Returns

The output of the discriminator

Return type

output (ModelOutput)