import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger

from config import config
from models import vae_models

# Workaround for the "libiomp5 already initialized" crash that can occur
# when multiple OpenMP runtimes get loaded (common with conda on macOS).
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

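# NOTE (assumed layout, inferred from the attribute accesses below):
# `config` behaves like a pydantic settings object, exposing `model_type`
# plus `model_config`, `train_config`, and `log_config` sections whose
# `.dict()` output can be unpacked into constructors.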
def make_model(config):
    """Look up and instantiate the VAE architecture named in the config."""
    model_type = config.model_type
    model_config = config.model_config

    if model_type not in vae_models:
        raise NotImplementedError(f"Model architecture '{model_type}' is not implemented")
    return vae_models[model_type](**model_config.dict())


if __name__ == "__main__":
    model = make_model(config)
    train_config = config.train_config

    logger = TensorBoardLogger(**config.log_config.dict())
    # LearningRateMonitor logs the current learning rate to TensorBoard,
    # which is useful when the LR finder below overrides the default.
    trainer = Trainer(**train_config.dict(), logger=logger,
                      callbacks=[LearningRateMonitor()])

    if train_config.auto_lr_find:
        # Run the learning-rate range test and adopt its suggestion
        # before training starts.
        lr_finder = trainer.tuner.lr_find(model)
        new_lr = lr_finder.suggestion()
        print("Learning rate chosen:", new_lr)
        model.lr = new_lr

    trainer.fit(model)

    # Save the trained weights, encoding key hyperparameters in the filename.
    os.makedirs("saved_models", exist_ok=True)
    ckpt_path = (
        f"saved_models/{config.model_type}"
        f"_alpha_{config.model_config.alpha}"
        f"_dim_{config.model_config.hidden_size}.ckpt"
    )
    trainer.save_checkpoint(ckpt_path)
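
    # Sketch (assumption): if the LightningModule calls
    # `save_hyperparameters()` in its `__init__`, the checkpoint can be
    # reloaded for inference without re-specifying constructor arguments:
    #
    #   model = vae_models[config.model_type].load_from_checkpoint(ckpt_path)
    #   model.eval()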