pyraug.trainer
- class pyraug.trainers.training_config.TrainingConfig(output_dir=None, batch_size=50, max_epochs=10000, learning_rate=0.0001, train_early_stopping=50, eval_early_stopping=None, steps_saving=None, seed=8, no_cuda=False, verbose=True)
TrainingConfig is the class in which all the training arguments are stored. This instance is then provided to a Trainer instance, which performs the model training.
- Parameters
output_dir (str) – The directory where model checkpoints, configs and the final model will be stored. Default: None
batch_size (int) – The number of training samples per batch. Default: 50
max_epochs (int) – The maximum number of training epochs. Default: 10000
learning_rate (float) – The learning rate applied to the optimizer. Default: 1e-4
train_early_stopping (int) – The maximum number of epochs allowed without improvement of the training loss. If None, no early stopping is performed. Default: 50
eval_early_stopping (int) – The maximum number of epochs allowed without improvement of the evaluation loss. If None, no early stopping is performed. Default: None
steps_saving (int) – A model checkpoint is saved every steps_saving epochs. Default: None
seed (int) – The random seed for reproducibility. Default: 8
no_cuda (bool) – Disable CUDA training. Default: False
verbose (bool) – Enable verbose output during training. Default: True
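The two early-stopping arguments describe a patience rule: training stops once the monitored loss has gone a given number of epochs without improving, and None disables the check. The plain-Python sketch below illustrates that rule; the helper name `should_stop` and the strict "lower is better" comparison are assumptions for illustration, not pyraug's internal code.

```python
from typing import Optional, Sequence


def should_stop(losses: Sequence[float], patience: Optional[int]) -> bool:
    """Hypothetical helper: return True once `patience` consecutive epochs
    have passed without the loss dropping below its best value so far.

    patience=None disables early stopping, mirroring the documented meaning
    of None for train_early_stopping and eval_early_stopping.
    """
    if patience is None or not losses:
        return False
    best = float("inf")
    epochs_without_improvement = 0
    for loss in losses:
        if loss < best:
            best = loss  # new best loss: reset the counter
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
    return epochs_without_improvement >= patience


# The loss improves for three epochs, then stalls for three epochs.
history = [1.0, 0.8, 0.7, 0.7, 0.75, 0.71]
print(should_stop(history, patience=3))     # True
print(should_stop(history, patience=50))    # False
print(should_stop(history, patience=None))  # False
```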
- classmethod from_dict(config_dict)
Creates a BaseConfig instance from a dictionary.
- Parameters
config_dict (dict) – The Python dictionary containing all the parameters
- Returns
The created instance
- Return type
BaseConfig
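To illustrate what `from_dict` produces, here is a minimal stand-in built on a Python dataclass. `TrainingConfigSketch` is hypothetical: its fields and defaults are copied from the `TrainingConfig` signature above, and building the instance with `cls(**config_dict)` is an assumption about how the real classmethod behaves (it may, for instance, treat unknown keys differently).

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TrainingConfigSketch:
    """Hypothetical stand-in mirroring the TrainingConfig signature defaults."""
    output_dir: Optional[str] = None
    batch_size: int = 50
    max_epochs: int = 10000
    learning_rate: float = 1e-4
    train_early_stopping: Optional[int] = 50
    eval_early_stopping: Optional[int] = None
    steps_saving: Optional[int] = None
    seed: int = 8
    no_cuda: bool = False
    verbose: bool = True

    @classmethod
    def from_dict(cls, config_dict: dict) -> "TrainingConfigSketch":
        # Keys must match field names; omitted keys keep their defaults.
        return cls(**config_dict)


config = TrainingConfigSketch.from_dict({"batch_size": 64, "max_epochs": 500})
print(config.batch_size, config.max_epochs, config.learning_rate)  # 64 500 0.0001
```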
- classmethod from_json_file(json_path)
Creates a BaseConfig instance from a JSON config file.
- Parameters
json_path (str) – The path to the JSON file containing all the parameters
- Returns
The created instance
- Return type
BaseConfig
- save_json(dir_path, filename)
Saves the configuration dataclass as a .json file.
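The JSON helpers can be pictured as a round trip: `save_json` writes the dataclass fields to a file in `dir_path`, and `from_json_file` reads them back into a new instance. The class below is a hypothetical two-field stand-in; appending the `.json` suffix to `filename` and returning the written path are assumptions, not documented pyraug behavior.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class ConfigSketch:
    """Hypothetical two-field stand-in for a BaseConfig subclass."""
    batch_size: int = 50
    learning_rate: float = 1e-4

    def save_json(self, dir_path: str, filename: str) -> str:
        # Assumption: the .json suffix is appended to `filename`.
        path = os.path.join(dir_path, f"{filename}.json")
        with open(path, "w") as f:
            json.dump(asdict(self), f, indent=2)
        return path

    @classmethod
    def from_json_file(cls, json_path: str) -> "ConfigSketch":
        with open(json_path) as f:
            return cls(**json.load(f))


with tempfile.TemporaryDirectory() as tmp:
    saved = ConfigSketch(batch_size=128).save_json(tmp, "training_config")
    reloaded = ConfigSketch.from_json_file(saved)
    print(reloaded)  # ConfigSketch(batch_size=128, learning_rate=0.0001)
```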
- class pyraug.trainers.Trainer(model, train_dataset, eval_dataset=None, training_config=None, optimizer=None)
Trainer is the main class to perform model training.
- Parameters
model (BaseVAE) – The model to train
train_dataset (BaseDataset) – The training dataset
eval_dataset – The evaluation dataset. Default: None
training_config (TrainingConfig) – The training arguments summarizing the main parameters used for training. If None, a default TrainingConfig instance is used. Default: None
optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training. If None, an Adam optimizer is used. Default: None
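When no optimizer is passed, the Trainer falls back to Adam. For intuition, the sketch below applies the standard Adam update rule to a single scalar parameter in plain Python, using the `TrainingConfig` default learning rate of 1e-4 together with moment coefficients (0.9, 0.999) and epsilon 1e-8, which are also the `torch.optim.Adam` defaults. It illustrates the algorithm only: `adam_minimize` is a hypothetical helper, not the PyTorch implementation, and whether pyraug sets any other Adam hyperparameters is not stated here.

```python
import math


def adam_minimize(grad_fn, x0: float, lr: float = 1e-4, steps: int = 1,
                  beta1: float = 0.9, beta2: float = 0.999,
                  eps: float = 1e-8) -> float:
    """Run `steps` Adam updates on a scalar parameter and return it."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g      # first-moment (mean) estimate
        v = beta2 * v + (1 - beta2) * g * g  # second-moment estimate
        m_hat = m / (1 - beta1 ** t)         # bias corrections
        v_hat = v / (1 - beta2 ** t)
        x -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return x


# Minimise f(x) = x**2 (gradient 2x) starting from x = 1.0.
grad = lambda x: 2 * x
print(round(adam_minimize(grad, 1.0, steps=1), 6))  # 0.9999
```

The first Adam step moves the parameter by roughly `lr` regardless of the gradient's scale, which is why the learning rate directly bounds how far each update can go.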
- save_model(dir_path)
Saves the final model along with its config files.
- Parameters
dir_path (str) – The folder where the model and config files should be saved
- train(log_output_dir=None)
Main training method: trains the model on the datasets passed to the Trainer.
- Parameters
log_output_dir (str) – The directory in which the training logs will be stored. Default: None
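As a rough picture of how a training method can combine the configuration values, here is a plain-Python epoch loop that runs for at most `max_epochs`, records a checkpoint every `steps_saving` epochs, and writes one log line per epoch to a file in `log_output_dir`. The helper `run_training`, the toy loss and the `training.log` file name are hypothetical; this is not pyraug's implementation of `Trainer.train`.

```python
import os
import tempfile
from typing import List, Optional


def run_training(max_epochs: int, steps_saving: Optional[int],
                 log_output_dir: Optional[str] = None) -> List[int]:
    """Hypothetical epoch loop; returns the epochs at which a checkpoint
    would be saved."""
    checkpoint_epochs = []
    log_lines = []
    for epoch in range(1, max_epochs + 1):
        train_loss = 1.0 / epoch  # toy stand-in for one pass over the data
        log_lines.append(f"epoch {epoch}: train_loss={train_loss:.4f}")
        # Assumption: steps_saving=None means no intermediate checkpoints.
        if steps_saving is not None and epoch % steps_saving == 0:
            checkpoint_epochs.append(epoch)
    if log_output_dir is not None:
        os.makedirs(log_output_dir, exist_ok=True)
        with open(os.path.join(log_output_dir, "training.log"), "w") as f:
            f.write("\n".join(log_lines) + "\n")
    return checkpoint_epochs


with tempfile.TemporaryDirectory() as tmp:
    print(run_training(max_epochs=10, steps_saving=3, log_output_dir=tmp))  # [3, 6, 9]
```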