import os
from dataclasses import dataclass, field

import pytorch_lightning as pl
import torch.nn.functional as F

import craftsman
from craftsman.utils.base import (
    Updateable,
    update_end_if_possible,
    update_if_possible,
)
from craftsman.utils.scheduler import parse_optimizer, parse_scheduler
from craftsman.utils.config import parse_structured
from craftsman.utils.misc import C, cleanup, get_device, load_module_weights
from craftsman.utils.saving import SaverMixin
from craftsman.utils.typing import *


class BaseSystem(pl.LightningModule, Updateable, SaverMixin):
    @dataclass
    class Config:
        loggers: dict = field(default_factory=dict)
        loss: dict = field(default_factory=dict)
        optimizer: dict = field(default_factory=dict)
        scheduler: Optional[dict] = None
        weights: Optional[str] = None
        weights_ignore_modules: Optional[List[str]] = None
        cleanup_after_validation_step: bool = False
        cleanup_after_test_step: bool = False

        pretrained_model_path: Optional[str] = None
        strict_load: bool = True

    cfg: Config
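
    # A minimal sketch of the config block that parse_structured() maps onto
    # Config above. Illustrative only: the enclosing `system:` key and the
    # concrete values are assumptions, not taken from this repo's configs.
    #
    #   system:
    #     weights: ckpts/last.ckpt            # optional warm start
    #     weights_ignore_modules: [denoiser]  # hypothetical module name
    #     cleanup_after_validation_step: true # free cached memory per batch
    #     loss: {...}
    #     optimizer: {...}
    #     scheduler: {...}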

    def __init__(self, cfg, resumed=False) -> None:
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self._save_dir: Optional[str] = None
        self._resumed: bool = resumed
        self._resumed_eval: bool = False
        self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0}
        if "loggers" in cfg:
            self.create_loggers(cfg.loggers)

        self.configure()
        if self.cfg.weights is not None:
            self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules)
        self.post_configure()

    def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None):
        state_dict, epoch, global_step = load_module_weights(
            weights, ignore_modules=ignore_modules, map_location="cpu"
        )
        self.load_state_dict(state_dict, strict=False)
        # restore step/epoch-dependent state in all Updateable submodules
        self.do_update_step(epoch, global_step, on_load_weights=True)
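
    # Illustrative call (the module name is hypothetical): load a checkpoint
    # while skipping one submodule, e.g. when fine-tuning only part of a model:
    #
    #   system.load_weights("ckpts/last.ckpt", ignore_modules=["denoiser"])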

    def set_resume_status(self, current_epoch: int, global_step: int):
        # restore the correct epoch and global step for evaluation
        self._resumed_eval = True
        self._resumed_eval_status["current_epoch"] = current_epoch
        self._resumed_eval_status["global_step"] = global_step

    @property
    def resumed(self):
        # whether this run was resumed from a checkpoint
        return self._resumed

    @property
    def true_global_step(self):
        if self._resumed_eval:
            return self._resumed_eval_status["global_step"]
        else:
            return self.global_step

    @property
    def true_current_epoch(self):
        if self._resumed_eval:
            return self._resumed_eval_status["current_epoch"]
        else:
            return self.current_epoch

    def configure(self) -> None:
        pass

    def post_configure(self) -> None:
        """Executed after weights are loaded."""
        pass

    def C(self, value: Any) -> float:
        return C(value, self.true_current_epoch, self.true_global_step)
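
    # `C` resolves a scalar that may vary with training progress. Assuming
    # craftsman.utils.misc.C follows the threestudio-style convention of
    # [start_step, start_value, end_value, end_step] (an assumption; check
    # craftsman/utils/misc.py), a config entry like
    #
    #   loss: {lambda_example: [0, 0.0, 1.0, 5000]}
    #
    # read via self.C(self.cfg.loss.lambda_example) would ramp linearly from
    # 0.0 to 1.0 over the first 5000 steps.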

    def configure_optimizers(self):
        optim = parse_optimizer(self.cfg.optimizer, self)
        ret = {
            "optimizer": optim,
        }
        if self.cfg.scheduler is not None:
            ret.update(
                {
                    "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim),
                }
            )
        return ret
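
    # Sketch of the optimizer/scheduler config these parsers might consume,
    # in the threestudio-style layout (field names are assumptions; the
    # authoritative format is craftsman/utils/scheduler.py):
    #
    #   optimizer:
    #     name: AdamW
    #     args: {lr: 1.e-4, betas: [0.9, 0.99]}
    #   scheduler:
    #     name: CosineAnnealingLR
    #     interval: step
    #     args: {T_max: 100000}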

    def training_step(self, batch, batch_idx):
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_train_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.train_dataloader.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.val_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_validation_step:
            # cleanup to save vram
            cleanup()

    def on_validation_epoch_end(self):
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_test_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.test_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def on_test_epoch_end(self):
        pass

    def predict_step(self, batch, batch_idx):
        raise NotImplementedError

    def on_predict_batch_end(self, outputs, batch, batch_idx):
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_end_if_possible(
            self.dataset, self.true_current_epoch, self.true_global_step
        )
        self.do_update_step_end(self.true_current_epoch, self.true_global_step)
        # predict reuses the test-stage cleanup flag
        if self.cfg.cleanup_after_test_step:
            # cleanup to save vram
            cleanup()

    def on_predict_epoch_end(self):
        pass

    def preprocess_data(self, batch, stage):
        pass

    # Implementing on_after_batch_transfer on the DataModule would achieve the
    # same thing, but on_after_batch_transfer does not support DP.
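
    # Sketch of a subclass override (the batch key and config field are
    # hypothetical); this hook runs right before each *_step, once the batch
    # is already on the device:
    #
    #   def preprocess_data(self, batch, stage):
    #       if stage == "train":
    #           # e.g. rescale inputs only during training
    #           batch["points"] = batch["points"] * self.cfg.scale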

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        self.preprocess_data(batch, "train")
        self.dataset = self.trainer.train_dataloader.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "validation")
        self.dataset = self.trainer.val_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "test")
        self.dataset = self.trainer.test_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0):
        self.preprocess_data(batch, "predict")
        self.dataset = self.trainer.predict_dataloaders.dataset
        update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step)
        self.do_update_step(self.true_current_epoch, self.true_global_step)

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        pass
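
    # Sketch of an update_step override (the annealed field is hypothetical):
    #
    #   def update_step(self, epoch, global_step, on_load_weights=False):
    #       # re-resolve a scheduled value once per step
    #       self.guidance_scale = self.C(self.cfg.guidance_scale)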

    def on_before_optimizer_step(self, optimizer):
        """
        # some gradient-related debugging goes here, for example (assuming a
        # `self.geometry` submodule exists in the subclass):
        from pytorch_lightning.utilities import grad_norm
        norms = grad_norm(self.geometry, norm_type=2)
        print(norms)
        """
        pass