tzco
/

English
File size: 4,633 Bytes
b976bf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch, argparse
from nerf.network import NeRFNetwork
from nerf.renderer import NeRFRenderer
from nerf.provider import get_loaders
from nerf.utils import seed_everything, PSNRMeter, Trainer


def _read_ids(list_path):
    """Return the non-empty lines of a newline-separated id list file.

    Whitespace around the whole file content is stripped before splitting, so
    leading/trailing blank lines do not yield empty ids. The file is closed
    before returning.
    """
    with open(list_path, 'r') as f:
        return f.read().strip().splitlines()


def fn(i, opt):
    """Per-process training entry point, launched once per GPU by ``torch.multiprocessing.spawn``.

    Args:
        i: index of this process on the current node (supplied by ``spawn`` as
            the first positional argument); used as the local rank and as the
            CUDA device index.
        opt: parsed command-line namespace built in the ``__main__`` block.

    Initializes the distributed process group when more than one process is
    used in total, seeds RNGs per global rank, builds the data loaders,
    network, renderer, Adam optimizer and a constant-LR scheduler, then runs
    ``Trainer.train`` for ``opt.epochs`` epochs.
    """
    world_size = opt.gpus * opt.nodes
    global_rank = i + opt.node * opt.gpus
    local_rank = i

    if world_size > 1:
        torch.distributed.init_process_group(
            backend='nccl',
            init_method=f'tcp://{opt.master}:{opt.port}',
            world_size=world_size,
            rank=global_rank,
        )

    # Print the full config once per node.
    if local_rank == 0:
        print(opt)

    print(f'initiate node{opt.node}, rank{global_rank}, gpu{local_rank}')
    use_cuda = torch.cuda.is_available()
    device = torch.device(f'cuda:{local_rank}' if use_cuda else 'cpu')
    # Only bind a CUDA device when CUDA exists: torch.cuda.set_device raises on
    # CPU-only machines, which would defeat the CPU fallback chosen above.
    if use_cuda:
        torch.cuda.set_device(local_rank)
    # Offset the seed by rank so each process draws different random samples.
    seed_everything(opt.seed + global_rank)

    train_ids = _read_ids(opt.path)
    # Validation objects are the first `validate_objects` training objects
    # (a subset of the training set, not held-out data).
    val_ids = train_ids[:opt.validate_objects]
    # Only the first 8 test objects are used.
    test_ids = _read_ids(opt.test_list)[:8]

    train_loader, val_loader, test_loader = get_loaders(opt, train_ids, val_ids, test_ids)

    network = NeRFNetwork(opt=opt, device=device)
    model = NeRFRenderer(opt=opt, network=network, device=device)
    # Unreduced per-element loss; presumably reduced inside Trainer — TODO confirm.
    criterion = torch.nn.MSELoss(reduction='none')

    optimizer = torch.optim.Adam(model.network.get_params(opt.lr0, opt.lr1), betas=(0.9, 0.99), eps=1e-6)
    # Constant multiplier of 1: the learning rate never changes.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda _epoch: 1)

    trainer = Trainer(name='train',
                      opt=opt,
                      device=device,
                      metrics=[PSNRMeter()],
                      optimizer=optimizer,
                      scheduler=scheduler,
                      criterion=criterion,
                      model=model,
                      ema_decay=opt.ema_decay,
                      eval_interval=opt.eval_interval,
                      workspace=opt.save_dir,
                      checkpoint_path=opt.ckpt,
                      # NOTE(review): the *global* rank is passed under the
                      # `local_rank` keyword; presumably Trainer only uses it
                      # for rank-0 checks (device is passed separately) — confirm.
                      local_rank=global_rank,
                      world_size=world_size,
                      )
    trainer.train(train_loader, val_loader, test_loader, opt.epochs)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # Positional arguments: the training object-id list file and the output workspace.
    parser.add_argument('path', type=str)
    parser.add_argument('save_dir', type=str)

    # Optional flags as (flag, type, default). A type of `bool` marks a
    # store_true switch. Declaration order is preserved on purpose: it fixes
    # the --help listing and the order of fields in the printed Namespace.
    option_specs = [
        # data
        ('--data_root', str, 'path/to/dataset'),
        ('--test_list', str, 'path/to/test_object_list'),
        ('--batch_size', int, 1),
        ('--validate_objects', int, 8),
        ('--downscale', int, 1),
        # training / distributed launch
        ('--gpus', int, 8),
        ('--nodes', int, 1),
        ('--node', int, 0),
        ('--master', str, '127.0.0.1'),
        ('--port', int, 12345),
        ('--seed', int, 0),
        ('--epochs', int, 1000),
        ('--lr0', float, 1e-3),
        ('--lr1', float, 1e-4),
        ('--ckpt', str, 'scratch'),
        ('--eval_interval', int, 1),
        ('--fp16', bool, None),
        ('--ema_decay', float, 0),
        ('--ema_freq', int, 10),
        ('--depth_loss', float, 0),
        ('--lpips_loss', float, 0.01),
        # encoder
        ('--image_channel', int, 3),
        ('--extractor_channel', int, 32),
        ('--coarse_volume_resolution', int, 32),
        ('--coarse_volume_channel', int, 4),
        ('--fine_volume_channel', int, 32),
        ('--gaussian_lambda', float, 1e4),
        ('--n_source', int, 32),
        ('--mlp_layer', int, 5),
        ('--mlp_dim', int, 256),
        ('--costreg_ch_mult', str, '2,4,8'),
        ('--encoder_clamp_range', float, 100),
        # render
        ('--num_rays', int, 24576),
        ('--num_steps', int, 256),
        ('--bound', float, 1),
    ]
    for flag, kind, default in option_specs:
        if kind is bool:
            parser.add_argument(flag, action='store_true')
        else:
            parser.add_argument(flag, type=kind, default=default)

    opt = parser.parse_args()
    # One worker process per GPU on this node; each runs fn(i, opt).
    torch.multiprocessing.spawn(fn, args=(opt,), nprocs=opt.gpus)