import datetime
import logging
import os
from copy import deepcopy
from typing import Any, Dict, Optional
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from pyraug.customexception import ModelError
from pyraug.data.datasets import BaseDataset
from pyraug.models import BaseVAE
from pyraug.trainers.trainer_utils import set_seed
from pyraug.trainers.training_config import TrainingConfig
# Module-level logger shared by every Trainer instance.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach a stream handler so that INFO records are echoed to the console.
console = logging.StreamHandler()
logger.addHandler(console)
[docs]class Trainer:
"""Trainer is the main class to perform model training.
Args:
model (BaseVAE): The model to train.
train_dataset (BaseDataset): The training dataset of type
:class:`~pyraug.data.datasets.BaseDataset`.
eval_dataset (BaseDataset): The evaluation dataset of type
:class:`~pyraug.data.datasets.BaseDataset`. If None, no evaluation is performed.
Default: None.
training_config (TrainingConfig): The training arguments summarizing the main parameters
used for training. If None, a basic training instance of :class:`TrainingConfig` is used.
Default: None.
optimizer (~torch.optim.Optimizer): An instance of `torch.optim.Optimizer` used for
training. If None, a :class:`~torch.optim.Adam` optimizer is used. Default: None.
"""
def __init__(
    self,
    model: BaseVAE,
    train_dataset: BaseDataset,
    eval_dataset: Optional[BaseDataset] = None,
    training_config: Optional[TrainingConfig] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
):
    """Prepare the configuration, device, optimizer and data loaders.

    Args:
        model (BaseVAE): The model to train. It is moved to the selected
            device and its ``device`` attribute is set accordingly.
        train_dataset (BaseDataset): The dataset used for training.
        eval_dataset (BaseDataset): Optional dataset used for evaluation.
            When None, no evaluation loader is built.
        training_config (TrainingConfig): The training configuration. When
            None, a default ``TrainingConfig()`` is created. If its
            ``output_dir`` is None it is set (in place) to
            ``"dummy_output_dir"``.
        optimizer (torch.optim.Optimizer): Optional optimizer. When None,
            the optimizer returned by ``set_default_optimizer`` is used;
            otherwise the tensors held in its state are moved to the device.
    """
    # Fall back to a default configuration when none was supplied.
    training_config = (
        TrainingConfig() if training_config is None else training_config
    )

    # The configuration object itself receives the fallback directory.
    if training_config.output_dir is None:
        training_config.output_dir = "dummy_output_dir"

    output_dir = training_config.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        logger.info(f"Created {output_dir} folder since did not exist.\n")

    self.training_config = training_config

    # Seed before anything random (optimizer creation, loaders) happens.
    set_seed(self.training_config.seed)

    use_cuda = torch.cuda.is_available() and not training_config.no_cuda
    device = "cuda" if use_cuda else "cpu"

    # Place the model on the device and remember where it lives.
    model = model.to(device)
    model.device = device

    if optimizer is None:
        optimizer = self.set_default_optimizer(model)
    else:
        optimizer = self._set_optimizer_on_device(optimizer, device)

    self.model = model
    self.optimizer = optimizer
    self.device = device
    self.train_dataset = train_dataset
    self.eval_dataset = eval_dataset

    # Decide which loss (train or eval), if any, drives early stopping.
    self._set_earlystopping_flags(train_dataset, eval_dataset, training_config)

    # Build the data loaders.
    self.train_loader = self.get_train_dataloader(train_dataset)
    self.eval_loader = (
        self.get_eval_dataloader(eval_dataset) if eval_dataset is not None else None
    )
