|
import os |
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
from functools import partial |
|
from lib.model.DSTformer import DSTformer |
|
|
|
class AverageMeter(object): |
|
"""Computes and stores the average and current value""" |
|
def __init__(self): |
|
self.reset() |
|
|
|
def reset(self): |
|
self.val = 0 |
|
self.avg = 0 |
|
self.sum = 0 |
|
self.count = 0 |
|
|
|
def update(self, val, n=1): |
|
self.val = val |
|
self.sum += val * n |
|
self.count += n |
|
self.avg = self.sum / self.count |
|
|
|
def accuracy(output, target, topk=(1,)): |
|
"""Computes the accuracy over the k top predictions for the specified values of k""" |
|
with torch.no_grad(): |
|
maxk = max(topk) |
|
batch_size = target.size(0) |
|
_, pred = output.topk(maxk, 1, True, True) |
|
pred = pred.t() |
|
correct = pred.eq(target.view(1, -1).expand_as(pred)) |
|
res = [] |
|
for k in topk: |
|
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) |
|
res.append(correct_k.mul_(100.0 / batch_size)) |
|
return res |
|
|
|
def load_pretrained_weights(model, checkpoint): |
|
"""Load pretrianed weights to model |
|
Incompatible layers (unmatched in name or size) will be ignored |
|
Args: |
|
- model (nn.Module): network model, which must not be nn.DataParallel |
|
- weight_path (str): path to pretrained weights |
|
""" |
|
import collections |
|
if 'state_dict' in checkpoint: |
|
state_dict = checkpoint['state_dict'] |
|
else: |
|
state_dict = checkpoint |
|
model_dict = model.state_dict() |
|
new_state_dict = collections.OrderedDict() |
|
matched_layers, discarded_layers = [], [] |
|
for k, v in state_dict.items(): |
|
|
|
|
|
if k.startswith('module.'): |
|
k = k[7:] |
|
if k in model_dict and model_dict[k].size() == v.size(): |
|
new_state_dict[k] = v |
|
matched_layers.append(k) |
|
else: |
|
discarded_layers.append(k) |
|
model_dict.update(new_state_dict) |
|
model.load_state_dict(model_dict, strict=True) |
|
print('load_weight', len(matched_layers)) |
|
return model |
|
|
|
def partial_train_layers(model, partial_list): |
|
"""Train partial layers of a given model.""" |
|
for name, p in model.named_parameters(): |
|
p.requires_grad = False |
|
for trainable in partial_list: |
|
if trainable in name: |
|
p.requires_grad = True |
|
break |
|
return model |
|
|
|
def load_backbone(args): |
|
if not(hasattr(args, "backbone")): |
|
args.backbone = 'DSTformer' |
|
if args.backbone=='DSTformer': |
|
model_backbone = DSTformer(dim_in=3, dim_out=3, dim_feat=args.dim_feat, dim_rep=args.dim_rep, |
|
depth=args.depth, num_heads=args.num_heads, mlp_ratio=args.mlp_ratio, norm_layer=partial(nn.LayerNorm, eps=1e-6), |
|
maxlen=args.maxlen, num_joints=args.num_joints) |
|
elif args.backbone=='TCN': |
|
from lib.model.model_tcn import PoseTCN |
|
model_backbone = PoseTCN() |
|
elif args.backbone=='poseformer': |
|
from lib.model.model_poseformer import PoseTransformer |
|
model_backbone = PoseTransformer(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3, embed_dim_ratio=32, depth=4, |
|
num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0, attn_mask=None) |
|
elif args.backbone=='mixste': |
|
from lib.model.model_mixste import MixSTE2 |
|
model_backbone = MixSTE2(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3, embed_dim_ratio=512, depth=8, |
|
num_heads=8, mlp_ratio=2., qkv_bias=True, qk_scale=None,drop_path_rate=0) |
|
elif args.backbone=='stgcn': |
|
from lib.model.model_stgcn import Model as STGCN |
|
model_backbone = STGCN() |
|
else: |
|
raise Exception("Undefined backbone type.") |
|
return model_backbone |