BaseMetric

class pyraug.models.nn.BaseMetric

This is a base class for metric neural networks (applicable only to Riemannian-based VAEs).
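For context, here is a hedged summary that this page does not state. In the Riemannian Hamiltonian VAE (RHVAE) implemented in pyraug, the metric network maps each training point \(x_i\) to a lower triangular matrix \(L_{\psi_i}\) with a positive diagonal. The inverse metric at a latent point \(z\) is then interpolated from these matrices as

\[
G^{-1}(z) = \sum_{i=1}^{N} L_{\psi_i} L_{\psi_i}^{\top} \exp\left(-\frac{\lVert z - c_i \rVert_2^2}{T^2}\right) + \lambda I_d ,
\]

where \(N\) is the number of training points, \(c_i\) are centroids in the latent space (the encoder means of the \(x_i\)), \(T\) is a temperature, \(\lambda\) is a regularization factor, and \(d\) is the latent dimension. The symbol names here are illustrative and may not match pyraug's parameter names.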

forward(x)

This function must be implemented in a child class. It takes the input data and returns the \(L_{\psi}\) matrices of the metric. If you decide to provide your own metric network, make your model inherit from this class and then define its forward function as follows:

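from pyraug.models.nn import BaseMetric

class My_Metric(BaseMetric):

    def __init__(self):
        BaseMetric.__init__(self)
        # your code: define the layers of the metric network here

    def forward(self, x):
        # your code: compute the L_psi matrices from the input x
        L = ...
        return L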
Parameters

x (torch.Tensor) – The input data that must be encoded

Returns

The \(L_{\psi}\) matrices of the metric

Return type

(torch.Tensor)
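Below is a minimal, runnable sketch of a concrete metric network. It assumes a fully connected architecture; the class name `MLPMetric`, its layer sizes, and the `hidden_dim` parameter are hypothetical and not part of pyraug. The network returns a batch of lower triangular \(L_{\psi}\) matrices and passes the diagonal through an exponential to keep it strictly positive. If pyraug is not installed, the sketch falls back to `torch.nn.Module`, on the assumption that `BaseMetric` itself subclasses it.

```python
import torch
from torch import nn

try:
    from pyraug.models.nn import BaseMetric
except ImportError:
    # Assumption: BaseMetric subclasses torch.nn.Module. This stand-in only
    # lets the sketch run where pyraug is not installed.
    BaseMetric = nn.Module


class MLPMetric(BaseMetric):
    """Hypothetical metric network returning lower triangular L_psi matrices."""

    def __init__(self, input_dim: int, latent_dim: int, hidden_dim: int = 400):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.layers = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # One output per diagonal entry of L_psi
        self.diag = nn.Linear(hidden_dim, latent_dim)
        # One output per strictly lower triangular entry of L_psi
        self.lower = nn.Linear(hidden_dim, latent_dim * (latent_dim - 1) // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten each input sample to a vector of size input_dim
        h = self.layers(x.reshape(-1, self.input_dim))
        diag = self.diag(h)
        lower = self.lower(h)

        # Scatter the strictly lower triangular entries into a zero matrix
        L = torch.zeros(
            h.shape[0], self.latent_dim, self.latent_dim,
            dtype=h.dtype, device=h.device,
        )
        idx = torch.tril_indices(
            self.latent_dim, self.latent_dim, offset=-1, device=h.device
        )
        L[:, idx[0], idx[1]] = lower

        # exp keeps the diagonal strictly positive, so L is invertible
        return L + torch.diag_embed(diag.exp())
```

Building each matrix from a diagonal head and a strictly lower triangular head mirrors a Cholesky factor. Any lower triangular \(L\) with a positive diagonal yields a symmetric positive definite \(L L^{\top}\), which is the property a Riemannian metric needs.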