Source code for langvae.decoders.sentence

import gc
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizer, DynamicCache
from pythae.models.nn import BaseDecoder
from pythae.models.base.base_utils import ModelOutput
# from langvae.data_conversion.sparse import densify_w_padding

FLASH_ATTN_SUPPORTED = [
    "meta-llama/Meta-Llama-3-8B",
    "meta-llama/Llama-3.2-3B",
    "mistralai/Mistral-7B-v0.3",
    "Qwen/Qwen2.5-3B"
]


[docs]class SentenceDecoder(BaseDecoder): """ Decoder class for generating sentences from latent representations. This decoder uses a pre-trained causal language model to generate text from latent representations. It outputs token probability distribution tensors (B x S x V), where :math:`B` is the batch size, :math:`S` is the maximum sentence length and :math:`V` is the decoder vocabulary size. Attributes: model_path (str): Path/locator to the pre-trained language model. latent_size (int): Size of the latent space. max_len (int): Maximum length (in tokens) of the generated sentences. device (torch.device): Device on which the model and data are allocated (e.g., 'cpu', 'cuda'). device_map (str): Device map configuration for model parallelism. args (ModelConfig, optional): Additional configuration arguments. """ def __init__(self, model_path: str, latent_size: int, max_len: int, conditional: bool = False, device: torch.device = "cpu", device_map: str = None, args=None): # Args is a ModelConfig instance BaseDecoder.__init__(self) self.model_path = model_path self.latent_size = latent_size self.max_len = max_len self.conditional = conditional self.device_map = device_map if (torch.cuda.is_available() and torch.cuda.device_count() > 1) else None self.device = device self.dec_hidden_layer_dev_map = None self._decoder = [] self._tokenizer = [] self.init_pretrained_model() dec_ids = torch.unsqueeze(torch.tensor([self.tokenizer.pad_token_id] * 2, dtype=torch.int64, device=self.device), dim=-1) pkv = self.decoder(dec_ids, use_cache=True).past_key_values self.pkv_dims = pkv[0][0].shape[1:] self.pkv_dtype = pkv[0][0].dtype self.context_hidden = nn.ModuleList([ nn.LazyLinear( self.pkv_dims[0] * self.pkv_dims[1] * self.pkv_dims[2] * 2 * (self.max_len + 1), # self.pkv_dims[0] * self.pkv_dims[1] * self.pkv_dims[2] == self.decoder.config.hidden_size dtype=self.pkv_dtype, device=f"cuda:{self.dec_hidden_layer_dev_map[i]}" if self.dec_hidden_layer_dev_map else self.device ) for i in range(self.decoder.config.num_hidden_layers) ]) self.dropout = nn.Dropout(p=0.1) # Logging outputs self._dbg_counter = 0 self.debug = False self.output_log_filepath = f"langvae_decoder_{model_path.replace('/', '--')}[{latent_size}_{max_len}].txt" @property def decoder(self) -> nn.Module: return self._decoder[0] @property def tokenizer(self) -> PreTrainedTokenizer: return self._tokenizer[0]
[docs] def to(self, device, include_pretrained: bool = True): super().to(device) self.device = device self.dec_hidden_layer_dev_map = None if (self._decoder and include_pretrained): self._decoder[0].to(device)
[docs] def init_pretrained_model(self): ex_params = dict() if (self.model_path in FLASH_ATTN_SUPPORTED): ex_params = {"attn_implementation": "flash_attention_2", "offload_buffers": True} if (str(self.device).startswith("cuda") and self.device_map): self._decoder = [AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype="auto", device_map=self.device_map, **ex_params)] dec_hidden_layer_prefix = [".".join(layer.split(".")[:-1]) for layer in self._decoder[0].hf_device_map if (layer.split(".")[-1] == str(self.decoder.config.num_hidden_layers - 1)) ][0] self.dec_hidden_layer_dev_map = {i: self._decoder[0].hf_device_map[f"{dec_hidden_layer_prefix}.{i}"] for i in range(self.decoder.config.num_hidden_layers)} else: self._decoder = [AutoModelForCausalLM.from_pretrained(self.model_path, torch_dtype="auto")] self._decoder = [self.decoder.to(self.device)] self._decoder[0].eval() self._decoder[0].requires_grad_(False) gc.collect() if (torch.cuda.is_available()): torch.cuda.empty_cache() self._tokenizer = [AutoTokenizer.from_pretrained(self.model_path, padding_side="left", add_prefix_space=True)] self._tokenizer[0].pad_token = self.tokenizer.eos_token self._tokenizer[0].pad_token_id = self.tokenizer.eos_token_id self._tokenizer[0].bos_token_id = (self._tokenizer[0].bos_token_id or self._decoder[0].config.bos_token_id) self.device = self.decoder.device
[docs] def forward(self, z: Tensor, max_len: int = 0, x: Tensor = None) -> ModelOutput: """ Processes the input latent tensor through the decoder to generate sentences. Args: z (Tensor): Input tensor containing latent representations. max_len (int): Maximum length (tokens) of output sentences. x (Tensor): Input tensor containing original tokens, for teach-forcing training. Returns: ModelOutput: The generated sentences as a ModelOutput object: token probability distribution tensors (B x S x V), where :math:`B` is the batch size, :math:`S is the maximum sentence length and :math:`V` is the decoder vocabulary size. """ # Fix for pythae device allocation bug if (not self.device_map): self._decoder[0] = self._decoder[0].to(self.device) else: self.device = self._decoder[0].device # self.context_embedder = self.context_embedder.to(self.device) if (not max_len): max_len = self.max_len z = z.to(self.pkv_dtype).to(self.device) dev_map = self.dec_hidden_layer_dev_map z_repl = None if (dev_map): for layer_idx in range(len(self.context_hidden)): self.context_hidden[layer_idx].to(f"cuda:{dev_map[layer_idx]}") z_repl = {dev: z.to(f"cuda:{dev}") for dev in set(dev_map.values())} generated = torch.zeros(z.shape[0], max_len + 1, self.decoder.config.vocab_size, device=self.device, dtype=self.pkv_dtype) dec_ids = torch.unsqueeze(torch.tensor([self.tokenizer.bos_token_id] * z.shape[0], dtype=torch.int64, device=self.device), dim=-1) # x_tok_ids = torch.cat([dec_ids, densify_w_padding(x, self.tokenizer.pad_token_id)], dim=1) if (x is not None) else None # decoded = self.decoder(input_ids=dec_ids) # generated[:, 0, :] = F.softmax(decoded.logits[:, -1, :], dim=-1) generated[:, 0,:] += F.one_hot(dec_ids, num_classes=generated.shape[-1]).squeeze() ctx_mem = [None] * self.decoder.config.num_hidden_layers past_dec = [None] * self.decoder.config.num_hidden_layers for layer_idx in range(self.decoder.config.num_hidden_layers): hidden_state = z_repl[dev_map[layer_idx]] if dev_map else z past = tuple([ h.view(-1, self.pkv_dims[0], self.pkv_dims[1] * (self.max_len + 1), self.pkv_dims[2]).to(f"cuda:{dev_map[layer_idx]}" if dev_map else self.device) for h in self.dropout(self.context_hidden[layer_idx](hidden_state)).chunk(2, dim=-1) ]) ctx_mem[layer_idx] = ( past[0], past[1] ) past_dec[layer_idx] = ( past[0][:, :, :1, :], past[1][:, :, :1, :] ) for i in range(max_len): # ctx_embed = context_embeds[:, max(0, i-self.max_look_behind + 1):i+1, :] # past_dec = self.compute_kv_residuals(decoded.past_key_values, z, z_repl, dev_map) # if (x_tok_ids is not None): # gen_ids = x_tok_ids[:, max(0, i):i + 1] # else: gen_ids = generated[:, max(0, i):i+1, :].argmax(dim=-1) # embeds = self.decoder.get_input_embeddings()(gen_ids) # + ctx_embed past_dec = DynamicCache.from_legacy_cache(past_dec) decoded = self.decoder(input_ids=gen_ids, use_cache=True, past_key_values=past_dec) past_dec = [ ( torch.cat([ decoded.past_key_values[layer_idx][0][:, :, :-2, :], ctx_mem[layer_idx][0][:, :, i+1:i+2, :], decoded.past_key_values[layer_idx][0][:, :, -1:, :] ], dim=-2), torch.cat([ decoded.past_key_values[layer_idx][1][:, :, :-2, :], ctx_mem[layer_idx][1][:, :, i+1:i+2, :], decoded.past_key_values[layer_idx][1][:, :, -1:, :] ], dim=-2), ) for layer_idx in range(self.decoder.config.num_hidden_layers) ] generated[:, i+1, :] = F.softmax(decoded.logits[:, -1, :], dim=-1) # Debug print (outputs) if (self.debug): if (self._dbg_counter % 100 == 0): with open(self.output_log_filepath, "w", encoding="utf-8") as output_log_file: print( "\n".join([s.replace(self.tokenizer.pad_token, "|#|") for s in self.tokenizer.batch_decode(torch.argmax(generated, dim=-1))]), file=output_log_file ) print("\n", "-" * 40, "\n", file=output_log_file) self._dbg_counter += 1 output = ModelOutput( reconstruction=generated[:, 1:max_len + 1,:] ) return output
[docs] def generate(self, z: Tensor, max_len: int = 0) -> ModelOutput: return self.forward(z, max_len=max_len)