Source code for langvae.trainers.cyclical_schedule_kl

import os
import logging
import torch
import torch.distributed as dist
from typing import List, Optional
from copy import deepcopy
from torch import Tensor
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from pydantic.dataclasses import dataclass
from pythae.trainers.base_trainer import BaseTrainer, BaseTrainerConfig
from pythae.trainers.training_callbacks import TrainingCallback
from pythae.data.datasets import BaseDataset
from langvae.arch.vae import LangVAE
from langvae.data_conversion.tokenization import collate_sparse_fn

# Module-level logger: emits INFO-and-above records through a dedicated
# stream handler so training progress is visible without extra configuration.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
console = logging.StreamHandler()
logger.addHandler(console)


def frange_cycle_zero_linear(
    n_iter: int,
    start: float = 0.0,
    stop: float = 1.0,
    n_cycle: int = 4,
    ratio_increase: float = 0.5,
    ratio_zero: float = 0.3,
) -> Tensor:
    """Build a cyclical "hold, ramp, plateau" schedule of ``n_iter`` values.

    The ``n_iter`` positions are divided into ``n_cycle`` cycles of length
    ``period = n_iter / n_cycle``. Within each cycle:

    * positions whose in-cycle index is below ``period * ratio_zero`` are set
      to ``start``;
    * subsequent positions receive a value that begins at ``start`` and grows
      by ``(stop - start) / (period * ratio_increase)`` per position, for as
      long as that value does not exceed ``stop``;
    * every position that is never written keeps the initial value ``stop``.

    When ``period`` is not an integer, cycle offsets are truncated with
    ``int``, so a later cycle may overwrite the last position(s) written by
    the previous one.

    Parameters:
        n_iter (int): Total number of schedule entries.
        start (float): Value used for the hold phase and as the ramp origin.
        stop (float): Upper bound of the ramp and default fill value.
        n_cycle (int): Number of cycles the schedule is divided into.
        ratio_increase (float): Fraction of a cycle over which the ramp goes
            from ``start`` to ``stop``.
        ratio_zero (float): Fraction of a cycle held at ``start``.

    Returns:
        (torch.Tensor): 1-D tensor of length ``n_iter`` with the schedule.

    Raises:
        ZeroDivisionError: If ``n_cycle`` or ``ratio_increase`` is zero while
            ``n_iter`` is positive.
    """
    beta_t_list = torch.ones(n_iter) * stop

    # An empty schedule has nothing to fill; returning early also avoids the
    # division by a zero-length period below.
    if n_iter == 0:
        return beta_t_list

    period = n_iter / n_cycle
    step = (stop - start) / (period * ratio_increase)  # linear schedule
    hold_len = period * ratio_zero

    for c in range(n_cycle):
        offset = c * period
        v, i = start, 0
        while v <= stop:
            idx = int(i + offset)
            if idx >= n_iter:
                break
            if i < hold_len:
                # Hold phase: the ramp value is not advanced here.
                beta_t_list[idx] = start
            else:
                beta_t_list[idx] = v
                v += step
            i += 1

    return beta_t_list
def copy_model_ptref(model: "LangVAE") -> "LangVAE":
    """Deep-copy ``model`` while sharing its pretrained components by reference.

    The four attributes ``encoder._encoder``, ``encoder._tokenizer``,
    ``decoder._decoder`` and ``decoder._tokenizer`` are temporarily set to
    ``None`` so that ``deepcopy`` does not duplicate them. Afterwards they are
    restored on ``model`` and the very same objects are attached to the copy.

    Parameters:
        model (LangVAE): The model to copy. It is mutated only transiently;
            its references are restored even if the deep copy fails.

    Returns:
        (LangVAE): A deep copy of ``model`` whose four attributes listed above
        point to the same objects as the original's.
    """
    shared_slots = (
        ("encoder", "_encoder"),
        ("encoder", "_tokenizer"),
        ("decoder", "_decoder"),
        ("decoder", "_tokenizer"),
    )
    shared_refs = {
        (part, attr): getattr(getattr(model, part), attr)
        for part, attr in shared_slots
    }

    try:
        for part, attr in shared_slots:
            setattr(getattr(model, part), attr, None)
        model_copy = deepcopy(model)
    finally:
        # Always hand the references back so a failed copy cannot leave the
        # caller's model stripped of its encoder/decoder/tokenizers.
        for (part, attr), ref in shared_refs.items():
            setattr(getattr(model, part), attr, ref)

    for (part, attr), ref in shared_refs.items():
        setattr(getattr(model_copy, part), attr, ref)

    return model_copy
