# MotionBERT/lib/utils/learning.py
import os
import collections
from functools import partial

import numpy as np
import torch
import torch.nn as nn

from lib.model.DSTformer import DSTformer

class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
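
# Usage sketch (illustrative, not part of the upstream file): track a running
# average of per-batch losses, weighted by batch size.
#
#     meter = AverageMeter()
#     for batch_loss, batch_size in [(0.9, 32), (0.7, 32)]:
#         meter.update(batch_loss, n=batch_size)
#     print(meter.avg)  # 0.8 == sum(val * n) / sum(n)
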
def accuracy(output, target, topk=(1,)):
"""Computes the accuracy over the k top predictions for the specified values of k"""
with torch.no_grad():
maxk = max(topk)
batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)  # [batch, maxk] top-k class indices
        pred = pred.t()  # [maxk, batch]
        correct = pred.eq(target.view(1, -1).expand_as(pred))  # boolean hit mask
res = []
for k in topk:
correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
res.append(correct_k.mul_(100.0 / batch_size))
return res
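
# Usage sketch (shapes are illustrative): `output` holds [batch, classes]
# logits, `target` holds [batch] class indices; results are percentages.
#
#     logits = torch.randn(8, 10)
#     labels = torch.randint(0, 10, (8,))
#     top1, top5 = accuracy(logits, labels, topk=(1, 5))
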
def load_pretrained_weights(model, checkpoint):
"""Load pretrianed weights to model
Incompatible layers (unmatched in name or size) will be ignored
Args:
- model (nn.Module): network model, which must not be nn.DataParallel
- weight_path (str): path to pretrained weights
"""
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
model_dict = model.state_dict()
new_state_dict = collections.OrderedDict()
matched_layers, discarded_layers = [], []
for k, v in state_dict.items():
# If the pretrained state_dict was saved as nn.DataParallel,
# keys would contain "module.", which should be ignored.
if k.startswith('module.'):
k = k[7:]
if k in model_dict and model_dict[k].size() == v.size():
new_state_dict[k] = v
matched_layers.append(k)
else:
discarded_layers.append(k)
model_dict.update(new_state_dict)
model.load_state_dict(model_dict, strict=True)
    print('load_pretrained_weights: matched %d layers, discarded %d' % (len(matched_layers), len(discarded_layers)))
return model
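
# Usage sketch (the checkpoint path is hypothetical): checkpoints saved from an
# nn.DataParallel model load transparently, since the "module." prefix is
# stripped above.
#
#     ckpt = torch.load('checkpoint/best_epoch.bin', map_location='cpu')
#     model = load_pretrained_weights(model, ckpt)
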
def partial_train_layers(model, partial_list):
"""Train partial layers of a given model."""
for name, p in model.named_parameters():
p.requires_grad = False
for trainable in partial_list:
if trainable in name:
p.requires_grad = True
break
return model
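
# Usage sketch (the substrings are illustrative): fine-tune only parameters
# whose names contain 'head' or 'pos_embed', keeping the rest frozen.
#
#     model = partial_train_layers(model, ['head', 'pos_embed'])
#     trainable = [n for n, p in model.named_parameters() if p.requires_grad]
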
def load_backbone(args):
    if not hasattr(args, "backbone"):
        args.backbone = 'DSTformer'  # Default backbone
    if args.backbone == 'DSTformer':
        model_backbone = DSTformer(dim_in=3, dim_out=3, dim_feat=args.dim_feat, dim_rep=args.dim_rep,
                                   depth=args.depth, num_heads=args.num_heads, mlp_ratio=args.mlp_ratio,
                                   norm_layer=partial(nn.LayerNorm, eps=1e-6),
                                   maxlen=args.maxlen, num_joints=args.num_joints)
    elif args.backbone == 'TCN':
        from lib.model.model_tcn import PoseTCN
        model_backbone = PoseTCN()
    elif args.backbone == 'poseformer':
        from lib.model.model_poseformer import PoseTransformer
        model_backbone = PoseTransformer(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3,
                                         embed_dim_ratio=32, depth=4, num_heads=8, mlp_ratio=2.,
                                         qkv_bias=True, qk_scale=None, drop_path_rate=0, attn_mask=None)
    elif args.backbone == 'mixste':
        from lib.model.model_mixste import MixSTE2
        model_backbone = MixSTE2(num_frame=args.maxlen, num_joints=args.num_joints, in_chans=3,
                                 embed_dim_ratio=512, depth=8, num_heads=8, mlp_ratio=2.,
                                 qkv_bias=True, qk_scale=None, drop_path_rate=0)
    elif args.backbone == 'stgcn':
        from lib.model.model_stgcn import Model as STGCN
        model_backbone = STGCN()
    else:
        raise ValueError('Undefined backbone type: %s' % args.backbone)
return model_backbone
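
# Usage sketch (values are illustrative, not the repository's configs): `args`
# is normally parsed from a YAML config; any namespace with these fields works.
#
#     from types import SimpleNamespace
#     args = SimpleNamespace(backbone='DSTformer', dim_feat=256, dim_rep=512,
#                            depth=5, num_heads=8, mlp_ratio=4,
#                            maxlen=243, num_joints=17)
#     model_backbone = load_backbone(args)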