TrainingPipeline

class pyraug.pipelines.TrainingPipeline(data_loader=None, data_processor=None, model=None, optimizer=None, training_config=None)

This Pipeline provides an end-to-end way to train your VAE model. The trained model is saved in the output_dir specified in the TrainingConfig. A folder named training_YYYY-MM-DD_hh-mm-ss is created there, in which the checkpoints and the final model are saved. Checkpoints are saved in checkpoint_epoch_{epoch} folders (the optimizer and the training config are saved as well, so that training can be resumed if needed), and the final model is saved in a final_model folder. If output_dir is None, a dummy_output_dir/training_YYYY-MM-DD_hh-mm-ss folder is created and everything is saved there.
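The folder layout described above can be sketched with the standard library. This is a hypothetical helper, not part of pyraug: it only computes the paths a run would use and creates nothing on disk. The `checkpoint_every` interval is an assumption, since the text does not say how often checkpoints are written.

```python
from datetime import datetime
from pathlib import Path
from typing import List, Optional


def training_layout(output_dir: Optional[str], n_epochs: int,
                    checkpoint_every: int, now: datetime) -> List[Path]:
    """Hypothetical helper: list the folders a training run would use."""
    # Fall back to dummy_output_dir when no output_dir is given.
    root = Path(output_dir if output_dir is not None else "dummy_output_dir")
    # One timestamped folder per run: training_YYYY-MM-DD_hh-mm-ss.
    run_dir = root / f"training_{now.strftime('%Y-%m-%d_%H-%M-%S')}"
    # Assumed checkpoint schedule: every `checkpoint_every` epochs.
    folders = [run_dir / f"checkpoint_epoch_{epoch}"
               for epoch in range(1, n_epochs + 1)
               if epoch % checkpoint_every == 0]
    # The final model always goes into final_model.
    folders.append(run_dir / "final_model")
    return folders


layout = training_layout(None, n_epochs=4, checkpoint_every=2,
                         now=datetime(2021, 5, 3, 14, 7, 9))
for path in layout:
    print(path.as_posix())
```

With no output_dir, this prints two checkpoint folders and the final_model folder, all under dummy_output_dir/training_2021-05-03_14-07-09.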

Parameters
  • data_loader (Optional[BaseDataGetter]) – The data loader used to load your data. This is useful, for instance, to read data stored in a particular format or in a specific folder. If None, the ImageGetterFromFolder is used. Default: None.

  • data_processor (Optional[DataProcessor]) – The data processor used to preprocess your data (e.g. normalization, reshaping, type conversion). If None, a basic DataProcessor is used (by default, each data sample is normalized so that its maximum value is 1 and its minimum value is 0). Default: None.

  • model (Optional[BaseVAE]) – An instance of BaseVAE you want to train. If None, a default RHVAE model is used. Default: None.

  • optimizer (Optional[Optimizer]) – An instance of Optimizer used to train the model. If None, an instance of the Adam optimizer is used. Default: None.

  • training_config (Optional[TrainingConfig]) – An instance of TrainingConfig specifying the training parameters. If None, a default configuration is used. Default: None.

Note

If you do not provide a data_processor, a default one is used. By default, it normalizes the data so that each data sample has a maximum value of 1 and a minimum value of 0.
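The default normalization in the note amounts to per-sample min-max scaling. The sketch below works on flattened samples in plain Python and is an illustration, not pyraug's DataProcessor code. Two points are assumptions: that each sample is scaled independently (read from "each data sample"), and how constant samples are handled, which the text does not specify.

```python
from typing import List, Sequence


def min_max_normalize(sample: Sequence[float]) -> List[float]:
    """Rescale one flattened sample so that its min is 0 and its max is 1."""
    lo, hi = min(sample), max(sample)
    span = hi - lo
    if span == 0:
        # Assumption: a constant sample is mapped to all zeros
        # to avoid dividing by zero.
        return [0.0 for _ in sample]
    return [(x - lo) / span for x in sample]


def normalize_dataset(samples: Sequence[Sequence[float]]) -> List[List[float]]:
    # Each sample is normalized on its own, using its own min and max.
    return [min_max_normalize(s) for s in samples]


print(min_max_normalize([0, 127.5, 255]))  # [0.0, 0.5, 1.0]
```

Because every sample is scaled with its own min and max, two images with different brightness ranges both end up spanning [0, 1].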

__call__(train_data, eval_data=None, log_output_dir=None)

Launch the model training on the provided data.

Parameters
  • train_data (Union[str, ndarray, Tensor]) – The training data: either the path to a folder in which each file is one data sample, or a numpy.ndarray or torch.Tensor of shape (mini_batch x n_channels x data_shape).

  • eval_data (Optional[Union[str, ndarray, Tensor]]) – The evaluation data, in the same formats as the training data: either the path to a folder in which each file is one data sample, or a numpy.ndarray or torch.Tensor. If None, no evaluation data is used. Default: None.
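The two accepted input forms (a folder path or an array) can be told apart as sketched below. `resolve_input` is a hypothetical helper, not part of pyraug, and it does not load any files. Arrays are recognized by duck typing on a `.shape` attribute, which is an assumption made so the sketch avoids importing numpy or torch.

```python
import os
from typing import Any, List, Optional, Union


def resolve_input(data: Any) -> Optional[Union[List[str], Any]]:
    """Hypothetical helper: classify training/evaluation input.

    - None: no data (e.g. no evaluation set), returned unchanged.
    - str: a folder in which each file is one data sample.
    - otherwise: an array-like of shape (mini_batch x n_channels x data_shape).
    """
    if data is None:
        return None
    if isinstance(data, str):
        if not os.path.isdir(data):
            raise FileNotFoundError(f"{data!r} is not a folder")
        # One file per sample; subfolders are skipped, and the paths are
        # sorted so the order is deterministic.
        return sorted(os.path.join(data, name) for name in os.listdir(data)
                      if os.path.isfile(os.path.join(data, name)))
    # Assumed duck typing: numpy arrays and torch tensors both expose .shape.
    if not hasattr(data, "shape") or len(data.shape) < 3:
        raise ValueError(
            "expected shape (mini_batch x n_channels x data_shape)")
    return data
```

A folder path returns the list of per-sample files, while an array-like with at least three dimensions is passed through untouched.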