StreamingSVD / diffusion_trainer /abstract_trainer.py
lev1's picture
Initial commit
8fd2f2f
raw
history blame
3.52 kB
import os
import pytorch_lightning as pl
import torch
from typing import Any
from modules.params.diffusion.inference_params import InferenceParams
from modules.loader.module_loader import GenericModuleLoader
from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams
class AbstractTrainer(pl.LightningModule):
    """Base LightningModule that maps Lightning's predict hooks onto custom hooks.

    Lightning's ``on_predict_start`` / ``on_predict_epoch_start`` /
    ``predict_step`` are forwarded to ``on_start`` /
    ``on_inference_epoch_start`` / ``on_inference_step``. Subclasses are
    expected to override ``post_init`` and ``generate_output``; the base
    versions do nothing.
    """

    def __init__(self,
                 inference_params: Any,
                 diff_trainer_params: DiffusionTrainerParams,
                 module_loader: GenericModuleLoader,
                 ):
        """
        Args:
            inference_params: Inference configuration. ``on_inference_step``
                reads its ``reset_seed_per_generation`` attribute, and the
                object is passed to ``generate_output``.
            diff_trainer_params: Trainer parameters; stored and forwarded to
                ``module_loader``.
            module_loader: Callable invoked once here with ``trainer=self`` and
                ``diff_trainer_params``. Presumably it attaches the model
                sub-modules to this trainer (TODO confirm against
                GenericModuleLoader).
        """
        super().__init__()
        self.inference_params = inference_params
        self.diff_trainer_params = diff_trainer_params
        self.module_loader = module_loader
        # Guard flag so that _on_start_once runs post_init only a single time.
        self.on_start_once_called = False
        self._setup_methods = []
        module_loader(
            trainer=self,
            diff_trainer_params=diff_trainer_params)

    # ------ IMPLEMENTATION HOOKS -------

    def post_init(self, batch=None):
        '''
        Is called after the LightningDataModule and LightningModule are created, but before any training/validation/prediction.
        First possible access to the 'trainer' object (e.g. to get 'device').

        ``batch`` defaults to None because ``_on_start_once`` calls this hook
        without arguments. With a required parameter, that call raised
        TypeError for any subclass that did not override the hook.
        The base implementation does nothing.
        '''

    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        '''
        Is called during validation/prediction to generate an output for each batch.
        Should return meta information about the produced result (e.g. where the results were stored).
        This is used for the metric evaluation.
        The base implementation does nothing and returns None.
        '''

    # ------- HELPER FUNCTIONS -------

    def _reset_random_generator(self):
        '''
        Reset the inference-only random generator to a fixed seed.

        On the first call, a ``torch.Generator`` is created on ``self.device`` and
        seeded from the ``PL_GLOBAL_SEED`` environment variable, falling back to 42.
        (Presumably that variable is set by Lightning's ``seed_everything``, which
        keeps the seed identical across workers; TODO confirm.) Later calls re-seed
        the same generator with its initial seed, so the random sequence restarts
        from the beginning.
        '''
        if not hasattr(self, "random_generator"):
            self.random_generator = torch.Generator(device=self.device)
            # set seed according to 'seed_everything' in config
            seed = int(os.environ.get("PL_GLOBAL_SEED", 42))
        else:
            seed = self.random_generator.initial_seed()
        self.random_generator.manual_seed(seed)

    # ----- PREDICT HOOKS ------

    def on_predict_start(self):
        self.on_start()

    def predict_step(self, batch, batch_idx):
        # The generation result is not returned here, so trainer.predict()
        # collects None for each batch (unchanged behavior).
        self.on_inference_step(batch=batch, batch_idx=batch_idx)

    def on_predict_epoch_start(self):
        self.on_inference_epoch_start()

    # ----- CUSTOM HOOKS -----
    # Global hooks (called by training, validation and prediction)

    def _on_start_once(self):
        '''
        Call ``post_init`` the first time this runs; later calls do nothing.
        It is invoked from ``on_start``, so it fires on the first of train/validation/prediction.
        '''
        if self.on_start_once_called:
            return
        self.on_start_once_called = True
        self.post_init()

    def on_start(self):
        '''
        Called at the beginning of training, validation and prediction.
        '''
        self._on_start_once()

    # ----- Inference hooks (called by 'validation' and 'predict') ------

    def on_inference_epoch_start(self):
        # reset seed at every inference epoch
        self._reset_random_generator()

    def on_inference_step(self, batch, batch_idx):
        '''
        Generate the output for one batch, optionally re-seeding the generator first.

        When ``inference_params.reset_seed_per_generation`` is truthy, every batch
        starts from the same random state.

        Returns:
            Whatever ``generate_output`` returns (previously this value was discarded).
        '''
        if self.inference_params.reset_seed_per_generation:
            self._reset_random_generator()
        return self.generate_output(
            batch=batch, inference_params=self.inference_params, batch_idx=batch_idx)