|
|
|
|
|
|
|
import math |
|
import torch |
|
from torch import nn |
|
from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks_default |
|
from fvcore.nn.weight_init import c2_msra_fill, c2_xavier_fill |
|
|
|
from .norm import get_norm |
|
from .stem import VideoModelStem |
|
from .resnet import ResStage |
|
from .head import X3DHead |
|
|
|
|
|
|
|
|
|
def round_width(width, multiplier, min_width=1, divisor=1):
    """Scale ``width`` by ``multiplier`` and round to a multiple of ``divisor``.

    Follows the EfficientNet/X3D channel-rounding rule: the result never
    drops below ``min_width`` and never falls under 90% of the scaled width.
    A falsy ``multiplier`` returns ``width`` unchanged.
    """
    if not multiplier:
        return width

    scaled = width * multiplier
    floor = min_width or divisor
    # Round to the nearest multiple of `divisor`, but never below `floor`.
    rounded = max(floor, int(scaled + divisor / 2) // divisor * divisor)
    # Guard: do not round down by more than 10% of the scaled width.
    if rounded < 0.9 * scaled:
        rounded += divisor
    return int(rounded)
|
|
|
|
|
|
|
|
|
def init_weights(
    model, fc_init_std=0.01, zero_init_final_bn=True, zero_init_final_conv=False
):
    """
    Performs ResNet style weight initialization on every sub-module in place.

    Args:
        model (nn.Module): network whose modules are initialized.
        fc_init_std (float): the expected standard deviation for fc layer.
        zero_init_final_bn (bool): if True, zero initialize the final bn for
            every bottleneck (modules flagged with `transform_final_bn`).
        zero_init_final_conv (bool): if True, zero initialize convolutions
            flagged with a `final_conv` attribute.
    """
    for module in model.modules():
        if isinstance(module, nn.Conv3d):
            if zero_init_final_conv and hasattr(module, "final_conv"):
                module.weight.data.zero_()
            else:
                # MSRA initialization following He, Kaiming, et al.
                # "Delving deep into rectifiers: Surpassing human-level
                # performance on imagenet classification."
                # arXiv preprint arXiv:1502.01852 (2015).
                c2_msra_fill(module)
        elif isinstance(module, (nn.BatchNorm3d, nn.BatchNorm2d, nn.BatchNorm1d)):
            # Zero gamma only for batchnorms explicitly flagged as the last
            # bn of a bottleneck, and only when the caller asked for it.
            zero_gamma = zero_init_final_bn and getattr(
                module, "transform_final_bn", False
            )
            if module.weight is not None:
                module.weight.data.fill_(0.0 if zero_gamma else 1.0)
            if module.bias is not None:
                module.bias.data.zero_()
        if isinstance(module, nn.Linear):
            if getattr(module, "xavier_init", False):
                c2_xavier_fill(module)
            else:
                module.weight.data.normal_(mean=0.0, std=fc_init_std)
            if module.bias is not None:
                module.bias.data.zero_()
|
|
|
|
|
|
|
|
|
# Pool-1 kernel sizes per architecture: one [T, H, W] entry per pathway
# ("slowfast" has two pathways, every other architecture has one).
_POOL1 = {
    "2d": [[1, 1, 1]],
    "c2d": [[2, 1, 1]],
    "slow_c2d": [[1, 1, 1]],
    "i3d": [[2, 1, 1]],
    "slow_i3d": [[1, 1, 1]],
    "slow": [[1, 1, 1]],
    "slowfast": [[1, 1, 1], [1, 1, 1]],
    "x3d": [[1, 1, 1]],
}
|
|
|
|
|
|
|
# Basis of temporal kernel sizes per architecture. Each architecture maps to
# five entries — one per network stage (the stem followed by the four residual
# stages) — and each entry holds one list per pathway ("slowfast" has two
# pathways, all other architectures have one). X3D uses entry 0 for the stem
# kernel and entry 1 for every residual stage (see `_construct_network`).
_TEMPORAL_KERNEL_BASIS = {
    "2d": [
        [[1]],  # stem temporal kernel.
        [[1]],  # res2 temporal kernel.
        [[1]],  # res3 temporal kernel.
        [[1]],  # res4 temporal kernel.
        [[1]],  # res5 temporal kernel.
    ],
    "c2d": [
        [[1]],
        [[1]],
        [[1]],
        [[1]],
        [[1]],
    ],
    "slow_c2d": [
        [[1]],
        [[1]],
        [[1]],
        [[1]],
        [[1]],
    ],
    "i3d": [
        [[5]],
        [[3]],
        [[3, 1]],
        [[3, 1]],
        [[1, 3]],
    ],
    "slow_i3d": [
        [[5]],
        [[3]],
        [[3, 1]],
        [[3, 1]],
        [[1, 3]],
    ],
    "slow": [
        [[1]],
        [[1]],
        [[1]],
        [[3]],
        [[3]],
    ],
    "slowfast": [
        [[1], [5]],
        [[1], [3]],
        [[1], [3]],
        [[3], [3]],
        [[3], [3]],
    ],
    "x3d": [
        [[5]],
        [[3]],
        [[3]],
        [[3]],
        [[3]],
    ],
}
|
|
|
|
|
|
|
# Number of residual blocks in (res2, res3, res4, res5) per ResNet depth.
_MODEL_STAGE_DEPTH = {18: (2, 2, 2, 2), 50: (3, 4, 6, 3), 101: (3, 4, 23, 3)}
|
|
|
|
|
|
|
|
|
class X3D(nn.Module):
    """
    X3D model builder. It builds a X3D network backbone, which is a ResNet.

    Christoph Feichtenhofer.
    "X3D: Expanding Architectures for Efficient Video Recognition."
    https://arxiv.org/abs/2004.04730
    """

    def __init__(self, cfg):
        """
        The `__init__` method of any subclass should also contain these
        arguments.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.
        """
        super(X3D, self).__init__()
        self.norm_module = get_norm(cfg)
        self.enable_detection = cfg.DETECTION.ENABLE
        # X3D is a single-pathway network.
        self.num_pathways = 1

        # Channel width doubles at every residual stage.
        exp_stage = 2.0
        self.dim_c1 = cfg.X3D.DIM_C1

        self.dim_res2 = (
            round_width(self.dim_c1, exp_stage, divisor=8)
            if cfg.X3D.SCALE_RES2
            else self.dim_c1
        )
        self.dim_res3 = round_width(self.dim_res2, exp_stage, divisor=8)
        self.dim_res4 = round_width(self.dim_res3, exp_stage, divisor=8)
        self.dim_res5 = round_width(self.dim_res4, exp_stage, divisor=8)

        # Each entry: [base block count, output width, spatial stride].
        self.block_basis = [
            [1, self.dim_res2, 2],
            [2, self.dim_res3, 2],
            [5, self.dim_res4, 2],
            [3, self.dim_res5, 2],
        ]
        self._construct_network(cfg)
        init_weights(
            self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
        )

    def _round_repeats(self, repeats, multiplier):
        """Round number of layers based on depth multiplier.

        A falsy multiplier leaves `repeats` unchanged; otherwise the scaled
        count is rounded up so no stage loses all of its blocks.
        """
        if not multiplier:
            return repeats
        return int(math.ceil(multiplier * repeats))

    def _construct_network(self, cfg):
        """
        Builds a single pathway X3D model: an x3d stem (s1) followed by four
        residual stages (s2-s5) and a classification head.

        Args:
            cfg (CfgNode): model building configs, details are in the
                comments of the config file.

        Raises:
            NotImplementedError: if `cfg.DETECTION.ENABLE` is set — X3D has
                no detection head.
        """
        assert cfg.MODEL.ARCH in _POOL1.keys()
        assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()

        num_groups = cfg.RESNET.NUM_GROUPS
        width_per_group = cfg.RESNET.WIDTH_PER_GROUP
        dim_inner = num_groups * width_per_group

        w_mul = cfg.X3D.WIDTH_FACTOR
        d_mul = cfg.X3D.DEPTH_FACTOR
        dim_res1 = round_width(self.dim_c1, w_mul)

        temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]

        self.s1 = VideoModelStem(
            dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
            dim_out=[dim_res1],
            # Temporal kernel from the basis table, 3x3 spatially.
            kernel=[temp_kernel[0][0] + [3, 3]],
            stride=[[1, 2, 2]],
            padding=[[temp_kernel[0][0][0] // 2, 1, 1]],
            norm_module=self.norm_module,
            stem_func_name="x3d_stem",
        )

        dim_in = dim_res1
        for stage, block in enumerate(self.block_basis):
            base_repeats, base_width, stride = block
            dim_out = round_width(base_width, w_mul)
            dim_inner = int(cfg.X3D.BOTTLENECK_FACTOR * dim_out)

            n_rep = self._round_repeats(base_repeats, d_mul)

            # Stages are named s2..s5, matching ResNet conventions.
            prefix = "s{}".format(stage + 2)

            s = ResStage(
                dim_in=[dim_in],
                dim_out=[dim_out],
                dim_inner=[dim_inner],
                temp_kernel_sizes=temp_kernel[1],
                stride=[stride],
                num_blocks=[n_rep],
                # Channelwise (depthwise) 3x3x3 convs use one group per
                # inner channel.
                num_groups=[dim_inner]
                if cfg.X3D.CHANNELWISE_3x3x3
                else [num_groups],
                num_block_temp_kernel=[n_rep],
                nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
                nonlocal_group=cfg.NONLOCAL.GROUP[0],
                nonlocal_pool=cfg.NONLOCAL.POOL[0],
                instantiation=cfg.NONLOCAL.INSTANTIATION,
                trans_func_name=cfg.RESNET.TRANS_FUNC,
                stride_1x1=cfg.RESNET.STRIDE_1X1,
                norm_module=self.norm_module,
                # Drop-connect rate increases linearly with stage depth.
                dilation=cfg.RESNET.SPATIAL_DILATIONS[stage],
                drop_connect_rate=cfg.MODEL.DROPCONNECT_RATE
                * (stage + 2)
                / (len(self.block_basis) + 1),
            )
            dim_in = dim_out
            self.add_module(prefix, s)

        if self.enable_detection:
            # Fix: the original code had a bare `NotImplementedError`
            # expression here, which is a no-op — the model would silently
            # build without any head. Raise it instead.
            raise NotImplementedError(
                "Detection is not supported by the X3D model."
            )
        else:
            spat_sz = int(math.ceil(cfg.DATA.TRAIN_CROP_SIZE / 32.0))
            self.head = X3DHead(
                dim_in=dim_out,
                dim_inner=dim_inner,
                dim_out=cfg.X3D.DIM_C5,
                num_classes=cfg.MODEL.NUM_CLASSES,
                pool_size=[cfg.DATA.NUM_FRAMES, spat_sz, spat_sz],
                dropout_rate=cfg.MODEL.DROPOUT_RATE,
                act_func=cfg.MODEL.HEAD_ACT,
                bn_lin5_on=cfg.X3D.BN_LIN5,
            )

    def forward(self, x, bboxes=None):
        """Run the input through every child module in registration order.

        Args:
            x (list): single-pathway input (a list with one tensor).
            bboxes: unused; kept for interface compatibility with models
                that support detection.
        """
        for module in self.children():
            x = module(x)
        return x
|
|
|
def build_model(cfg, gpu_id=None):
    """
    Build an X3D model from the config and place it on the requested
    device(s), optionally wrapping it for distributed training.

    Args:
        cfg (CfgNode): configs; details are in the comments of the config
            file.
        gpu_id (int, optional): GPU index to build the model on; if None,
            the current CUDA device is used.

    Returns:
        torch.nn.Module: the constructed (and possibly DDP-wrapped) model.

    Raises:
        AssertionError: if the requested GPU count is inconsistent with the
            available hardware.
        ImportError: if apex sync-batchnorm is requested but apex is not
            installed.
    """
    if torch.cuda.is_available():
        assert (
            cfg.NUM_GPUS <= torch.cuda.device_count()
        ), "Cannot use more GPU devices than available"
    else:
        # Fixed message: closed the backtick around `NUM_GPUS: 0`.
        assert (
            cfg.NUM_GPUS == 0
        ), "Cuda is not available. Please set `NUM_GPUS: 0` for running on CPUs."

    # Construct the model (on CPU; moved to GPU below if requested).
    model = X3D(cfg)

    if cfg.BN.NORM_TYPE == "sync_batchnorm_apex":
        try:
            import apex
        except ImportError as err:
            # Fixed typo ("pelase" -> "please") and chained the cause.
            raise ImportError(
                "APEX is required for this model, please install"
            ) from err

        process_group = apex.parallel.create_syncbn_process_group(
            group_size=cfg.BN.NUM_SYNC_DEVICES
        )
        model = apex.parallel.convert_syncbn_model(
            model, process_group=process_group
        )

    if cfg.NUM_GPUS:
        if gpu_id is None:
            # Determine the GPU used by the current process.
            cur_device = torch.cuda.current_device()
        else:
            cur_device = gpu_id
        # Transfer the model to the current GPU device.
        model = model.cuda(device=cur_device)
        if cfg.NUM_GPUS > 1:
            # Wrap for multi-GPU training; unused-parameter detection is
            # needed when the final fc is detached or for contrastive models.
            model = torch.nn.parallel.DistributedDataParallel(
                module=model,
                device_ids=[cur_device],
                output_device=cur_device,
                find_unused_parameters=bool(
                    cfg.MODEL.DETACH_FINAL_FC
                    or cfg.MODEL.MODEL_NAME == "ContrastiveModel"
                ),
            )
            if cfg.MODEL.FP16_ALLREDUCE:
                # Compress gradients to fp16 during all-reduce to save
                # bandwidth.
                model.register_comm_hook(
                    state=None, hook=comm_hooks_default.fp16_compress_hook
                )
    return model
|
|