File size: 3,772 Bytes
753fd9a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122

from yacs.config import CfgNode as CN
import argparse
import yaml
import os

# Absolute path of the directory two levels above this file; every default
# path below is built from it.
abs_barc_dir = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)

# Root of the default configuration tree.
_C = CN({
    'barc_dir': abs_barc_dir,
    'device': 'cuda',
})

## path settings
# Absolute path strings rooted at abs_barc_dir; the directory entries end
# with a trailing '/'.
_C.paths = CN()
_C.paths.ROOT_OUT_PATH = '{}/results/'.format(abs_barc_dir)
_C.paths.ROOT_CHECKPOINT_PATH = '{}/checkpoint/'.format(abs_barc_dir)
# Pretrained normalizing-flow weights, stored under the checkpoint directory.
_C.paths.MODELPATH_NORMFLOW = (
    _C.paths.ROOT_CHECKPOINT_PATH + 'barc_normflow_pret/rgbddog_v3_model.pt'
)

## parameter settings
# Model / network hyper-parameters. N_KEYP, N_SEG, VLIN, STRUCTURE_Z_TO_B,
# N_BETAS, N_BETAS_LIMBS and N_Z_FREE feed the derived entries computed by
# update_dependent_vars().
_C.params = CN({
    'ARCH': 'hg8',
    'STRUCTURE_POSE_NET': 'normflow',   # commented-out alternatives: 'default', 'vae'
    'NF_VERSION': 3,
    'N_JOINTS': 35,
    'N_KEYP': 24,                       # commented-out alternative: 20
    'N_SEG': 2,
    'N_PARTSEG': 15,
    'UPSAMPLE_SEG': True,
    # partseg: for the CVPR paper this part of the network exists, but is not
    # trained (no part labels in StanExt)
    'ADD_PARTSEG': True,
    'N_BETAS': 30,                      # commented-out alternative: 10
    'N_BETAS_LIMBS': 7,
    'N_BONES': 24,
    'N_BREEDS': 121,                    # 120 breeds plus background
    'IMG_SIZE': 256,
    'SILH_NO_TAIL': False,
    'KP_THRESHOLD': None,
    'ADD_Z_TO_3D_INPUT': False,
    'N_SEGBPS': 64 * 2,
    'ADD_SEGBPS_TO_3D_INPUT': True,
    'FIX_FLENGTH': False,
    'RENDER_ALL': True,
    'VLIN': 2,                          # update_dependent_vars() accepts 0, 1 or 2
    'STRUCTURE_Z_TO_B': 'lin',
    'N_Z_FREE': 64,
    'PCK_THRESH': 0.15,
    'REF_NET_TYPE': 'add',              # refinement network type
    'REF_DETACH_SHAPE': True,
    'GRAPHCNN_TYPE': 'inexistent',
    'ISFLAT_TYPE': 'inexistent',
    'SHAPEREF_TYPE': 'inexistent',
})

## SMAL settings
_C.smal = CN({
    'SMAL_MODEL_TYPE': 'barc',
    'SMAL_KEYP_CONF': 'green',
})

## optimization settings
_C.optim = CN({
    'LR': 5e-4,
    'SCHEDULE': [150, 175, 200],
    'GAMMA': 0.1,
    'MOMENTUM': 0,
    'WEIGHT_DECAY': 0,
    'EPOCHS': 220,
    # keep 12: it needs to be an even number, as there is a custom data sampler
    'BATCH_SIZE': 12,
    'TRAIN_PARTS': 'all_without_shapedirs',
})

## dataset settings
_C.data = CN({
    'DATASET': 'stanext24',
    'V12': True,
    'SHORTEN_VAL_DATASET_TO': None,
    'VAL_OPT': 'val',
    'VAL_METRICS': 'no_loss',
})

# ---------------------------------------
def update_dependent_vars(cfg):
    """Recompute the config entries that are derived from other entries.

    Call this again whenever independent values change (for example after
    merging a YAML file) so the derived entries stay consistent.

    Entries written to ``cfg.params``:
      * ``N_CLASSES`` = ``N_KEYP + N_SEG``.
      * ``NUM_STAGE_COMB``, ``NUM_STAGE_HEADS``, ``NUM_STAGE_HEADS_POSE`` and
        ``TRANS_SEP``, selected by ``VLIN`` (0, 1 or 2).
      * ``N_Z`` = ``N_BETAS + N_BETAS_LIMBS`` if ``STRUCTURE_Z_TO_B`` is
        ``'1dconv'``, otherwise ``N_Z_FREE``.

    Args:
        cfg: config object with a ``params`` sub-node; modified in place.

    Raises:
        NotImplementedError: if ``cfg.params.VLIN`` is not 0, 1 or 2
            (``N_CLASSES`` has already been written at that point).
    """
    params = cfg.params
    params.N_CLASSES = params.N_KEYP + params.N_SEG

    if params.VLIN == 0:
        params.NUM_STAGE_COMB = 2
        params.NUM_STAGE_HEADS = 1
        params.NUM_STAGE_HEADS_POSE = 1
        params.TRANS_SEP = False
    elif params.VLIN in (1, 2):
        # Versions 1 and 2 use the same stage counts and differ only in
        # TRANS_SEP; bool() keeps the stored value a plain Python bool.
        params.NUM_STAGE_COMB = 3
        params.NUM_STAGE_HEADS = 1
        params.NUM_STAGE_HEADS_POSE = 2
        params.TRANS_SEP = bool(params.VLIN == 2)
    else:
        # Name the offending value so a bad config is easy to diagnose.
        raise NotImplementedError(
            'unsupported cfg.params.VLIN={!r}; expected 0, 1 or 2'.format(params.VLIN))

    if params.STRUCTURE_Z_TO_B == '1dconv':
        params.N_Z = params.N_BETAS + params.N_BETAS_LIMBS
    else:
        params.N_Z = params.N_Z_FREE


# Fill in the derived entries of the defaults, then keep a separate
# module-level copy. update_cfg_global_with_yaml() mutates that copy and
# get_cfg_global_updated() returns it. (The former `global _cfg_global`
# statement was a no-op at module scope and has been removed.)
update_dependent_vars(_C)
_cfg_global = _C.clone()


def get_cfg_defaults():
    """Return a yacs CfgNode holding the default values defined in this file.

    The result is a clone, so modifying it leaves the module defaults intact.
    """
    defaults_copy = _C.clone()
    return defaults_copy

def update_cfg_global_with_yaml(cfg_yaml_file):
    """Merge a YAML config file into the shared global config, in place.

    After the merge, the derived entries are recomputed with
    ``update_dependent_vars``. Returns ``None``.

    Args:
        cfg_yaml_file: path of the YAML file passed to ``merge_from_file``.
    """
    target_cfg = _cfg_global
    target_cfg.merge_from_file(cfg_yaml_file)
    update_dependent_vars(target_cfg)

def get_cfg_global_updated():
    """Return the shared global config object itself, not a copy.

    Any change the caller makes is visible to every other user of the global
    config; call ``.clone()`` on the result if an independent copy is needed.
    """
    shared_cfg = _cfg_global
    return shared_cfg