Source code for langvae.data_conversion.sparse

import torch
from torch import Tensor


[docs]@torch.compile def densify_w_padding(x: Tensor, pad_token_id: int) -> Tensor: """Converts sparse one-hot tensors to token ids with padding.""" x = x.coalesce() x_dense = torch.zeros(x.shape, dtype=torch.int64, device=x.device) x_dense[:, :, pad_token_id] = 1 nz_idx = x.indices().detach().clone() nz_idx[-1] = pad_token_id x_dense[nz_idx.tolist()] = 0 x_dense[x.indices().tolist()] = x.values().long() x_tok_ids = x_dense.argmax(dim=-1) return x_tok_ids