langvae.trainers package

Submodules

langvae.trainers.cyclical_schedule_kl module

class langvae.trainers.cyclical_schedule_kl.CyclicalScheduleKLThresholdTrainer(model: LangVAE, train_dataset: BaseDataset, eval_dataset: BaseDataset | None = None, training_config: CyclicalScheduleKLThresholdTrainerConfig | None = None, callbacks: List[TrainingCallback] = None)

Bases: BaseTrainer

eval_step(epoch: int)

Perform an evaluation step over the evaluation dataloader.

Parameters:

epoch (int) – The current epoch number

Returns:

The evaluation loss

Return type:

(torch.Tensor)

get_eval_dataloader(eval_dataset: BaseDataset) → DataLoader
get_train_dataloader(train_dataset: BaseDataset) → DataLoader
train(log_output_dir: str = None)

Main training entry point; runs the full training loop.

Parameters:

log_output_dir (str) – The path in which the log will be stored

train_step(epoch: int)

Perform the training loop over the train_loader for one epoch.

Parameters:

epoch (int) – The current epoch number

Returns:

The step training loss

Return type:

(torch.Tensor)
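How train, train_step and eval_step typically fit together can be shown with a framework-free sketch. It assumes the common pythae BaseTrainer pattern (one train_step and, when an evaluation set exists, one eval_step per epoch, keeping the best model by evaluation loss unless keep_best_on_train is set). The names run_training, train_step_fn and eval_step_fn are hypothetical stand-ins, not part of langvae, and the real trainer may differ in details such as checkpointing and scheduler steps.

```python
from typing import Callable, Optional, Tuple


def run_training(
    num_epochs: int,
    train_step_fn: Callable[[int], float],
    eval_step_fn: Optional[Callable[[int], float]] = None,
    keep_best_on_train: bool = False,
) -> Tuple[int, float]:
    """Hypothetical sketch of an epoch loop; returns (best_epoch, best_loss)."""
    best_epoch, best_loss = 0, float("inf")
    for epoch in range(1, num_epochs + 1):
        # One full pass over the training dataloader (the role of train_step).
        train_loss = train_step_fn(epoch)
        # Evaluate only when an evaluation set is available (the role of eval_step).
        eval_loss = eval_step_fn(epoch) if eval_step_fn is not None else None
        # Select on the evaluation loss unless keep_best_on_train is set,
        # falling back to the training loss when no evaluation set exists.
        if eval_loss is not None and not keep_best_on_train:
            monitored = eval_loss
        else:
            monitored = train_loss
        if monitored < best_loss:
            # A real trainer would deep-copy the model weights at this point.
            best_epoch, best_loss = epoch, monitored
    return best_epoch, best_loss


# Usage with made-up per-epoch losses.
train_losses = {1: 3.0, 2: 2.0, 3: 2.5}
eval_losses = {1: 3.5, 2: 2.8, 3: 2.1}
print(run_training(3, train_losses.__getitem__, eval_losses.__getitem__))  # (3, 2.1)
print(run_training(3, train_losses.__getitem__, eval_losses.__getitem__,
                   keep_best_on_train=True))                               # (2, 2.0)
```

Selecting on evaluation loss by default guards against keeping a model that merely overfits the training set; keep_best_on_train is the escape hatch when no meaningful evaluation split is available.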

class langvae.trainers.cyclical_schedule_kl.CyclicalScheduleKLThresholdTrainerConfig(output_dir: str | None = None, per_device_train_batch_size: int = 64, per_device_eval_batch_size: int = 64, num_epochs: int = 100, train_dataloader_num_workers: int = 0, eval_dataloader_num_workers: int = 0, optimizer_cls: str = 'Adam', optimizer_params: dict | None = None, scheduler_cls: str | None = None, scheduler_params: dict | None = None, learning_rate: float = 0.0001, steps_saving: int | None = None, steps_predict: int | None = None, keep_best_on_train: bool = False, seed: int = 8, no_cuda: bool = False, world_size: int = -1, local_rank: int = -1, rank: int = -1, dist_backend: str = 'nccl', master_addr: str = 'localhost', master_port: str = '12345', amp: bool = False, start_beta: float = 0.0, max_beta: float = 1.0, n_cycles: int = 1, target_kl: float = 2.0)

Bases: BaseTrainerConfig

max_beta: float = 1.0
n_cycles: int = 1
start_beta: float = 0.0
target_kl: float = 2.0
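The class name and the target_kl field suggest that the KL term is thresholded, with its weight β moving between start_beta and max_beta over n_cycles. One common technique, used for example in the Optimus language VAE, is per-dimension KL thresholding: latent dimensions whose KL falls at or below a threshold are masked out before the β-weighted sum. The sketch below assumes that behaviour and that target_kl plays the role of the threshold; both are assumptions rather than documented langvae behaviour, and thresholded_kl_loss is a hypothetical name. The per-dimension KL uses the standard closed form for a diagonal Gaussian against N(0, 1): 0.5 (μ² + σ² − log σ² − 1).

```python
import math
from typing import List, Sequence


def gaussian_kl_per_dim(mu: Sequence[float], logvar: Sequence[float]) -> List[float]:
    """KL(N(mu, exp(logvar)) || N(0, 1)) for each latent dimension."""
    return [0.5 * (m * m + math.exp(lv) - lv - 1.0) for m, lv in zip(mu, logvar)]


def thresholded_kl_loss(
    mu: Sequence[float], logvar: Sequence[float], beta: float, target_kl: float
) -> float:
    """Hypothetical: beta times the sum of per-dimension KL terms above target_kl."""
    kl = gaussian_kl_per_dim(mu, logvar)
    # Dimensions whose KL is at or below the threshold are masked out,
    # so they contribute nothing to the weighted KL term.
    return beta * sum(k for k in kl if k > target_kl)


print(gaussian_kl_per_dim([0.0, 3.0], [0.0, 0.0]))            # [0.0, 4.5]
print(thresholded_kl_loss([0.0, 3.0], [0.0, 0.0], 0.5, 2.0))  # 2.25
```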
langvae.trainers.cyclical_schedule_kl.copy_model_ptref(model: LangVAE) → LangVAE
langvae.trainers.cyclical_schedule_kl.frange_cycle_zero_linear(n_iter: int, start: float = 0.0, stop: float = 1.0, n_cycle: int = 4, ratio_increase: float = 0.5, ratio_zero: float = 0.3) → Tensor
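frange_cycle_zero_linear shares its name and parameters with a cyclical annealing helper from the Optimus codebase. In that helper, each of n_cycle equal periods starts with a flat stretch at start (a ratio_zero fraction of the period), then ramps linearly towards stop at a rate that would cover the full range in a ratio_increase fraction of the period, and holds at stop for the rest. The pure-Python sketch below assumes the langvae function follows the same logic; it returns a list instead of a Tensor, and frange_cycle_zero_linear_sketch is a hypothetical name.

```python
from typing import List


def frange_cycle_zero_linear_sketch(
    n_iter: int,
    start: float = 0.0,
    stop: float = 1.0,
    n_cycle: int = 4,
    ratio_increase: float = 0.5,
    ratio_zero: float = 0.3,
) -> List[float]:
    """Hypothetical zero-then-linear cyclical schedule of length n_iter."""
    schedule = [stop] * n_iter          # positions not overwritten stay at stop
    period = n_iter / n_cycle           # length of one cycle, in iterations
    step = (stop - start) / (period * ratio_increase)  # per-iteration increment
    for c in range(n_cycle):
        v, i = start, 0
        while v <= stop and int(i + c * period) < n_iter:
            idx = int(i + c * period)
            if i < period * ratio_zero:
                schedule[idx] = start   # flat phase at the start of each cycle
            else:
                schedule[idx] = v       # linear ramp towards stop
                v += step
            i += 1
    return schedule


betas = frange_cycle_zero_linear_sketch(20, n_cycle=2)
print([round(b, 2) for b in betas])
```

Used as a per-iteration β for the KL term, the flat phase lets the decoder learn from the latent code before the KL penalty grows, and restarting each cycle gives the encoder repeated chances to escape posterior collapse.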

langvae.trainers.training_callbacks module

Module contents