BaseMetric
- class pyraug.models.nn.BaseMetric
This is the base class for metric neural networks (only applicable to Riemannian-based VAEs).
- forward(x)
This function must be implemented in a child class. It takes the input data and returns the \(L_{\psi}\) matrices. To provide your own metric network, make your model inherit from this class and define your forward function as follows:
from pyraug.models.nn import BaseMetric

class My_Metric(BaseMetric):
    def __init__(self):
        BaseMetric.__init__(self)
        # your code: define the layers of the metric network

    def forward(self, x):
        # your code: compute L from x
        return L
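As a more concrete illustration of the skeleton above, here is a minimal sketch of a child class. Everything beyond `BaseMetric` and `forward(x)` is an assumption made for the example: the name `TinyMetric`, the layer sizes, and the choice to build each \(L_{\psi}\) as a lower-triangular matrix with a positive diagonal (the documentation above only states that `forward` returns the \(L_{\psi}\) matrices). The fallback to `torch.nn.Module` is there only so the sketch runs without pyraug installed, and assumes `BaseMetric` behaves like a plain PyTorch module.

```python
import torch
import torch.nn as nn

try:
    from pyraug.models.nn import BaseMetric
except ImportError:
    # Stand-in so the sketch runs without pyraug (assumption: BaseMetric
    # behaves like a plain torch.nn.Module).
    BaseMetric = nn.Module


class TinyMetric(BaseMetric):
    """Hypothetical metric network: maps x to one L matrix per sample."""

    def __init__(self, input_dim=8, latent_dim=2, hidden_dim=16):
        BaseMetric.__init__(self)
        self.latent_dim = latent_dim
        self.hidden = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        # One output per diagonal entry, one per strictly-lower entry.
        self.diag = nn.Linear(hidden_dim, latent_dim)
        self.lower = nn.Linear(hidden_dim, latent_dim * (latent_dim - 1) // 2)

    def forward(self, x):
        # Flatten each sample before the fully connected layers.
        h = self.hidden(x.reshape(x.shape[0], -1))
        d = self.latent_dim
        L = torch.zeros(h.shape[0], d, d, device=h.device, dtype=h.dtype)
        # Fill the strictly-lower triangle with the network outputs.
        rows, cols = torch.tril_indices(d, d, offset=-1, device=h.device)
        L[:, rows, cols] = self.lower(h)
        # exp keeps the diagonal strictly positive (an assumed design choice).
        L = L + torch.diag_embed(self.diag(h).exp())
        return L


metric = TinyMetric()
x = torch.randn(4, 8)   # batch of 4 samples with 8 features each
L = metric(x)
print(L.shape)          # torch.Size([4, 2, 2])
```

With these assumed sizes, a batch of 4 inputs yields 4 matrices of shape 2 × 2; the entries above the diagonal stay zero.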
- Parameters
x (torch.Tensor) – The input data that must be encoded
- Returns
The \(L_{\psi}\) matrices of the metric
- Return type