pyraug.trainers

class pyraug.trainers.training_config.TrainingConfig(output_dir=None, batch_size=50, max_epochs=10000, learning_rate=0.0001, train_early_stopping=50, eval_early_stopping=None, steps_saving=None, seed=8, no_cuda=False, verbose=True)

TrainingConfig stores all the training arguments. An instance of it is passed to a Trainer, which uses it to train a model.

Parameters
  • output_dir (str) – The directory where model checkpoints, configs and final model will be stored. Default: None.

  • batch_size (int) – The number of training samples per batch. Default: 50

  • max_epochs (int) – The maximal number of epochs for training. Default: 10000

  • learning_rate (float) – The learning rate applied to the Optimizer. Default: 1e-4

  • train_early_stopping (int) – The maximal number of epochs authorized without train loss improvement. If None no early stopping is performed. Default: 50

  • eval_early_stopping (int) – The maximal number of epochs authorized without eval loss improvement. If None no early stopping is performed. Default: None

  • steps_saving (int) – A model checkpoint is saved every steps_saving epochs. Default: None

  • seed (int) – The random seed for reproducibility. Default: 8

  • no_cuda (bool) – Disable CUDA training. Default: False

  • verbose (bool) – Enable verbose output. Default: True
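
The early stopping parameters above describe a patience counter. The following is a minimal sketch of that idea, not pyraug's own code: the function name epochs_run is hypothetical, and whether the library stops when the counter reaches the limit or only once it exceeds it is an assumption.

```python
from typing import Iterable, Optional


def epochs_run(losses: Iterable[float], patience: Optional[int]) -> int:
    """Return how many epochs run before a patience-based early stop.

    `patience` plays the role of train_early_stopping or eval_early_stopping:
    the number of consecutive epochs allowed without a new best loss.
    None disables early stopping.
    """
    best = float("inf")
    stale = 0  # epochs since the last improvement
    count = 0
    for loss in losses:
        count += 1
        if loss < best:
            best = loss
            stale = 0
        else:
            stale += 1
        # Assumption: training stops as soon as the counter reaches the limit.
        if patience is not None and stale >= patience:
            break
    return count


history = [1.0, 0.8, 0.9, 0.85, 0.81, 0.7]
stopped_at = epochs_run(history, patience=3)   # stops after the 5th epoch
full_run = epochs_run(history, patience=None)  # runs all 6 epochs
```

With a patience of 3, the run above ends before the final improvement to 0.7 is ever seen, which is the usual trade-off of a small patience value.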

classmethod from_dict(config_dict)

Creates a BaseConfig instance from a dictionary

Parameters

config_dict (dict) – The Python dictionary containing all the parameters

Returns

The created instance

Return type

BaseConfig
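
As an illustration of the from_dict pattern, here is a minimal sketch built on a standard-library dataclass. ConfigSketch is a hypothetical stand-in with a few of the documented fields, and rejecting unknown keys is an assumption of this sketch rather than documented pyraug behaviour.

```python
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class ConfigSketch:
    """Hypothetical stand-in carrying a few of the documented fields."""

    output_dir: Optional[str] = None
    batch_size: int = 50
    max_epochs: int = 10000
    learning_rate: float = 1e-4

    @classmethod
    def from_dict(cls, config_dict: dict) -> "ConfigSketch":
        # Assumption: keys that are not fields of the config are an error.
        known = {f.name for f in fields(cls)}
        unknown = set(config_dict) - known
        if unknown:
            raise ValueError(f"Unknown config keys: {sorted(unknown)}")
        return cls(**config_dict)


# Keys missing from the dictionary keep their default values.
config = ConfigSketch.from_dict({"batch_size": 100, "max_epochs": 200})
```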

classmethod from_json_file(json_path)

Creates a BaseConfig instance from a JSON config file

Parameters

json_path (str) – The path to the json file containing all the parameters

Returns

The created instance

Return type

BaseConfig

save_json(dir_path, filename)

Saves the dataclass to a .json file

Parameters
  • dir_path (str) – path to the folder

  • filename (str) – the name of the file
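
A minimal sketch of what a save_json(dir_path, filename) method could look like on a standard-library dataclass. ConfigSketch is a hypothetical stand-in; appending the ".json" extension to filename and returning the written path are assumptions of this sketch.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass


@dataclass
class ConfigSketch:
    """Hypothetical stand-in carrying two of the documented fields."""

    batch_size: int = 50
    learning_rate: float = 1e-4

    def save_json(self, dir_path: str, filename: str) -> str:
        # Assumption: the ".json" extension is appended to `filename`.
        os.makedirs(dir_path, exist_ok=True)
        path = os.path.join(dir_path, f"{filename}.json")
        with open(path, "w", encoding="utf-8") as handle:
            json.dump(asdict(self), handle, indent=2)
        return path  # returned for convenience in this sketch only


# Write the config to a temporary folder and read the file back.
with tempfile.TemporaryDirectory() as tmp_dir:
    saved_path = ConfigSketch(batch_size=16).save_json(tmp_dir, "training_config")
    with open(saved_path, "r", encoding="utf-8") as handle:
        written = json.load(handle)
```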

to_dict()

Transforms the object into a Python dictionary

Returns

The dictionary containing all the parameters

Return type

(dict)

to_json_string()

Transforms object into a JSON string

Returns

The JSON str containing all the parameters

Return type

(str)
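
The two serialization methods can be sketched with the standard library as follows. ConfigSketch is a hypothetical stand-in, not the pyraug class; the sketch only shows that the dictionary and JSON forms carry the same parameters and can rebuild an equal instance.

```python
import json
from dataclasses import asdict, dataclass
from typing import Optional


@dataclass
class ConfigSketch:
    """Hypothetical stand-in carrying three of the documented fields."""

    output_dir: Optional[str] = None
    batch_size: int = 50
    no_cuda: bool = False

    def to_dict(self) -> dict:
        return asdict(self)

    def to_json_string(self) -> str:
        return json.dumps(self.to_dict())


config = ConfigSketch(batch_size=64)
as_dict = config.to_dict()
as_json = config.to_json_string()

# Round trip: parsing the JSON string yields an equal config.
restored = ConfigSketch(**json.loads(as_json))
```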

class pyraug.trainers.Trainer(model, train_dataset, eval_dataset=None, training_config=None, optimizer=None)

Trainer is the main class to perform model training.

Parameters
  • model (BaseVAE) – The model to train

  • train_dataset (BaseDataset) – The training dataset, of type BaseDataset

  • eval_dataset (BaseDataset) – The evaluation dataset. Default: None.

  • training_config (TrainingConfig) – The training arguments summarizing the main parameters used for training. If None, a default TrainingConfig instance is used. Default: None.

  • optimizer (Optimizer) – An instance of torch.optim.Optimizer used for training. If None, an Adam optimizer is used. Default: None.

eval_step()

Performs an evaluation step

Returns

The evaluation loss

Return type

(torch.Tensor)

save_checkpoint(dir_path, epoch)

Saves a checkpoint from which training can be restarted

Parameters
  • dir_path (str) – The folder where the checkpoint should be saved

  • epoch (int) – The epoch number
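
To illustrate how steps_saving from TrainingConfig could decide when this method is called, here is a minimal sketch. The helper checkpoint_epochs is hypothetical, and both the 1-based epoch numbering and the behaviour for None (no intermediate checkpoints) are assumptions, not documented pyraug behaviour.

```python
from typing import List, Optional


def checkpoint_epochs(max_epochs: int, steps_saving: Optional[int]) -> List[int]:
    """List the epochs at which a checkpoint would be written.

    Assumptions: epochs are numbered from 1, and None means that no
    intermediate checkpoint is saved.
    """
    if steps_saving is None:
        return []
    return [
        epoch
        for epoch in range(1, max_epochs + 1)
        if epoch % steps_saving == 0
    ]


saved = checkpoint_epochs(max_epochs=10, steps_saving=3)        # [3, 6, 9]
never_saved = checkpoint_epochs(max_epochs=10, steps_saving=None)  # []
```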

save_model(dir_path)

This method saves the final model along with the config files

Parameters

dir_path (str) – The folder where the model and config files should be saved

train(log_output_dir=None)

The main training function

Parameters

log_output_dir (str) – The path in which the log will be stored

train_step()

Performs a training loop over the train_loader.

Returns

The step training loss

Return type

(torch.Tensor)