# Source code for pyraug.data.datasets

"""The pyraug's Datasets inherit from
:class:`torch.utils.data.Dataset` and must be used to convert the data before
training. As of today, it only contains the :class:`pyraug.data.datasets.BaseDataset`, which is used to train a
VAE model, but other datasets will be added as models are added.
"""
import torch
from torch.utils.data import Dataset


class BaseDataset(Dataset):
    """Base class for pyraug's datasets.

    ``__getitem__`` is redefined and returns a Python dictionary with the
    keys ``"data"`` and ``"labels"``. This class should be used for any new
    dataset.

    Args:
        digits (torch.Tensor): The samples. They are indexed along their
            first dimension and stored as ``torch.float``.
        labels (torch.Tensor): The labels, indexed along their first
            dimension and stored as ``torch.float``.
        binarize (bool): If ``True``, every entry of ``digits`` is replaced
            by ``1.0`` when a uniform random draw in ``[0, 1)`` falls
            strictly below it and by ``0.0`` otherwise. The draw happens
            once, at construction time. Default: ``False``.
    """

    def __init__(
        self,
        digits: torch.Tensor,
        labels: torch.Tensor,
        binarize: bool = False,
    ) -> None:
        self.labels = labels.type(torch.float)

        if binarize:
            # Stochastic binarization: each entry is compared with a fresh
            # uniform sample. NOTE(review): presumably `digits` holds values
            # in [0, 1] so they act as probabilities -- confirm with callers.
            self.data = (torch.rand_like(digits) < digits).type(torch.float)
        else:
            self.data = digits.type(torch.float)

    def __len__(self) -> int:
        """Return the number of samples (size of ``data`` along dim 0)."""
        return len(self.data)

    def __getitem__(self, index) -> dict:
        """Generate one sample of data.

        Args:
            index (int): The index of the data in the dataset.

        Returns:
            (dict): A dictionary with the keys ``'data'`` and ``'labels'``
            and the corresponding :class:`torch.Tensor` values.
        """
        # Select the sample and its label at the same position.
        sample = self.data[index]
        label = self.labels[index]

        return {"data": sample, "labels": label}