"""Source code for pyraug.models.base.base_sampler."""

import logging
import os

import torch

from .base_config import BaseSamplerConfig
from .base_vae import BaseVAE

# Module-level logger for this sampler module, set to emit INFO and above.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach a console handler so messages are actually printed
# (logging.StreamHandler writes to sys.stderr by default).
# NOTE(review): adding a handler inside library code can duplicate output when
# the application also configures logging; consider logging.NullHandler instead.
console = logging.StreamHandler()
logger.addHandler(console)


class BaseSampler:
    """Base class for samplers used to generate data from the VAE models.

    Args:
        model (BaseVAE): The VAE model to sample from.
        sampler_config (BaseSamplerConfig): An instance of BaseSamplerConfig in
            which any sampler's parameters are made available. If None, a default
            configuration (``BaseSamplerConfig()``) is used. Default: None

    .. note::
        If ``sampler_config.output_dir`` is None, it is set in place to
        ``"dummy_output_dir"``. The output directory is created if it does not
        exist, and the model is moved to ``"cuda"`` when available (and
        ``sampler_config.no_cuda`` is False), otherwise to ``"cpu"``.
    """

    def __init__(self, model: BaseVAE, sampler_config: BaseSamplerConfig = None):
        # Honour the documented default: without this, passing None crashed on
        # the attribute access just below.
        if sampler_config is None:
            sampler_config = BaseSamplerConfig()

        if sampler_config.output_dir is None:
            output_dir = "dummy_output_dir"
            sampler_config.output_dir = output_dir

        if not os.path.exists(sampler_config.output_dir):
            # exist_ok guards against the folder appearing between the check
            # above and the creation (e.g. a concurrent run).
            os.makedirs(sampler_config.output_dir, exist_ok=True)
            logger.info(
                f"Created {sampler_config.output_dir} folder since did not exist.\n"
            )

        self.model = model
        self.sampler_config = sampler_config
        self.batch_size = sampler_config.batch_size
        self.samples_per_save = sampler_config.samples_per_save

        self.device = (
            "cuda"
            if torch.cuda.is_available() and not sampler_config.no_cuda
            else "cpu"
        )

        self.model.to(self.device)

    def sample(self, num_samples):
        """Main sampling function of the samplers; must be implemented by subclasses.

        The data is saved in the ``output_dir/generation_`` folder passed in the
        `~pyraug.models.model_config.SamplerConfig` instance. If ``output_dir`` is
        None, a folder named ``dummy_output_dir`` is created in the current
        working directory.

        Args:
            num_samples (int): The number of samples to generate.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def save(self, dir_path):
        """Save the sampler config.

        The config is saved as a ``sampler_config.json`` file in ``dir_path``.

        Args:
            dir_path (str): The folder where the config file is written.
        """
        self.sampler_config.save_json(dir_path, "sampler_config")

    def save_data_batch(self, data, dir_path, number_of_samples, batch_idx):
        """Save a batch of generated data with ``torch.save``.

        The data is saved in the ``dir_path`` folder, which is created if needed.
        The file is named ``generated_data_{number_of_samples}_{batch_idx}.pt``.

        Args:
            data (torch.Tensor): The data to save.
            dir_path (str): The folder where the data must be saved.
            number_of_samples (int): Value embedded in the file name (presumably
                the number of samples in this batch — confirm against callers).
            batch_idx (int): The batch index, also embedded in the file name.

        .. note::
            You can then easily reload the generated data using

            .. code-block:: python

                >>> import torch
                >>> import os
                >>> data = torch.load(
                ...     os.path.join(
                ...         'dir_path', 'generated_data_{number_of_samples}_{batch_idx}.pt'))
        """
        # exist_ok avoids a race between checking for and creating the folder.
        os.makedirs(dir_path, exist_ok=True)

        torch.save(
            data,
            os.path.join(
                dir_path, f"generated_data_{number_of_samples}_{batch_idx}.pt"
            ),
        )