SupermanxKiaski commited on
Commit
eb144e3
1 Parent(s): fdddc61

Create train_video.py

Browse files
Files changed (1) hide show
  1. train_video.py +104 -0
train_video.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import random
3
+ from argparse import ArgumentParser
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ import yaml
9
+ from tqdm import tqdm
10
+
11
+ from datasets.video_dataset import AtlasDataset
12
+ from models.video_model import VideoModel
13
+ from util.atlas_loss import AtlasLoss
14
+ from util.util import get_optimizer
15
+ from util.video_logger import DataLogger
16
+
17
+
18
+ def train_model(config):
19
+ # set seed
20
+ seed = config["seed"]
21
+ if seed == -1:
22
+ seed = np.random.randint(2 ** 32)
23
+ random.seed(seed)
24
+ np.random.seed(seed)
25
+ torch.manual_seed(seed)
26
+ print(f"running with seed: {seed}.")
27
+
28
+ dataset = AtlasDataset(config)
29
+ model = VideoModel(config)
30
+ criterion = AtlasLoss(config)
31
+ optimizer = get_optimizer(config, model.parameters())
32
+
33
+ logger = DataLogger(config, dataset)
34
+ with tqdm(range(1, config["n_epochs"] + 1)) as tepoch:
35
+ for epoch in tepoch:
36
+ inputs = dataset[0]
37
+ optimizer.zero_grad()
38
+ outputs = model(inputs)
39
+ losses = criterion(outputs, inputs)
40
+
41
+ loss = 0.
42
+ if config["finetune_foreground"]:
43
+ loss += losses["foreground"]["loss"]
44
+ elif config["finetune_background"]:
45
+ loss += losses["background"]["loss"]
46
+
47
+ lr = optimizer.param_groups[0]["lr"]
48
+ log_data = logger.log_data(epoch, lr, losses, model, dataset)
49
+
50
+ loss.backward()
51
+ optimizer.step()
52
+ optimizer.param_groups[0]["lr"] = max(config["min_lr"], config["gamma"] * optimizer.param_groups[0]["lr"])
53
+
54
+ if config["use_wandb"]:
55
+ wandb.log(log_data)
56
+ else:
57
+ if epoch % config["log_images_freq"] == 0:
58
+ logger.save_locally(log_data)
59
+
60
+ tepoch.set_description(f"Epoch {epoch}")
61
+ tepoch.set_postfix(loss=loss.item())
62
+
63
+
64
+ if __name__ == "__main__":
65
+ parser = ArgumentParser()
66
+ parser.add_argument(
67
+ "--config",
68
+ default="./configs/video_config.yaml",
69
+ help="Config path",
70
+ )
71
+ parser.add_argument(
72
+ "--example_config",
73
+ default="car-turn_winter.yaml",
74
+ help="Example config name",
75
+ )
76
+ args = parser.parse_args()
77
+ config_path = args.config
78
+
79
+ with open(config_path, "r") as f:
80
+ config = yaml.safe_load(f)
81
+ with open(f"./configs/video_example_configs/{args.example_config}", "r") as f:
82
+ example_config = yaml.safe_load(f)
83
+ config["example_config"] = args.example_config
84
+ config.update(example_config)
85
+
86
+ run_name = f"-{config['checkpoint_path'].split('/')[-2]}"
87
+ if config["use_wandb"]:
88
+ import wandb
89
+
90
+ wandb.init(project=config["wandb_project"], entity=config["wandb_entity"], config=config, name=run_name)
91
+ wandb.run.name = str(wandb.run.id) + wandb.run.name
92
+ config = dict(wandb.config)
93
+ else:
94
+ now = datetime.datetime.now()
95
+ run_name = f"{now.strftime('%Y-%m-%d_%H-%M-%S')}{run_name}"
96
+ path = Path(f"{config['results_folder']}/{run_name}")
97
+ path.mkdir(parents=True, exist_ok=True)
98
+ with open(path / "config.yaml", "w") as f:
99
+ yaml.dump(config, f)
100
+ config["results_folder"] = str(path)
101
+
102
+ train_model(config)
103
+ if config["use_wandb"]:
104
+ wandb.finish()