def get_train_dataloader(
self, train_dataset: BaseDataset
) -> torch.utils.data.DataLoader:
return DataLoader(
dataset=train_dataset,
batch_size=self.training_config.batch_size,
shuffle=True,
)
def get_eval_dataloader(
self, eval_dataset: BaseDataset
) -> torch.utils.data.DataLoader:
return DataLoader(
dataset=eval_dataset,
batch_size=self.training_config.batch_size,
shuffle=False,
)
def set_default_optimizer(self, model: BaseVAE) -> torch.optim.Optimizer:
optimizer = optim.Adam(
model.parameters(), lr=self.training_config.learning_rate
)
return optimizer
def _run_model_sanity_check(self, model, train_dataset):
try:
train_dataset = self._set_inputs_to_device(train_dataset[:2])
model(train_dataset)
except Exception as e:
raise ModelError(
"Error when calling forward method from model. Potential issues: \n"
" - Wrong model architecture -> check encoder, decoder and metric architecture if "
"you provide yours \n"
" - The data input dimension provided is wrong -> when no encoder, decoder or metric "
"provided, a network is built automatically but requires the shape of the flatten "
"input data.\n"
f"Exception raised: {type(e)} with message: " + str(e)
) from e
#
def _set_earlystopping_flags(self, train_dataset, eval_dataset, training_config):
# Initialize early_stopping flags
self.make_eval_early_stopping = False
self.make_train_early_stopping = False
if training_config.train_early_stopping is not None:
self.make_train_early_stopping = True
# Check if eval_dataset is provided
if eval_dataset is not None and training_config.eval_early_stopping is not None:
self.make_eval_early_stopping = True
# By default we make the early stopping on evaluation dataset
self.make_train_early_stopping = False
def _set_optimizer_on_device(self, optim, device):
for param in optim.state.values():
# Not sure there are any global tensors in the state dict
if isinstance(param, torch.Tensor):
param.data = param.data.to(device)
if param._grad is not None:
param._grad.data = param._grad.data.to(device)
elif isinstance(param, dict):
for subparam in param.values():
if isinstance(subparam, torch.Tensor):
subparam.data = subparam.data.to(device)
if subparam._grad is not None:
subparam._grad.data = subparam._grad.data.to(device)
return optim
def _set_inputs_to_device(self, inputs: Dict[str, Any]):
inputs_on_device = inputs
if self.device == "cuda":
cuda_inputs = dict.fromkeys(inputs)
for key in inputs.keys():
if torch.is_tensor(inputs[key]):
cuda_inputs[key] = inputs[key].cuda()
else:
cuda_inputs = inputs[key]
inputs_on_device = cuda_inputs
return inputs_on_device
def train(self, log_output_dir: Optional[str] = None):
    """This function is the main training function

    The model first goes through a sanity check, then it is trained for
    ``training_config.max_epochs`` epochs (numbered from 1 to ``max_epochs``
    inclusive). Training stops earlier when the monitored loss has not
    improved for the configured number of epochs, whether or not a log
    file is written. A checkpoint is saved every
    ``training_config.steps_saving`` epochs (when set) and the final model
    is saved in ``<output_dir>/training_<signature>/final_model``.

    Args:
        log_output_dir (str): The path in which the log will be stored. If
            None, no log file is written. Default: None.
    """
    # run sanity check on the model
    self._run_model_sanity_check(self.model, self.train_dataset)
    logger.info("Model passed sanity check !\n")

    # Timestamp (second resolution) identifying this run on disk.
    self._training_signature = (
        str(datetime.datetime.now())[0:19].replace(" ", "_").replace(":", "-")
    )

    training_dir = os.path.join(
        self.training_config.output_dir, f"training_{self._training_signature}"
    )
    if not os.path.exists(training_dir):
        os.makedirs(training_dir)
        logger.info(
            f"Created {training_dir}. \n"
            "Training config, checkpoints and final model will be saved here.\n"
        )

    file_logger = None
    f_handler = None

    try:
        # set up log file
        if log_output_dir is not None:
            file_logger, f_handler = self._open_training_log(log_output_dir)
        log_verbose = file_logger is not None

        logger.info("Successfully launched training !")

        # set best losses for early stopping
        best_train_loss = float("inf")
        best_eval_loss = float("inf")
        epoch_es_train = 0
        epoch_es_eval = 0
        epoch_eval_loss = None

        # Upper bound is inclusive so that exactly max_epochs epochs run.
        for epoch in range(1, self.training_config.max_epochs + 1):
            epoch_train_loss = self.train_step()

            if self.eval_dataset is not None:
                epoch_eval_loss = self.eval_step()

            # early stopping (evaluation loss has priority over train loss)
            stop_message = None

            if self.make_eval_early_stopping:
                if epoch_eval_loss < best_eval_loss:
                    epoch_es_eval = 0
                    best_eval_loss = epoch_eval_loss
                else:
                    epoch_es_eval += 1

                if epoch_es_eval >= self.training_config.eval_early_stopping:
                    stop_message = (
                        f"Training ended at epoch {epoch}! "
                        f" Eval loss did not improve for {epoch_es_eval} epochs."
                    )

            elif self.make_train_early_stopping:
                if epoch_train_loss < best_train_loss:
                    epoch_es_train = 0
                    best_train_loss = epoch_train_loss
                else:
                    epoch_es_train += 1

                if epoch_es_train >= self.training_config.train_early_stopping:
                    stop_message = (
                        f"Training ended at epoch {epoch}! "
                        f" Train loss did not improve for {epoch_es_train} epochs."
                    )

            if stop_message is not None:
                # Stopping must not depend on whether a log file exists.
                logger.info(stop_message)
                if log_verbose:
                    file_logger.info(stop_message)
                break

            # save checkpoints
            if (
                self.training_config.steps_saving is not None
                and epoch % self.training_config.steps_saving == 0
            ):
                self.save_checkpoint(dir_path=training_dir, epoch=epoch)
                logger.info(f"Saved checkpoint at epoch {epoch}\n")

                if log_verbose:
                    file_logger.info(f"Saved checkpoint at epoch {epoch}\n")

            # Periodic progress report, written to the log file only.
            if log_verbose and epoch % 10 == 0:
                file_logger.info(
                    self._format_epoch_report(
                        epoch,
                        epoch_train_loss,
                        epoch_eval_loss,
                        epoch_es_train,
                        epoch_es_eval,
                        best_train_loss,
                        best_eval_loss,
                    )
                )
    finally:
        # Release the log file even if training raised.
        if f_handler is not None:
            file_logger.removeHandler(f_handler)
            f_handler.close()

    final_dir = os.path.join(training_dir, "final_model")
    self.save_model(dir_path=final_dir)

    logger.info("----------------------------------")
    logger.info("Training ended!")
    logger.info(f"Saved final model in {final_dir}")


