import os
from copy import deepcopy
from typing import Optional
import dill
import torch
import torch.nn as nn
from pyraug.customexception import BadInheritanceError
from pyraug.models.nn import BaseDecoder, BaseEncoder
from pyraug.models.nn.default_architectures import Decoder_MLP, Encoder_MLP
from .base_config import BaseModelConfig
class BaseVAE(nn.Module):
    """Base class for VAE based models.

    Args:
        model_config (BaseModelConfig): An instance of BaseModelConfig in which any model's
            parameters is made available.
        encoder (BaseEncoder): An instance of BaseEncoder (inheriting from ``torch.nn.Module``)
            which plays the role of encoder. This argument allows you to use your own neural
            networks architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.
        decoder (BaseDecoder): An instance of BaseDecoder (inheriting from ``torch.nn.Module``)
            which plays the role of decoder. This argument allows you to use your own neural
            networks architectures if desired. If None is provided, a simple Multi Layer
            Perceptron (https://en.wikipedia.org/wiki/Multilayer_perceptron) is used.
            Default: None.

    .. note::
        For high dimensional data we advise you to provide your own network architectures. With
        the provided MLP you may end up with a ``MemoryError``.
    """

    def __init__(
        self,
        model_config: BaseModelConfig,
        encoder: Optional[BaseEncoder] = None,
        decoder: Optional[BaseDecoder] = None,
    ):
        nn.Module.__init__(self)

        self.input_dim = model_config.input_dim
        self.latent_dim = model_config.latent_dim
        self.model_config = model_config

        if encoder is None:
            # Building the default MLP encoder requires knowing the data dimension.
            if model_config.input_dim is None:
                raise AttributeError(
                    "No input dimension provided !"
                    "'input_dim' parameter of BaseModelConfig instance must be set to 'data_shape' where "
                    "the shape of the data is [mini_batch x data_shape]. Unable to build encoder "
                    "automatically"
                )
            encoder = Encoder_MLP(model_config)
            self.model_config.uses_default_encoder = True
        else:
            self.model_config.uses_default_encoder = False

        if decoder is None:
            # Same constraint for the default MLP decoder.
            if model_config.input_dim is None:
                raise AttributeError(
                    "No input dimension provided !"
                    "'input_dim' parameter of BaseModelConfig instance must be set to 'data_shape' where "
                    "the shape of the data is [mini_batch x data_shape]. Unable to build decoder"
                    "automatically"
                )
            decoder = Decoder_MLP(model_config)
            self.model_config.uses_default_decoder = True
        else:
            self.model_config.uses_default_decoder = False

        self.set_encoder(encoder)
        self.set_decoder(decoder)

        self.device = None

    def forward(self, inputs):
        """Main forward pass outputting the VAE outputs.

        This function should output a model_output instance gathering all the model outputs.

        Args:
            inputs (Dict[str, torch.Tensor]): The training data with labels, masks etc...

        Returns:
            (ModelOutput): The output of the model.

        Raises:
            NotImplementedError: Always; child classes must implement this method.

        .. note::
            The loss must be computed in this forward pass and accessed through
            ``loss = model_output.loss``
        """
        raise NotImplementedError()

    def update(self):
        """Method that allows model update during the training.

        If needed, this method must be implemented in a child class.
        By default, it does nothing.
        """
        pass

    def save(self, dir_path):
        """Method to save the model at a specific location. It saves the model weights as a
        ``model.pt`` file along with the model config as a ``model_config.json`` file. If the
        model to save used a custom encoder (resp. decoder) provided by the user, these are also
        saved as ``encoder.pkl`` (resp. ``decoder.pkl``).

        Args:
            dir_path (str): The path where the model should be saved. If the path
                does not exist a folder will be created at the provided location.
        """
        model_path = dir_path

        # deepcopy decouples the saved weights from any further in-place training updates
        model_dict = {"model_state_dict": deepcopy(self.state_dict())}

        # exist_ok avoids the check-then-create race of the former
        # ``if not os.path.exists: os.makedirs`` pattern
        os.makedirs(model_path, exist_ok=True)

        self.model_config.save_json(model_path, "model_config")

        # only save .pkl files if a custom architecture was provided; default
        # MLP nets are rebuilt from the config alone at load time
        if not self.model_config.uses_default_encoder:
            with open(os.path.join(model_path, "encoder.pkl"), "wb") as fp:
                dill.dump(self.encoder, fp)

        if not self.model_config.uses_default_decoder:
            with open(os.path.join(model_path, "decoder.pkl"), "wb") as fp:
                dill.dump(self.decoder, fp)

        torch.save(model_dict, os.path.join(model_path, "model.pt"))

    @classmethod
    def _load_model_config_from_folder(cls, dir_path):
        """Rebuild the :class:`BaseModelConfig` from ``model_config.json`` in ``dir_path``.

        Raises:
            FileNotFoundError: If ``model_config.json`` is missing from ``dir_path``.
        """
        file_list = os.listdir(dir_path)

        if "model_config.json" not in file_list:
            raise FileNotFoundError(
                f"Missing model config file ('model_config.json') in"
                f"{dir_path}... Cannot perform model building."
            )

        path_to_model_config = os.path.join(dir_path, "model_config.json")
        model_config = BaseModelConfig.from_json_file(path_to_model_config)

        return model_config

    @classmethod
    def _load_model_weights_from_folder(cls, dir_path):
        """Load the model state dict stored in ``model.pt`` in ``dir_path`` (on CPU).

        Raises:
            FileNotFoundError: If ``model.pt`` is missing from ``dir_path``.
            RuntimeError: If ``torch.load`` cannot deserialize the weights file.
            KeyError: If the loaded dict has no ``'model_state_dict'`` entry.
        """
        file_list = os.listdir(dir_path)

        if "model.pt" not in file_list:
            raise FileNotFoundError(
                f"Missing model weights file ('model.pt') file in"
                f"{dir_path}... Cannot perform model building."
            )

        path_to_model_weights = os.path.join(dir_path, "model.pt")

        try:
            model_weights = torch.load(path_to_model_weights, map_location="cpu")
        except RuntimeError as e:
            # The original code built this exception but never raised it, silently
            # dropping the error and leaving ``model_weights`` unbound (NameError below).
            raise RuntimeError(
                "Unable to load model weights. Ensure they are saved in a '.pt' format."
            ) from e

        if "model_state_dict" not in model_weights.keys():
            raise KeyError(
                "Model state dict is not available in 'model.pt' file. Got keys:"
                f"{model_weights.keys()}"
            )

        model_weights = model_weights["model_state_dict"]

        return model_weights

    @classmethod
    def _load_custom_encoder_from_folder(cls, dir_path):
        """Unpickle a user-provided encoder from ``encoder.pkl`` in ``dir_path``.

        Raises:
            FileNotFoundError: If ``encoder.pkl`` is missing from ``dir_path``.
        """
        file_list = os.listdir(dir_path)

        if "encoder.pkl" not in file_list:
            raise FileNotFoundError(
                f"Missing encoder pkl file ('encoder.pkl') in"
                f"{dir_path}... This file is needed to rebuild custom encoders."
                " Cannot perform model building."
            )

        else:
            # NOTE(review): dill.load executes arbitrary code — only load folders you trust
            with open(os.path.join(dir_path, "encoder.pkl"), "rb") as fp:
                encoder = dill.load(fp)

        return encoder

    @classmethod
    def _load_custom_decoder_from_folder(cls, dir_path):
        """Unpickle a user-provided decoder from ``decoder.pkl`` in ``dir_path``.

        Raises:
            FileNotFoundError: If ``decoder.pkl`` is missing from ``dir_path``.
        """
        file_list = os.listdir(dir_path)

        if "decoder.pkl" not in file_list:
            raise FileNotFoundError(
                f"Missing decoder pkl file ('decoder.pkl') in"
                f"{dir_path}... This file is needed to rebuild custom decoders."
                " Cannot perform model building."
            )

        else:
            # NOTE(review): dill.load executes arbitrary code — only load folders you trust
            with open(os.path.join(dir_path, "decoder.pkl"), "rb") as fp:
                decoder = dill.load(fp)

        return decoder

    @classmethod
    def load_from_folder(cls, dir_path):
        """Class method to be used to load the model from a specific folder.

        Args:
            dir_path (str): The path where the model should have been saved.

        .. note::
            This function requires the folder to contain:

            a ``model_config.json`` and a ``model.pt`` if no custom architectures were
            provided

            or

            a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp.
            ``decoder.pkl``) if a custom encoder (resp. decoder) was provided
        """
        model_config = cls._load_model_config_from_folder(dir_path)
        model_weights = cls._load_model_weights_from_folder(dir_path)

        # passing None lets __init__ rebuild the default MLP nets from the config
        if not model_config.uses_default_encoder:
            encoder = cls._load_custom_encoder_from_folder(dir_path)
        else:
            encoder = None

        if not model_config.uses_default_decoder:
            decoder = cls._load_custom_decoder_from_folder(dir_path)
        else:
            decoder = None

        model = cls(model_config, encoder=encoder, decoder=decoder)
        model.load_state_dict(model_weights)

        return model

    def set_encoder(self, encoder: BaseEncoder) -> None:
        """Set the encoder of the model.

        Raises:
            BadInheritanceError: If ``encoder`` does not inherit from :class:`BaseEncoder`.
        """
        if not issubclass(type(encoder), BaseEncoder):
            raise BadInheritanceError(
                (
                    "Encoder must inherit from BaseEncoder class from "
                    "pyraug.models.base_architectures.BaseEncoder. Refer to documentation."
                )
            )
        self.encoder = encoder

    def set_decoder(self, decoder: BaseDecoder) -> None:
        """Set the decoder of the model.

        Raises:
            BadInheritanceError: If ``decoder`` does not inherit from :class:`BaseDecoder`.
        """
        if not issubclass(type(decoder), BaseDecoder):
            raise BadInheritanceError(
                (
                    "Decoder must inherit from BaseDecoder class from "
                    "pyraug.models.base_architectures.BaseDecoder. Refer to documentation."
                )
            )
        self.decoder = decoder