Quick start

Installation

To use LangVAE, first install it using pip:

(.venv) $ pip install langvae

Training an LM-VAE

Here’s a basic example of how to train a VAE on text data using LangVAE (or use our example Colab notebook):

from pythae.models.vae import VAEConfig
from saf_datasets import EntailmentBankDataSet
from langvae import LangVAE
from langvae.encoders import SentenceEncoder
from langvae.decoders import SentenceDecoder
from langvae.data_conversion.tokenization import TokenizedDataSet
from langvae.pipelines import LanguageTrainingPipeline
from langvae.trainers import CyclicalScheduleKLThresholdTrainerConfig
from langvae.trainers.training_callbacks import TensorBoardCallback

DEVICE = "cuda"
LATENT_SIZE = 128
MAX_SENT_LEN = 32

# Load pre-trained sentence encoder and decoder models.
decoder = SentenceDecoder("gpt2", LATENT_SIZE, MAX_SENT_LEN, device=DEVICE, device_map="auto")
encoder = SentenceEncoder("bert-base-cased", LATENT_SIZE, decoder.tokenizer, caching=True, device=DEVICE)

# Select explanatory sentences from the EntailmentBank dataset.
dataset = [
    sent for sent in EntailmentBankDataSet()
    if (sent.annotations["type"] == "answer" or
        sent.annotations["type"].startswith("context"))
]

# Set training and evaluation datasets with auto tokenization.
eval_size = int(0.1 * len(dataset))
train_dataset = TokenizedDataSet(sorted(dataset[:-eval_size], key=lambda x: len(x.surface), reverse=True),
                                 decoder.tokenizer, decoder.max_len, caching=True,
                                 cache_persistence=f"eb_train_tok-gpt2_cache.jsonl")
eval_dataset = TokenizedDataSet(sorted(dataset[-eval_size:], key=lambda x: len(x.surface), reverse=True),
                                decoder.tokenizer, decoder.max_len, caching=True,
                                cache_persistence=f"eb_eval_tok-gpt2_cache.jsonl")


# Define VAE model configuration
model_config = VAEConfig(latent_dim=LATENT_SIZE)

# Initialize LangVAE model
model = LangVAE(model_config, encoder, decoder)

exp_label = f"eb-langvae-bert-gpt2-{LATENT_SIZE}"

# Train VAE on explanatory sentences
training_config = CyclicalScheduleKLThresholdTrainerConfig(
    output_dir=exp_label,
    num_epochs=20,
    learning_rate=1e-3,
    per_device_train_batch_size=50,
    per_device_eval_batch_size=50,
    steps_saving=5,
    optimizer_cls="AdamW",
    scheduler_cls="ReduceLROnPlateau",
    scheduler_params={"patience": 5, "factor": 0.5},
    max_beta=1.0,
    n_cycles=16,  # num_epochs * 0.8
    target_kl=2.0,
    keep_best_on_train=True
)

pipeline = LanguageTrainingPipeline(
    training_config=training_config,
    model=model
)

# Monitor the training progress with `tensorboard --logdir=runs &`
tb_callback = TensorBoardCallback(exp_label)

pipeline(
    train_data=train_dataset,
    eval_data=eval_dataset,
    callbacks=[tb_callback]
)