def _open_training_log(self, log_dir):
    """Create the per-run file logger and write the run header to it.

    Args:
        log_dir (str): Folder receiving the log file; created if missing.

    Returns:
        tuple: ``(file_logger, f_handler)``, the logger dedicated to this
        run and the file handler attached to it. The caller is responsible
        for removing and closing the handler.
    """
    # if dir does not exist create it
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
        logger.info(f"Created {log_dir} folder since did not exists.")
        logger.info("Training logs will be recodered here.\n")
        logger.info(" -> Training can be monitored here.\n")

    # create and set logger
    log_name = f"training_logs_{self._training_signature}"

    file_logger = logging.getLogger(log_name)
    file_logger.setLevel(logging.INFO)

    f_handler = logging.FileHandler(
        os.path.join(log_dir, f"training_logs_{self._training_signature}.log")
    )
    f_handler.setLevel(logging.INFO)
    file_logger.addHandler(f_handler)

    # Do not output logs in the console
    file_logger.propagate = False

    file_logger.info("Training started !\n")
    file_logger.info(
        f"Training params:\n - max_epochs: {self.training_config.max_epochs}\n"
        f" - train es: {self.training_config.train_early_stopping}\n"
        f" - eval es: {self.training_config.eval_early_stopping}\n"
        f" - batch_size: {self.training_config.batch_size}\n"
        f" - checkpoint saving every {self.training_config.steps_saving}\n"
    )
    file_logger.info(f"Model Architecture: {self.model}\n")
    file_logger.info(f"Optimizer: {self.optimizer}\n")

    return file_logger, f_handler


