langvae.trainers package
Submodules
langvae.trainers.cyclical_schedule_kl module
- class langvae.trainers.cyclical_schedule_kl.CyclicalScheduleKLThresholdTrainer(model: LangVAE, train_dataset: BaseDataset, eval_dataset: BaseDataset | None = None, training_config: CyclicalScheduleKLThresholdTrainerConfig | None = None, callbacks: List[TrainingCallback] = None)
Bases: BaseTrainer
- eval_step(epoch: int)
Perform an evaluation step.
- Parameters:
epoch (int) – The current epoch number
- Returns:
The evaluation loss
- Return type:
(torch.Tensor)
- class langvae.trainers.cyclical_schedule_kl.CyclicalScheduleKLThresholdTrainerConfig(output_dir: str | None = None, per_device_train_batch_size: int = 64, per_device_eval_batch_size: int = 64, num_epochs: int = 100, train_dataloader_num_workers: int = 0, eval_dataloader_num_workers: int = 0, optimizer_cls: str = 'Adam', optimizer_params: dict | None = None, scheduler_cls: str | None = None, scheduler_params: dict | None = None, learning_rate: float = 0.0001, steps_saving: int | None = None, steps_predict: int | None = None, keep_best_on_train: bool = False, seed: int = 8, no_cuda: bool = False, world_size: int = -1, local_rank: int = -1, rank: int = -1, dist_backend: str = 'nccl', master_addr: str = 'localhost', master_port: str = '12345', amp: bool = False, start_beta: float = 0.0, max_beta: float = 1.0, n_cycles: int = 1, target_kl: float = 2.0)
Bases: BaseTrainerConfig
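This page does not describe how start_beta, max_beta, n_cycles and target_kl are used during training. Going by the parameter names alone, one plausible reading is a cyclical KL-weight (beta) annealing schedule combined with a lower threshold on the KL term. The sketch below illustrates that reading in plain Python. The helpers cyclical_beta and thresholded_kl and the ramp_fraction argument are hypothetical names introduced for illustration; they are not part of langvae, and the trainer's actual behaviour may differ.

```python
def cyclical_beta(step, total_steps, n_cycles=1, start_beta=0.0,
                  max_beta=1.0, ramp_fraction=0.5):
    """Hypothetical helper: KL weight at a given training step.

    Assumption: each cycle ramps linearly from start_beta to max_beta over
    the first `ramp_fraction` of the cycle, then holds at max_beta.
    """
    if total_steps <= 0 or n_cycles <= 0:
        raise ValueError("total_steps and n_cycles must be positive")
    if not 0.0 < ramp_fraction <= 1.0:
        raise ValueError("ramp_fraction must be in (0, 1]")
    cycle_len = total_steps / n_cycles
    # Position inside the current cycle, in [0, 1).
    pos = (step % cycle_len) / cycle_len
    # Linear ramp, capped at 1.0 once the ramp phase is over.
    ramp = min(pos / ramp_fraction, 1.0)
    return start_beta + (max_beta - start_beta) * ramp


def thresholded_kl(kl_value, target_kl=2.0):
    """Hypothetical helper: clamp the KL term from below at target_kl.

    Assumption: a "free bits" style threshold, so the KL term contributes
    no gradient pressure once it falls below target_kl.
    """
    return max(kl_value, target_kl)


if __name__ == "__main__":
    total = 100
    for step in (0, 25, 50, 75):
        beta = cyclical_beta(step, total, n_cycles=1)
        print(step, beta, beta * thresholded_kl(0.5))
```

Under these assumptions, a larger n_cycles restarts the ramp more often, and target_kl sets the floor below which the KL term is no longer pushed down.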