Source code for pyraug.models.rhvae.rhvae_sampler

import datetime
import logging
import os

import torch

from ..base.base_sampler import BaseSampler
from .rhvae_config import RHVAESamplerConfig
from .rhvae_model import RHVAE

logger = logging.getLogger(__name__)

# make the logger print to the console
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)


class RHVAESampler(BaseSampler):
    """Hamiltonian Monte Carlo Sampler class.

    This is an implementation of the Hamiltonian/Hybrid Monte Carlo sampler
    (https://en.wikipedia.org/wiki/Hamiltonian_Monte_Carlo).

    Args:
        model (RHVAE): The VAE model to sample from.
        sampler_config (RHVAESamplerConfig): An RHVAESamplerConfig instance containing
            the main parameters of the sampler. If None, a pre-defined configuration
            is used. Default: None
    """

    def __init__(self, model: RHVAE, sampler_config: RHVAESamplerConfig = None):

        # fall back to the default configuration when none is provided,
        # as stated in the docstring
        if sampler_config is None:
            sampler_config = RHVAESamplerConfig()

        BaseSampler.__init__(self, model=model, sampler_config=sampler_config)

        self.sampler_config = sampler_config

        # move the learned metric tensors and centroids to the sampling device
        self.model.M_tens = self.model.M_tens.to(self.device)
        self.model.centroids_tens = self.model.centroids_tens.to(self.device)

        # HMC hyper-parameters: number of MCMC steps, number of leapfrog steps,
        # leapfrog step size and initial inverse temperature
        self.mcmc_steps_nbr = sampler_config.mcmc_steps_nbr
        self.n_lf = torch.tensor([sampler_config.n_lf]).to(self.device)
        self.eps_lf = torch.tensor([sampler_config.eps_lf]).to(self.device)
        self.beta_zero_sqrt = (
            torch.tensor([sampler_config.beta_zero]).to(self.device).sqrt()
        )

        # target log-density (up to a constant) and its gradient
        self.log_pi = RHVAESampler.log_sqrt_det_G_inv
        self.grad_func = RHVAESampler.grad_log_prop
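    # Example (a minimal sketch, not part of the library): assuming a trained
    # RHVAE instance ``model`` whose metric tensors ``M_tens`` and centroids
    # ``centroids_tens`` were fitted during training, a sampler is built from a
    # config. All values below are illustrative, not recommended defaults:
    #
    #     config = RHVAESamplerConfig(
    #         output_dir="./generated_data",  # hypothetical output folder
    #         batch_size=100,
    #         mcmc_steps_nbr=100,
    #         n_lf=15,
    #         eps_lf=0.03,
    #         beta_zero=0.3,
    #     )
    #     sampler = RHVAESampler(model=model, sampler_config=config)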
    def sample(self, samples_number):
        """HMC sampling with an RHVAE.

        The data is saved in the ``output_dir`` (folder passed in the
        :class:`~pyraug.models.base.base_config.BaseSamplerConfig` instance) in a
        folder named ``generation_YYYY-MM-DD_hh-mm-ss``. If ``output_dir`` is None,
        a folder named ``dummy_output_dir`` is created in this folder.

        Args:
            samples_number (int): The number of samples to generate
        """
        assert samples_number > 0, "Provide a number of samples > 0"

        self._sampling_signature = (
            str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-")
        )

        sampling_dir = os.path.join(
            self.sampler_config.output_dir, f"generation_{self._sampling_signature}"
        )

        if not os.path.exists(sampling_dir):
            os.makedirs(sampling_dir)
            logger.info(
                f"Created {sampling_dir}. "
                "Generated data and sampler config will be saved here.\n"
            )

        full_batch_nbr = int(samples_number / self.sampler_config.batch_size)
        last_batch_samples_nbr = samples_number % self.sampler_config.batch_size

        generated_data = []

        file_count = 0
        data_count = 0

        logger.info("Generation successfully launched!\n")

        # generate full batches first
        for i in range(full_batch_nbr):
            samples = self.hmc_sampling(self.batch_size)
            x_gen = self.model.decoder(z=samples).detach()

            assert len(x_gen.shape) == 2
            generated_data.append(x_gen)
            data_count += self.batch_size

            # flush to disk every ``samples_per_save`` samples
            while data_count >= self.samples_per_save:
                self.save_data_batch(
                    data=torch.cat(generated_data)[: self.samples_per_save],
                    dir_path=sampling_dir,
                    number_of_samples=self.samples_per_save,
                    batch_idx=file_count,
                )
                file_count += 1
                data_count -= self.samples_per_save
                generated_data = list(
                    torch.cat(generated_data)[self.samples_per_save :].unsqueeze(0)
                )

        # generate the remaining (incomplete) batch, if any
        if last_batch_samples_nbr > 0:
            samples = self.hmc_sampling(last_batch_samples_nbr)
            x_gen = self.model.decoder(z=samples).detach()

            generated_data.append(x_gen)
            data_count += last_batch_samples_nbr

            while data_count >= self.samples_per_save:
                self.save_data_batch(
                    data=torch.cat(generated_data)[: self.samples_per_save],
                    dir_path=sampling_dir,
                    number_of_samples=self.samples_per_save,
                    batch_idx=file_count,
                )
                file_count += 1
                data_count -= self.samples_per_save
                generated_data = list(
                    torch.cat(generated_data)[self.samples_per_save :].unsqueeze(0)
                )

        # save any leftover samples
        if data_count > 0:
            self.save_data_batch(
                data=torch.cat(generated_data),
                dir_path=sampling_dir,
                number_of_samples=data_count,
                batch_idx=file_count,
            )

        self.save(sampling_dir)
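    # Example (a minimal sketch following the docstring above): with the
    # ``sampler`` built from a config whose ``output_dir`` is set, this call
    # draws 500 samples by HMC, decodes them, and writes them under
    # ``output_dir/generation_YYYY-MM-DD_hh-mm-ss`` in files of at most
    # ``samples_per_save`` samples each:
    #
    #     sampler.sample(samples_number=500)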
    def hmc_sampling(self, n_samples):

        with torch.no_grad():

            # start the chains from randomly drawn centroids
            idx = torch.randint(len(self.model.centroids_tens), (n_samples,))

            z0 = self.model.centroids_tens[idx]

            beta_sqrt_old = self.beta_zero_sqrt
            z = z0
            for i in range(self.mcmc_steps_nbr):

                # sample an initial momentum at the initial temperature
                gamma = torch.randn_like(z, device=self.device)
                rho = gamma / self.beta_zero_sqrt

                # Hamiltonian at the start of the trajectory
                H0 = -self.log_pi(z, self.model) + 0.5 * torch.norm(rho, dim=1) ** 2

                # leapfrog integration of the Hamiltonian dynamics
                for k in range(self.n_lf):

                    g = -self.grad_func(z, self.model).reshape(
                        n_samples, self.model.latent_dim
                    )

                    # step 1: half-step momentum update
                    rho_ = rho - (self.eps_lf / 2) * g

                    # step 2: full-step position update
                    z = z + self.eps_lf * rho_

                    g = -self.grad_func(z, self.model).reshape(
                        n_samples, self.model.latent_dim
                    )

                    # step 3: half-step momentum update
                    rho__ = rho_ - (self.eps_lf / 2) * g

                    # tempering: rescale the momentum along the trajectory
                    beta_sqrt = RHVAESampler.tempering(
                        k + 1, self.n_lf, self.beta_zero_sqrt
                    )

                    rho = (beta_sqrt_old / beta_sqrt) * rho__
                    beta_sqrt_old = beta_sqrt

                # Hamiltonian at the end of the trajectory
                H = -self.log_pi(z, self.model) + 0.5 * torch.norm(rho, dim=1) ** 2

                # Metropolis-Hastings acceptance step
                alpha = torch.exp(-H) / (torch.exp(-H0))
                acc = torch.rand(n_samples).to(self.device)
                moves = (acc < alpha).type(torch.int).reshape(n_samples, 1)

                z = z * moves + (1 - moves) * z0

                z0 = z

            return z

    @staticmethod
    def tempering(k, K, beta_zero_sqrt):
        beta_k = ((1 - 1 / beta_zero_sqrt) * (k / K) ** 2) + 1 / beta_zero_sqrt

        return 1 / beta_k

    @staticmethod
    def log_sqrt_det_G_inv(z, model):
        # log sqrt(det G^{-1}(z)), the target log-density up to a constant
        return torch.log(torch.sqrt(torch.det(model.G_inv(z))) + 1e-10)

    @staticmethod
    def grad_log_sqrt_det_G_inv(z, model):
        return (
            -0.5
            * torch.transpose(model.G(z), 1, 2)
            @ torch.transpose(
                (
                    -2
                    / (model.temperature ** 2)
                    * (model.centroids_tens.unsqueeze(0) - z.unsqueeze(1)).unsqueeze(2)
                    @ (
                        model.M_tens.unsqueeze(0)
                        * torch.exp(
                            -torch.norm(
                                model.centroids_tens.unsqueeze(0) - z.unsqueeze(1),
                                dim=-1,
                            )
                            ** 2
                            / (model.temperature ** 2)
                        )
                        .unsqueeze(-1)
                        .unsqueeze(-1)
                    )
                ).sum(dim=1),
                1,
                2,
            )
        )

    @staticmethod
    def grad_log_prop(z, model):
        return RHVAESampler.grad_log_sqrt_det_G_inv(z, model)
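# A quick numerical check of the tempering schedule (a standalone sketch, not
# part of the library): ``tempering(k, K, beta_zero_sqrt)`` returns the square
# root of the inverse temperature at leapfrog step k. It moves from
# ``beta_zero_sqrt`` at k=0 to 1 at k=K, so momenta sampled at temperature
# ``1 / beta_zero`` are progressively rescaled toward the target temperature:
#
#     import torch
#
#     beta_zero_sqrt = torch.tensor([0.3]).sqrt()  # illustrative beta_zero = 0.3
#     for k in range(16):
#         print(k, RHVAESampler.tempering(k, 15, beta_zero_sqrt).item())
#     # k = 0  -> ~0.5477 (= sqrt(0.3)); k = 15 -> 1.0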