Spaces:
Runtime error
Runtime error
import os | |
import pytorch_lightning as pl | |
import torch | |
from typing import Any | |
from modules.params.diffusion.inference_params import InferenceParams | |
from modules.loader.module_loader import GenericModuleLoader | |
from modules.params.diffusion_trainer.params_streaming_diff_trainer import DiffusionTrainerParams | |
class AbstractTrainer(pl.LightningModule):
    '''
    Base LightningModule that routes Lightning's predict hooks to custom hooks.

    Subclasses implement ``post_init`` and ``generate_output``. During
    prediction, ``on_start`` runs ``post_init`` exactly once. The inference
    random generator is re-seeded at the start of every predict epoch (and,
    optionally, before every batch). ``generate_output`` is called for every
    batch.
    '''

    def __init__(self,
                 inference_params: Any,
                 diff_trainer_params: DiffusionTrainerParams,
                 module_loader: GenericModuleLoader,
                 ):
        super().__init__()

        self.inference_params = inference_params
        self.diff_trainer_params = diff_trainer_params
        self.module_loader = module_loader

        # Guards ``_on_start_once`` so that ``post_init`` runs a single time.
        self.on_start_once_called = False
        # Not used in this class; presumably filled by subclasses/loader — TODO confirm.
        self._setup_methods = []

        # The loader is handed this trainer instance; presumably it attaches the
        # configured sub-modules to it — verify against GenericModuleLoader.
        module_loader(
            trainer=self,
            diff_trainer_params=diff_trainer_params)

    # ------ IMPLEMENTATION HOOKS -------

    def post_init(self, batch=None):
        '''
        Is called after LightningDataModule and LightningModule is created, but before any training/validation/prediction.
        First possible access to the 'trainer' object (e.g. to get 'device').

        The base implementation does nothing. ``_on_start_once`` invokes this
        hook without arguments, so ``batch`` defaults to ``None``. Overrides
        must also be callable without arguments.
        '''

    def generate_output(self, batch, batch_idx, inference_params: InferenceParams):
        '''
        Is called during validation to generate for each batch an output.
        Return the meta information about produced result (where result were stored).
        This is used for the metric evaluation.

        The base implementation does nothing and returns ``None``.
        '''

    # ------- HELPER FUNCTIONS -------

    def _reset_random_generator(self):
        '''
        Reset the inference random generator to a fixed seed.

        On the first call, a ``torch.Generator`` is created on ``self.device``
        and seeded from the ``PL_GLOBAL_SEED`` environment variable (default
        42). Every process reads the same variable, so presumably all workers
        share the seed. Later calls re-seed the existing generator with its own
        initial seed, so each call restarts the same random sequence.
        '''
        if hasattr(self, "random_generator"):
            seed = self.random_generator.initial_seed()
        else:
            self.random_generator = torch.Generator(device=self.device)
            # set seed according to 'seed_everything' in config
            seed = int(os.environ.get("PL_GLOBAL_SEED", 42))
        self.random_generator.manual_seed(seed)

    # ----- PREDICT HOOKS ------

    def on_predict_start(self):
        '''Lightning hook: run the shared start-up logic before prediction.'''
        self.on_start()

    def predict_step(self, batch, batch_idx):
        '''
        Lightning hook: generate the output for one prediction batch.

        Returns whatever ``generate_output`` returns, so that Lightning can
        collect it as the prediction for this batch.
        '''
        return self.on_inference_step(batch=batch, batch_idx=batch_idx)

    def on_predict_epoch_start(self):
        '''Lightning hook: reset inference state at the start of a predict epoch.'''
        self.on_inference_epoch_start()

    # ----- CUSTOM HOOKS -----

    # Global Hooks (Called by Training, Validation and Prediction)

    def _on_start_once(self):
        '''
        Will be called only once by on_start. Thus, it will be called by the first call of train,validation or prediction.
        '''
        if self.on_start_once_called:
            return
        # Flag is set before calling post_init, so a failing post_init is not retried.
        self.on_start_once_called = True
        self.post_init()

    def on_start(self):
        '''
        Called at the beginning of training, validation and prediction.
        '''
        self._on_start_once()

    # ----- Inference Hooks (called by 'validation' and 'predict') ------

    def on_inference_epoch_start(self):
        '''Re-seed the inference random generator at the start of every inference epoch.'''
        # reset seed at every inference
        self._reset_random_generator()

    def on_inference_step(self, batch, batch_idx):
        '''
        Generate the output for one inference batch.

        If ``inference_params.reset_seed_per_generation`` is truthy, the
        random generator is re-seeded first, so every batch starts from the
        same random state. Returns the value of ``generate_output``.
        '''
        if self.inference_params.reset_seed_per_generation:
            self._reset_random_generator()
        return self.generate_output(
            batch=batch, inference_params=self.inference_params, batch_idx=batch_idx)