@dataclass
class CyclicalScheduleKLThresholdTrainerConfig(BaseTrainerConfig):
    """Configuration for ``CyclicalScheduleKLThresholdTrainer``.

    Extends ``BaseTrainerConfig`` with the parameters of the cyclical beta
    schedule and a KL target value that the trainer forwards to the model.
    """

    # Passed as ``start`` to ``frange_cycle_zero_linear`` by the trainer.
    start_beta: float = 0.0
    # Passed as ``stop`` to ``frange_cycle_zero_linear`` by the trainer.
    max_beta: float = 1.0
    # Passed as ``n_cycle`` to ``frange_cycle_zero_linear`` by the trainer.
    n_cycles: int = 1
    # Copied onto the model as ``target_kl`` by the trainer; presumably a KL
    # threshold used in the model's loss -- confirm against LangVAE.
    target_kl: float = 2.0
class CyclicalScheduleKLThresholdTrainer(BaseTrainer):
    """Trainer that feeds the model a cyclical beta schedule during training.

    A schedule with one entry per training iteration is precomputed with
    ``frange_cycle_zero_linear`` and, before every training batch, the entry
    for the current iteration is written to the model's ``cur_beta``
    attribute. The configured ``target_kl`` is written to the model once at
    construction time. How the model consumes these two attributes is not
    defined in this module.
    """

    def __init__(
        self,
        model: LangVAE,
        train_dataset: BaseDataset,
        eval_dataset: Optional[BaseDataset] = None,
        training_config: Optional[CyclicalScheduleKLThresholdTrainerConfig] = None,
        callbacks: Optional[List[TrainingCallback]] = None,
    ):
        """Set up the base trainer and precompute the beta schedule.

        Parameters:
            model (LangVAE): The model to train.
            train_dataset (BaseDataset): The training dataset.
            eval_dataset (Optional[BaseDataset]): The evaluation dataset.
            training_config (Optional[CyclicalScheduleKLThresholdTrainerConfig]):
                Training configuration; a default instance is used if None.
            callbacks (Optional[List[TrainingCallback]]): Training callbacks
                forwarded to the base trainer.
        """
        if training_config is None:
            training_config = CyclicalScheduleKLThresholdTrainerConfig()

        super().__init__(model, train_dataset, eval_dataset, training_config, callbacks)

        # Number of batches per epoch is the ceiling of len / batch_size: a
        # partial final batch counts as ONE extra batch, not as `remainder`
        # extra batches. This value is the per-epoch stride used to index the
        # schedule in `train_step`.
        # NOTE(review): under distributed training each rank presumably sees
        # fewer batches than this; the stride is then an upper bound -- confirm.
        batch_size = training_config.per_device_train_batch_size
        self.num_batches = (len(train_dataset) + batch_size - 1) // batch_size

        n_iter = training_config.num_epochs * self.num_batches
        self.beta_t_list = frange_cycle_zero_linear(
            n_iter,
            start=training_config.start_beta,
            stop=training_config.max_beta,
            n_cycle=training_config.n_cycles,
            ratio_increase=0.25,
            ratio_zero=0.25,
        )
        self._core_model().target_kl = training_config.target_kl

    def _core_model(self):
        """Return the model that owns the custom attributes and methods.

        In distributed mode the model's own methods are reached through
        ``self.model.module`` (as done for ``update()``), so attributes such
        as ``cur_beta`` are set on that object too. Falls back to
        ``self.model`` when not distributed or when no ``module`` exists.
        """
        if getattr(self, "distributed", False):
            return getattr(self.model, "module", self.model)
        return self.model

    def train_step(self, epoch: int):
        """Run one training epoch over ``self.train_loader``.

        Parameters:
            epoch (int): The current epoch number (1-based; it is used to
                offset into the precomputed beta schedule).

        Returns:
            (float): The training loss averaged over the batches of the epoch.

        Raises:
            ArithmeticError: If the accumulated loss becomes NaN.
        """
        self.callback_handler.on_train_step_begin(
            training_config=self.training_config,
            train_loader=self.train_loader,
            epoch=epoch,
            rank=self.rank,
        )

        # set model in train mode
        self.model.train()

        core_model = self._core_model()
        schedule_offset = (epoch - 1) * self.num_batches
        forward_kwargs = dict(
            epoch=epoch,
            dataset_size=len(self.train_loader.dataset),
            uses_ddp=self.distributed,
        )

        epoch_loss = 0

        for i, inputs in enumerate(self.train_loader):
            inputs = self._set_inputs_to_device(inputs)

            # Expose the scheduled beta for this iteration to the model.
            core_model.cur_beta = self.beta_t_list[i + schedule_offset]

            with self.amp_context:
                model_output = self.model(inputs, **forward_kwargs)

            # NOTE(review): clipping runs before `_optimizers_step`; if that
            # helper performs zero_grad()/backward() itself, this call acts on
            # the previous step's gradients -- confirm against BaseTrainer.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self._optimizers_step(model_output)

            loss = model_output.loss
            epoch_loss += loss.item()

            # NaN is the only value that is not equal to itself.
            if epoch_loss != epoch_loss:
                raise ArithmeticError("NaN detected in train loss")

            self.callback_handler.on_train_step_end(
                training_config=self.training_config
            )

        # Allows model updates if needed
        if self.distributed:
            self.model.module.update()
        else:
            self.model.update()

        epoch_loss /= len(self.train_loader)

        return epoch_loss

    def eval_step(self, epoch: int):
        """Run one evaluation pass over ``self.eval_loader``.

        The model's ``cur_beta`` is not modified here, so evaluation uses
        whatever value the last training iteration left in place.

        Parameters:
            epoch (int): The current epoch number.

        Returns:
            (float): The evaluation loss averaged over the batches.

        Raises:
            ArithmeticError: If the accumulated loss becomes NaN.
        """
        self.callback_handler.on_eval_step_begin(
            training_config=self.training_config,
            eval_loader=self.eval_loader,
            epoch=epoch,
            rank=self.rank,
        )

        self.model.eval()

        forward_kwargs = dict(
            epoch=epoch,
            dataset_size=len(self.eval_loader.dataset),
            uses_ddp=self.distributed,
        )

        epoch_loss = 0

        with self.amp_context:
            for inputs in self.eval_loader:
                inputs = self._set_inputs_to_device(inputs)

                try:
                    with torch.no_grad():
                        model_output = self.model(inputs, **forward_kwargs)
                except RuntimeError:
                    # Retry with gradients enabled when the forward pass
                    # cannot run under `no_grad`.
                    model_output = self.model(inputs, **forward_kwargs)

                loss = model_output.loss
                epoch_loss += loss.item()

                # NaN is the only value that is not equal to itself.
                if epoch_loss != epoch_loss:
                    raise ArithmeticError("NaN detected in eval loss")

                self.callback_handler.on_eval_step_end(
                    training_config=self.training_config
                )

        epoch_loss /= len(self.eval_loader)

        return epoch_loss

    def get_train_dataloader(
        self, train_dataset: BaseDataset
    ) -> torch.utils.data.DataLoader:
        """Build the training ``DataLoader``.

        Uses a ``DistributedSampler`` in distributed mode and no sampler
        otherwise. Shuffling is disabled in both cases and batches are
        assembled with ``collate_sparse_fn``.

        Parameters:
            train_dataset (BaseDataset): The training dataset.

        Returns:
            (torch.utils.data.DataLoader): The training data loader.
        """
        train_sampler = (
            DistributedSampler(
                train_dataset, num_replicas=self.world_size, rank=self.rank
            )
            if self.distributed
            else None
        )
        return DataLoader(
            dataset=train_dataset,
            batch_size=self.training_config.per_device_train_batch_size,
            num_workers=self.training_config.train_dataloader_num_workers,
            shuffle=False,  # (train_sampler is None),
            sampler=train_sampler,
            collate_fn=collate_sparse_fn,
        )

    def get_eval_dataloader(
        self, eval_dataset: BaseDataset
    ) -> torch.utils.data.DataLoader:
        """Build the evaluation ``DataLoader``.

        Uses a ``DistributedSampler`` in distributed mode; otherwise no
        sampler is used and shuffling is enabled. Batches are assembled with
        ``collate_sparse_fn``.

        Parameters:
            eval_dataset (BaseDataset): The evaluation dataset.

        Returns:
            (torch.utils.data.DataLoader): The evaluation data loader.
        """
        eval_sampler = (
            DistributedSampler(
                eval_dataset, num_replicas=self.world_size, rank=self.rank
            )
            if self.distributed
            else None
        )
        return DataLoader(
            dataset=eval_dataset,
            batch_size=self.training_config.per_device_eval_batch_size,
            num_workers=self.training_config.eval_dataloader_num_workers,
            shuffle=(eval_sampler is None),
            sampler=eval_sampler,
            collate_fn=collate_sparse_fn,
        )

    def train(self, log_output_dir: Optional[str] = None):
        """Run the full training loop.

        For every epoch this runs a training step, an optional evaluation
        step, the scheduler step, best-model tracking, optional prediction
        callbacks and optional checkpointing. At the end the best model is
        saved under ``<training_dir>/final_model`` by the main process.

        Args:
            log_output_dir (Optional[str]): Directory in which a log file is
                written (main process only). No file log is kept if None.
        """
        self.prepare_training()

        self.callback_handler.on_train_begin(
            training_config=self.training_config, model_config=self.model_config
        )

        log_verbose = False
        file_logger = None

        msg = (
            f"Training params:\n - max_epochs: {self.training_config.num_epochs}\n"
            " - per_device_train_batch_size: "
            f"{self.training_config.per_device_train_batch_size}\n"
            " - per_device_eval_batch_size: "
            f"{self.training_config.per_device_eval_batch_size}\n"
            f" - checkpoint saving every: {self.training_config.steps_saving}\n"
            f"Optimizer: {self.optimizer}\n"
            f"Scheduler: {self.scheduler}\n"
        )

        if self.is_main_process:
            logger.info(msg)

        # set up log file
        if log_output_dir is not None and self.is_main_process:
            log_verbose = True
            file_logger = self._get_file_logger(log_output_dir=log_output_dir)
            file_logger.info(msg)

        if self.is_main_process:
            logger.info("Successfully launched training !\n")

        # set best losses for early stopping
        best_train_loss = 1e10
        best_eval_loss = 1e10
        # No snapshot exists until some epoch qualifies as "best". This can
        # remain None, e.g. without an eval dataset when keep_best_on_train is
        # False; consumers below then fall back to the live model.
        best_model = None

        for epoch in range(1, self.training_config.num_epochs + 1):
            self.callback_handler.on_epoch_begin(
                training_config=self.training_config,
                epoch=epoch,
                train_loader=self.train_loader,
                eval_loader=self.eval_loader,
            )

            metrics = {}

            epoch_train_loss = self.train_step(epoch)
            metrics["train_epoch_loss"] = epoch_train_loss

            if self.eval_dataset is not None:
                epoch_eval_loss = self.eval_step(epoch)
                metrics["eval_epoch_loss"] = epoch_eval_loss
                self._schedulers_step(epoch_eval_loss)
            else:
                epoch_eval_loss = best_eval_loss
                self._schedulers_step(epoch_train_loss)

            if (
                epoch_eval_loss < best_eval_loss
                and not self.training_config.keep_best_on_train
            ):
                best_eval_loss = epoch_eval_loss
                best_model = copy_model_ptref(self.model)
                self._best_model = best_model

            elif (
                epoch_train_loss < best_train_loss
                and self.training_config.keep_best_on_train
            ):
                best_train_loss = epoch_train_loss
                best_model = copy_model_ptref(self.model)
                self._best_model = best_model

            # Model handed to prediction/checkpointing for this epoch.
            snapshot = best_model if best_model is not None else self.model

            if (
                self.training_config.steps_predict is not None
                and epoch % self.training_config.steps_predict == 0
                and self.is_main_process
            ):
                true_data, reconstructions, generations = self.predict(snapshot)

                self.callback_handler.on_prediction_step(
                    self.training_config,
                    true_data=true_data,
                    reconstructions=reconstructions,
                    generations=generations,
                    global_step=epoch,
                )

            self.callback_handler.on_epoch_end(training_config=self.training_config)

            # save checkpoints
            if (
                self.training_config.steps_saving is not None
                and epoch % self.training_config.steps_saving == 0
            ):
                if self.is_main_process:
                    self.save_checkpoint(
                        model=snapshot, dir_path=self.training_dir, epoch=epoch
                    )
                    logger.info(f"Saved checkpoint at epoch {epoch}\n")

                    if log_verbose:
                        file_logger.info(f"Saved checkpoint at epoch {epoch}\n")

            self.callback_handler.on_log(
                self.training_config,
                metrics,
                logger=logger,
                global_step=epoch,
                rank=self.rank,
            )

        final_dir = os.path.join(self.training_dir, "final_model")

        if self.is_main_process:
            final_model = best_model if best_model is not None else self.model
            self.save_model(final_model, dir_path=final_dir)

            logger.info("Training ended!")
            logger.info(f"Saved final model in {final_dir}")

        if self.distributed:
            dist.destroy_process_group()

        self.callback_handler.on_train_end(self.training_config)