def _format_epoch_report(
    self,
    epoch,
    train_loss,
    eval_loss,
    es_train,
    es_eval,
    best_train,
    best_eval,
):
    """Build the progress message written to the log file for one epoch.

    The eval loss line is included only when an evaluation dataset is set,
    and a single early-stopping line is added for whichever criterion
    (eval first, then train) is active.
    """
    has_eval = self.eval_dataset is not None

    report = (
        f"Epoch {epoch} / {self.training_config.max_epochs}\n"
        f"- Current Train loss: {train_loss:.2f}\n"
    )

    if has_eval:
        report += f"- Current Eval loss: {eval_loss:.2f}\n"

    if has_eval and self.make_eval_early_stopping:
        report += (
            f"- Eval Early Stopping: {es_eval}/{self.training_config.eval_early_stopping}"
            f" (Best: {best_eval:.2f})\n"
        )
    elif self.make_train_early_stopping:
        report += (
            f"- Train Early Stopping: {es_train}/{self.training_config.train_early_stopping}"
            f" (Best: {best_train:.2f})\n"
        )

    return report
[docs] def eval_step(self):
"""Perform an evaluation step
Returns:
(torch.Tensor): The evaluation loss
"""
self.model.eval()
epoch_loss = 0
for (batch_idx, inputs) in enumerate(self.eval_loader):
inputs = self._set_inputs_to_device(inputs)
model_output = self.model(inputs)
loss = model_output.loss
epoch_loss += loss.item()
epoch_loss /= len(self.eval_loader)
return epoch_loss
[docs] def train_step(self):
"""The trainer performs training loop over the train_loader.
Returns:
(torch.Tensor): The step training loss
"""
# set model in train model
self.model.train()
epoch_loss = 0
for (batch_idx, inputs) in enumerate(self.train_loader):
inputs = self._set_inputs_to_device(inputs)
self.optimizer.zero_grad()
model_output = self.model(inputs)
loss = model_output.loss
loss.backward()
self.optimizer.step()
epoch_loss += loss.item()
# Allows model updates if needed
self.model.update()
epoch_loss /= len(self.train_loader)
return epoch_loss
def save_model(self, dir_path):
    """This method saves the final model along with the config files

    Args:
        dir_path (str): The folder where the model and config files should be
            saved. It is created when it does not exist yet.
    """
    target_missing = not os.path.exists(dir_path)
    if target_missing:
        os.makedirs(dir_path)

    # The model first, then the training configuration next to it.
    self.model.save(dir_path)
    self.training_config.save_json(dir_path, "training_config")
def save_checkpoint(self, dir_path, epoch: int):
    """Saves a checkpoint allowing to restart training from here

    The checkpoint is written in ``<dir_path>/checkpoint_epoch_<epoch>`` and
    holds the optimizer state (``optimizer.pt``), the model and the training
    configuration.

    Args:
        dir_path (str): The folder where the checkpoint should be saved
        epoch (int): The epoch number
    """
    checkpoint_dir = os.path.join(dir_path, f"checkpoint_epoch_{epoch}")
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    # save optimizer (a deep copy of its state dict is what gets serialized)
    optimizer_file = os.path.join(checkpoint_dir, "optimizer.pt")
    optimizer_state = deepcopy(self.optimizer.state_dict())
    torch.save(optimizer_state, optimizer_file)

    # save model
    self.model.save(checkpoint_dir)

    # save training config
    self.training_config.save_json(checkpoint_dir, "training_config")