import sys
import os
import json
import logging
import numpy as np
import pythae.models.base.base_utils
import torch
import torch.nn.functional as F
from typing import Tuple, List, Dict, Optional
from copy import deepcopy
from dataclasses import asdict
from transformers import AutoTokenizer, AutoModelForCausalLM
from pythae.trainers import BaseTrainerConfig
# from pythae.models.nn import BaseEncoder, BaseDecoder
from pythae.models.base.base_config import BaseAEConfig, EnvironmentConfig
from pythae.models.base.base_utils import ModelOutput
from pythae.data.datasets import BaseDataset
from torch import Tensor
# from torch.utils.tensorboard import SummaryWriter
from pythae.models.vae import VAE, VAEConfig
from pythae.trainers.training_callbacks import TrainingCallback
from pythae.models.base.base_utils import hf_hub_is_available
from langvae.encoders import SentenceEncoder
from langvae.decoders import SentenceDecoder
from langvae.data_conversion.sparse import densify_w_padding
logger = logging.getLogger(__name__)
console = logging.StreamHandler()
logger.addHandler(console)
logger.setLevel(logging.INFO)
model_card_template = """---
language: en
tags:
- langvae
license: apache-2.0
---
### Downloading this model from the Hub
This model was trained with {clsname}. It can be downloaded or reloaded using the method `load_from_hf_hub`
```python
>>> from langvae import {clsname}
>>> model = {clsname}.load_from_hf_hub(hf_hub_path="your_hf_username/repo_name")
```
"""
[docs]@torch.compile
def vae_nll_loss(recon_x: Tensor,
x: Tensor,
mu: Tensor,
log_var: Tensor,
z: Tensor,
pad_token_id: int,
beta: float,
target_kl: float) -> Tuple[Tensor, Tensor, Tensor]:
"""
Calculates the negative log-likelihood (NLL) loss for a Variational Autoencoder (VAE).
Args:
recon_x (Tensor): The reconstructed input tensor.
x (Tensor): The original input tensor.
mu (Tensor): The mean of the latent variable distribution.
log_var (Tensor): The logarithm of the variance of the latent variable distribution.
z (Tensor): The latent variable tensor.
pad_token_id (int): The padding token ID for the input sequence.
beta (float): A hyperparameter that controls the trade-off between reconstruction loss and KL divergence.
target_kl (float): A target value for the KL divergence (cut-off).
Returns:
Tuple[Tensor, Tensor, Tensor]:
- Total NLL loss (reconstruction loss + KL divergence).
- Average reconstruction loss.
- Average KL divergence.
"""
# x = torch.squeeze(x).to(recon_x.device)
# len = min(x.shape[1], recon_x.shape[1])
# # print(f"X [{x.shape[1]}], X' [{recon_x.shape[1]}]")
# recon_x = recon_x[:, :len, :]
if (x.layout == torch.sparse_coo):
x_tok_ids = densify_w_padding(x, pad_token_id)
# x_tok_ids = [x[i].coalesce().indices()[1][:len] for i in range(x.shape[0])]
# x_tok_ids = torch.stack([
# torch.cat([tok_ids, torch.tensor([pad_token_id] * int(tok_ids.shape[0] < len) +
# [0] * max(len - tok_ids.shape[0] - 1, 0),
# dtype=torch.int64, device=x.device)
# ])
# for tok_ids in x_tok_ids
# ])
mask = (x_tok_ids != pad_token_id).to(torch.int8)
else:
x_tok_ids = x.argmax(dim=-1)
mask = (x_tok_ids != pad_token_id).to(torch.int8)
recon_loss = (F.nll_loss(torch.log(recon_x).view(recon_x.shape[0] * recon_x.shape[1], recon_x.shape[2]),
x_tok_ids.view(recon_x.shape[0] * recon_x.shape[1]),
reduction="none").sum(dim=-1) * mask).sum(dim=-1) / x.shape[0]
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
kl_mask = (KLD > target_kl).float()
KLD = beta * (kl_mask * KLD)
# print(f"recon x: {recon_x.mean(dim=0)}")
# print(f"recon loss: {recon_loss.mean(dim=0)}, ")
return (recon_loss + KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0)
@torch.jit.script
def vae_nll_loss_supervised(recon_x: Tensor,
x: Tensor,
mu: Tensor,
log_var: Tensor,
z: Tensor,
pad_token_id: int,
beta: float,
target_kl: float,
num_annotations: int) -> Tuple[Tensor, Tensor, Tensor]:
x = torch.squeeze(x).to(recon_x.device)
x_split = x.chunk(num_annotations + 1, dim=-1)
recon_x_split = recon_x.chunk(num_annotations + 1, dim=-1)
x_tok_ids = torch.argmax(x_split[0], dim=-1)
mask = (x_tok_ids != pad_token_id).to(torch.int8)
rec_x = recon_x_split[0]
recon_loss = (F.nll_loss(torch.log(rec_x).view(rec_x.shape[0] * rec_x.shape[1], rec_x.shape[2]),
x_tok_ids.view(rec_x.shape[0] * rec_x.shape[1]),
reduction="none").sum(dim=-1) * mask).sum(dim=-1) / x_split[0].shape[0]
for lbl_split, rec_lbl in zip(x_split[1:], recon_x_split[1:]):
x_lbl_ids = torch.argmax(lbl_split, dim=-1)
mask = (x_lbl_ids != pad_token_id).to(torch.int8)
recon_loss += (F.nll_loss(torch.log(rec_lbl).view(rec_lbl.shape[0] * rec_lbl.shape[1], rec_lbl.shape[2]),
x_lbl_ids.view(rec_lbl.shape[0] * rec_lbl.shape[1]),
reduction="none").sum(dim=-1) * mask).sum(dim=-1) / lbl_split.shape[0]
KLD = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=-1)
kl_mask = (KLD > target_kl).float()
KLD = beta * (kl_mask * KLD)
return (recon_loss + KLD).mean(dim=0), recon_loss.mean(dim=0), KLD.mean(dim=0)
[docs]class LangVAE(VAE):
"""
A language-oriented Variational Autoencoder (VAE) that can be used for text generation.
Args:
model_config (VAEConfig): The configuration of the VAE model.
encoder (Optional[SentenceEncoder]): Language encoder model that processes input data and returns sentence embeddings.
decoder (Optional[SentenceDecoder]): Language decoder model that generates text from latent representations.
"""
# loss_writer = SummaryWriter()
def __init__(
self,
model_config: VAEConfig,
encoder: Optional[SentenceEncoder],
decoder: Optional[SentenceDecoder]
):
super().__init__(model_config=model_config, encoder=encoder, decoder=decoder)
self.cur_beta: float = 0.0
self.target_kl = 1.0
# Logging losses
self.debug = False
self._dbg_counter = 0
self._loss_agg = [0.0, 0.0]
[docs] def forward(self, inputs: BaseDataset, **kwargs):
"""
The VAE model
Args:
inputs (BaseDataset): The training dataset with labels
Returns:
ModelOutput: An instance of ModelOutput containing all the relevant parameters
"""
x = inputs["data"]
x_annot = inputs.keys() - {"data", "input_ids", "attention_mask"}
cvars = None
if (x_annot):
cvars = {annot: inputs[annot] for annot in x_annot}
encoder_output = self.encoder(x, cvars)
mu, log_var = encoder_output.embedding, encoder_output.log_covariance
cvars_emb = encoder_output.cvars_embedding
std = torch.exp(0.5 * log_var)
z, eps = self._sample_gauss(mu, std)
z = torch.cat([z] + cvars_emb, dim=-1) if (cvars and self.decoder.conditional) else z
recon_x = self.decoder(z, max_len=x.shape[1])["reconstruction"]
loss, recon_loss, kld = self.loss_function(recon_x, x, mu, log_var, z)
output = ModelOutput(
recon_loss=recon_loss,
reg_loss=kld,
loss=loss,
recon_x=recon_x,
z=z,
)
return output
[docs] def loss_function(self, recon_x, x, mu, log_var, z) -> Tuple[Tensor, Tensor, Tensor]:
"""
Computes the loss function for the VAE model.
Args:
recon_x (Tensor): The reconstructed input tensor.
x (Tensor): The original input tensor.
mu (Tensor): The mean of the latent variable distribution.
log_var (Tensor): The logarithm of the variance of the latent variable distribution.
z (Tensor): The sampled latent variable tensor.
Returns:
Tuple[Tensor, Tensor, Tensor]: A tuple containing the reconstruction loss, the KL divergence
loss, and the total loss.
"""
mu = mu.to(recon_x.device)
log_var = log_var.to(recon_x.device)
recon_x.clamp_min_(torch.finfo(recon_x.dtype).tiny * 10) # Prevents underflow
if (recon_x.shape[-1] > self.decoder.decoder.config.vocab_size):
num_annotations = recon_x.shape[-1] // self.decoder.decoder.config.vocab_size - 1
losses = vae_nll_loss_supervised(recon_x, x, mu, log_var, z, self.decoder.tokenizer.pad_token_id,
self.cur_beta, self.target_kl, num_annotations)
else:
losses = vae_nll_loss(recon_x, x, mu, log_var, z, self.decoder.tokenizer.pad_token_id,
self.cur_beta, self.target_kl)
# Log losses with tensorboard.
# self._loss_agg[0] += losses[0].item()
# self._loss_agg[1] += losses[2].item()
# if (self.debug and self._dbg_counter % 10 == 0):
# # print("\n", [l.item() for l in losses])
# LangVAE.loss_writer.add_scalar("Loss/train_joint", self._loss_agg[0] / 10, self._dbg_counter // 10)
# LangVAE.loss_writer.add_scalar("Loss/train_kld", self._loss_agg[1] / 10, self._dbg_counter // 10)
# LangVAE.loss_writer.flush()
# self._loss_agg[0] = 0.0
# self._loss_agg[1] = 0.0
# self._dbg_counter += 1
return losses
[docs] def encode_z(self, x: Tensor, c: Dict[str, Tensor] = None) -> Tuple[Tensor, List[Tensor]]:
"""
Encodes the input tensor into a latent variable tensor.
Args:
x (Tensor): The input tensor to be encoded.
Returns:
Tuple[Tensor, List[Tensor]]: A tuple of tensors containing the sampled latent variables and conditional
variable embeddings if available, respectively.
"""
encoded = self.encoder(x, c)
mu, log_var = encoded.embedding, encoded.log_covariance
cvars_emb = encoded.cvars_embedding
std = torch.exp(0.5 * log_var)
z, eps = self._sample_gauss(mu, std)
return (z, cvars_emb)
[docs] def decode_sentences(self, z: Tensor, cvars_emb: List[Tensor] = None) -> List[str]:
"""
Decodes the latent variable tensor into a list of sentences.
Args:
z (Tensor): The latent variable tensor to be decoded.
Returns:
List[str]: A list of strings representing the decoded sentences.
"""
z = torch.cat([z] + cvars_emb, dim=-1) if (cvars_emb and self.decoder.conditional) else z
generated = self.decoder(z)["reconstruction"]
sents = self.decoder.tokenizer.batch_decode(torch.argmax(generated, dim=-1), skip_special_tokens=True)
return sents
[docs] def push_to_hf_hub(self, hf_hub_path: str):
"""
Uploads the VAE model to the Hugging Face Hub.
Args:
hf_hub_path (str): The HF hub path where the model should be uploaded to.
"""
self.device = "cpu"
self.encoder.device = "cpu"
self.decoder.device = "cpu"
self.debug = False
self.encoder.debug = False
self.decoder.debug = False
pythae.models.base.base_utils.model_card_template = model_card_template.format(clsname=self.__class__.__name__)
super().push_to_hf_hub(hf_hub_path)
[docs] def save(self, dir_path: str):
"""Method to save the model at a specific location. It saves, the model weights as a
``models.pt`` file along with the model config as a ``model_config.json`` file. If the
model to save used custom encoder (resp. decoder) provided by the user, these are also
saved as ``decoder.pkl`` (resp. ``decoder.pkl``).
Args:
dir_path (str): The path where the model should be saved. If the path
path does not exist a folder will be created at the provided location.
"""
env_spec = EnvironmentConfig(
python_version=f"{sys.version_info[0]}.{sys.version_info[1]}"
)
if not os.path.exists(dir_path):
try:
os.makedirs(dir_path)
except FileNotFoundError as e:
raise e
env_spec.save_json(dir_path, "environment")
self.model_config.save_json(dir_path, "model_config")
if not self.model_config.uses_default_encoder:
torch.save(self.encoder.state_dict(), os.path.join(dir_path, "encoder.pt"))
with open(os.path.join(dir_path, "encoder_cfg.json"), "w") as enc_cfg_file:
json.dump({"model_path": self.encoder.model_path,
"latent_size": self.encoder.latent_size,
"automodel_preset": asdict(self.encoder.automodel_preset),
"caching": self.encoder.caching}, enc_cfg_file)
if not self.model_config.uses_default_decoder:
torch.save(self.decoder.state_dict(), os.path.join(dir_path, "decoder.pt"))
with open(os.path.join(dir_path, "decoder_cfg.json"), "w") as dec_cfg_file:
cfg = {
"model_path": self.decoder.model_path,
"latent_size": self.decoder.latent_size,
"max_len": self.decoder.max_len,
"conditional": self.decoder.conditional,
"device_map": self.decoder.device_map
}
json.dump(cfg, dec_cfg_file)
@classmethod
def _load_custom_encoder_from_folder(cls, dir_path, tokenizer):
file_list = os.listdir(dir_path)
cls._check_python_version_from_folder(dir_path=dir_path)
if "encoder_cfg.json" not in file_list:
raise FileNotFoundError(
f"Missing encoder config file ('encoder_cfg.json') in"
f"{dir_path}... This file is needed to rebuild custom encoders."
" Cannot perform model building."
)
else:
with open(os.path.join(dir_path, "encoder_cfg.json"), "r") as fp:
cfg = json.load(fp)
with open(os.path.join(dir_path, "encoder.pt"), "rb") as fp:
encoder = SentenceEncoder(**(cfg | {"decoder_tokenizer": tokenizer}))
encoder.load_state_dict(torch.load(fp, map_location=torch.device(encoder.device), weights_only=True))
if (not encoder._encoder):
encoder.init_pretrained_model()
return encoder
@classmethod
def _load_custom_decoder_from_folder(cls, dir_path):
file_list = os.listdir(dir_path)
cls._check_python_version_from_folder(dir_path=dir_path)
if "decoder_cfg.json" not in file_list:
raise FileNotFoundError(
f"Missing decoder config file ('decoder_cfg.json') in"
f"{dir_path}... This file is needed to rebuild custom decoders."
" Cannot perform model building."
)
else:
with open(os.path.join(dir_path, "decoder_cfg.json"), "r") as fp:
cfg = json.load(fp)
with open(os.path.join(dir_path, "decoder.pt"), "rb") as fp:
decoder = SentenceDecoder(**cfg)
decoder.load_state_dict(torch.load(fp, map_location=torch.device(decoder.device), weights_only=True))
return decoder
[docs] @classmethod
def load_from_folder(cls, dir_path):
"""Class method to be used to load the model from a specific folder
Args:
dir_path (str): The path where the model should have been be saved.
.. note::
This function requires the folder to contain:
- | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided
**or**
- | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp.
``decoder.pkl``) if a custom encoder (resp. decoder) was provided
"""
model_config = cls._load_model_config_from_folder(dir_path)
decoder = cls._load_custom_decoder_from_folder(dir_path)
encoder = cls._load_custom_encoder_from_folder(dir_path, decoder.tokenizer)
model = cls(model_config, encoder=encoder, decoder=decoder)
return model
[docs] @classmethod
def load_from_hf_hub(cls, hf_hub_path: str): # pragma: no cover
"""Class method to be used to load a pretrained model from the Hugging Face hub
Args:
hf_hub_path (str): The path where the model should have been be saved on the
hugginface hub.
.. note::
This function requires the folder to contain:
- | a ``model_config.json`` and a ``model.pt`` if no custom architectures were provided
**or**
- | a ``model_config.json``, a ``model.pt`` and a ``encoder.pkl`` (resp.
``decoder.pkl``) if a custom encoder (resp. decoder) was provided
"""
if not hf_hub_is_available():
raise ModuleNotFoundError(
"`huggingface_hub` package must be installed to load models from the HF hub. "
"Run `python -m pip install huggingface_hub` and log in to your account with "
"`huggingface-cli login`."
)
else:
from huggingface_hub import hf_hub_download
logger.info(f"Downloading {cls.__name__} files for rebuilding...")
_ = hf_hub_download(repo_id=hf_hub_path, filename="environment.json")
config_path = hf_hub_download(repo_id=hf_hub_path, filename="model_config.json")
dir_path = os.path.dirname(config_path)
model_config = cls._load_model_config_from_folder(dir_path)
if not model_config.uses_default_decoder:
_ = hf_hub_download(repo_id=hf_hub_path, filename="decoder_cfg.json")
_ = hf_hub_download(repo_id=hf_hub_path, filename="decoder.pt")
decoder = cls._load_custom_decoder_from_folder(dir_path)
else:
decoder = None
if not model_config.uses_default_encoder:
_ = hf_hub_download(repo_id=hf_hub_path, filename="encoder_cfg.json")
_ = hf_hub_download(repo_id=hf_hub_path, filename="encoder.pt")
encoder = cls._load_custom_encoder_from_folder(dir_path, decoder.tokenizer)
else:
encoder = None
logger.info(f"Successfully downloaded {cls.__name__} model!")
model = cls(model_config, encoder=encoder, decoder=decoder)
return model