TrainingPipeline
- class pyraug.pipelines.TrainingPipeline(data_loader=None, data_processor=None, model=None, optimizer=None, training_config=None)
This pipeline provides an end-to-end way to train your VAE model. The trained model is saved in the output_dir stated in the TrainingConfig. A folder training_YYYY-MM-DD_hh-mm-ss is created, in which the checkpoints and the final model are saved. Checkpoints are saved in checkpoint_epoch_{epoch} folders (the optimizer and training config are saved as well, so that training can be resumed if needed), and the final model is saved in a final_model folder. If output_dir is None, a folder dummy_output_dir/training_YYYY-MM-DD_hh-mm-ss is created instead.
- Parameters
  - data_loader (Optional[BaseDataGetter]) – The data loader used to load your data. This is useful, for instance, to read data stored in a particular format or in a specific folder. If None, the ImageGetterFromFolder is used. Default: None.
  - data_processor (Optional[DataProcessor]) – The data processor used to preprocess your data (e.g. normalization, reshaping, type conversion). If None, a basic DataProcessor is used (by default, each data sample is normalized so that its max value is 1 and its min value is 0). Default: None.
  - model (Optional[BaseVAE]) – An instance of BaseVAE you want to train. If None, a default RHVAE model is used. Default: None.
  - optimizer (Optional[Optimizer]) – An instance of Optimizer used to train the model. If None, an instance of the Adam optimizer is provided. Default: None.
  - training_config (Optional[TrainingConfig]) – An instance of TrainingConfig stating the training parameters. If None, a default configuration is used. Default: None.
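The output layout described in the class description can be illustrated with pathlib. The helper run_folders below is hypothetical and not part of pyraug; it is a minimal sketch that assumes the timestamp uses a 24-hour clock (strftime %H) and that a None output_dir falls back to dummy_output_dir, as stated above.

```python
from datetime import datetime
from pathlib import Path
from typing import Optional, Tuple


def run_folders(
    output_dir: Optional[str], epoch: int, now: Optional[datetime] = None
) -> Tuple[Path, Path, Path]:
    """Hypothetical helper: return (training_dir, checkpoint_dir, final_model_dir)."""
    # Fall back to "dummy_output_dir" when no output_dir is given
    root = Path(output_dir) if output_dir is not None else Path("dummy_output_dir")
    # Assumption: "hh" is a 24-hour clock, hence %H
    stamp = (now if now is not None else datetime.now()).strftime("%Y-%m-%d_%H-%M-%S")
    training_dir = root / f"training_{stamp}"
    # One folder per saved checkpoint (model, optimizer and training config)
    checkpoint_dir = training_dir / f"checkpoint_epoch_{epoch}"
    # Folder holding the model obtained at the end of training
    final_model_dir = training_dir / "final_model"
    return training_dir, checkpoint_dir, final_model_dir


training_dir, ckpt_dir, final_dir = run_folders(None, 10, datetime(2021, 6, 1, 14, 3, 7))
print(training_dir.as_posix())  # dummy_output_dir/training_2021-06-01_14-03-07
```

Passing the timestamp explicitly keeps the sketch deterministic; in practice the current time would be used.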
Note
If you do not provide a data_processor, a default one is used. By default, it normalizes each data sample so that its max value equals 1 and its min value equals 0.
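The default normalization mentioned in the note amounts to a per-sample min-max scaling, which can be sketched with NumPy. minmax_per_sample is a hypothetical helper, not pyraug's DataProcessor; it assumes the first axis indexes samples and maps constant samples to 0 to avoid dividing by zero (how pyraug itself handles that case is not stated here).

```python
import numpy as np


def minmax_per_sample(data: np.ndarray) -> np.ndarray:
    """Rescale each sample (first axis) so that its min is 0 and its max is 1."""
    data = np.asarray(data, dtype=np.float64)
    # Reduce over every axis except the leading mini-batch axis
    axes = tuple(range(1, data.ndim))
    mins = data.min(axis=axes, keepdims=True)
    maxs = data.max(axis=axes, keepdims=True)
    span = maxs - mins
    # Assumption: constant samples (max == min) are mapped to 0
    span = np.where(span == 0, 1.0, span)
    return (data - mins) / span


batch = np.array([[[0.0, 5.0], [10.0, 2.0]], [[-1.0, 1.0], [3.0, 3.0]]])
print(minmax_per_sample(batch)[0])  # [[0.  0.5] [1.  0.2]] (printed over two lines)
```

Each sample is scaled independently, so a bright image and a dark image both end up spanning the full [0, 1] range.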
- __call__(train_data, eval_data=None, log_output_dir=None)
Launch the model training on the provided data.
- Parameters
  - train_data (Union[str, ndarray, Tensor]) – The training data: either the path to a folder in which each file is one data sample, or a numpy.ndarray or torch.Tensor of shape (mini_batch x n_channels x data_shape).
  - eval_data (Optional[Union[str, ndarray, Tensor]]) – The evaluation data: either the path to a folder in which each file is one data sample, or a numpy.ndarray or torch.Tensor. If None, no evaluation data is used. Default: None.
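The three input types accepted by train_data and eval_data could be unified into a single array along the following lines. to_array is a hypothetical helper, not pyraug's loader: it assumes a folder holds one .npy file per sample (pyraug's own loaders, such as the default ImageGetterFromFolder, may expect other file formats), and it treats any object with a .numpy() method as a CPU tensor that does not require gradients.

```python
from pathlib import Path
from typing import Union

import numpy as np


def to_array(data: Union[str, np.ndarray, "torch.Tensor"]) -> np.ndarray:
    """Hypothetical sketch: convert a folder path, ndarray or tensor to an ndarray."""
    if isinstance(data, str):
        # Assumption: one .npy file per sample, loaded in sorted filename order
        files = sorted(Path(data).glob("*.npy"))
        if not files:
            raise ValueError(f"no .npy files found in {data!r}")
        # Stack samples along a new leading mini-batch axis
        return np.stack([np.load(f) for f in files])
    if isinstance(data, np.ndarray):
        return data
    if hasattr(data, "numpy"):
        # Assumption: a CPU tensor without grad, so .numpy() works directly
        return data.numpy()
    raise TypeError(f"unsupported data type: {type(data).__name__}")
```

Stacking along a new leading axis yields the (mini_batch x n_channels x data_shape) layout, provided each file stores a single (n_channels x data_shape) sample.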